import csv
from datetime import datetime, timedelta
from io import BytesIO, TextIOWrapper
from zipfile import ZipFile

import requests
from django.core.management import BaseCommand

from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, Route, Stop, StopTime, Transfer, Trip


class Command(BaseCommand):
    help = "Update the SNCF GTFS database."

    GTFS_FEEDS = {
        "TGV": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export_gtfs_voyages.zip",
        "IC": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-intercites-gtfs-last.zip",
        "TER": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip",
        "TN": "https://eu.ftp.opendatasoft.com/sncf/gtfs/transilien-gtfs.zip",
    }

    # Seconds without any response bytes from the feed server before requests gives up,
    # so a stalled connection cannot hang the command indefinitely.
    REQUEST_TIMEOUT = 60

    # Version used when no FeedInfo row exists yet, i.e. on the very first import.
    _NEVER_UPDATED = "1970-01-01"

    def add_arguments(self, parser):
        parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
        parser.add_argument('--dry-run', action='store_true',
                            help="Do not update the database, only print what would be done.")
        parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")

    def handle(self, *args, **options):
        """
        Download every feed of GTFS_FEEDS and upsert its content into the database.

        Unless --force is given, nothing is downloaded when no feed advertises a
        Last-Modified date later than the stored FeedInfo version.
        """
        self._bulk_size = options['bulk_size']
        self._dry_run = options['dry_run']
        force = options['force']

        if self._dry_run:
            self.stdout.write(self.style.WARNING("Dry run mode activated."))

        if not FeedInfo.objects.exists():
            last_update_date = self._NEVER_UPDATED
        else:
            # NOTE(review): .get() raises MultipleObjectsReturned if feed_info rows were
            # stored under several publisher names — confirm there is only ever one.
            last_update_date = FeedInfo.objects.get().version

        if not force and not self._any_feed_modified_since(last_update_date):
            self.stdout.write(self.style.WARNING("Database already up-to-date."))
            return

        self.stdout.write("Updating database...")

        for transport_type, feed_url in self.GTFS_FEEDS.items():
            self.stdout.write(f"Downloading {transport_type} GTFS feed...")
            response = requests.get(feed_url, timeout=self.REQUEST_TIMEOUT)
            # Surface HTTP errors directly instead of as a BadZipFile on an error page.
            response.raise_for_status()
            with ZipFile(BytesIO(response.content)) as archive:
                self._import_feed(archive, transport_type, last_update_date)

    def _any_feed_modified_since(self, last_update_date):
        """Return True if any feed's Last-Modified date is later than the ISO date ``last_update_date``."""
        for url in self.GTFS_FEEDS.values():
            last_modified = requests.head(url, timeout=self.REQUEST_TIMEOUT).headers["Last-Modified"]
            last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z")
            if last_modified.date().isoformat() > last_update_date:
                return True
        return False

    def _import_feed(self, archive, transport_type, last_update_date):
        """Import one feed's GTFS files, referenced tables (agencies, stops, ...) first."""
        self._import_agencies(archive)
        self._import_stops(archive, transport_type, last_update_date)
        self._import_routes(archive, transport_type)
        explicit_service_ids = self._import_calendars(archive, transport_type)
        self._import_calendar_dates(archive, transport_type, explicit_service_ids)
        self._import_trips(archive, transport_type)
        self._import_stop_times(archive, transport_type)
        self._import_transfers(archive)
        self._import_feed_info(archive)

    @staticmethod
    def _read_csv(archive, name):
        """
        Yield the rows of the CSV file ``name`` inside ``archive`` as header-keyed dicts.

        The member is streamed rather than loaded whole. ``utf-8-sig`` drops a leading
        BOM, and ``newline=''`` is what the csv module requires to parse quoted
        newlines correctly.
        """
        with archive.open(name) as raw, TextIOWrapper(raw, encoding='utf-8-sig', newline='') as text:
            yield from csv.DictReader(text)

    @staticmethod
    def _parse_gtfs_time(value):
        """Convert a GTFS "HH:MM:SS" or "H:MM:SS" time (hours may exceed 23) into a timedelta."""
        hours, minutes, seconds = map(int, value.split(':'))
        return timedelta(hours=hours, minutes=minutes, seconds=seconds)

    def _bulk_upsert(self, model, objects, update_fields):
        """
        Write ``objects`` to ``model`` in chunks of --bulk_size rows.

        Rows whose ``id`` already exists get ``update_fields`` overwritten. In dry-run
        mode the objects are still built, so malformed input still fails. Nothing is
        written, and each chunk is discarded instead of accumulating in memory.
        """
        batch = []
        for obj in objects:
            batch.append(obj)
            if len(batch) >= self._bulk_size:
                self._write_batch(model, batch, update_fields)
                batch = []
        if batch:
            self._write_batch(model, batch, update_fields)

    def _write_batch(self, model, batch, update_fields):
        """Upsert one chunk keyed on ``id``, unless running in dry-run mode."""
        if not self._dry_run:
            model.objects.bulk_create(batch, update_conflicts=True, update_fields=update_fields,
                                      unique_fields=['id'])

    def _import_agencies(self, archive):
        """Upsert agency.txt."""
        agencies = (
            Agency(
                id=row['agency_id'],
                name=row['agency_name'],
                url=row['agency_url'],
                timezone=row['agency_timezone'],
                lang=row['agency_lang'],
                phone=row.get('agency_phone', ""),
                email=row.get('agency_email', ""),
            )
            for row in self._read_csv(archive, "agency.txt")
        )
        self._bulk_upsert(Agency, agencies, ['name', 'url', 'timezone', 'lang', 'phone', 'email'])

    def _import_stops(self, archive, transport_type, last_update_date):
        """Upsert stops.txt."""
        # On the very first import, Transilien ("TN") parent stations are left unset.
        # NOTE(review): presumably because referenced parent stops may not exist yet at
        # that point — confirm.
        link_parents = last_update_date != self._NEVER_UPDATED or transport_type != "TN"
        stops = (
            Stop(
                id=row["stop_id"],
                name=row['stop_name'],
                desc=row['stop_desc'],
                lat=row['stop_lat'],
                lon=row['stop_lon'],
                zone_id=row['zone_id'],
                url=row['stop_url'],
                location_type=row['location_type'],
                parent_station_id=(row['parent_station'] or None) if link_parents else None,
                timezone=row.get('stop_timezone', ""),
                wheelchair_boarding=row.get('wheelchair_boarding', 0),
                level_id=row.get('level_id', ""),
                platform_code=row.get('platform_code', ""),
            )
            for row in self._read_csv(archive, "stops.txt")
        )
        self._bulk_upsert(Stop, stops, [
            'name', 'desc', 'lat', 'lon', 'zone_id', 'url', 'location_type', 'parent_station_id',
            'timezone', 'wheelchair_boarding', 'level_id', 'platform_code',
        ])

    def _import_routes(self, archive, transport_type):
        """Upsert routes.txt, tagging each route with the feed's transport type."""
        routes = (
            Route(
                id=row['route_id'],
                agency_id=row['agency_id'],
                short_name=row['route_short_name'],
                long_name=row['route_long_name'],
                desc=row['route_desc'],
                type=row['route_type'],
                url=row['route_url'],
                color=row['route_color'],
                text_color=row['route_text_color'],
                transport_type=transport_type,
            )
            for row in self._read_csv(archive, "routes.txt")
        )
        # transport_type is updated for every chunk, including the last one.
        self._bulk_upsert(Route, routes, [
            'agency_id', 'short_name', 'long_name', 'desc', 'type', 'url', 'color', 'text_color',
            'transport_type',
        ])

    def _import_calendars(self, archive, transport_type):
        """
        Upsert calendar.txt if the feed has one.

        Service ids are prefixed with the transport type.

        Returns the set of Calendar ids created from calendar.txt (empty if absent).
        """
        service_ids = set()
        if "calendar.txt" not in archive.namelist():
            return service_ids

        def build():
            for row in self._read_csv(archive, "calendar.txt"):
                calendar = Calendar(
                    id=f"{transport_type}-{row['service_id']}",
                    monday=row['monday'],
                    tuesday=row['tuesday'],
                    wednesday=row['wednesday'],
                    thursday=row['thursday'],
                    friday=row['friday'],
                    saturday=row['saturday'],
                    sunday=row['sunday'],
                    start_date=row['start_date'],
                    end_date=row['end_date'],
                    transport_type=transport_type,
                )
                service_ids.add(calendar.id)
                yield calendar

        self._bulk_upsert(Calendar, build(), [
            'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday',
            'start_date', 'end_date', 'transport_type',
        ])
        return service_ids

    def _import_calendar_dates(self, archive, transport_type, explicit_service_ids):
        """
        Upsert calendar_dates.txt.

        Beforehand, create an all-weekdays-False Calendar for each service that
        calendar.txt does not define.
        """
        # First pass: one date range per implicit service, so each Calendar id is written
        # exactly once. Repeating an id inside one bulk upsert is rejected by PostgreSQL.
        # GTFS dates are fixed-width YYYYMMDD strings, so min/max on strings is chronological.
        date_ranges = {}
        for row in self._read_csv(archive, "calendar_dates.txt"):
            service_id = f"{transport_type}-{row['service_id']}"
            if service_id in explicit_service_ids:
                continue
            date = row['date']
            if service_id in date_ranges:
                start, end = date_ranges[service_id]
                date_ranges[service_id] = (min(start, date), max(end, date))
            else:
                date_ranges[service_id] = (date, date)

        implicit_calendars = (
            Calendar(
                id=service_id,
                monday=False,
                tuesday=False,
                wednesday=False,
                thursday=False,
                friday=False,
                saturday=False,
                sunday=False,
                start_date=start,
                end_date=end,
                transport_type=transport_type,
            )
            for service_id, (start, end) in date_ranges.items()
        )
        # Existing rows only get their end_date extended, as before.
        self._bulk_upsert(Calendar, implicit_calendars, ['end_date'])

        # Second pass: the exception dates themselves, once every Calendar they point to exists.
        calendar_dates = (
            CalendarDate(
                id=f"{transport_type}-{row['service_id']}-{row['date']}",
                service_id=f"{transport_type}-{row['service_id']}",
                date=row['date'],
                exception_type=row['exception_type'],
            )
            for row in self._read_csv(archive, "calendar_dates.txt")
        )
        self._bulk_upsert(CalendarDate, calendar_dates, ['service_id', 'date', 'exception_type'])

    def _import_trips(self, archive, transport_type):
        """
        Upsert trips.txt.

        Non-Transilien trip ids have the form "<id>:<timestamp>". They are split so the
        stored id stays stable across feed updates, and the timestamp goes into
        ``last_update``.
        """
        def build():
            for row in self._read_csv(archive, "trips.txt"):
                trip_id = row['trip_id']
                if transport_type != "TN":
                    trip_id, last_update = trip_id.split(':', 1)
                    last_update = datetime.fromisoformat(last_update)
                else:
                    last_update = None
                yield Trip(
                    id=trip_id,
                    route_id=row['route_id'],
                    service_id=f"{transport_type}-{row['service_id']}",
                    headsign=row['trip_headsign'],
                    short_name=row.get('trip_short_name', ""),
                    direction_id=row['direction_id'] or None,
                    block_id=row['block_id'],
                    shape_id=row['shape_id'],
                    wheelchair_accessible=row.get('wheelchair_accessible', None),
                    bikes_allowed=row.get('bikes_allowed', None),
                    last_update=last_update,
                )

        # last_update must be refreshed too: because the id is stable, existing trips
        # always take the conflict/update path.
        self._bulk_upsert(Trip, build(), [
            'route_id', 'service_id', 'headsign', 'short_name', 'direction_id', 'block_id', 'shape_id',
            'wheelchair_accessible', 'bikes_allowed', 'last_update',
        ])

    def _import_stop_times(self, archive, transport_type):
        """Upsert stop_times.txt, linking each row to the timestamp-stripped trip id."""
        def build():
            for row in self._read_csv(archive, "stop_times.txt"):
                trip_id = row['trip_id']
                if transport_type != "TN":
                    trip_id = trip_id.split(':', 1)[0]
                yield StopTime(
                    # NOTE(review): this id embeds the raw trip_id, timestamp suffix included.
                    # When a trip's timestamp changes, new rows are therefore created next to
                    # the old ones instead of replacing them. A trip that visits the same stop
                    # twice would also collide. Confirm before changing the id scheme.
                    id=f"{row['trip_id']}-{row['stop_id']}",
                    trip_id=trip_id,
                    arrival_time=self._parse_gtfs_time(row['arrival_time']),
                    departure_time=self._parse_gtfs_time(row['departure_time']),
                    stop_id=row['stop_id'],
                    stop_sequence=row['stop_sequence'],
                    stop_headsign=row['stop_headsign'],
                    pickup_type=row['pickup_type'],
                    drop_off_type=row['drop_off_type'],
                    timepoint=row.get('timepoint', None),
                )

        self._bulk_upsert(StopTime, build(), [
            'stop_id', 'arrival_time', 'departure_time', 'stop_sequence', 'stop_headsign',
            'pickup_type', 'drop_off_type', 'timepoint',
        ])

    def _import_transfers(self, archive):
        """Upsert transfers.txt, keyed on the (from_stop, to_stop) pair."""
        transfers = (
            Transfer(
                id=f"{row['from_stop_id']}-{row['to_stop_id']}",
                from_stop_id=row['from_stop_id'],
                to_stop_id=row['to_stop_id'],
                transfer_type=row['transfer_type'],
                min_transfer_time=row['min_transfer_time'],
            )
            for row in self._read_csv(archive, "transfers.txt")
        )
        self._bulk_upsert(Transfer, transfers, ['transfer_type', 'min_transfer_time'])

    def _import_feed_info(self, archive):
        """Upsert feed_info.txt (one FeedInfo per publisher name), unless dry-run or absent."""
        if self._dry_run or "feed_info.txt" not in archive.namelist():
            return
        for row in self._read_csv(archive, "feed_info.txt"):
            FeedInfo.objects.update_or_create(
                publisher_name=row['feed_publisher_name'],
                defaults=dict(
                    publisher_url=row['feed_publisher_url'],
                    lang=row['feed_lang'],
                    start_date=row['feed_start_date'],
                    end_date=row['feed_end_date'],
                    version=row['feed_version'],
                ),
            )