diff --git a/trainvel/gtfs/management/commands/update_trainvel_gtfs.py b/trainvel/gtfs/management/commands/update_trainvel_gtfs.py index 6b81f62..880e8ff 100644 --- a/trainvel/gtfs/management/commands/update_trainvel_gtfs.py +++ b/trainvel/gtfs/management/commands/update_trainvel_gtfs.py @@ -1,6 +1,8 @@ import csv +import os.path +import tempfile from datetime import datetime, timedelta -from io import BytesIO, TextIOWrapper +from time import time from zipfile import ZipFile from zoneinfo import ZoneInfo @@ -9,7 +11,7 @@ from django.core.management import BaseCommand from tqdm import tqdm from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \ - Transfer, Trip, PickupType + Transfer, Trip, PickupType, TripUpdate class Command(BaseCommand): @@ -17,7 +19,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode") - parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.") + parser.add_argument('--bulk_size', '-b', type=int, default=10000, help="Number of objects to create in bulk.") parser.add_argument('--dry-run', action='store_true', help="Do not update the database, only print what would be done.") parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.") @@ -30,8 +32,6 @@ class Command(BaseCommand): self.stdout.write("Updating database...") for gtfs_feed in GTFSFeed.objects.all(): - gtfs_code = gtfs_feed.code - if not force: # Check if the source file was updated resp = requests.head(gtfs_feed.feed_url, allow_redirects=True) @@ -51,340 +51,17 @@ class Command(BaseCommand): self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...") resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True) - with ZipFile(BytesIO(resp.content)) as zipfile: - def read_csv(filename): - with zipfile.open(filename, 'r') as zf: - with TextIOWrapper(zf, encoding='utf-8') as wrapper: - reader = csv.DictReader(wrapper) - reader.fieldnames = [field.replace('\ufeff', '').strip() - for field in reader.fieldnames] - for row in tqdm(reader, desc=filename, unit=' rows'): - yield {k.strip(): v.strip() for k, v in row.items()} - agencies = [] - for agency_dict in read_csv("agency.txt"): - agency_dict: dict - agency = Agency( - id=f"{gtfs_code}-{agency_dict['agency_id']}", - name=agency_dict['agency_name'], - url=agency_dict['agency_url'], - timezone=agency_dict['agency_timezone'], - lang=agency_dict.get('agency_lang', "fr"), - phone=agency_dict.get('agency_phone', ""), - email=agency_dict.get('agency_email', ""), - gtfs_feed=gtfs_feed, - ) - agencies.append(agency) - if agencies and not dry_run: - Agency.objects.bulk_create(agencies, - update_conflicts=True, - update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email', - 'gtfs_feed'], - unique_fields=['id']) - agencies.clear() + with tempfile.TemporaryFile(suffix=".zip") as file: + for chunk in resp.iter_content(chunk_size=128): + file.write(chunk) + file.seek(0) - stops = [] - for stop_dict in read_csv("stops.txt"): - stop_dict: dict - stop_id = stop_dict['stop_id'] - stop_id = f"{gtfs_code}-{stop_id}" + with tempfile.TemporaryDirectory() as tmp_dir: + with ZipFile(file) as zipfile: + zipfile.extractall(tmp_dir) - parent_station_id = stop_dict.get('parent_station', None) - parent_station_id = f"{gtfs_code}-{parent_station_id}" if parent_station_id else None - - stop = Stop( - id=stop_id, - name=stop_dict['stop_name'], - desc=stop_dict.get('stop_desc', ""), - lat=stop_dict['stop_lat'], - lon=stop_dict['stop_lon'], - zone_id=stop_dict.get('zone_id', ""), - url=stop_dict.get('stop_url', ""), - location_type=stop_dict.get('location_type', 0) or 0, - parent_station_id=parent_station_id, - timezone=stop_dict.get('stop_timezone', ""), - wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0), - level_id=stop_dict.get('level_id', ""), - platform_code=stop_dict.get('platform_code', ""), - gtfs_feed=gtfs_feed, - ) - stops.append(stop) - - if stops and not dry_run: - Stop.objects.bulk_create(stops, - batch_size=bulk_size, - update_conflicts=True, - update_fields=['name', 'desc', 'lat', 'lon', 'zone_id', 'url', - 'location_type', 'parent_station_id', 'timezone', - 'wheelchair_boarding', 'level_id', 'platform_code', - 'gtfs_feed'], - unique_fields=['id']) - stops.clear() - - routes = [] - for route_dict in read_csv("routes.txt"): - route_dict: dict - route_id = route_dict['route_id'] - route_id = f"{gtfs_code}-{route_id}" - # Agency is optional there is only one - agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed=gtfs_feed) - route = Route( - id=route_id, - agency_id=f"{gtfs_code}-{agency_id}", - short_name=route_dict['route_short_name'], - long_name=route_dict['route_long_name'], - desc=route_dict.get('route_desc', ""), - type=route_dict['route_type'], - url=route_dict.get('route_url', ""), - color=route_dict.get('route_color', ""), - text_color=route_dict.get('route_text_color', ""), - gtfs_feed=gtfs_feed, - ) - routes.append(route) - - if len(routes) >= bulk_size and not dry_run: - Route.objects.bulk_create(routes, - update_conflicts=True, - update_fields=['agency_id', 'short_name', 'long_name', 'desc', - 'type', 'url', 'color', 'text_color', - 'gtfs_feed'], - unique_fields=['id']) - routes.clear() - if routes and not dry_run: - Route.objects.bulk_create(routes, - update_conflicts=True, - update_fields=['agency_id', 'short_name', 'long_name', 'desc', - 'type', 'url', 'color', 'text_color', - 'gtfs_feed'], - unique_fields=['id']) - routes.clear() - - Calendar.objects.filter(gtfs_feed=gtfs_feed).delete() - calendars = {} - if "calendar.txt" in zipfile.namelist(): - for calendar_dict in read_csv("calendar.txt"): - calendar_dict: dict - calendar = Calendar( - id=f"{gtfs_code}-{calendar_dict['service_id']}", - monday=calendar_dict['monday'], - tuesday=calendar_dict['tuesday'], - wednesday=calendar_dict['wednesday'], - thursday=calendar_dict['thursday'], - friday=calendar_dict['friday'], - saturday=calendar_dict['saturday'], - sunday=calendar_dict['sunday'], - start_date=calendar_dict['start_date'], - end_date=calendar_dict['end_date'], - gtfs_feed=gtfs_feed, - ) - calendars[calendar.id] = calendar - - if len(calendars) >= bulk_size and not dry_run: - Calendar.objects.bulk_create(calendars.values(), - update_conflicts=True, - update_fields=['monday', 'tuesday', 'wednesday', 'thursday', - 'friday', 'saturday', 'sunday', 'start_date', - 'end_date', 'gtfs_feed'], - unique_fields=['id']) - calendars.clear() - if calendars and not dry_run: - Calendar.objects.bulk_create(calendars.values(), update_conflicts=True, - update_fields=['monday', 'tuesday', 'wednesday', 'thursday', - 'friday', 'saturday', 'sunday', 'start_date', - 'end_date', 'gtfs_feed'], - unique_fields=['id']) - calendars.clear() - - calendar_dates = [] - for calendar_date_dict in read_csv("calendar_dates.txt"): - calendar_date_dict: dict - calendar_date = CalendarDate( - id=f"{gtfs_code}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}", - service_id=f"{gtfs_code}-{calendar_date_dict['service_id']}", - date=calendar_date_dict['date'], - exception_type=calendar_date_dict['exception_type'], - ) - calendar_dates.append(calendar_date) - - if calendar_date.service_id not in calendars: - calendar = Calendar( - id=f"{gtfs_code}-{calendar_date_dict['service_id']}", - monday=False, - tuesday=False, - wednesday=False, - thursday=False, - friday=False, - saturday=False, - sunday=False, - start_date=calendar_date_dict['date'], - end_date=calendar_date_dict['date'], - gtfs_feed=gtfs_feed, - ) - calendars[calendar.id] = calendar - else: - calendar = calendars[f"{gtfs_code}-{calendar_date_dict['service_id']}"] - if calendar.start_date > calendar_date.date: - calendar.start_date = calendar_date.date - if calendar.end_date < calendar_date.date: - calendar.end_date = calendar_date.date - - if calendar_dates and not dry_run: - Calendar.objects.bulk_create(calendars.values(), - batch_size=bulk_size, - update_conflicts=True, - update_fields=['start_date', 'end_date', 'gtfs_feed'], - unique_fields=['id']) - CalendarDate.objects.bulk_create(calendar_dates, - batch_size=bulk_size, - update_conflicts=True, - update_fields=['service_id', 'date', 'exception_type'], - unique_fields=['id']) - calendars.clear() - calendar_dates.clear() - - trips = [] - for trip_dict in read_csv("trips.txt"): - trip_dict: dict - trip_id = trip_dict['trip_id'] - route_id = trip_dict['route_id'] - trip_id = f"{gtfs_code}-{trip_id}" - route_id = f"{gtfs_code}-{route_id}" - trip = Trip( - id=trip_id, - route_id=route_id, - service_id=f"{gtfs_code}-{trip_dict['service_id']}", - headsign=trip_dict.get('trip_headsign', ""), - short_name=trip_dict.get('trip_short_name', ""), - direction_id=trip_dict.get('direction_id', None) or None, - block_id=trip_dict.get('block_id', ""), - shape_id=trip_dict.get('shape_id', ""), - wheelchair_accessible=trip_dict.get('wheelchair_accessible', None), - bikes_allowed=trip_dict.get('bikes_allowed', None), - gtfs_feed=gtfs_feed, - ) - trips.append(trip) - - if len(trips) >= bulk_size and not dry_run: - Trip.objects.bulk_create(trips, - update_conflicts=True, - update_fields=['route_id', 'service_id', 'headsign', 'short_name', - 'direction_id', 'block_id', 'shape_id', - 'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'], - unique_fields=['id']) - trips.clear() - if trips and not dry_run: - Trip.objects.bulk_create(trips, - update_conflicts=True, - update_fields=['route_id', 'service_id', 'headsign', 'short_name', - 'direction_id', 'block_id', 'shape_id', - 'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'], - unique_fields=['id']) - trips.clear() - - stop_times = [] - for stop_time_dict in read_csv("stop_times.txt"): - stop_time_dict: dict - - stop_id = stop_time_dict['stop_id'] - stop_id = f"{gtfs_code}-{stop_id}" - - trip_id = stop_time_dict['trip_id'] - trip_id = f"{gtfs_code}-{trip_id}" - - arr_time = stop_time_dict['arrival_time'] - arr_h, arr_m, arr_s = map(int, arr_time.split(':')) - arr_time = arr_h * 3600 + arr_m * 60 + arr_s - dep_time = stop_time_dict['departure_time'] - dep_h, dep_m, dep_s = map(int, dep_time.split(':')) - dep_time = dep_h * 3600 + dep_m * 60 + dep_s - - pickup_type = stop_time_dict.get('pickup_type', PickupType.REGULAR) - drop_off_type = stop_time_dict.get('drop_off_type', PickupType.REGULAR) - # if stop_time_dict['stop_sequence'] == "1": - # # First stop - # drop_off_type = PickupType.NONE - # elif arr_time == dep_time: - # # Last stop - # pickup_type = PickupType.NONE - - st = StopTime( - id=f"{gtfs_code}-{stop_time_dict['trip_id']}-{stop_time_dict['stop_id']}" - f"-{stop_time_dict['departure_time']}", - trip_id=trip_id, - arrival_time=timedelta(seconds=arr_time), - departure_time=timedelta(seconds=dep_time), - stop_id=stop_id, - stop_sequence=stop_time_dict['stop_sequence'], - stop_headsign=stop_time_dict.get('stop_headsign', ""), - pickup_type=pickup_type, - drop_off_type=drop_off_type, - timepoint=stop_time_dict.get('timepoint', None), - ) - stop_times.append(st) - - if len(stop_times) >= bulk_size and not dry_run: - StopTime.objects.bulk_create(stop_times, - update_conflicts=True, - update_fields=['stop_id', 'arrival_time', 'departure_time', - 'stop_headsign', 'pickup_type', - 'drop_off_type', 'timepoint'], - unique_fields=['id']) - stop_times.clear() - if stop_times and not dry_run: - StopTime.objects.bulk_create(stop_times, - update_conflicts=True, - update_fields=['stop_id', 'arrival_time', 'departure_time', - 'stop_headsign', 'pickup_type', - 'drop_off_type', 'timepoint'], - unique_fields=['id']) - stop_times.clear() - - if "transfers.txt" in zipfile.namelist(): - transfers = [] - for transfer_dict in read_csv("transfers.txt"): - transfer_dict: dict - from_stop_id = transfer_dict['from_stop_id'] - to_stop_id = transfer_dict['to_stop_id'] - from_stop_id = f"{gtfs_code}-{from_stop_id}" - to_stop_id = f"{gtfs_code}-{to_stop_id}" - - transfer = Transfer( - id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}", - from_stop_id=from_stop_id, - to_stop_id=to_stop_id, - transfer_type=transfer_dict['transfer_type'], - min_transfer_time=transfer_dict.get('min_transfer_time', 0) or 0, - ) - transfers.append(transfer) - - if len(transfers) >= bulk_size and not dry_run: - Transfer.objects.bulk_create(transfers, - update_conflicts=True, - update_fields=['transfer_type', 'min_transfer_time'], - unique_fields=['id']) - transfers.clear() - - if transfers and not dry_run: - Transfer.objects.bulk_create(transfers, - update_conflicts=True, - update_fields=['transfer_type', 'min_transfer_time'], - unique_fields=['id']) - transfers.clear() - - if "feed_info.txt" in zipfile.namelist() and not dry_run: - for feed_info_dict in read_csv("feed_info.txt"): - feed_info_dict: dict - FeedInfo.objects.update_or_create( - publisher_name=feed_info_dict['feed_publisher_name'], - gtfs_feed=gtfs_feed, - defaults=dict( - publisher_url=feed_info_dict['feed_publisher_url'], - lang=feed_info_dict['feed_lang'], - start_date=feed_info_dict.get('feed_start_date', datetime.now().date()), - end_date=feed_info_dict.get('feed_end_date', datetime.now().date()), - version=feed_info_dict.get('feed_version', 1), - ) - ) + self.parse_gtfs(tmp_dir, gtfs_feed, bulk_size, dry_run, verbosity) if 'ETag' in resp.headers: gtfs_feed.etag = resp.headers['ETag'] @@ -394,3 +71,344 @@ class Command(BaseCommand): gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \ .replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1])) gtfs_feed.save() + + def parse_gtfs(self, zip_dir: str, gtfs_feed: GTFSFeed, bulk_size: int, dry_run: bool, verbosity: int): + gtfs_code = gtfs_feed.code + + def read_csv(filename): + with open(os.path.join(zip_dir, filename), 'r') as f: + reader = csv.DictReader(f) + reader.fieldnames = [field.replace('\ufeff', '').strip() + for field in reader.fieldnames] + iterator = tqdm(reader, desc=filename, unit=' rows') if verbosity >= 2 else reader + for row in iterator: + yield {k.strip(): v.strip() for k, v in row.items()} + + agencies = [] + for agency_dict in read_csv("agency.txt"): + agency_dict: dict + agency = Agency( + id=f"{gtfs_code}-{agency_dict['agency_id']}", + name=agency_dict['agency_name'], + url=agency_dict['agency_url'], + timezone=agency_dict['agency_timezone'], + lang=agency_dict.get('agency_lang', "fr"), + phone=agency_dict.get('agency_phone', ""), + email=agency_dict.get('agency_email', ""), + gtfs_feed_id=gtfs_code, + ) + agencies.append(agency) + if agencies and not dry_run: + Agency.objects.bulk_create(agencies, + update_conflicts=True, + update_fields=['name', 'url', 'timezone', 'lang', 'phone', 'email', + 'gtfs_feed'], + unique_fields=['id']) + agencies.clear() + + stops = [] + for stop_dict in read_csv("stops.txt"): + stop_dict: dict + stop_id = stop_dict['stop_id'] + stop_id = f"{gtfs_code}-{stop_id}" + + parent_station_id = stop_dict.get('parent_station', None) + parent_station_id = f"{gtfs_code}-{parent_station_id}" if parent_station_id else None + + stop = Stop( + id=stop_id, + name=stop_dict['stop_name'], + desc=stop_dict.get('stop_desc', ""), + lat=stop_dict['stop_lat'], + lon=stop_dict['stop_lon'], + zone_id=stop_dict.get('zone_id', ""), + url=stop_dict.get('stop_url', ""), + location_type=stop_dict.get('location_type', 0) or 0, + parent_station_id=parent_station_id, + timezone=stop_dict.get('stop_timezone', ""), + wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0), + level_id=stop_dict.get('level_id', ""), + platform_code=stop_dict.get('platform_code', ""), + gtfs_feed_id=gtfs_code, + ) + stops.append(stop) + + if stops and not dry_run: + Stop.objects.bulk_create(stops, + batch_size=bulk_size, + update_conflicts=True, + update_fields=['name', 'desc', 'lat', 'lon', 'zone_id', 'url', + 'location_type', 'parent_station_id', 'timezone', + 'wheelchair_boarding', 'level_id', 'platform_code', + 'gtfs_feed'], + unique_fields=['id']) + stops.clear() + + routes = [] + for route_dict in read_csv("routes.txt"): + route_dict: dict + route_id = route_dict['route_id'] + route_id = f"{gtfs_code}-{route_id}" + # Agency is optional there is only one + agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed_id=gtfs_code) + route = Route( + id=route_id, + agency_id=f"{gtfs_code}-{agency_id}", + short_name=route_dict['route_short_name'], + long_name=route_dict['route_long_name'], + desc=route_dict.get('route_desc', ""), + type=route_dict['route_type'], + url=route_dict.get('route_url', ""), + color=route_dict.get('route_color', ""), + text_color=route_dict.get('route_text_color', ""), + gtfs_feed_id=gtfs_code, + ) + routes.append(route) + + if len(routes) >= bulk_size and not dry_run: + Route.objects.bulk_create(routes, + update_conflicts=True, + update_fields=['agency_id', 'short_name', 'long_name', 'desc', + 'type', 'url', 'color', 'text_color', + 'gtfs_feed'], + unique_fields=['id']) + routes.clear() + if routes and not dry_run: + Route.objects.bulk_create(routes, + update_conflicts=True, + update_fields=['agency_id', 'short_name', 'long_name', 'desc', + 'type', 'url', 'color', 'text_color', + 'gtfs_feed'], + unique_fields=['id']) + routes.clear() + + start_time = 0 + if verbosity >= 1: + self.stdout.write("Deleting old calendars, trips and stop times…") + start_time = time() + + TripUpdate.objects.filter(trip__gtfs_feed_id=gtfs_code).delete() + StopTime.objects.filter(trip__gtfs_feed_id=gtfs_code)._raw_delete(StopTime.objects.db) + Trip.objects.filter(gtfs_feed_id=gtfs_code)._raw_delete(Trip.objects.db) + Calendar.objects.filter(gtfs_feed_id=gtfs_code).delete() + + if verbosity >= 1: + end = time() + self.stdout.write(f"Done in {end - start_time:.2f} s") + + calendars = {} + if os.path.exists(os.path.join(zip_dir, "calendar.txt")): + for calendar_dict in read_csv("calendar.txt"): + calendar_dict: dict + calendar = Calendar( + id=f"{gtfs_code}-{calendar_dict['service_id']}", + monday=calendar_dict['monday'], + tuesday=calendar_dict['tuesday'], + wednesday=calendar_dict['wednesday'], + thursday=calendar_dict['thursday'], + friday=calendar_dict['friday'], + saturday=calendar_dict['saturday'], + sunday=calendar_dict['sunday'], + start_date=calendar_dict['start_date'], + end_date=calendar_dict['end_date'], + gtfs_feed_id=gtfs_code, + ) + calendars[calendar.id] = calendar + + if len(calendars) >= bulk_size and not dry_run: + Calendar.objects.bulk_create(calendars.values(), + update_conflicts=True, + update_fields=['monday', 'tuesday', 'wednesday', 'thursday', + 'friday', 'saturday', 'sunday', 'start_date', + 'end_date', 'gtfs_feed'], + unique_fields=['id']) + calendars.clear() + if calendars and not dry_run: + Calendar.objects.bulk_create(calendars.values(), update_conflicts=True, + update_fields=['monday', 'tuesday', 'wednesday', 'thursday', + 'friday', 'saturday', 'sunday', 'start_date', + 'end_date', 'gtfs_feed'], + unique_fields=['id']) + calendars.clear() + + calendar_dates = [] + for calendar_date_dict in read_csv("calendar_dates.txt"): + calendar_date_dict: dict + calendar_date = CalendarDate( + id=f"{gtfs_code}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}", + service_id=f"{gtfs_code}-{calendar_date_dict['service_id']}", + date=calendar_date_dict['date'], + exception_type=calendar_date_dict['exception_type'], + ) + calendar_dates.append(calendar_date) + + if calendar_date.service_id not in calendars: + calendar = Calendar( + id=f"{gtfs_code}-{calendar_date_dict['service_id']}", + monday=False, + tuesday=False, + wednesday=False, + thursday=False, + friday=False, + saturday=False, + sunday=False, + start_date=calendar_date_dict['date'], + end_date=calendar_date_dict['date'], + gtfs_feed_id=gtfs_code, + ) + calendars[calendar.id] = calendar + else: + calendar = calendars[f"{gtfs_code}-{calendar_date_dict['service_id']}"] + if calendar.start_date > calendar_date.date: + calendar.start_date = calendar_date.date + if calendar.end_date < calendar_date.date: + calendar.end_date = calendar_date.date + + if calendar_dates and not dry_run: + Calendar.objects.bulk_create(calendars.values(), + batch_size=bulk_size, + update_conflicts=True, + update_fields=['start_date', 'end_date', 'gtfs_feed'], + unique_fields=['id']) + CalendarDate.objects.bulk_create(calendar_dates, + batch_size=bulk_size, + update_conflicts=True, + update_fields=['service_id', 'date', 'exception_type'], + unique_fields=['id']) + calendars.clear() + calendar_dates.clear() + + trips = [] + # start_time = time() + for trip_dict in read_csv("trips.txt"): + trip_dict: dict + trip_id = trip_dict['trip_id'] + route_id = trip_dict['route_id'] + trip_id = f"{gtfs_code}-{trip_id}" + route_id = f"{gtfs_code}-{route_id}" + trip = Trip( + id=trip_id, + route_id=route_id, + service_id=f"{gtfs_code}-{trip_dict['service_id']}", + headsign=trip_dict.get('trip_headsign', ""), + short_name=trip_dict.get('trip_short_name', ""), + direction_id=trip_dict.get('direction_id', None) or None, + block_id=trip_dict.get('block_id', ""), + shape_id=trip_dict.get('shape_id', ""), + wheelchair_accessible=trip_dict.get('wheelchair_accessible', None), + bikes_allowed=trip_dict.get('bikes_allowed', None), + gtfs_feed_id=gtfs_code, + ) + trips.append(trip) + + if len(trips) >= bulk_size and not dry_run: + # now = time() + # print(f"Elapsed time: {now - start_time:.3f}s, " + # f"{1000 * (now - start_time) / len(trips):.2f}ms per iteration") + # start_time = now + Trip.objects.bulk_create(trips) + # now = time() + # print(f"Elapsed time: {now - start_time:.3f}s to save") + # start_time = now + trips.clear() + if trips and not dry_run: + Trip.objects.bulk_create(trips) + trips.clear() + + stop_times = [] + # start_time = time() + for stop_time_dict in read_csv("stop_times.txt"): + stop_time_dict: dict + + stop_id = stop_time_dict['stop_id'] + stop_id = f"{gtfs_code}-{stop_id}" + + trip_id = stop_time_dict['trip_id'] + trip_id = f"{gtfs_code}-{trip_id}" + + arr_time = stop_time_dict['arrival_time'] + arr_h, arr_m, arr_s = map(int, arr_time.split(':')) + arr_time = arr_h * 3600 + arr_m * 60 + arr_s + dep_time = stop_time_dict['departure_time'] + dep_h, dep_m, dep_s = map(int, dep_time.split(':')) + dep_time = dep_h * 3600 + dep_m * 60 + dep_s + + pickup_type = stop_time_dict.get('pickup_type', PickupType.REGULAR) + drop_off_type = stop_time_dict.get('drop_off_type', PickupType.REGULAR) + # if stop_time_dict['stop_sequence'] == "1": + # # First stop + # drop_off_type = PickupType.NONE + # elif arr_time == dep_time: + # # Last stop + # pickup_type = PickupType.NONE + + st = StopTime( + id=f"{gtfs_code}-{stop_time_dict['trip_id']}-{stop_time_dict['stop_id']}" + f"-{stop_time_dict['departure_time']}", + trip_id=trip_id, + arrival_time=timedelta(seconds=arr_time), + departure_time=timedelta(seconds=dep_time), + stop_id=stop_id, + stop_sequence=stop_time_dict['stop_sequence'], + stop_headsign=stop_time_dict.get('stop_headsign', ""), + pickup_type=pickup_type, + drop_off_type=drop_off_type, + timepoint=stop_time_dict.get('timepoint', None), + ) + stop_times.append(st) + + if len(stop_times) >= bulk_size and not dry_run: + # now = time() + # print(f"Elapsed time: {now - start_time:.3f}s, " + # f"{1000 * (now - start_time) / len(stop_times):.2f}ms per iteration") + # start_time = now + StopTime.objects.bulk_create(stop_times) + # now = time() + # print(f"Elapsed time: {now - start_time:.3f}s to save") + # start_time = now + stop_times.clear() + + if stop_times and not dry_run: + StopTime.objects.bulk_create(stop_times) + stop_times.clear() + + if os.path.exists(os.path.join(zip_dir, "transfers.txt")): + transfers = [] + for transfer_dict in read_csv("transfers.txt"): + transfer_dict: dict + from_stop_id = transfer_dict['from_stop_id'] + to_stop_id = transfer_dict['to_stop_id'] + from_stop_id = f"{gtfs_code}-{from_stop_id}" + to_stop_id = f"{gtfs_code}-{to_stop_id}" + + transfer = Transfer( + id=f"{gtfs_code}-{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}", + from_stop_id=from_stop_id, + to_stop_id=to_stop_id, + transfer_type=transfer_dict['transfer_type'], + min_transfer_time=transfer_dict.get('min_transfer_time', 0) or 0, + ) + transfers.append(transfer) + + if len(transfers) >= bulk_size and not dry_run: + Transfer.objects.bulk_create(transfers) + transfers.clear() + + if transfers and not dry_run: + Transfer.objects.bulk_create(transfers) + transfers.clear() + + if os.path.exists(os.path.join(zip_dir, "feed_info.txt")) and not dry_run: + for feed_info_dict in read_csv("feed_info.txt"): + feed_info_dict: dict + FeedInfo.objects.update_or_create( + publisher_name=feed_info_dict['feed_publisher_name'], + gtfs_feed_id=gtfs_code, + defaults=dict( + publisher_url=feed_info_dict['feed_publisher_url'], + lang=feed_info_dict['feed_lang'], + start_date=feed_info_dict.get('feed_start_date', datetime.now().date()), + end_date=feed_info_dict.get('feed_end_date', datetime.now().date()), + version=feed_info_dict.get('feed_version', 1), + ) + ) diff --git a/trainvel/gtfs/migrations/0002_alter_stop_parent_station.py b/trainvel/gtfs/migrations/0002_alter_stop_parent_station.py new file mode 100644 index 0000000..57d8156 --- /dev/null +++ b/trainvel/gtfs/migrations/0002_alter_stop_parent_station.py @@ -0,0 +1,26 @@ +# Generated by Django 5.0.6 on 2024-05-12 09:31 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("gtfs", "0001_initial"), + ] + + operations = [ + migrations.AlterField( + model_name="stop", + name="parent_station", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="children", + to="gtfs.stop", + verbose_name="Parent station", + ), + ), + ] diff --git a/trainvel/gtfs/models.py b/trainvel/gtfs/models.py index 69e892a..afd465c 100644 --- a/trainvel/gtfs/models.py +++ b/trainvel/gtfs/models.py @@ -288,7 +288,7 @@ class Stop(models.Model): parent_station = models.ForeignKey( to="Stop", - on_delete=models.PROTECT, + on_delete=models.SET_NULL, verbose_name=_("Parent station"), related_name="children", blank=True,