From 86d274ac846b80f4ad4f55ea690e3c1f66832024 Mon Sep 17 00:00:00 2001 From: Emmy D'Anello Date: Sat, 11 May 2024 23:18:57 +0200 Subject: [PATCH] Optimize CSV processing --- .../commands/update_trainvel_gtfs.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/trainvel/gtfs/management/commands/update_trainvel_gtfs.py b/trainvel/gtfs/management/commands/update_trainvel_gtfs.py index b653940..6b81f62 100644 --- a/trainvel/gtfs/management/commands/update_trainvel_gtfs.py +++ b/trainvel/gtfs/management/commands/update_trainvel_gtfs.py @@ -1,6 +1,6 @@ import csv from datetime import datetime, timedelta -from io import BytesIO +from io import BytesIO, TextIOWrapper from zipfile import ZipFile from zoneinfo import ZoneInfo @@ -8,8 +8,8 @@ import requests from django.core.management import BaseCommand from tqdm import tqdm -from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, Transfer, Trip, \ - PickupType +from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \ + Transfer, Trip, PickupType class Command(BaseCommand): @@ -52,16 +52,18 @@ class Command(BaseCommand): self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...") resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True) with ZipFile(BytesIO(resp.content)) as zipfile: - def read_file(filename): - lines = zipfile.read(filename).decode().replace('\ufeff', '').splitlines() - return [line.strip() for line in lines] + def read_csv(filename): + with zipfile.open(filename, 'r') as zf: + with TextIOWrapper(zf, encoding='utf-8') as wrapper: + reader = csv.DictReader(wrapper) + reader.fieldnames = [field.replace('\ufeff', '').strip() + for field in reader.fieldnames] + for row in tqdm(reader, desc=filename, unit=' rows'): + yield {k.strip(): v.strip() for k, v in row.items()} agencies = [] - for agency_dict in csv.DictReader(read_file("agency.txt")): + for agency_dict in read_csv("agency.txt"): agency_dict: dict - # if gtfs_code == "FR-EUROSTAR" \ - # and agency_dict['agency_id'] != 'ES' and agency_dict['agency_id'] != 'ER': - # continue agency = Agency( id=f"{gtfs_code}-{agency_dict['agency_id']}", name=agency_dict['agency_name'], @@ -82,7 +84,7 @@ class Command(BaseCommand): agencies.clear() stops = [] - for stop_dict in csv.DictReader(tqdm(read_file("stops.txt"), desc="Stops")): + for stop_dict in read_csv("stops.txt"): stop_dict: dict stop_id = stop_dict['stop_id'] stop_id = f"{gtfs_code}-{stop_id}" @@ -120,7 +122,7 @@ class Command(BaseCommand): stops.clear() routes = [] - for route_dict in csv.DictReader(tqdm(read_file("routes.txt"), desc="Routes")): + for route_dict in read_csv("routes.txt"): route_dict: dict route_id = route_dict['route_id'] route_id = f"{gtfs_code}-{route_id}" @@ -160,7 +162,7 @@ class Command(BaseCommand): Calendar.objects.filter(gtfs_feed=gtfs_feed).delete() calendars = {} if "calendar.txt" in zipfile.namelist(): - for calendar_dict in csv.DictReader(tqdm(read_file("calendar.txt"), desc="Calendars")): + for calendar_dict in read_csv("calendar.txt"): calendar_dict: dict calendar = Calendar( id=f"{gtfs_code}-{calendar_dict['service_id']}", @@ -194,7 +196,7 @@ class Command(BaseCommand): calendars.clear() calendar_dates = [] - for calendar_date_dict in csv.DictReader(tqdm(read_file("calendar_dates.txt"), desc="Calendar dates")): + for calendar_date_dict in read_csv("calendar_dates.txt"): calendar_date_dict: dict calendar_date = CalendarDate( id=f"{gtfs_code}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}", @@ -241,7 +243,7 @@ class Command(BaseCommand): calendar_dates.clear() trips = [] - for trip_dict in csv.DictReader(tqdm(read_file("trips.txt"), desc="Trips")): + for trip_dict in read_csv("trips.txt"): trip_dict: dict trip_id = trip_dict['trip_id'] route_id = trip_dict['route_id'] @@ -280,7 +282,7 @@ class Command(BaseCommand): trips.clear() stop_times = [] - for stop_time_dict in csv.DictReader(tqdm(read_file("stop_times.txt"), desc="Stop times")): + for stop_time_dict in read_csv("stop_times.txt"): stop_time_dict: dict stop_id = stop_time_dict['stop_id'] @@ -339,7 +341,7 @@ class Command(BaseCommand): if "transfers.txt" in zipfile.namelist(): transfers = [] - for transfer_dict in csv.DictReader(tqdm(read_file("transfers.txt"), desc="Transfers")): + for transfer_dict in read_csv("transfers.txt"): transfer_dict: dict from_stop_id = transfer_dict['from_stop_id'] to_stop_id = transfer_dict['to_stop_id'] @@ -370,7 +372,7 @@ class Command(BaseCommand): transfers.clear() if "feed_info.txt" in zipfile.namelist() and not dry_run: - for feed_info_dict in csv.DictReader(tqdm(read_file("feed_info.txt"), desc="Feed info")): + for feed_info_dict in read_csv("feed_info.txt"): feed_info_dict: dict FeedInfo.objects.update_or_create( publisher_name=feed_info_dict['feed_publisher_name'],