import csv
from datetime import datetime, timedelta
from io import BytesIO
from zipfile import ZipFile
from zoneinfo import ZoneInfo

import requests
from django.core.management import BaseCommand

from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, Transfer, Trip, \
    PickupType


class Command(BaseCommand):
    """Download every registered GTFS feed and upsert its content into the database.

    For each ``GTFSFeed`` the command optionally checks the ``ETag`` /
    ``Last-Modified`` headers of the feed URL, downloads the zip archive and
    imports agencies, stops, routes, calendars, calendar dates, trips, stop
    times, transfers and feed info, in that order.
    """

    help = "Update the SNCF GTFS database."

    # Seconds to wait for the remote server (connect / between bytes) before
    # giving up, so that an unresponsive host cannot hang the command forever.
    REQUEST_TIMEOUT = 60

    def add_arguments(self, parser):
        """Register the command-line options of the command."""
        parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
        parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
        parser.add_argument('--dry-run', action='store_true',
                            help="Do not update the database, only print what would be done.")
        parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")

    def handle(self, debug: bool = False, bulk_size: int = 100, dry_run: bool = False, force: bool = False,
               verbosity: int = 1, *args, **options):
        """Update the database from every registered GTFS feed.

        :param debug: accepted for command-line compatibility; not used here.
        :param bulk_size: maximum number of objects sent per bulk upsert.
        :param dry_run: parse the feeds but do not write anything to the database.
        :param force: skip the ETag / Last-Modified freshness check.
        :param verbosity: when >= 1, report the feeds that are skipped.
        :raises requests.HTTPError: if a feed download answers with an HTTP error.
        """
        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run mode activated."))

        self.stdout.write("Updating database...")

        for gtfs_feed in GTFSFeed.objects.all():
            if not force and self._is_up_to_date(gtfs_feed):
                if verbosity >= 1:
                    self.stdout.write(self.style.WARNING(f"Database is already up-to-date for {gtfs_feed}."))
                continue

            self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
            resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, timeout=self.REQUEST_TIMEOUT)
            # Fail with an explicit HTTP error rather than a BadZipFile raised
            # while trying to open an error page as an archive.
            resp.raise_for_status()

            with ZipFile(BytesIO(resp.content)) as archive:
                self._import_feed(archive, gtfs_feed, bulk_size, dry_run)

            # A dry run must not record the new ETag / Last-Modified values,
            # otherwise the next real run would consider the feed up-to-date.
            if not dry_run:
                self._store_cache_headers(gtfs_feed, resp.headers)

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_http_date(value: str) -> datetime:
        """Parse an HTTP date header value into a timezone-aware datetime.

        The timezone is built from the last space-separated token of the value.
        """
        parsed = datetime.strptime(value, "%a, %d %b %Y %H:%M:%S %Z")
        return parsed.replace(tzinfo=ZoneInfo(value.split(' ')[-1]))

    def _is_up_to_date(self, gtfs_feed) -> bool:
        """Return True when the remote feed has not changed since the last import.

        Issues a HEAD request and compares its ``ETag`` with the stored one;
        failing that, compares ``Last-Modified`` with the stored timestamp.
        """
        headers = requests.head(gtfs_feed.feed_url, allow_redirects=True, timeout=self.REQUEST_TIMEOUT).headers

        if gtfs_feed.etag and headers.get('ETag') == gtfs_feed.etag:
            return True

        if 'Last-Modified' in headers and gtfs_feed.last_modified:
            return self._parse_http_date(headers['Last-Modified']) <= gtfs_feed.last_modified

        return False

    def _store_cache_headers(self, gtfs_feed, headers):
        """Persist the ETag / Last-Modified headers of the downloaded feed, if any."""
        changed = False
        if 'ETag' in headers:
            gtfs_feed.etag = headers['ETag']
            changed = True
        if 'Last-Modified' in headers:
            gtfs_feed.last_modified = self._parse_http_date(headers['Last-Modified'])
            changed = True
        if changed:
            gtfs_feed.save()

    # ------------------------------------------------------------------
    # Generic helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_csv(archive, filename):
        """Return a ``csv.DictReader`` over one text file of the archive.

        Byte-order marks are removed and every line is stripped before parsing.
        """
        text = archive.read(filename).decode().replace('\ufeff', '')
        return csv.DictReader([line.strip() for line in text.splitlines()])

    @staticmethod
    def _upsert(model, objects, update_fields, batch_size=None):
        """Insert ``objects``, updating ``update_fields`` on rows whose ``id`` already exists."""
        model.objects.bulk_create(objects, batch_size=batch_size, update_conflicts=True,
                                  update_fields=update_fields, unique_fields=['id'])

    def _upsert_stream(self, model, objects, update_fields, bulk_size, dry_run):
        """Consume ``objects`` and upsert them in chunks of at most ``bulk_size``.

        In dry-run mode the iterable is still consumed, so that every row is
        parsed, but nothing is kept in memory nor written.
        """
        pending = []
        for obj in objects:
            if dry_run:
                continue
            pending.append(obj)
            if len(pending) >= bulk_size:
                self._upsert(model, pending, update_fields)
                pending.clear()
        if pending:
            self._upsert(model, pending, update_fields)

    @staticmethod
    def _parse_gtfs_time(value: str) -> int:
        """Convert an ``H:M:S`` string into a number of seconds."""
        hours, minutes, seconds = map(int, value.split(':'))
        return hours * 3600 + minutes * 60 + seconds

    # ------------------------------------------------------------------
    # Import steps
    # ------------------------------------------------------------------

    def _import_feed(self, archive, gtfs_feed, bulk_size, dry_run):
        """Import every supported file of the archive, parents before children."""
        self._import_agencies(archive, gtfs_feed, dry_run)
        self._import_stops(archive, gtfs_feed, bulk_size, dry_run)
        self._import_routes(archive, gtfs_feed, bulk_size, dry_run)
        self._import_calendars(archive, gtfs_feed, bulk_size, dry_run)
        self._import_trips(archive, gtfs_feed, bulk_size, dry_run)
        self._import_stop_times(archive, gtfs_feed, bulk_size, dry_run)
        self._import_transfers(archive, gtfs_feed, bulk_size, dry_run)
        self._import_feed_info(archive, gtfs_feed, dry_run)

    def _import_agencies(self, archive, gtfs_feed, dry_run):
        """Upsert the agencies listed in ``agency.txt`` in a single query."""
        gtfs_code = gtfs_feed.code
        agencies = [
            Agency(
                id=f"{gtfs_code}-{row['agency_id']}",
                name=row['agency_name'],
                url=row['agency_url'],
                timezone=row['agency_timezone'],
                lang=row.get('agency_lang', "fr"),
                phone=row.get('agency_phone', ""),
                email=row.get('agency_email', ""),
                gtfs_feed=gtfs_feed,
            )
            for row in self._read_csv(archive, "agency.txt")
        ]
        if agencies and not dry_run:
            self._upsert(Agency, agencies,
                         ['name', 'url', 'timezone', 'lang', 'phone', 'email', 'gtfs_feed'])

    def _import_stops(self, archive, gtfs_feed, bulk_size, dry_run):
        """Upsert the stops listed in ``stops.txt``, batched by ``bulk_size``."""
        gtfs_code = gtfs_feed.code
        stops = []
        for row in self._read_csv(archive, "stops.txt"):
            parent_station = row.get('parent_station', None)
            stops.append(Stop(
                id=f"{gtfs_code}-{row['stop_id']}",
                name=row['stop_name'],
                desc=row.get('stop_desc', ""),
                lat=row['stop_lat'],
                lon=row['stop_lon'],
                zone_id=row.get('zone_id', ""),
                url=row.get('stop_url', ""),
                location_type=row.get('location_type', 0) or 0,
                # An absent or empty parent station means "no parent".
                parent_station_id=f"{gtfs_code}-{parent_station}" if parent_station else None,
                timezone=row.get('stop_timezone', ""),
                wheelchair_boarding=row.get('wheelchair_boarding', 0),
                level_id=row.get('level_id', ""),
                platform_code=row.get('platform_code', ""),
                gtfs_feed=gtfs_feed,
            ))
        if stops and not dry_run:
            self._upsert(Stop, stops,
                         ['name', 'desc', 'lat', 'lon', 'zone_id', 'url', 'location_type', 'parent_station_id',
                          'timezone', 'wheelchair_boarding', 'level_id', 'platform_code', 'gtfs_feed'],
                         batch_size=bulk_size)

    def _import_routes(self, archive, gtfs_feed, bulk_size, dry_run):
        """Upsert the routes listed in ``routes.txt``."""
        gtfs_code = gtfs_feed.code
        routes = (
            Route(
                id=f"{gtfs_code}-{row['route_id']}",
                agency_id=f"{gtfs_code}-{row['agency_id']}",
                short_name=row['route_short_name'],
                long_name=row['route_long_name'],
                desc=row.get('route_desc', ""),
                type=row['route_type'],
                url=row.get('route_url', ""),
                color=row.get('route_color', ""),
                text_color=row.get('route_text_color', ""),
                gtfs_feed=gtfs_feed,
            )
            for row in self._read_csv(archive, "routes.txt")
        )
        self._upsert_stream(Route, routes,
                            ['agency_id', 'short_name', 'long_name', 'desc', 'type', 'url', 'color',
                             'text_color', 'gtfs_feed'],
                            bulk_size, dry_run)

    def _import_calendars(self, archive, gtfs_feed, bulk_size, dry_run):
        """Replace the calendars of the feed from ``calendar.txt`` and ``calendar_dates.txt``.

        Services that only appear in ``calendar_dates.txt`` get a placeholder
        calendar with no active weekday. The start/end dates of every calendar
        are widened to include all of its exception dates.
        """
        gtfs_code = gtfs_feed.code
        filenames = archive.namelist()

        # The existing calendars of the feed are dropped before the import;
        # a dry run must leave them untouched.
        if not dry_run:
            Calendar.objects.filter(gtfs_feed=gtfs_feed).delete()

        # Kept for the whole step (never cleared) so that the exception dates
        # below can widen the calendars declared in calendar.txt.
        calendars = {}
        if "calendar.txt" in filenames:
            for row in self._read_csv(archive, "calendar.txt"):
                calendar = Calendar(
                    id=f"{gtfs_code}-{row['service_id']}",
                    monday=row['monday'],
                    tuesday=row['tuesday'],
                    wednesday=row['wednesday'],
                    thursday=row['thursday'],
                    friday=row['friday'],
                    saturday=row['saturday'],
                    sunday=row['sunday'],
                    start_date=row['start_date'],
                    end_date=row['end_date'],
                    gtfs_feed=gtfs_feed,
                )
                calendars[calendar.id] = calendar

            if calendars and not dry_run:
                self._upsert(Calendar, list(calendars.values()),
                             ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday',
                              'start_date', 'end_date', 'gtfs_feed'],
                             batch_size=bulk_size)

        calendar_dates = []
        if "calendar_dates.txt" in filenames:
            for row in self._read_csv(archive, "calendar_dates.txt"):
                service_id = f"{gtfs_code}-{row['service_id']}"
                exception_date = row['date']
                calendar_dates.append(CalendarDate(
                    id=f"{service_id}-{exception_date}",
                    service_id=service_id,
                    date=exception_date,
                    exception_type=row['exception_type'],
                ))

                calendar = calendars.get(service_id)
                if calendar is None:
                    calendars[service_id] = Calendar(
                        id=service_id,
                        monday=False,
                        tuesday=False,
                        wednesday=False,
                        thursday=False,
                        friday=False,
                        saturday=False,
                        sunday=False,
                        start_date=exception_date,
                        end_date=exception_date,
                        gtfs_feed=gtfs_feed,
                    )
                else:
                    # Dates are compared as the raw strings read from the files.
                    if calendar.start_date > exception_date:
                        calendar.start_date = exception_date
                    if calendar.end_date < exception_date:
                        calendar.end_date = exception_date

        if calendar_dates and not dry_run:
            # Only the date range is updated on conflict, so the weekdays of the
            # calendars inserted from calendar.txt are preserved.
            self._upsert(Calendar, list(calendars.values()), ['start_date', 'end_date', 'gtfs_feed'],
                         batch_size=bulk_size)
            self._upsert(CalendarDate, calendar_dates, ['service_id', 'date', 'exception_type'],
                         batch_size=bulk_size)

    def _import_trips(self, archive, gtfs_feed, bulk_size, dry_run):
        """Upsert the trips listed in ``trips.txt``."""
        gtfs_code = gtfs_feed.code
        trips = (
            Trip(
                id=f"{gtfs_code}-{row['trip_id']}",
                route_id=f"{gtfs_code}-{row['route_id']}",
                service_id=f"{gtfs_code}-{row['service_id']}",
                headsign=row.get('trip_headsign', ""),
                short_name=row.get('trip_short_name', ""),
                direction_id=row.get('direction_id', None) or None,
                block_id=row.get('block_id', ""),
                shape_id=row.get('shape_id', ""),
                wheelchair_accessible=row.get('wheelchair_accessible', None),
                bikes_allowed=row.get('bikes_allowed', None),
                gtfs_feed=gtfs_feed,
            )
            for row in self._read_csv(archive, "trips.txt")
        )
        self._upsert_stream(Trip, trips,
                            ['route_id', 'service_id', 'headsign', 'short_name', 'direction_id', 'block_id',
                             'shape_id', 'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
                            bulk_size, dry_run)

    def _build_stop_time(self, row, gtfs_code):
        """Build a ``StopTime`` from one row of ``stop_times.txt``."""
        arrival = self._parse_gtfs_time(row['arrival_time'])
        departure = self._parse_gtfs_time(row['departure_time'])

        pickup_type = row.get('pickup_type', 0)
        drop_off_type = row.get('drop_off_type', 0)
        # NOTE(review): the first stop is detected by stop_sequence == "1" and
        # the last one by identical arrival/departure times — confirm that
        # this holds for every imported feed.
        if row['stop_sequence'] == "1":
            # First stop
            drop_off_type = PickupType.NONE
        elif arrival == departure:
            # Last stop
            pickup_type = PickupType.NONE

        return StopTime(
            id=f"{gtfs_code}-{row['trip_id']}-{row['stop_id']}"
               f"-{row['departure_time']}",
            trip_id=f"{gtfs_code}-{row['trip_id']}",
            arrival_time=timedelta(seconds=arrival),
            departure_time=timedelta(seconds=departure),
            stop_id=f"{gtfs_code}-{row['stop_id']}",
            stop_sequence=row['stop_sequence'],
            stop_headsign=row.get('stop_headsign', ""),
            pickup_type=pickup_type,
            drop_off_type=drop_off_type,
            timepoint=row.get('timepoint', None),
        )

    def _import_stop_times(self, archive, gtfs_feed, bulk_size, dry_run):
        """Upsert the stop times listed in ``stop_times.txt``."""
        gtfs_code = gtfs_feed.code
        stop_times = (self._build_stop_time(row, gtfs_code)
                      for row in self._read_csv(archive, "stop_times.txt"))
        self._upsert_stream(StopTime, stop_times,
                            ['stop_id', 'arrival_time', 'departure_time', 'stop_headsign', 'pickup_type',
                             'drop_off_type', 'timepoint'],
                            bulk_size, dry_run)

    def _import_transfers(self, archive, gtfs_feed, bulk_size, dry_run):
        """Upsert the transfers listed in ``transfers.txt``, when the file exists."""
        if "transfers.txt" not in archive.namelist():
            return

        gtfs_code = gtfs_feed.code
        transfers = (
            Transfer(
                # NOTE(review): unlike every other identifier, this one is not
                # prefixed with the feed code — confirm whether it is intended
                # before changing it, as it would alter the stored ids.
                id=f"{row['from_stop_id']}-{row['to_stop_id']}",
                from_stop_id=f"{gtfs_code}-{row['from_stop_id']}",
                to_stop_id=f"{gtfs_code}-{row['to_stop_id']}",
                transfer_type=row['transfer_type'],
                min_transfer_time=row['min_transfer_time'],
            )
            for row in self._read_csv(archive, "transfers.txt")
        )
        self._upsert_stream(Transfer, transfers, ['transfer_type', 'min_transfer_time'], bulk_size, dry_run)

    def _import_feed_info(self, archive, gtfs_feed, dry_run):
        """Create or update the feed information from ``feed_info.txt``, when the file exists."""
        if dry_run or "feed_info.txt" not in archive.namelist():
            return

        for row in self._read_csv(archive, "feed_info.txt"):
            FeedInfo.objects.update_or_create(
                publisher_name=row['feed_publisher_name'],
                gtfs_feed=gtfs_feed,
                defaults=dict(
                    publisher_url=row['feed_publisher_url'],
                    lang=row['feed_lang'],
                    start_date=row.get('feed_start_date', datetime.now().date()),
                    end_date=row.get('feed_end_date', datetime.now().date()),
                    version=row.get('feed_version', 1),
                )
            )