from datetime import timedelta, datetime, date, time
from zoneinfo import ZoneInfo

import requests
from django.core.management import BaseCommand
from django.db.models import Q, Max

from sncfgtfs.gtfs_realtime_pb2 import FeedMessage
from sncfgtfs.models import Calendar, CalendarDate, StopTime, StopTimeUpdate, Trip, TripUpdate, Stop


class Command(BaseCommand):
    """Fetch the SNCF GTFS-RT trip-update feeds and mirror them into the local GTFS tables.

    For every feed, each trip update is matched to a static ``Trip``. Trips announced as
    ADDED are first synthesised (calendar, trip and stop times). Then one ``StopTimeUpdate``
    row per stop is upserted.
    """

    help = "Update the SNCF GTFS Realtime database."

    GTFS_RT_FEEDS = {
        "TGV": "https://proxy.transport.data.gouv.fr/resource/sncf-tgv-gtfs-rt-trip-updates",
        "IC": "https://proxy.transport.data.gouv.fr/resource/sncf-ic-gtfs-rt-trip-updates",
        "TER": "https://proxy.transport.data.gouv.fr/resource/sncf-ter-gtfs-rt-trip-updates",
    }

    # Seconds to wait for a feed; without a timeout a stalled proxy would hang the command forever.
    REQUEST_TIMEOUT = 30

    # Time zone used to turn the feed's POSIX timestamps into local datetimes. The same tzinfo
    # object is used for both operands of every subtraction below, so differences are wall-clock
    # differences (Python ignores utcoffset when both operands share one tzinfo).
    TIMEZONE = ZoneInfo("Europe/Paris")

    # GTFS-RT TripDescriptor.ScheduleRelationship.ADDED
    TRIP_ADDED = 1
    # GTFS-RT TripUpdate.StopTimeUpdate.ScheduleRelationship.SKIPPED
    STOP_SKIPPED = 1

    def add_arguments(self, parser):
        pass

    def handle(self, *args, **options):
        for feed_type, feed_url in self.GTFS_RT_FEEDS.items():
            self.stdout.write(f"Updating {feed_type} feed...")
            feed_message = self._fetch_feed(feed_url)

            stop_times_updates = []

            for entity in feed_message.entity:
                if not entity.HasField("trip_update"):
                    # Alerts / vehicle positions are not handled: just echo them.
                    self.stdout.write(str(entity))
                    continue
                stop_times_updates.extend(self._process_trip_update(feed_type, entity.trip_update))

            # Upsert on (trip_update, stop_time). schedule_relationship must be refreshed too,
            # otherwise a stop that becomes SKIPPED after a first import keeps its old state.
            StopTimeUpdate.objects.bulk_create(
                stop_times_updates,
                update_conflicts=True,
                update_fields=['arrival_delay', 'arrival_time', 'departure_delay', 'departure_time',
                               'schedule_relationship'],
                unique_fields=['trip_update', 'stop_time'],
            )

    def _fetch_feed(self, feed_url):
        """Download and decode one GTFS-RT feed.

        :raises requests.RequestException: on network failure, timeout or HTTP error status,
            instead of trying to decode an error page as protobuf.
        """
        response = requests.get(feed_url, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        feed_message = FeedMessage()
        feed_message.ParseFromString(response.content)
        return feed_message

    def _process_trip_update(self, feed_type, trip_update):
        """Upsert the TripUpdate for one feed entity and build its (unsaved) StopTimeUpdate rows.

        :return: list of ``StopTimeUpdate`` instances to bulk-create; empty when the trip
            cannot be matched to a known ``Trip``.
        """
        trip_id = trip_update.trip.trip_id
        if feed_type in ["TGV", "IC", "TER"]:
            # Realtime ids carry a ":..." suffix that static trip ids do not have.
            trip_id = trip_id.split(":", 1)[0]

        raw_date = trip_update.trip.start_date  # "YYYYMMDD"
        start_date = date(year=int(raw_date[:4]), month=int(raw_date[4:6]), day=int(raw_date[6:]))
        start_dt = datetime.combine(start_date, time(0), tzinfo=self.TIMEZONE)

        if trip_update.trip.schedule_relationship == self.TRIP_ADDED:
            if not self._create_added_trip(feed_type, trip_id, trip_update, start_date, start_dt):
                return []

        if not Trip.objects.filter(id=trip_id).exists():
            self.stdout.write(f"Trip {trip_id} does not exist in the GTFS feed.")
            return []

        tu, _created = TripUpdate.objects.update_or_create(
            trip_id=trip_id,
            start_date=trip_update.trip.start_date,
            start_time=trip_update.trip.start_time,
            defaults=dict(
                schedule_relationship=trip_update.trip.schedule_relationship,
            ),
        )

        updates = []
        for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
            st = self._matching_stop_time(trip_id, stop_sequence, stop_time_update, start_dt)
            updates.append(StopTimeUpdate(
                trip_update=tu,
                stop_time=st,
                arrival_delay=timedelta(seconds=stop_time_update.arrival.delay),
                arrival_time=datetime.fromtimestamp(stop_time_update.arrival.time, tz=self.TIMEZONE),
                departure_delay=timedelta(seconds=stop_time_update.departure.delay),
                departure_time=datetime.fromtimestamp(stop_time_update.departure.time, tz=self.TIMEZONE),
                schedule_relationship=stop_time_update.schedule_relationship or 0,
            ))
        return updates

    def _create_added_trip(self, feed_type, trip_id, trip_update, start_date, start_dt):
        """Synthesise Calendar, CalendarDate, Trip and StopTime rows for an ADDED trip.

        The route is guessed from existing trips serving the same stations.

        :return: ``False`` when the trip cannot be created (no stops, or no route found),
            in which case the caller skips the entity; ``True`` otherwise.
        """
        # NOTE(review): assumes the train number sits at trip_id[5:-1] for added trips — confirm.
        headsign = trip_id[5:-1]

        if not trip_update.stop_time_update:
            self.stdout.write(f"No stop time update for new trip {trip_id}.")
            return False

        trip_ids = self._candidate_trip_ids(trip_update)
        route_ids = set(Trip.objects.filter(id__in=trip_ids).values_list('route_id', flat=True))
        self.stdout.write(f"{len(route_ids)} routes found on trip for new train {headsign}")
        if not route_ids:
            self.stdout.write(f"Route not found for trip {trip_id}.")
            return False
        elif len(route_ids) > 1:
            self.stdout.write(f"Multiple routes found for trip {trip_id}.")
            self.stdout.write(", ".join(map(str, route_ids)))
        route_id = route_ids.pop()

        Calendar.objects.update_or_create(
            id=f"{feed_type}-new-{headsign}",
            defaults={
                "transport_type": feed_type,
                "monday": False,
                "tuesday": False,
                "wednesday": False,
                "thursday": False,
                "friday": False,
                "saturday": False,
                "sunday": False,
                "start_date": start_date,
                "end_date": start_date,
            },
        )
        CalendarDate.objects.update_or_create(
            id=f"{feed_type}-{headsign}-{trip_update.trip.start_date}",
            defaults={
                "service_id": f"{feed_type}-new-{headsign}",
                "date": trip_update.trip.start_date,
                "exception_type": 1,
            },
        )
        Trip.objects.update_or_create(
            id=trip_id,
            defaults={
                "route_id": route_id,
                "service_id": f"{feed_type}-new-{headsign}",
                "headsign": headsign,
                "direction_id": trip_update.trip.direction_id,
            },
        )

        # Non-empty: route_id was taken from a trip in trip_ids.
        sample_trip = Trip.objects.filter(id__in=trip_ids, route_id=route_id).first()
        last_index = len(trip_update.stop_time_update) - 1
        for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
            stop_id = stop_time_update.stop_id
            stop = Stop.objects.get(id=stop_id)
            if stop.location_type == 1:
                # The feed gave a station: pick a concrete platform below it.
                if not StopTime.objects.filter(trip_id=trip_id).exists():
                    self.stdout.write(f"Using trip {sample_trip.id} as template for new trip {trip_id}")
                    stop = StopTime.objects.get(trip_id=sample_trip.id, stop__parent_station_id=stop_id).stop
                else:
                    stop = self._platform_for(trip_id, stop_id)
            stop_id = stop.id

            arr_time, dep_time = self._schedule_offsets(stop_time_update, start_dt)
            # GTFS: 0 = regular, 1 = none. Boarding is possible everywhere except at the terminus,
            # alighting everywhere except at the origin.
            pickup_type = 0 if stop_time_update.departure.time and stop_sequence < last_index else 1
            drop_off_type = 0 if stop_time_update.arrival.time and stop_sequence > 0 else 1

            StopTime.objects.update_or_create(
                id=f"{trip_id}-{stop_id}",
                trip_id=trip_id,
                defaults={
                    "stop_id": stop_id,
                    "stop_sequence": stop_sequence,
                    "arrival_time": arr_time,
                    "departure_time": dep_time,
                    "pickup_type": pickup_type,
                    "drop_off_type": drop_off_type,
                },
            )
        return True

    def _candidate_trip_ids(self, trip_update):
        """Return the ids of static trips serving the stations of *trip_update*.

        Starts from the trips serving both the first and last station. The set is then
        narrowed stop by stop; a stop that would empty the set is ignored with a warning.
        """
        trip_ids = Trip.objects.all().values_list('id', flat=True)
        first_stop_queryset = StopTime.objects.filter(
            stop__parent_station_id=trip_update.stop_time_update[0].stop_id,
        ).values('trip_id')
        last_stop_queryset = StopTime.objects.filter(
            stop__parent_station_id=trip_update.stop_time_update[-1].stop_id,
        ).values('trip_id')
        trip_ids = trip_ids.intersection(first_stop_queryset).intersection(last_stop_queryset)

        for stop_sequence, stop_time_update in enumerate(trip_update.stop_time_update):
            stop_id = stop_time_update.stop_id
            st_queryset = StopTime.objects.filter(stop__parent_station_id=stop_id)
            if stop_sequence == 0:
                st_queryset = st_queryset.filter(stop_sequence=0)
            trip_ids_restrict = trip_ids.intersection(st_queryset.values('trip_id'))
            if trip_ids_restrict:
                trip_ids = trip_ids_restrict
            else:
                stop = Stop.objects.get(id=stop_id)
                self.stdout.write(self.style.WARNING(f"Warning: No trip is found passing by stop "
                                                     f"{stop.name} ({stop_id})"))
        return set(trip_ids)

    @staticmethod
    def _platform_for(trip_id, station_id):
        """Pick the child stop of *station_id* to use for *trip_id*.

        Reuses the platform already recorded for this trip when there is one. Otherwise takes
        the first child stop whose ``stop_type`` is compatible (substring either way) with one
        the trip already uses.

        :raises StopIteration: when no compatible child stop exists.
        """
        existing = StopTime.objects.filter(trip_id=trip_id, stop__parent_station_id=station_id)
        if existing.exists():
            return existing.get().stop
        return next(
            s for s in Stop.objects.filter(parent_station_id=station_id).all()
            for s2 in StopTime.objects.filter(trip_id=trip_id).all()
            if s.stop_type in s2.stop.stop_type or s2.stop.stop_type in s.stop_type
        )

    def _schedule_offsets(self, stop_time_update, start_dt):
        """Return ``(arrival, departure)`` as timedeltas since local midnight of the service day.

        At the origin/terminus one of the two times is absent (0). It then falls back to the
        other instead of producing an offset relative to the 1970 epoch.
        """
        arrival_ts = stop_time_update.arrival.time or stop_time_update.departure.time
        departure_ts = stop_time_update.departure.time or stop_time_update.arrival.time
        return (
            datetime.fromtimestamp(arrival_ts, tz=self.TIMEZONE) - start_dt,
            datetime.fromtimestamp(departure_ts, tz=self.TIMEZONE) - start_dt,
        )

    def _matching_stop_time(self, trip_id, stop_sequence, stop_time_update, start_dt):
        """Find (or create) the static StopTime that a realtime stop update refers to.

        - ``StopArea:`` ids are resolved to a platform, and the StopTime is upserted from the
          realtime times.
        - SKIPPED stops are marked as neither pickup nor drop-off.
        - Otherwise the existing StopTime is looked up and its ``stop_sequence`` is realigned.
        """
        stop_id = stop_time_update.stop_id

        if stop_id.startswith('StopArea:'):
            stop = self._platform_for(trip_id, stop_id)
            arr_time, dep_time = self._schedule_offsets(stop_time_update, start_dt)
            st, _created = StopTime.objects.update_or_create(
                id=f"{trip_id}-{stop.id}",
                trip_id=trip_id,
                stop_id=stop.id,
                defaults={
                    "stop_sequence": stop_sequence,
                    "arrival_time": arr_time,
                    "departure_time": dep_time,
                    "pickup_type": 0 if stop_time_update.departure.time else 1,
                    "drop_off_type": 0 if stop_time_update.arrival.time else 1,
                },
            )
            return st

        stop_filter = Q(stop=stop_id) | Q(stop__parent_station_id=stop_id)

        if stop_time_update.schedule_relationship == self.STOP_SKIPPED:
            st = StopTime.objects.get(stop_filter, trip_id=trip_id)
            if st.pickup_type != 1 or st.drop_off_type != 1:
                st.pickup_type = 1
                st.drop_off_type = 1
                st.save()
            return st

        qs = StopTime.objects.filter(stop_filter, trip_id=trip_id)
        # A trip may call twice at the same stop; disambiguate by position in that case.
        st = qs.first() if qs.count() == 1 else qs.get(stop_sequence=stop_sequence)
        if st.stop_sequence != stop_sequence:
            st.stop_sequence = stop_sequence
            st.save()
        return st