Compare commits

..

2 Commits

Author SHA1 Message Date
b85a1b7734
More optimization 2024-05-12 11:52:44 +02:00
eade9e84de
Install django-extensions + update requirements.txt 2024-05-12 10:03:12 +02:00
5 changed files with 388 additions and 342 deletions

View File

@ -1,7 +1,8 @@
Django>=5.0,<6.0 Django>=5.0.4,<6.0
django-cors-headers django-cors-headers~=4.3.1
django-filter~=23.5 django-extensions~=3.2.3
django-filter~=24.2
djangorestframework~=3.14.0 djangorestframework~=3.14.0
protobuf protobuf~=5.26.1
requests~=2.31.0 requests~=2.31.0
tqdm tqdm~=4.66.4

View File

@ -1,6 +1,8 @@
import csv import csv
import os.path
import tempfile
from datetime import datetime, timedelta from datetime import datetime, timedelta
from io import BytesIO, TextIOWrapper from time import time
from zipfile import ZipFile from zipfile import ZipFile
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@ -9,7 +11,7 @@ from django.core.management import BaseCommand
from tqdm import tqdm from tqdm import tqdm
from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \ from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \
Transfer, Trip, PickupType Transfer, Trip, PickupType, TripUpdate
class Command(BaseCommand): class Command(BaseCommand):
@ -17,7 +19,7 @@ class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode") parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.") parser.add_argument('--bulk_size', '-b', type=int, default=10000, help="Number of objects to create in bulk.")
parser.add_argument('--dry-run', action='store_true', parser.add_argument('--dry-run', action='store_true',
help="Do not update the database, only print what would be done.") help="Do not update the database, only print what would be done.")
parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.") parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")
@ -30,8 +32,6 @@ class Command(BaseCommand):
self.stdout.write("Updating database...") self.stdout.write("Updating database...")
for gtfs_feed in GTFSFeed.objects.all(): for gtfs_feed in GTFSFeed.objects.all():
gtfs_code = gtfs_feed.code
if not force: if not force:
# Check if the source file was updated # Check if the source file was updated
resp = requests.head(gtfs_feed.feed_url, allow_redirects=True) resp = requests.head(gtfs_feed.feed_url, allow_redirects=True)
@ -51,14 +51,37 @@ class Command(BaseCommand):
self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...") self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True) resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True)
with ZipFile(BytesIO(resp.content)) as zipfile:
with tempfile.TemporaryFile(suffix=".zip") as file:
for chunk in resp.iter_content(chunk_size=128):
file.write(chunk)
file.seek(0)
with tempfile.TemporaryDirectory() as tmp_dir:
with ZipFile(file) as zipfile:
zipfile.extractall(tmp_dir)
self.parse_gtfs(tmp_dir, gtfs_feed, bulk_size, dry_run, verbosity)
if 'ETag' in resp.headers:
gtfs_feed.etag = resp.headers['ETag']
gtfs_feed.save()
if 'Last-Modified' in resp.headers:
last_modified = resp.headers['Last-Modified']
gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
gtfs_feed.save()
def parse_gtfs(self, zip_dir: str, gtfs_feed: GTFSFeed, bulk_size: int, dry_run: bool, verbosity: int):
gtfs_code = gtfs_feed.code
def read_csv(filename): def read_csv(filename):
with zipfile.open(filename, 'r') as zf: with open(os.path.join(zip_dir, filename), 'r') as f:
with TextIOWrapper(zf, encoding='utf-8') as wrapper: reader = csv.DictReader(f)
reader = csv.DictReader(wrapper)
reader.fieldnames = [field.replace('\ufeff', '').strip() reader.fieldnames = [field.replace('\ufeff', '').strip()
for field in reader.fieldnames] for field in reader.fieldnames]
for row in tqdm(reader, desc=filename, unit=' rows'): iterator = tqdm(reader, desc=filename, unit=' rows') if verbosity >= 2 else reader
for row in iterator:
yield {k.strip(): v.strip() for k, v in row.items()} yield {k.strip(): v.strip() for k, v in row.items()}
agencies = [] agencies = []
@ -72,7 +95,7 @@ class Command(BaseCommand):
lang=agency_dict.get('agency_lang', "fr"), lang=agency_dict.get('agency_lang', "fr"),
phone=agency_dict.get('agency_phone', ""), phone=agency_dict.get('agency_phone', ""),
email=agency_dict.get('agency_email', ""), email=agency_dict.get('agency_email', ""),
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
) )
agencies.append(agency) agencies.append(agency)
if agencies and not dry_run: if agencies and not dry_run:
@ -106,7 +129,7 @@ class Command(BaseCommand):
wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0), wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
level_id=stop_dict.get('level_id', ""), level_id=stop_dict.get('level_id', ""),
platform_code=stop_dict.get('platform_code', ""), platform_code=stop_dict.get('platform_code', ""),
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
) )
stops.append(stop) stops.append(stop)
@ -127,7 +150,7 @@ class Command(BaseCommand):
route_id = route_dict['route_id'] route_id = route_dict['route_id']
route_id = f"{gtfs_code}-{route_id}" route_id = f"{gtfs_code}-{route_id}"
# Agency is optional if there is only one # Agency is optional if there is only one
agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed=gtfs_feed) agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed_id=gtfs_code)
route = Route( route = Route(
id=route_id, id=route_id,
agency_id=f"{gtfs_code}-{agency_id}", agency_id=f"{gtfs_code}-{agency_id}",
@ -138,7 +161,7 @@ class Command(BaseCommand):
url=route_dict.get('route_url', ""), url=route_dict.get('route_url', ""),
color=route_dict.get('route_color', ""), color=route_dict.get('route_color', ""),
text_color=route_dict.get('route_text_color', ""), text_color=route_dict.get('route_text_color', ""),
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
) )
routes.append(route) routes.append(route)
@ -159,9 +182,22 @@ class Command(BaseCommand):
unique_fields=['id']) unique_fields=['id'])
routes.clear() routes.clear()
Calendar.objects.filter(gtfs_feed=gtfs_feed).delete() start_time = 0
if verbosity >= 1:
self.stdout.write("Deleting old calendars, trips and stop times…")
start_time = time()
TripUpdate.objects.filter(trip__gtfs_feed_id=gtfs_code).delete()
StopTime.objects.filter(trip__gtfs_feed_id=gtfs_code)._raw_delete(StopTime.objects.db)
Trip.objects.filter(gtfs_feed_id=gtfs_code)._raw_delete(Trip.objects.db)
Calendar.objects.filter(gtfs_feed_id=gtfs_code).delete()
if verbosity >= 1:
end = time()
self.stdout.write(f"Done in {end - start_time:.2f} s")
calendars = {} calendars = {}
if "calendar.txt" in zipfile.namelist(): if os.path.exists(os.path.join(zip_dir, "calendar.txt")):
for calendar_dict in read_csv("calendar.txt"): for calendar_dict in read_csv("calendar.txt"):
calendar_dict: dict calendar_dict: dict
calendar = Calendar( calendar = Calendar(
@ -175,7 +211,7 @@ class Command(BaseCommand):
sunday=calendar_dict['sunday'], sunday=calendar_dict['sunday'],
start_date=calendar_dict['start_date'], start_date=calendar_dict['start_date'],
end_date=calendar_dict['end_date'], end_date=calendar_dict['end_date'],
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
) )
calendars[calendar.id] = calendar calendars[calendar.id] = calendar
@ -218,7 +254,7 @@ class Command(BaseCommand):
sunday=False, sunday=False,
start_date=calendar_date_dict['date'], start_date=calendar_date_dict['date'],
end_date=calendar_date_dict['date'], end_date=calendar_date_dict['date'],
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
) )
calendars[calendar.id] = calendar calendars[calendar.id] = calendar
else: else:
@ -243,6 +279,7 @@ class Command(BaseCommand):
calendar_dates.clear() calendar_dates.clear()
trips = [] trips = []
# start_time = time()
for trip_dict in read_csv("trips.txt"): for trip_dict in read_csv("trips.txt"):
trip_dict: dict trip_dict: dict
trip_id = trip_dict['trip_id'] trip_id = trip_dict['trip_id']
@ -260,28 +297,26 @@ class Command(BaseCommand):
shape_id=trip_dict.get('shape_id', ""), shape_id=trip_dict.get('shape_id', ""),
wheelchair_accessible=trip_dict.get('wheelchair_accessible', None), wheelchair_accessible=trip_dict.get('wheelchair_accessible', None),
bikes_allowed=trip_dict.get('bikes_allowed', None), bikes_allowed=trip_dict.get('bikes_allowed', None),
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
) )
trips.append(trip) trips.append(trip)
if len(trips) >= bulk_size and not dry_run: if len(trips) >= bulk_size and not dry_run:
Trip.objects.bulk_create(trips, # now = time()
update_conflicts=True, # print(f"Elapsed time: {now - start_time:.3f}s, "
update_fields=['route_id', 'service_id', 'headsign', 'short_name', # f"{1000 * (now - start_time) / len(trips):.2f}ms per iteration")
'direction_id', 'block_id', 'shape_id', # start_time = now
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'], Trip.objects.bulk_create(trips)
unique_fields=['id']) # now = time()
# print(f"Elapsed time: {now - start_time:.3f}s to save")
# start_time = now
trips.clear() trips.clear()
if trips and not dry_run: if trips and not dry_run:
Trip.objects.bulk_create(trips, Trip.objects.bulk_create(trips)
update_conflicts=True,
update_fields=['route_id', 'service_id', 'headsign', 'short_name',
'direction_id', 'block_id', 'shape_id',
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
unique_fields=['id'])
trips.clear() trips.clear()
stop_times = [] stop_times = []
# start_time = time()
for stop_time_dict in read_csv("stop_times.txt"): for stop_time_dict in read_csv("stop_times.txt"):
stop_time_dict: dict stop_time_dict: dict
@ -323,23 +358,21 @@ class Command(BaseCommand):
stop_times.append(st) stop_times.append(st)
if len(stop_times) >= bulk_size and not dry_run: if len(stop_times) >= bulk_size and not dry_run:
StopTime.objects.bulk_create(stop_times, # now = time()
update_conflicts=True, # print(f"Elapsed time: {now - start_time:.3f}s, "
update_fields=['stop_id', 'arrival_time', 'departure_time', # f"{1000 * (now - start_time) / len(stop_times):.2f}ms per iteration")
'stop_headsign', 'pickup_type', # start_time = now
'drop_off_type', 'timepoint'], StopTime.objects.bulk_create(stop_times)
unique_fields=['id']) # now = time()
stop_times.clear() # print(f"Elapsed time: {now - start_time:.3f}s to save")
if stop_times and not dry_run: # start_time = now
StopTime.objects.bulk_create(stop_times,
update_conflicts=True,
update_fields=['stop_id', 'arrival_time', 'departure_time',
'stop_headsign', 'pickup_type',
'drop_off_type', 'timepoint'],
unique_fields=['id'])
stop_times.clear() stop_times.clear()
if "transfers.txt" in zipfile.namelist(): if stop_times and not dry_run:
StopTime.objects.bulk_create(stop_times)
stop_times.clear()
if os.path.exists(os.path.join(zip_dir, "transfers.txt")):
transfers = [] transfers = []
for transfer_dict in read_csv("transfers.txt"): for transfer_dict in read_csv("transfers.txt"):
transfer_dict: dict transfer_dict: dict
@ -349,7 +382,7 @@ class Command(BaseCommand):
to_stop_id = f"{gtfs_code}-{to_stop_id}" to_stop_id = f"{gtfs_code}-{to_stop_id}"
transfer = Transfer( transfer = Transfer(
id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}", id=f"{gtfs_code}-{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
from_stop_id=from_stop_id, from_stop_id=from_stop_id,
to_stop_id=to_stop_id, to_stop_id=to_stop_id,
transfer_type=transfer_dict['transfer_type'], transfer_type=transfer_dict['transfer_type'],
@ -358,25 +391,19 @@ class Command(BaseCommand):
transfers.append(transfer) transfers.append(transfer)
if len(transfers) >= bulk_size and not dry_run: if len(transfers) >= bulk_size and not dry_run:
Transfer.objects.bulk_create(transfers, Transfer.objects.bulk_create(transfers)
update_conflicts=True,
update_fields=['transfer_type', 'min_transfer_time'],
unique_fields=['id'])
transfers.clear() transfers.clear()
if transfers and not dry_run: if transfers and not dry_run:
Transfer.objects.bulk_create(transfers, Transfer.objects.bulk_create(transfers)
update_conflicts=True,
update_fields=['transfer_type', 'min_transfer_time'],
unique_fields=['id'])
transfers.clear() transfers.clear()
if "feed_info.txt" in zipfile.namelist() and not dry_run: if os.path.exists(os.path.join(zip_dir, "feed_info.txt")) and not dry_run:
for feed_info_dict in read_csv("feed_info.txt"): for feed_info_dict in read_csv("feed_info.txt"):
feed_info_dict: dict feed_info_dict: dict
FeedInfo.objects.update_or_create( FeedInfo.objects.update_or_create(
publisher_name=feed_info_dict['feed_publisher_name'], publisher_name=feed_info_dict['feed_publisher_name'],
gtfs_feed=gtfs_feed, gtfs_feed_id=gtfs_code,
defaults=dict( defaults=dict(
publisher_url=feed_info_dict['feed_publisher_url'], publisher_url=feed_info_dict['feed_publisher_url'],
lang=feed_info_dict['feed_lang'], lang=feed_info_dict['feed_lang'],
@ -385,12 +412,3 @@ class Command(BaseCommand):
version=feed_info_dict.get('feed_version', 1), version=feed_info_dict.get('feed_version', 1),
) )
) )
if 'ETag' in resp.headers:
gtfs_feed.etag = resp.headers['ETag']
gtfs_feed.save()
if 'Last-Modified' in resp.headers:
last_modified = resp.headers['Last-Modified']
gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
gtfs_feed.save()

View File

@ -0,0 +1,26 @@
# Generated by Django 5.0.6 on 2024-05-12 09:31
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("gtfs", "0001_initial"),
]
operations = [
migrations.AlterField(
model_name="stop",
name="parent_station",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="children",
to="gtfs.stop",
verbose_name="Parent station",
),
),
]

View File

@ -288,7 +288,7 @@ class Stop(models.Model):
parent_station = models.ForeignKey( parent_station = models.ForeignKey(
to="Stop", to="Stop",
on_delete=models.PROTECT, on_delete=models.SET_NULL,
verbose_name=_("Parent station"), verbose_name=_("Parent station"),
related_name="children", related_name="children",
blank=True, blank=True,

View File

@ -41,6 +41,7 @@ INSTALLED_APPS = [
"django.contrib.staticfiles", "django.contrib.staticfiles",
"corsheaders", "corsheaders",
"django_extensions",
"django_filters", "django_filters",
"rest_framework", "rest_framework",