From 0890460ba76f15b19973792e977e1e3a353c139e Mon Sep 17 00:00:00 2001 From: Emmy D'Anello Date: Sat, 27 Jan 2024 17:13:42 +0100 Subject: [PATCH] New API viewsets for next departures and next arrivals --- sncf/api/views.py | 131 ++++++++++++++++++++++++++++++++++++++++++++++ sncf/settings.py | 6 ++- sncf/urls.py | 4 +- 3 files changed, 138 insertions(+), 3 deletions(-) diff --git a/sncf/api/views.py b/sncf/api/views.py index 64a18f6..425fb0f 100644 --- a/sncf/api/views.py +++ b/sncf/api/views.py @@ -1,3 +1,7 @@ +from datetime import datetime, timedelta, date + +from django.db.models import F, Q, OuterRef, Subquery +from django_filters.rest_framework import DjangoFilterBackend from rest_framework import viewsets from sncf.api.serializers import AgencySerializer, StopSerializer, RouteSerializer, TripSerializer, \ @@ -10,43 +14,170 @@ from sncfgtfs.models import Agency, Stop, Route, Trip, StopTime, Calendar, Calen class AgencyViewSet(viewsets.ReadOnlyModelViewSet): queryset = Agency.objects.all() serializer_class = AgencySerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class StopViewSet(viewsets.ReadOnlyModelViewSet): queryset = Stop.objects.all() serializer_class = StopSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class RouteViewSet(viewsets.ReadOnlyModelViewSet): queryset = Route.objects.all() serializer_class = RouteSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class TripViewSet(viewsets.ReadOnlyModelViewSet): queryset = Trip.objects.all() serializer_class = TripSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class StopTimeViewSet(viewsets.ReadOnlyModelViewSet): queryset = StopTime.objects.order_by('id').all() serializer_class = StopTimeSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class CalendarViewSet(viewsets.ReadOnlyModelViewSet): queryset = Calendar.objects.all() serializer_class = CalendarSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class CalendarDateViewSet(viewsets.ReadOnlyModelViewSet): queryset = CalendarDate.objects.all() serializer_class = CalendarDateSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' class TransferViewSet(viewsets.ReadOnlyModelViewSet): queryset = Transfer.objects.all() serializer_class = TransferSerializer + filter_backends = [DjangoFilterBackend] class FeedInfoViewSet(viewsets.ReadOnlyModelViewSet): queryset = FeedInfo.objects.all() serializer_class = FeedInfoSerializer + filter_backends = [DjangoFilterBackend] + filterset_fields = '__all__' + + +class NextDeparturesViewSet(viewsets.ReadOnlyModelViewSet): + queryset = StopTime.objects.none() + serializer_class = StopTimeSerializer + filter_backends = [DjangoFilterBackend] + + def get_queryset(self): + now = datetime.now() + + stop_id = self.request.query_params.get('stop_id', None) + stop_name = self.request.query_params.get('stop_name', None) + query_date = date.fromisoformat(self.request.query_params.get('date', now.date().isoformat())) + query_time = self.request.query_params.get('time', now.time().isoformat(timespec='seconds')) + query_time = timedelta(seconds=int(query_time[:2]) * 3600 + int(query_time[3:5]) * 60 + int(query_time[6:])) + + yesterday = query_date - timedelta(days=1) + time_yesterday = query_time + timedelta(days=1) + tomorrow = query_date + timedelta(days=1) + + stop_filter = Q(stop__location_type=0) + if stop_id: + stop = Stop.objects.get(id=stop_id) + stops = Stop.objects.filter(Q(id=stop_id) + | Q(parent_station=stop_id) + | Q(parent_station=stop.parent_station_id)).values_list('id', flat=True) + stop_filter = Q(stop__in=stops) + elif stop_name: + stops = Stop.objects.filter(name__iexact=stop_name).values_list('id', flat=True) + stop_filter = Q(stop__in=stops) + + def calendar_filter(d: date): + return Q(trip__service_id__in=CalendarDate.objects.filter(date=d, exception_type=1) + .values_list('service_id')) \ + | Q(trip__service_id__in=Calendar.objects.filter( + start_date__lte=d, + end_date__gte=d, + **{f"{d:%A}".lower(): True}) + .filter(~Q(id__in=CalendarDate.objects.filter(date=d, exception_type=2) + .values_list('service_id', flat=True))) + .values_list('id')) + + qs_today = StopTime.objects.filter(stop_filter) \ + .filter(Q(departure_time__gte=query_time, pickup_type=0), calendar_filter(query_date)) \ + .annotate(departure_time_24h=F('departure_time')) + + qs_yesterday = StopTime.objects.filter(stop_filter) \ + .filter(Q(departure_time__gte=time_yesterday, pickup_type=0), calendar_filter(yesterday)) \ + .annotate(departure_time_24h=F('departure_time') - timedelta(days=1)) + + qs_tomorrow = StopTime.objects.filter(stop_filter) \ + .filter(Q(departure_time__gte=timedelta(0), pickup_type=0), calendar_filter(tomorrow)) \ + .annotate(departure_time_24h=F('departure_time') + timedelta(days=1)) + + return qs_today.union(qs_yesterday).union(qs_tomorrow).order_by("departure_time_24h").all() + + +class NextArrivalsViewSet(viewsets.ReadOnlyModelViewSet): + queryset = StopTime.objects.none() + serializer_class = StopTimeSerializer + filter_backends = [DjangoFilterBackend] + + def get_queryset(self): + now = datetime.now() + + stop_id = self.request.query_params.get('stop_id', None) + stop_name = self.request.query_params.get('stop_name', None) + query_date = date.fromisoformat(self.request.query_params.get('date', now.date().isoformat())) + query_time = self.request.query_params.get('time', now.time().isoformat(timespec='seconds')) + query_time = timedelta(seconds=int(query_time[:2]) * 3600 + int(query_time[3:5]) * 60 + int(query_time[6:])) + + yesterday = query_date - timedelta(days=1) + time_yesterday = query_time + timedelta(days=1) + tomorrow = query_date + timedelta(days=1) + + stop_filter = Q(stop__location_type=0) + if stop_id: + stop = Stop.objects.get(id=stop_id) + stops = Stop.objects.filter(Q(id=stop_id) + | Q(parent_station=stop_id) + | Q(parent_station=stop.parent_station_id)).values_list('id', flat=True) + stop_filter = Q(stop__in=stops) + elif stop_name: + stops = Stop.objects.filter(name__iexact=stop_name).values_list('id', flat=True) + stop_filter = Q(stop__in=stops) + + def calendar_filter(d: date): + return Q(trip__service_id__in=CalendarDate.objects.filter(date=d, exception_type=1) + .values_list('service_id')) \ + | Q(trip__service_id__in=Calendar.objects.filter( + start_date__lte=d, + end_date__gte=d, + **{f"{d:%A}".lower(): True}) + .filter(~Q(id__in=CalendarDate.objects.filter(date=d, exception_type=2) + .values_list('service_id', flat=True))) + .values_list('id')) + + qs_today = StopTime.objects.filter(stop_filter) \ + .filter(Q(arrival_time__gte=query_time, drop_off_type=0), calendar_filter(query_date)) \ + .annotate(arrival_time_24h=F('arrival_time')) + + qs_yesterday = StopTime.objects.filter(stop_filter) \ + .filter(Q(arrival_time__gte=time_yesterday, drop_off_type=0), calendar_filter(yesterday)) \ + .annotate(arrival_time_24h=F('arrival_time') - timedelta(days=1)) + + qs_tomorrow = StopTime.objects.filter(stop_filter) \ + .filter(Q(arrival_time__gte=timedelta(0), drop_off_type=0), calendar_filter(tomorrow)) \ + .annotate(arrival_time_24h=F('arrival_time') + timedelta(days=1)) + + return qs_today.union(qs_yesterday).union(qs_tomorrow).order_by("arrival_time_24h").all() diff --git a/sncf/settings.py b/sncf/settings.py index 739ee31..5f392f2 100644 --- a/sncf/settings.py +++ b/sncf/settings.py @@ -38,6 +38,7 @@ INSTALLED_APPS = [ "django.contrib.messages", "django.contrib.staticfiles", + "django_filters", "rest_framework", "sncf.api", @@ -128,6 +129,7 @@ STATIC_URL = "static/" DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" REST_FRAMEWORK = { - 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', - 'PAGE_SIZE': 20 + 'DEFAULT_FILTER_BACKENDS': ['django_filters.rest_framework.DjangoFilterBackend'], + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination', + 'PAGE_SIZE': 20, } diff --git a/sncf/urls.py b/sncf/urls.py index 4032ff6..e4f27e4 100644 --- a/sncf/urls.py +++ b/sncf/urls.py @@ -19,7 +19,7 @@ from django.urls import path, include from rest_framework import routers from sncf.api.views import AgencyViewSet, StopViewSet, RouteViewSet, TripViewSet, StopTimeViewSet, \ - CalendarViewSet, CalendarDateViewSet, TransferViewSet, FeedInfoViewSet + CalendarViewSet, CalendarDateViewSet, TransferViewSet, FeedInfoViewSet, NextDeparturesViewSet, NextArrivalsViewSet router = routers.DefaultRouter() router.register("gtfs/agency", AgencyViewSet) @@ -31,6 +31,8 @@ router.register("gtfs/calendar", CalendarViewSet) router.register("gtfs/calendar_date", CalendarDateViewSet) router.register("gtfs/transfer", TransferViewSet) router.register("gtfs/feed_info", FeedInfoViewSet) +router.register("station/next_departures", NextDeparturesViewSet, basename="next_departures") +router.register("station/next_arrivals", NextArrivalsViewSet, basename="next_arrivals") urlpatterns = [ path("admin/", admin.site.urls, name="admin"),