Compare commits

..

No commits in common. "b85a1b773464c65391100b9f744a9c6095d9192c" and "7ed092410891c0d949807f400d336b50237b188c" have entirely different histories.

5 changed files with 342 additions and 388 deletions

View File

@ -1,8 +1,7 @@
Django>=5.0.4,<6.0 Django>=5.0,<6.0
django-cors-headers~=4.3.1 django-cors-headers
django-extensions~=3.2.3 django-filter~=23.5
django-filter~=24.2
djangorestframework~=3.14.0 djangorestframework~=3.14.0
protobuf~=5.26.1 protobuf
requests~=2.31.0 requests~=2.31.0
tqdm~=4.66.4 tqdm

View File

@ -1,8 +1,6 @@
import csv import csv
import os.path
import tempfile
from datetime import datetime, timedelta from datetime import datetime, timedelta
from time import time from io import BytesIO, TextIOWrapper
from zipfile import ZipFile from zipfile import ZipFile
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@ -11,7 +9,7 @@ from django.core.management import BaseCommand
from tqdm import tqdm from tqdm import tqdm
from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \ from trainvel.gtfs.models import Agency, Calendar, CalendarDate, FeedInfo, GTFSFeed, Route, Stop, StopTime, \
Transfer, Trip, PickupType, TripUpdate Transfer, Trip, PickupType
class Command(BaseCommand): class Command(BaseCommand):
@ -19,7 +17,7 @@ class Command(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode") parser.add_argument('--debug', '-d', action='store_true', help="Activate debug mode")
parser.add_argument('--bulk_size', '-b', type=int, default=10000, help="Number of objects to create in bulk.") parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
parser.add_argument('--dry-run', action='store_true', parser.add_argument('--dry-run', action='store_true',
help="Do not update the database, only print what would be done.") help="Do not update the database, only print what would be done.")
parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.") parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")
@ -32,6 +30,8 @@ class Command(BaseCommand):
self.stdout.write("Updating database...") self.stdout.write("Updating database...")
for gtfs_feed in GTFSFeed.objects.all(): for gtfs_feed in GTFSFeed.objects.all():
gtfs_code = gtfs_feed.code
if not force: if not force:
# Check if the source file was updated # Check if the source file was updated
resp = requests.head(gtfs_feed.feed_url, allow_redirects=True) resp = requests.head(gtfs_feed.feed_url, allow_redirects=True)
@ -51,37 +51,14 @@ class Command(BaseCommand):
self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...") self.stdout.write(f"Downloading GTFS feed for {gtfs_feed}...")
resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True) resp = requests.get(gtfs_feed.feed_url, allow_redirects=True, stream=True)
with ZipFile(BytesIO(resp.content)) as zipfile:
with tempfile.TemporaryFile(suffix=".zip") as file:
for chunk in resp.iter_content(chunk_size=128):
file.write(chunk)
file.seek(0)
with tempfile.TemporaryDirectory() as tmp_dir:
with ZipFile(file) as zipfile:
zipfile.extractall(tmp_dir)
self.parse_gtfs(tmp_dir, gtfs_feed, bulk_size, dry_run, verbosity)
if 'ETag' in resp.headers:
gtfs_feed.etag = resp.headers['ETag']
gtfs_feed.save()
if 'Last-Modified' in resp.headers:
last_modified = resp.headers['Last-Modified']
gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
gtfs_feed.save()
def parse_gtfs(self, zip_dir: str, gtfs_feed: GTFSFeed, bulk_size: int, dry_run: bool, verbosity: int):
gtfs_code = gtfs_feed.code
def read_csv(filename): def read_csv(filename):
with open(os.path.join(zip_dir, filename), 'r') as f: with zipfile.open(filename, 'r') as zf:
reader = csv.DictReader(f) with TextIOWrapper(zf, encoding='utf-8') as wrapper:
reader = csv.DictReader(wrapper)
reader.fieldnames = [field.replace('\ufeff', '').strip() reader.fieldnames = [field.replace('\ufeff', '').strip()
for field in reader.fieldnames] for field in reader.fieldnames]
iterator = tqdm(reader, desc=filename, unit=' rows') if verbosity >= 2 else reader for row in tqdm(reader, desc=filename, unit=' rows'):
for row in iterator:
yield {k.strip(): v.strip() for k, v in row.items()} yield {k.strip(): v.strip() for k, v in row.items()}
agencies = [] agencies = []
@ -95,7 +72,7 @@ class Command(BaseCommand):
lang=agency_dict.get('agency_lang', "fr"), lang=agency_dict.get('agency_lang', "fr"),
phone=agency_dict.get('agency_phone', ""), phone=agency_dict.get('agency_phone', ""),
email=agency_dict.get('agency_email', ""), email=agency_dict.get('agency_email', ""),
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
) )
agencies.append(agency) agencies.append(agency)
if agencies and not dry_run: if agencies and not dry_run:
@ -129,7 +106,7 @@ class Command(BaseCommand):
wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0), wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
level_id=stop_dict.get('level_id', ""), level_id=stop_dict.get('level_id', ""),
platform_code=stop_dict.get('platform_code', ""), platform_code=stop_dict.get('platform_code', ""),
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
) )
stops.append(stop) stops.append(stop)
@ -150,7 +127,7 @@ class Command(BaseCommand):
route_id = route_dict['route_id'] route_id = route_dict['route_id']
route_id = f"{gtfs_code}-{route_id}" route_id = f"{gtfs_code}-{route_id}"
# Agency is optional if there is only one # Agency is optional if there is only one
agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed_id=gtfs_code) agency_id = route_dict.get('agency_id', "") or Agency.objects.get(gtfs_feed=gtfs_feed)
route = Route( route = Route(
id=route_id, id=route_id,
agency_id=f"{gtfs_code}-{agency_id}", agency_id=f"{gtfs_code}-{agency_id}",
@ -161,7 +138,7 @@ class Command(BaseCommand):
url=route_dict.get('route_url', ""), url=route_dict.get('route_url', ""),
color=route_dict.get('route_color', ""), color=route_dict.get('route_color', ""),
text_color=route_dict.get('route_text_color', ""), text_color=route_dict.get('route_text_color', ""),
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
) )
routes.append(route) routes.append(route)
@ -182,22 +159,9 @@ class Command(BaseCommand):
unique_fields=['id']) unique_fields=['id'])
routes.clear() routes.clear()
start_time = 0 Calendar.objects.filter(gtfs_feed=gtfs_feed).delete()
if verbosity >= 1:
self.stdout.write("Deleting old calendars, trips and stop times…")
start_time = time()
TripUpdate.objects.filter(trip__gtfs_feed_id=gtfs_code).delete()
StopTime.objects.filter(trip__gtfs_feed_id=gtfs_code)._raw_delete(StopTime.objects.db)
Trip.objects.filter(gtfs_feed_id=gtfs_code)._raw_delete(Trip.objects.db)
Calendar.objects.filter(gtfs_feed_id=gtfs_code).delete()
if verbosity >= 1:
end = time()
self.stdout.write(f"Done in {end - start_time:.2f} s")
calendars = {} calendars = {}
if os.path.exists(os.path.join(zip_dir, "calendar.txt")): if "calendar.txt" in zipfile.namelist():
for calendar_dict in read_csv("calendar.txt"): for calendar_dict in read_csv("calendar.txt"):
calendar_dict: dict calendar_dict: dict
calendar = Calendar( calendar = Calendar(
@ -211,7 +175,7 @@ class Command(BaseCommand):
sunday=calendar_dict['sunday'], sunday=calendar_dict['sunday'],
start_date=calendar_dict['start_date'], start_date=calendar_dict['start_date'],
end_date=calendar_dict['end_date'], end_date=calendar_dict['end_date'],
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
) )
calendars[calendar.id] = calendar calendars[calendar.id] = calendar
@ -254,7 +218,7 @@ class Command(BaseCommand):
sunday=False, sunday=False,
start_date=calendar_date_dict['date'], start_date=calendar_date_dict['date'],
end_date=calendar_date_dict['date'], end_date=calendar_date_dict['date'],
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
) )
calendars[calendar.id] = calendar calendars[calendar.id] = calendar
else: else:
@ -279,7 +243,6 @@ class Command(BaseCommand):
calendar_dates.clear() calendar_dates.clear()
trips = [] trips = []
# start_time = time()
for trip_dict in read_csv("trips.txt"): for trip_dict in read_csv("trips.txt"):
trip_dict: dict trip_dict: dict
trip_id = trip_dict['trip_id'] trip_id = trip_dict['trip_id']
@ -297,26 +260,28 @@ class Command(BaseCommand):
shape_id=trip_dict.get('shape_id', ""), shape_id=trip_dict.get('shape_id', ""),
wheelchair_accessible=trip_dict.get('wheelchair_accessible', None), wheelchair_accessible=trip_dict.get('wheelchair_accessible', None),
bikes_allowed=trip_dict.get('bikes_allowed', None), bikes_allowed=trip_dict.get('bikes_allowed', None),
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
) )
trips.append(trip) trips.append(trip)
if len(trips) >= bulk_size and not dry_run: if len(trips) >= bulk_size and not dry_run:
# now = time() Trip.objects.bulk_create(trips,
# print(f"Elapsed time: {now - start_time:.3f}s, " update_conflicts=True,
# f"{1000 * (now - start_time) / len(trips):.2f}ms per iteration") update_fields=['route_id', 'service_id', 'headsign', 'short_name',
# start_time = now 'direction_id', 'block_id', 'shape_id',
Trip.objects.bulk_create(trips) 'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
# now = time() unique_fields=['id'])
# print(f"Elapsed time: {now - start_time:.3f}s to save")
# start_time = now
trips.clear() trips.clear()
if trips and not dry_run: if trips and not dry_run:
Trip.objects.bulk_create(trips) Trip.objects.bulk_create(trips,
update_conflicts=True,
update_fields=['route_id', 'service_id', 'headsign', 'short_name',
'direction_id', 'block_id', 'shape_id',
'wheelchair_accessible', 'bikes_allowed', 'gtfs_feed'],
unique_fields=['id'])
trips.clear() trips.clear()
stop_times = [] stop_times = []
# start_time = time()
for stop_time_dict in read_csv("stop_times.txt"): for stop_time_dict in read_csv("stop_times.txt"):
stop_time_dict: dict stop_time_dict: dict
@ -358,21 +323,23 @@ class Command(BaseCommand):
stop_times.append(st) stop_times.append(st)
if len(stop_times) >= bulk_size and not dry_run: if len(stop_times) >= bulk_size and not dry_run:
# now = time() StopTime.objects.bulk_create(stop_times,
# print(f"Elapsed time: {now - start_time:.3f}s, " update_conflicts=True,
# f"{1000 * (now - start_time) / len(stop_times):.2f}ms per iteration") update_fields=['stop_id', 'arrival_time', 'departure_time',
# start_time = now 'stop_headsign', 'pickup_type',
StopTime.objects.bulk_create(stop_times) 'drop_off_type', 'timepoint'],
# now = time() unique_fields=['id'])
# print(f"Elapsed time: {now - start_time:.3f}s to save")
# start_time = now
stop_times.clear() stop_times.clear()
if stop_times and not dry_run: if stop_times and not dry_run:
StopTime.objects.bulk_create(stop_times) StopTime.objects.bulk_create(stop_times,
update_conflicts=True,
update_fields=['stop_id', 'arrival_time', 'departure_time',
'stop_headsign', 'pickup_type',
'drop_off_type', 'timepoint'],
unique_fields=['id'])
stop_times.clear() stop_times.clear()
if os.path.exists(os.path.join(zip_dir, "transfers.txt")): if "transfers.txt" in zipfile.namelist():
transfers = [] transfers = []
for transfer_dict in read_csv("transfers.txt"): for transfer_dict in read_csv("transfers.txt"):
transfer_dict: dict transfer_dict: dict
@ -382,7 +349,7 @@ class Command(BaseCommand):
to_stop_id = f"{gtfs_code}-{to_stop_id}" to_stop_id = f"{gtfs_code}-{to_stop_id}"
transfer = Transfer( transfer = Transfer(
id=f"{gtfs_code}-{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}", id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
from_stop_id=from_stop_id, from_stop_id=from_stop_id,
to_stop_id=to_stop_id, to_stop_id=to_stop_id,
transfer_type=transfer_dict['transfer_type'], transfer_type=transfer_dict['transfer_type'],
@ -391,19 +358,25 @@ class Command(BaseCommand):
transfers.append(transfer) transfers.append(transfer)
if len(transfers) >= bulk_size and not dry_run: if len(transfers) >= bulk_size and not dry_run:
Transfer.objects.bulk_create(transfers) Transfer.objects.bulk_create(transfers,
update_conflicts=True,
update_fields=['transfer_type', 'min_transfer_time'],
unique_fields=['id'])
transfers.clear() transfers.clear()
if transfers and not dry_run: if transfers and not dry_run:
Transfer.objects.bulk_create(transfers) Transfer.objects.bulk_create(transfers,
update_conflicts=True,
update_fields=['transfer_type', 'min_transfer_time'],
unique_fields=['id'])
transfers.clear() transfers.clear()
if os.path.exists(os.path.join(zip_dir, "feed_info.txt")) and not dry_run: if "feed_info.txt" in zipfile.namelist() and not dry_run:
for feed_info_dict in read_csv("feed_info.txt"): for feed_info_dict in read_csv("feed_info.txt"):
feed_info_dict: dict feed_info_dict: dict
FeedInfo.objects.update_or_create( FeedInfo.objects.update_or_create(
publisher_name=feed_info_dict['feed_publisher_name'], publisher_name=feed_info_dict['feed_publisher_name'],
gtfs_feed_id=gtfs_code, gtfs_feed=gtfs_feed,
defaults=dict( defaults=dict(
publisher_url=feed_info_dict['feed_publisher_url'], publisher_url=feed_info_dict['feed_publisher_url'],
lang=feed_info_dict['feed_lang'], lang=feed_info_dict['feed_lang'],
@ -412,3 +385,12 @@ class Command(BaseCommand):
version=feed_info_dict.get('feed_version', 1), version=feed_info_dict.get('feed_version', 1),
) )
) )
if 'ETag' in resp.headers:
gtfs_feed.etag = resp.headers['ETag']
gtfs_feed.save()
if 'Last-Modified' in resp.headers:
last_modified = resp.headers['Last-Modified']
gtfs_feed.last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z") \
.replace(tzinfo=ZoneInfo(last_modified.split(' ')[-1]))
gtfs_feed.save()

View File

@ -1,26 +0,0 @@
# Generated by Django 5.0.6 on 2024-05-12 09:31
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("gtfs", "0001_initial"),
]
operations = [
migrations.AlterField(
model_name="stop",
name="parent_station",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="children",
to="gtfs.stop",
verbose_name="Parent station",
),
),
]

View File

@ -288,7 +288,7 @@ class Stop(models.Model):
parent_station = models.ForeignKey( parent_station = models.ForeignKey(
to="Stop", to="Stop",
on_delete=models.SET_NULL, on_delete=models.PROTECT,
verbose_name=_("Parent station"), verbose_name=_("Parent station"),
related_name="children", related_name="children",
blank=True, blank=True,

View File

@ -41,7 +41,6 @@ INSTALLED_APPS = [
"django.contrib.staticfiles", "django.contrib.staticfiles",
"corsheaders", "corsheaders",
"django_extensions",
"django_filters", "django_filters",
"rest_framework", "rest_framework",