Compare commits

...

2 Commits

SHA1        Message                                               Date
b85a1b7734  More optimization                                     2024-05-12 11:52:44 +02:00
eade9e84de  Install django-extensions + update requirements.txt  2024-05-12 10:03:12 +02:00
5 changed files with 388 additions and 342 deletions

View File

@@ -1,7 +1,8 @@
-Django>=5.0,<6.0
-django-cors-headers
-django-filter~=23.5
+Django>=5.0.4,<6.0
+django-cors-headers~=4.3.1
+django-extensions~=3.2.3
+django-filter~=24.2
 djangorestframework~=3.14.0
-protobuf
+protobuf~=5.26.1
 requests~=2.31.0
-tqdm
+tqdm~=4.66.4
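Every previously unpinned package (django-cors-headers, protobuf, tqdm) is now pinned with the compatible-release operator. As a quick illustration of what `~=` allows (this check uses the third-party packaging library, an assumption for illustration, not part of this diff):

# Hypothetical check, assuming the "packaging" library is installed:
# "~=24.2" accepts any 24.x release from 24.2 on, but excludes 25.0.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=24.2")
print(spec.contains("24.9"))   # True  - still within the 24.x series
print(spec.contains("25.0"))   # False - the next major series is excluded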

View File

@@ -1,6 +1,8 @@
 import csv
+import os.path
+import tempfile
 from datetime import datetime, timedelta
 from io import BytesIO, TextIOWrapper
 from time import time
 from zipfile import ZipFile
 from zoneinfo import ZoneInfo
@@ -9,7 +11,7 @@ from django.core.management import BaseCommand
 from tqdm import tqdm

 from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \
-    Transfer, Trip, PickupType
+    Transfer, Trip, PickupType, TripUpdate


 class Command(BaseCommand):
@@ -17,7 +19,7 @@ class Command(BaseCommand):

     def add_arguments(self, parser):
         parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
-        parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
+        parser.add_argument('--bulk_size', '-b', type=int, default=10000, help="Number of objects to create in bulk.")
         parser.add_argument('--dry-run', action='store_true',
                             help="Do not update the database, only print what would be done.")
         parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")
@@ -30,8 +32,6 @@
         self.stdout.write("Updating database...")

         for gtfs_feed in GTFSFeed.objects.all():
-            gtfs_code = gtfs_feed.code
-
             if not force:
                 # Check if the source file was updated
                 resp = requests.head(gtfs_feed.feed_url, allow_redirects=True)
@@ -51,14 +51,37 @@
             self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
             resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True)
-            with ZipFile(BytesIO(resp.content)) as zipfile:
+            with tempfile.TemporaryFile(suffix=".zip") as file:
+                for chunk in resp.iter_content(chunk_size=128):
+                    file.write(chunk)
+                file.seek(0)
+
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    with ZipFile(file) as zipfile:
+                        zipfile.extractall(tmp_dir)
+
+                    self.parse_gtfs(tmp_dir, gtfs_feed, bulk_size, dry_run, verbosity)
+
+            if 'ETag' in resp.headers:
+                gtfs_feed.etag = resp.headers['ETag']
+                gtfs_feed.save()
+            if 'Last-Modified' in resp.headers:
+                last_modified = resp.headers['Last-Modified']
+                gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
+                    .replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
+                gtfs_feed.save()
+
+    def parse_gtfs(self, zip_dir: str, gtfs_feed: GTFSFeed, bulk_size: int, dry_run: bool, verbosity: int):
+        gtfs_code = gtfs_feed.code

         def read_csv(filename):
-            with zipfile.open(filename, 'r') as zf:
-                with TextIOWrapper(zf, encoding='utf-8') as wrapper:
-                    reader = csv.DictReader(wrapper)
+            with open(os.path.join(zip_dir, filename), 'r') as f:
+                reader = csv.DictReader(f)
                 reader.fieldnames = [field.replace('\ufeff', '').strip()
                                      for field in reader.fieldnames]
-                for row in tqdm(reader, desc=filename, unit=' rows'):
+                iterator = tqdm(reader, desc=filename, unit=' rows') if verbosity >= 2 else reader
+                for row in iterator:
                     yield {k.strip(): v.strip() for k, v in row.items()}

         agencies = []
@@ -72,7 +95,7 @@
                 lang=agency_dict.get('agency_lang', "fr"),
                 phone=agency_dict.get('agency_phone', ""),
                 email=agency_dict.get('agency_email', ""),
-                gtfs_feed=gtfs_feed,
+                gtfs_feed_id=gtfs_code,
             )
             agencies.append(agency)
         if agencies and not dry_run:
@@ -106,7 +129,7 @@
                 wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
                 level_id=stop_dict.get('level_id', ""),
                 platform_code=stop_dict.get('platform_code', ""),
-                gtfs_feed=gtfs_feed,
+                gtfs_feed_id=gtfs_code,
             )
             stops.append(stop)

@@ -127,7 +150,7 @@
             route_id = route_dict['route_id']
             route_id = f"{gtfs_code}-{route_id}"
             # Agency is optional when there is only one
-            agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed=gtfs_feed)
+            agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed_id=gtfs_code)
             route = Route(
                 id=route_id,
                 agency_id=f"{gtfs_code}-{agency_id}",
@@ -138,7 +161,7 @@
                 url=route_dict.get('route_url', ""),
                 color=route_dict.get('route_color', ""),
                 text_color=route_dict.get('route_text_color', ""),
-                gtfs_feed=gtfs_feed,
+                gtfs_feed_id=gtfs_code,
             )
             routes.append(route)

@@ -159,9 +182,22 @@
                                       unique_fields=['id'])
             routes.clear()

-        Calendar.objects.filter(gtfs_feed=gtfs_feed).delete()
+        start_time = 0
+        if verbosity >= 1:
+            self.stdout.write("Deleting old calendars, trips and stop times…")
+            start_time = time()
+
+        TripUpdate.objects.filter(trip__gtfs_feed_id=gtfs_code).delete()
+        StopTime.objects.filter(trip__gtfs_feed_id=gtfs_code)._raw_delete(StopTime.objects.db)
+        Trip.objects.filter(gtfs_feed_id=gtfs_code)._raw_delete(Trip.objects.db)
+        Calendar.objects.filter(gtfs_feed_id=gtfs_code).delete()
+
+        if verbosity >= 1:
+            end = time()
+            self.stdout.write(f"Done in {end - start_time:.2f} s")
+
         calendars = {}
-        if "calendar.txt" in zipfile.namelist():
+        if os.path.exists(os.path.join(zip_dir, "calendar.txt")):
             for calendar_dict in read_csv("calendar.txt"):
                 calendar_dict: dict
                 calendar = Calendar(
@@ -175,7 +211,7 @@
                     sunday=calendar_dict['sunday'],
                     start_date=calendar_dict['start_date'],
                     end_date=calendar_dict['end_date'],
-                    gtfs_feed=gtfs_feed,
+                    gtfs_feed_id=gtfs_code,
                 )
                 calendars[calendar.id] = calendar

@@ -218,7 +254,7 @@
                         sunday=False,
                         start_date=calendar_date_dict['date'],
                         end_date=calendar_date_dict['date'],
-                        gtfs_feed=gtfs_feed,
+                        gtfs_feed_id=gtfs_code,
                     )
                     calendars[calendar.id] = calendar
                 else:
@@ -243,6 +279,7 @@
             calendar_dates.clear()

         trips = []
+        # start_time = time()
         for trip_dict in read_csv("trips.txt"):
             trip_dict: dict
             trip_id = trip_dict['trip_id']
@@ -260,28 +297,26 @@
                 shape_id=trip_dict.get('shape_id', ""),
                 wheelchair_accessible=trip_dict.get('wheelchair_accessible', None),
                 bikes_allowed=trip_dict.get('bikes_allowed', None),
-                gtfs_feed=gtfs_feed,
+                gtfs_feed_id=gtfs_code,
             )
             trips.append(trip)

             if len(trips) >= bulk_size and not dry_run:
-                Trip.objects.bulk_create(trips,
-                                         update_conflicts=True,
-                                         update_fields=['route_id', 'service_id', 'headsign', 'short_name',
-                                                        'direction_id', 'block_id', 'shape_id',
-                                                        'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
-                                         unique_fields=['id'])
+                # now = time()
+                # print(f"Elapsed time: {now - start_time:.3f}s, "
+                #       f"{1000 * (now - start_time) / len(trips):.2f}ms per iteration")
+                # start_time = now
+                Trip.objects.bulk_create(trips)
+                # now = time()
+                # print(f"Elapsed time: {now - start_time:.3f}s to save")
+                # start_time = now
                 trips.clear()

         if trips and not dry_run:
-            Trip.objects.bulk_create(trips,
-                                     update_conflicts=True,
-                                     update_fields=['route_id', 'service_id', 'headsign', 'short_name',
-                                                    'direction_id', 'block_id', 'shape_id',
-                                                    'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
-                                     unique_fields=['id'])
+            Trip.objects.bulk_create(trips)
             trips.clear()

         stop_times = []
+        # start_time = time()
         for stop_time_dict in read_csv("stop_times.txt"):
             stop_time_dict: dict
@@ -323,23 +358,21 @@
             stop_times.append(st)

             if len(stop_times) >= bulk_size and not dry_run:
-                StopTime.objects.bulk_create(stop_times,
-                                             update_conflicts=True,
-                                             update_fields=['stop_id', 'arrival_time', 'departure_time',
-                                                            'stop_headsign', 'pickup_type',
-                                                            'drop_off_type', 'timepoint'],
-                                             unique_fields=['id'])
-                stop_times.clear()
-
-        if stop_times and not dry_run:
-            StopTime.objects.bulk_create(stop_times,
-                                         update_conflicts=True,
-                                         update_fields=['stop_id', 'arrival_time', 'departure_time',
-                                                        'stop_headsign', 'pickup_type',
-                                                        'drop_off_type', 'timepoint'],
-                                         unique_fields=['id'])
+                # now = time()
+                # print(f"Elapsed time: {now - start_time:.3f}s, "
+                #       f"{1000 * (now - start_time) / len(stop_times):.2f}ms per iteration")
+                # start_time = now
+                StopTime.objects.bulk_create(stop_times)
+                # now = time()
+                # print(f"Elapsed time: {now - start_time:.3f}s to save")
+                # start_time = now
+                stop_times.clear()
+
+        if stop_times and not dry_run:
+            StopTime.objects.bulk_create(stop_times)
             stop_times.clear()

-        if "transfers.txt" in zipfile.namelist():
+        if os.path.exists(os.path.join(zip_dir, "transfers.txt")):
             transfers = []
             for transfer_dict in read_csv("transfers.txt"):
                 transfer_dict: dict
@@ -349,7 +382,7 @@
                 to_stop_id = f"{gtfs_code}-{to_stop_id}"

                 transfer = Transfer(
-                    id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
+                    id=f"{gtfs_code}-{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
                     from_stop_id=from_stop_id,
                     to_stop_id=to_stop_id,
                     transfer_type=transfer_dict['transfer_type'],
@@ -358,25 +391,19 @@
                 transfers.append(transfer)

                 if len(transfers) >= bulk_size and not dry_run:
-                    Transfer.objects.bulk_create(transfers,
-                                                 update_conflicts=True,
-                                                 update_fields=['transfer_type', 'min_transfer_time'],
-                                                 unique_fields=['id'])
+                    Transfer.objects.bulk_create(transfers)
                     transfers.clear()

             if transfers and not dry_run:
-                Transfer.objects.bulk_create(transfers,
-                                             update_conflicts=True,
-                                             update_fields=['transfer_type', 'min_transfer_time'],
-                                             unique_fields=['id'])
+                Transfer.objects.bulk_create(transfers)
                 transfers.clear()

-        if "feed_info.txt" in zipfile.namelist() and not dry_run:
+        if os.path.exists(os.path.join(zip_dir, "feed_info.txt")) and not dry_run:
             for feed_info_dict in read_csv("feed_info.txt"):
                 feed_info_dict: dict
                 FeedInfo.objects.update_or_create(
                     publisher_name=feed_info_dict['feed_publisher_name'],
-                    gtfs_feed=gtfs_feed,
+                    gtfs_feed_id=gtfs_code,
                     defaults=dict(
                         publisher_url=feed_info_dict['feed_publisher_url'],
                         lang=feed_info_dict['feed_lang'],
@@ -385,12 +412,3 @@
                         version=feed_info_dict.get('feed_version', 1),
                     )
                 )
-
-            if 'ETag' in resp.headers:
-                gtfs_feed.etag = resp.headers['ETag']
-                gtfs_feed.save()
-            if 'Last-Modified' in resp.headers:
-                last_modified = resp.headers['Last-Modified']
-                gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
-                    .replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
-                gtfs_feed.save()
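Taken together, the two commits replace the upsert-style `bulk_create(..., update_conflicts=True, ...)` calls with plain inserts: because the feed's `TripUpdate`, `StopTime`, `Trip` and `Calendar` rows are now deleted up front (`_raw_delete` is a private ORM API that issues a single DELETE, skipping signals and cascade handling), no conflict can occur and the cheaper INSERT path suffices. A minimal sketch of the resulting batching pattern, with assumed names (`bulk_insert` is illustrative, not a helper from this codebase):

# Sketch of the delete-then-insert batching pattern, assuming rows matching
# the feed were already deleted so no ON CONFLICT handling is required.
def bulk_insert(model, objects, bulk_size=10000, dry_run=False):
    batch = []
    for obj in objects:
        batch.append(obj)
        if len(batch) >= bulk_size and not dry_run:
            model.objects.bulk_create(batch)  # plain INSERT, one query per batch
            batch.clear()
    if batch and not dry_run:
        model.objects.bulk_create(batch)  # flush the remainder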

View File

@@ -0,0 +1,26 @@
+# Generated by Django 5.0.6 on 2024-05-12 09:31
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("gtfs", "0001_initial"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="stop",
+            name="parent_station",
+            field=models.ForeignKey(
+                blank=True,
+                null=True,
+                on_delete=django.db.models.deletion.SET_NULL,
+                related_name="children",
+                to="gtfs.stop",
+                verbose_name="Parent station",
+            ),
+        ),
+    ]

View File

@@ -288,7 +288,7 @@ class Stop(models.Model):
     parent_station = models.ForeignKey(
         to="Stop",
-        on_delete=models.PROTECT,
+        on_delete=models.SET_NULL,
         verbose_name=_("Parent station"),
         related_name="children",
         blank=True,
         null=True,
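Switching `on_delete` from `PROTECT` to `SET_NULL` is exactly what the auto-generated migration above records; `SET_NULL` is only valid on a nullable column, hence the `null=True` already on the field. A condensed sketch of the resulting self-referential field (other fields and model options omitted):

# Condensed sketch of the changed field: deleting a parent station now
# nulls out its children's parent_station instead of raising ProtectedError.
from django.db import models
from django.utils.translation import gettext_lazy as _

class Stop(models.Model):
    parent_station = models.ForeignKey(
        to="Stop",                   # self-referential FK
        on_delete=models.SET_NULL,   # was models.PROTECT
        verbose_name=_("Parent station"),
        related_name="children",
        blank=True,
        null=True,                   # required for SET_NULL
    )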

View File

@@ -41,6 +41,7 @@ INSTALLED_APPS = [
     "django.contrib.staticfiles",
     "corsheaders",
+    "django_extensions",
     "django_filters",
     "rest_framework",