import csv
from datetime import datetime, timedelta
from io import BytesIO
from zipfile import ZipFile

import requests
from django.core.management import BaseCommand

from sncfgtfs.models import (Agency, Calendar, CalendarDate, FeedInfo, Route,
                             Stop, StopTime, Transfer, Trip)


class Command(BaseCommand):
    help = "Update the SNCF GTFS database."

    GTFS_FEEDS = {
        "TGV": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export_gtfs_voyages.zip",
        "IC": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-intercites-gtfs-last.zip",
        "TER": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip",
        "TN": "https://eu.ftp.opendatasoft.com/sncf/gtfs/transilien-gtfs.zip",
    }

    def add_arguments(self, parser):
        parser.add_argument('--bulk_size', type=int, default=1000,
                            help='Number of objects to create in bulk.')

    def handle(self, *args, **options):
        bulk_size = options['bulk_size']

        # The feed version recorded on the last run serves as the reference
        # date for the freshness check. (`version` is the field written by
        # update_or_create at the end of the import.)
        if not FeedInfo.objects.exists():
            last_update_date = "1970-01-01"
        else:
            last_update_date = FeedInfo.objects.get().version

        # Compare each feed's Last-Modified header against the stored date.
        # The else branch of this for loop runs only when no break occurred,
        # i.e. when none of the feeds is newer.
        for url in self.GTFS_FEEDS.values():
            last_modified = requests.head(url).headers["Last-Modified"]
            last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z")
            if last_modified.date().isoformat() > last_update_date:
                break
        else:
            self.stdout.write(self.style.WARNING("Database already up-to-date."))
            return

        self.stdout.write("Updating database...")

        for transport_type, feed_url in self.GTFS_FEEDS.items():
            self.stdout.write(f"Downloading {transport_type} GTFS feed...")
            with ZipFile(BytesIO(requests.get(feed_url).content)) as zipfile:
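                # Every file below is loaded with the same bulk-upsert
                # pattern: bulk_create() with update_conflicts=True (Django
                # 4.1+) inserts new rows and, on a conflict over
                # unique_fields, refreshes update_fields instead, so
                # re-running the command is idempotent.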
                agencies = []
                for agency_dict in csv.DictReader(zipfile.read("agency.txt").decode().splitlines()):
                    agency_dict: dict
                    agency = Agency(
                        id=agency_dict['agency_id'],
                        name=agency_dict['agency_name'],
                        url=agency_dict['agency_url'],
                        timezone=agency_dict['agency_timezone'],
                        lang=agency_dict['agency_lang'],
                        phone=agency_dict.get('agency_phone', ""),
                        email=agency_dict.get('agency_email', ""),
                    )
                    agencies.append(agency)
                if agencies:
                    Agency.objects.bulk_create(agencies, update_conflicts=True,
                                               update_fields=['name', 'url', 'timezone', 'lang',
                                                              'phone', 'email'],
                                               unique_fields=['id'])
                    agencies.clear()

                stops = []
                for stop_dict in csv.DictReader(zipfile.read("stops.txt").decode().splitlines()):
                    stop_dict: dict
                    stop = Stop(
                        id=stop_dict["stop_id"],
                        name=stop_dict['stop_name'],
                        desc=stop_dict['stop_desc'],
                        lat=stop_dict['stop_lat'],
                        lon=stop_dict['stop_lon'],
                        zone_id=stop_dict['zone_id'],
                        url=stop_dict['stop_url'],
                        location_type=stop_dict['location_type'],
                        # On the very first import of the Transilien feed,
                        # parent stations may not exist yet, so the
                        # reference is dropped.
                        parent_station_id=(stop_dict['parent_station'] or None)
                        if last_update_date != "1970-01-01" or transport_type != "TN"
                        else None,
                        timezone=stop_dict.get('stop_timezone', ""),
                        wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
                        level_id=stop_dict.get('level_id', ""),
                        platform_code=stop_dict.get('platform_code', ""),
                    )
                    stops.append(stop)
                    if len(stops) >= bulk_size:
                        Stop.objects.bulk_create(stops, update_conflicts=True,
                                                 update_fields=['name', 'desc', 'lat', 'lon', 'zone_id',
                                                                'url', 'location_type', 'parent_station_id',
                                                                'timezone', 'wheelchair_boarding',
                                                                'level_id', 'platform_code'],
                                                 unique_fields=['id'])
                        stops.clear()
                if stops:
                    Stop.objects.bulk_create(stops, update_conflicts=True,
                                             update_fields=['name', 'desc', 'lat', 'lon', 'zone_id',
                                                            'url', 'location_type', 'parent_station_id',
                                                            'timezone', 'wheelchair_boarding',
                                                            'level_id', 'platform_code'],
                                             unique_fields=['id'])
                    stops.clear()

                routes = []
                for route_dict in csv.DictReader(zipfile.read("routes.txt").decode().splitlines()):
                    route_dict: dict
                    route = Route(
                        id=route_dict['route_id'],
                        agency_id=route_dict['agency_id'],
                        short_name=route_dict['route_short_name'],
                        long_name=route_dict['route_long_name'],
                        desc=route_dict['route_desc'],
                        type=route_dict['route_type'],
                        url=route_dict['route_url'],
                        color=route_dict['route_color'],
                        text_color=route_dict['route_text_color'],
                    )
                    routes.append(route)
                    if len(routes) >= bulk_size:
                        Route.objects.bulk_create(routes, update_conflicts=True,
                                                  update_fields=['agency_id', 'short_name', 'long_name',
                                                                 'desc', 'type', 'url', 'color',
                                                                 'text_color'],
                                                  unique_fields=['id'])
                        routes.clear()
                if routes:
                    Route.objects.bulk_create(routes, update_conflicts=True,
                                              update_fields=['agency_id', 'short_name', 'long_name',
                                                             'desc', 'type', 'url', 'color',
                                                             'text_color'],
                                              unique_fields=['id'])
                    routes.clear()

                # Service ids are namespaced with the transport type to avoid
                # collisions between the four feeds.
                calendar_ids = []
                if "calendar.txt" in zipfile.namelist():
                    calendars = []
                    for calendar_dict in csv.DictReader(zipfile.read("calendar.txt").decode().splitlines()):
                        calendar_dict: dict
                        calendar = Calendar(
                            id=f"{transport_type}-{calendar_dict['service_id']}",
                            monday=calendar_dict['monday'],
                            tuesday=calendar_dict['tuesday'],
                            wednesday=calendar_dict['wednesday'],
                            thursday=calendar_dict['thursday'],
                            friday=calendar_dict['friday'],
                            saturday=calendar_dict['saturday'],
                            sunday=calendar_dict['sunday'],
                            start_date=calendar_dict['start_date'],
                            end_date=calendar_dict['end_date'],
                            transport_type=transport_type,
                        )
                        calendars.append(calendar)
                        calendar_ids.append(calendar.id)
                        if len(calendars) >= bulk_size:
                            Calendar.objects.bulk_create(calendars, update_conflicts=True,
                                                         update_fields=['monday', 'tuesday', 'wednesday',
                                                                        'thursday', 'friday', 'saturday',
                                                                        'sunday', 'start_date', 'end_date',
                                                                        'transport_type'],
                                                         unique_fields=['id'])
                            calendars.clear()
                    if calendars:
                        Calendar.objects.bulk_create(calendars, update_conflicts=True,
                                                     update_fields=['monday', 'tuesday', 'wednesday',
                                                                    'thursday', 'friday', 'saturday',
                                                                    'sunday', 'start_date', 'end_date',
                                                                    'transport_type'],
                                                     unique_fields=['id'])
                        calendars.clear()

                # Services that appear only in calendar_dates.txt still need a
                # Calendar row. A dict keyed by service id avoids creating the
                # same synthetic row twice within one bulk_create batch (which
                # would make the upsert fail); on conflicts across batches only
                # end_date is refreshed, which assumes dates are listed in
                # ascending order per service.
                extra_calendars = {}
                calendar_dates = []
                for calendar_date_dict in csv.DictReader(zipfile.read("calendar_dates.txt").decode().splitlines()):
                    calendar_date_dict: dict
                    calendar_date = CalendarDate(
                        id=f"{transport_type}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
                        service_id=f"{transport_type}-{calendar_date_dict['service_id']}",
                        date=calendar_date_dict['date'],
                        exception_type=calendar_date_dict['exception_type'],
                    )
                    calendar_dates.append(calendar_date)
                    if calendar_date.service_id not in calendar_ids:
                        if calendar_date.service_id in extra_calendars:
                            extra_calendars[calendar_date.service_id].end_date = calendar_date_dict['date']
                        else:
                            extra_calendars[calendar_date.service_id] = Calendar(
                                id=calendar_date.service_id,
                                monday=False, tuesday=False, wednesday=False, thursday=False,
                                friday=False, saturday=False, sunday=False,
                                start_date=calendar_date_dict['date'],
                                end_date=calendar_date_dict['date'],
                                transport_type=transport_type,
                            )
                    if len(calendar_dates) >= bulk_size:
                        Calendar.objects.bulk_create(list(extra_calendars.values()), update_conflicts=True,
                                                     update_fields=['end_date'], unique_fields=['id'])
                        CalendarDate.objects.bulk_create(calendar_dates, update_conflicts=True,
                                                         update_fields=['service_id', 'date',
                                                                        'exception_type'],
                                                         unique_fields=['id'])
                        extra_calendars.clear()
                        calendar_dates.clear()
                if calendar_dates:
                    Calendar.objects.bulk_create(list(extra_calendars.values()), update_conflicts=True,
                                                 update_fields=['end_date'], unique_fields=['id'])
                    CalendarDate.objects.bulk_create(calendar_dates, update_conflicts=True,
                                                     update_fields=['service_id', 'date',
                                                                    'exception_type'],
                                                     unique_fields=['id'])
                    extra_calendars.clear()
                    calendar_dates.clear()

                trips = []
                for trip_dict in csv.DictReader(zipfile.read("trips.txt").decode().splitlines()):
                    trip_dict: dict
                    trip = Trip(
                        id=trip_dict['trip_id'],
                        route_id=trip_dict['route_id'],
                        service_id=f"{transport_type}-{trip_dict['service_id']}",
                        headsign=trip_dict['trip_headsign'],
                        short_name=trip_dict.get('trip_short_name', ""),
                        direction_id=trip_dict['direction_id'] or None,
                        block_id=trip_dict['block_id'],
                        shape_id=trip_dict['shape_id'],
                        wheelchair_accessible=trip_dict.get('wheelchair_accessible'),
                        bikes_allowed=trip_dict.get('bikes_allowed'),
                    )
                    trips.append(trip)
                    if len(trips) >= bulk_size:
                        Trip.objects.bulk_create(trips, update_conflicts=True,
                                                 update_fields=['route_id', 'service_id', 'headsign',
                                                                'short_name', 'direction_id', 'block_id',
                                                                'shape_id', 'wheelchair_accessible',
                                                                'bikes_allowed'],
                                                 unique_fields=['id'])
                        trips.clear()
                if trips:
                    Trip.objects.bulk_create(trips, update_conflicts=True,
                                             update_fields=['route_id', 'service_id', 'headsign',
                                                            'short_name', 'direction_id', 'block_id',
                                                            'shape_id', 'wheelchair_accessible',
                                                            'bikes_allowed'],
                                             unique_fields=['id'])
                    trips.clear()
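                # GTFS stop times may exceed 24:00:00 for trips running past
                # midnight (e.g. "25:30:00"), so they cannot be parsed with
                # strptime; they are stored as timedelta offsets from the
                # start of the service day instead.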
                stop_times = []
                for stop_time_dict in csv.DictReader(zipfile.read("stop_times.txt").decode().splitlines()):
                    stop_time_dict: dict
                    # Split on ':' rather than slicing fixed positions, so
                    # hours parse correctly even without zero-padding.
                    arr_h, arr_m, arr_s = map(int, stop_time_dict['arrival_time'].split(':'))
                    dep_h, dep_m, dep_s = map(int, stop_time_dict['departure_time'].split(':'))
                    st = StopTime(
                        id=f"{stop_time_dict['trip_id']}-{stop_time_dict['stop_sequence']}",
                        trip_id=stop_time_dict['trip_id'],
                        arrival_time=timedelta(hours=arr_h, minutes=arr_m, seconds=arr_s),
                        departure_time=timedelta(hours=dep_h, minutes=dep_m, seconds=dep_s),
                        stop_id=stop_time_dict['stop_id'],
                        stop_sequence=stop_time_dict['stop_sequence'],
                        stop_headsign=stop_time_dict['stop_headsign'],
                        pickup_type=stop_time_dict['pickup_type'],
                        drop_off_type=stop_time_dict['drop_off_type'],
                        timepoint=stop_time_dict.get('timepoint'),
                    )
                    stop_times.append(st)
                    if len(stop_times) >= bulk_size:
                        StopTime.objects.bulk_create(stop_times, update_conflicts=True,
                                                     update_fields=['stop_id', 'arrival_time',
                                                                    'departure_time', 'stop_headsign',
                                                                    'pickup_type', 'drop_off_type',
                                                                    'timepoint'],
                                                     unique_fields=['id'])
                        stop_times.clear()
                if stop_times:
                    StopTime.objects.bulk_create(stop_times, update_conflicts=True,
                                                 update_fields=['stop_id', 'arrival_time',
                                                                'departure_time', 'stop_headsign',
                                                                'pickup_type', 'drop_off_type',
                                                                'timepoint'],
                                                 unique_fields=['id'])
                    stop_times.clear()

                # transfers.txt is optional, like calendar.txt and
                # feed_info.txt, so guard against feeds that omit it.
                if "transfers.txt" in zipfile.namelist():
                    transfers = []
                    for transfer_dict in csv.DictReader(zipfile.read("transfers.txt").decode().splitlines()):
                        transfer_dict: dict
                        transfer = Transfer(
                            id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
                            from_stop_id=transfer_dict['from_stop_id'],
                            to_stop_id=transfer_dict['to_stop_id'],
                            transfer_type=transfer_dict['transfer_type'],
                            min_transfer_time=transfer_dict['min_transfer_time'],
                        )
                        transfers.append(transfer)
                        if len(transfers) >= bulk_size:
                            Transfer.objects.bulk_create(transfers, update_conflicts=True,
                                                         update_fields=['transfer_type',
                                                                        'min_transfer_time'],
                                                         unique_fields=['id'])
                            transfers.clear()
                    if transfers:
                        Transfer.objects.bulk_create(transfers, update_conflicts=True,
                                                     update_fields=['transfer_type',
                                                                    'min_transfer_time'],
                                                     unique_fields=['id'])
                        transfers.clear()
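                # feed_info.txt closes the loop with the freshness check at
                # the top of handle(): the version stored here is what the
                # next run compares against each feed's Last-Modified date.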
                if "feed_info.txt" in zipfile.namelist():
                    for feed_info_dict in csv.DictReader(zipfile.read("feed_info.txt").decode().splitlines()):
                        feed_info_dict: dict
                        FeedInfo.objects.update_or_create(
                            publisher_name=feed_info_dict['feed_publisher_name'],
                            defaults=dict(
                                publisher_url=feed_info_dict['feed_publisher_url'],
                                lang=feed_info_dict['feed_lang'],
                                start_date=feed_info_dict['feed_start_date'],
                                end_date=feed_info_dict['feed_end_date'],
                                version=feed_info_dict['feed_version'],
                            ),
                        )
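
# A minimal usage sketch. The command name comes from this module's file
# name; assuming the file lives at
# sncfgtfs/management/commands/update_gtfs.py (an assumption, not confirmed
# by the source), it would be invoked as:
#
#   python manage.py update_gtfs
#   python manage.py update_gtfs --bulk_size 500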