# Source path: trainvel/sncfgtfs/management/commands/update_sncf_gtfs.py
# (Repository viewer metadata: 405 lines, 22 KiB, Python.)

import csv
from datetime import datetime, timedelta
from io import BytesIO
from zipfile import ZipFile
import requests
from django.core.management import BaseCommand
from sncfgtfs.models import Agency, Calendar, CalendarDate, FeedInfo, Route, Stop, StopTime, Transfer, Trip
class Command(BaseCommand):
    """Download the configured GTFS feeds and upsert their content into the database.

    For every entry of ``GTFS_FEEDS`` the archive is fetched over HTTP and its ``agency``, ``stops``,
    ``routes``, ``calendar``/``calendar_dates``, ``trips``, ``stop_times``, ``transfers`` and
    ``feed_info`` files are imported, in that order. Rows are written with
    ``bulk_create(update_conflicts=True)`` keyed on ``id``, so re-running the command updates
    existing rows instead of failing on duplicates.
    """

    help = "Update the SNCF GTFS database."

    # Transport type label -> URL of the GTFS zip archive. The label is stored on imported rows
    # and is used to namespace some identifiers (see the individual builders).
    GTFS_FEEDS = {
        "TGV": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export_gtfs_voyages.zip",
        "IC": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-intercites-gtfs-last.zip",
        "TER": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip",
        "TN": "https://eu.ftp.opendatasoft.com/sncf/gtfs/transilien-gtfs.zip",
        # "ES": "https://www.data.gouv.fr/fr/datasets/r/9089b550-696e-4ae0-87b5-40ea55a14292",
        # "TI": "https://www.data.gouv.fr/fr/datasets/r/4d1dd21a-b061-47ac-9514-57ffcc09b4a5",
        # "RENFE": "https://ssl.renfe.com/gtransit/Fichero_AV_LD/google_transit.zip",
        # "OBB": "https://static.oebb.at/open-data/soll-fahrplan-gtfs/GTFS_OP_2024_obb.zip",
    }

    def add_arguments(self, parser):
        """Register the ``--bulk_size``, ``--dry-run`` and ``--force``/``-f`` options."""
        parser.add_argument('--bulk_size', type=int, default=1000, help="Number of objects to create in bulk.")
        parser.add_argument('--dry-run', action='store_true',
                            help="Do not update the database, only print what would be done.")
        parser.add_argument('--force', '-f', action='store_true', help="Force the update of the database.")

    def handle(self, *args, **options):
        """Run the import.

        Options read from ``options``:
            bulk_size: maximum number of objects sent to the database per ``bulk_create`` batch.
            dry_run: parse every feed but do not write to (or delete from) the database.
            force: import even when no remote feed is newer than the stored version.

        Raises:
            requests.HTTPError: when a feed download answers with an HTTP error status.
        """
        bulk_size = options['bulk_size']
        dry_run = options['dry_run']
        force = options['force']

        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run mode activated."))

        # With --force the result of the freshness probe is ignored, so skip its HTTP requests.
        if not force and not self._has_newer_feed():
            self.stdout.write(self.style.WARNING("Database already up-to-date."))
            return

        self.stdout.write("Updating database...")

        for transport_type, feed_url in self.GTFS_FEEDS.items():
            self.stdout.write(f"Downloading {transport_type} GTFS feed...")
            response = requests.get(feed_url)
            # Fail on the HTTP error itself instead of feeding an error page to ZipFile.
            response.raise_for_status()
            with ZipFile(BytesIO(response.content)) as zipfile:
                self._import_feed(zipfile, transport_type, bulk_size, dry_run)

    def _has_newer_feed(self):
        """Return True when a feed's ``Last-Modified`` date is later than the stored version.

        The stored version is the ``version`` of the ``FeedInfo`` row whose publisher name is
        ``SNCF_default``; when that row does not exist the epoch date is used, so any feed that
        exposes a ``Last-Modified`` header counts as newer. Feeds without the header are ignored.
        """
        feed_info = FeedInfo.objects.filter(publisher_name='SNCF_default').first()
        last_update_date = feed_info.version if feed_info is not None else "1970-01-01"

        for url in self.GTFS_FEEDS.values():
            headers = requests.head(url).headers
            if "Last-Modified" not in headers:
                continue
            last_modified = datetime.strptime(headers["Last-Modified"], "%a, %d %b %Y %H:%M:%S %Z")
            # ISO dates compare chronologically as plain strings.
            if last_modified.date().isoformat() > last_update_date:
                return True
        return False

    def _import_feed(self, zipfile, transport_type, bulk_size, dry_run):
        """Import every supported file of one opened GTFS archive."""
        names = zipfile.namelist()

        def rows(filename):
            return self._read_rows(zipfile, filename)

        self._upsert(Agency,
                     list(self._build_agencies(rows("agency.txt"), transport_type)),
                     ['name', 'url', 'timezone', 'lang', 'phone', 'email'],
                     dry_run=dry_run)

        self._upsert(Stop,
                     list(self._build_stops(rows("stops.txt"), transport_type)),
                     ['name', 'desc', 'lat', 'lon', 'zone_id', 'url',
                      'location_type', 'parent_station_id', 'timezone',
                      'wheelchair_boarding', 'level_id', 'platform_code',
                      'transport_type'],
                     dry_run=dry_run,
                     batch_size=bulk_size)

        self._upsert_in_chunks(Route,
                               self._build_routes(rows("routes.txt"), transport_type),
                               ['agency_id', 'short_name', 'long_name', 'desc',
                                'type', 'url', 'color', 'text_color',
                                'transport_type'],
                               bulk_size=bulk_size,
                               dry_run=dry_run)

        self._import_calendars(zipfile, names, transport_type, bulk_size, dry_run)

        # NOTE(review): 'last_update' is not listed here, so it is only written when a trip is
        # first inserted and is left untouched on conflict — confirm this is intended.
        self._upsert_in_chunks(Trip,
                               self._build_trips(rows("trips.txt"), transport_type),
                               ['route_id', 'service_id', 'headsign', 'short_name',
                                'direction_id', 'block_id', 'shape_id',
                                'wheelchair_accessible', 'bikes_allowed'],
                               bulk_size=bulk_size,
                               dry_run=dry_run)

        self._upsert_in_chunks(StopTime,
                               self._build_stop_times(rows("stop_times.txt"), transport_type),
                               ['stop_id', 'arrival_time', 'departure_time',
                                'stop_headsign', 'pickup_type',
                                'drop_off_type', 'timepoint'],
                               bulk_size=bulk_size,
                               dry_run=dry_run)

        if "transfers.txt" in names:
            self._upsert_in_chunks(Transfer,
                                   self._build_transfers(rows("transfers.txt"), transport_type),
                                   ['transfer_type', 'min_transfer_time'],
                                   bulk_size=bulk_size,
                                   dry_run=dry_run)

        if "feed_info.txt" in names and not dry_run:
            self._import_feed_info(rows("feed_info.txt"))

    @staticmethod
    def _read_rows(zipfile, filename):
        """Return a ``csv.DictReader`` over one text file of the archive.

        The file is decoded with the default codec (UTF-8), byte-order marks are removed and every
        line is stripped of surrounding whitespace before being handed to the CSV parser.
        """
        text = zipfile.read(filename).decode().replace('\ufeff', '')
        return csv.DictReader([line.strip() for line in text.splitlines()])

    @staticmethod
    def _upsert(model, objects, update_fields, *, dry_run, batch_size=None):
        """Insert ``objects``, updating ``update_fields`` on rows whose ``id`` already exists.

        Does nothing when ``objects`` is empty or when ``dry_run`` is set.
        """
        if dry_run or not objects:
            return
        model.objects.bulk_create(objects,
                                  batch_size=batch_size,
                                  update_conflicts=True,
                                  update_fields=update_fields,
                                  unique_fields=['id'])

    def _upsert_in_chunks(self, model, objects, update_fields, *, bulk_size, dry_run):
        """Consume the ``objects`` iterable and upsert it ``bulk_size`` instances at a time.

        The iterable is always fully consumed, so parsing errors surface in dry-run mode too; the
        buffer is released after every chunk in both modes to keep memory bounded.
        """
        chunk = []
        for obj in objects:
            chunk.append(obj)
            if len(chunk) >= bulk_size:
                self._upsert(model, chunk, update_fields, dry_run=dry_run)
                chunk.clear()
        self._upsert(model, chunk, update_fields, dry_run=dry_run)

    @staticmethod
    def _build_agencies(rows, transport_type):
        """Yield an ``Agency`` per ``agency.txt`` row."""
        for row in rows:
            # For the "ES" feed only the agencies whose id is ES or ER are imported.
            if transport_type == "ES" and row['agency_id'] not in ('ES', 'ER'):
                continue
            yield Agency(
                id=row['agency_id'],
                name=row['agency_name'],
                url=row['agency_url'],
                timezone=row['agency_timezone'],
                lang=row.get('agency_lang', "fr"),
                phone=row.get('agency_phone', ""),
                email=row.get('agency_email', ""),
            )

    @staticmethod
    def _build_stops(rows, transport_type):
        """Yield a ``Stop`` per ``stops.txt`` row; ES/TI/RENFE stop ids get a ``<type>-`` prefix."""
        for row in rows:
            stop_id = row['stop_id']
            if transport_type in ("ES", "TI", "RENFE"):
                stop_id = f"{transport_type}-{stop_id}"
            yield Stop(
                id=stop_id,
                name=row['stop_name'],
                desc=row.get('stop_desc', ""),
                lat=row['stop_lat'],
                lon=row['stop_lon'],
                zone_id=row.get('zone_id', ""),
                url=row.get('stop_url', ""),
                # A missing or empty location type falls back to 1.
                location_type=row.get('location_type', 1) or 1,
                # An empty parent station is stored as NULL rather than as an empty id.
                parent_station_id=row.get('parent_station', None) or None,
                timezone=row.get('stop_timezone', ""),
                wheelchair_boarding=row.get('wheelchair_boarding', 0),
                level_id=row.get('level_id', ""),
                platform_code=row.get('platform_code', ""),
                transport_type=transport_type,
            )

    @staticmethod
    def _build_routes(rows, transport_type):
        """Yield a ``Route`` per ``routes.txt`` row; TI route ids get a ``TI-`` prefix."""
        for row in rows:
            route_id = row['route_id']
            if transport_type == "TI":
                route_id = f"{transport_type}-{route_id}"
            yield Route(
                id=route_id,
                agency_id=row['agency_id'],
                short_name=row['route_short_name'],
                long_name=row['route_long_name'],
                desc=row.get('route_desc', ""),
                type=row['route_type'],
                url=row.get('route_url', ""),
                color=row.get('route_color', ""),
                text_color=row.get('route_text_color', ""),
                transport_type=transport_type,
            )

    def _import_calendars(self, zipfile, names, transport_type, bulk_size, dry_run):
        """Replace the calendars of ``transport_type`` and import their exception dates.

        Existing calendars of the transport type are deleted first (skipped in dry-run mode, which
        must not touch the database). Calendars read from ``calendar.txt`` are kept in memory after
        being written so that the dates of ``calendar_dates.txt`` widen their real validity range;
        services that only appear in ``calendar_dates.txt`` get a calendar with every weekday
        disabled that spans their exception dates. Service ids are prefixed with ``<type>-``.
        """
        if not dry_run:
            Calendar.objects.filter(transport_type=transport_type).delete()

        calendars = {}
        if "calendar.txt" in names:
            for row in self._read_rows(zipfile, "calendar.txt"):
                calendar = Calendar(
                    id=f"{transport_type}-{row['service_id']}",
                    monday=row['monday'],
                    tuesday=row['tuesday'],
                    wednesday=row['wednesday'],
                    thursday=row['thursday'],
                    friday=row['friday'],
                    saturday=row['saturday'],
                    sunday=row['sunday'],
                    start_date=row['start_date'],
                    end_date=row['end_date'],
                    transport_type=transport_type,
                )
                calendars[calendar.id] = calendar
            self._upsert(Calendar,
                         list(calendars.values()),
                         ['monday', 'tuesday', 'wednesday', 'thursday',
                          'friday', 'saturday', 'sunday', 'start_date',
                          'end_date', 'transport_type'],
                         dry_run=dry_run,
                         batch_size=bulk_size)

        calendar_dates = []
        for row in self._read_rows(zipfile, "calendar_dates.txt"):
            service_id = f"{transport_type}-{row['service_id']}"
            date = row['date']
            calendar_dates.append(CalendarDate(
                id=f"{service_id}-{date}",
                service_id=service_id,
                date=date,
                exception_type=row['exception_type'],
            ))

            calendar = calendars.get(service_id)
            if calendar is None:
                calendar = Calendar(
                    id=service_id,
                    monday=False,
                    tuesday=False,
                    wednesday=False,
                    thursday=False,
                    friday=False,
                    saturday=False,
                    sunday=False,
                    start_date=date,
                    end_date=date,
                    transport_type=transport_type,
                )
                calendars[service_id] = calendar

            # Dates are compared as the raw feed strings.
            # assumes the GTFS YYYYMMDD format, which sorts chronologically — TODO confirm
            if calendar.start_date > date:
                calendar.start_date = date
            if calendar.end_date < date:
                calendar.end_date = date

        if calendar_dates:
            # Calendars must exist before the dates referring to them; on conflict only the
            # (possibly widened) validity range is rewritten.
            self._upsert(Calendar,
                         list(calendars.values()),
                         ['start_date', 'end_date'],
                         dry_run=dry_run,
                         batch_size=bulk_size)
            self._upsert(CalendarDate,
                         calendar_dates,
                         ['service_id', 'date', 'exception_type'],
                         dry_run=dry_run,
                         batch_size=bulk_size)

    @staticmethod
    def _build_trips(rows, transport_type):
        """Yield a ``Trip`` per ``trips.txt`` row.

        For TGV/IC/TER the feed's trip id is ``<id>:<ISO timestamp>``: the timestamp is split off
        and stored as ``last_update``. ES/RENFE/TI trip ids (and TI route ids) get a ``<type>-``
        prefix and no ``last_update``.
        """
        for row in rows:
            trip_id = row['trip_id']
            route_id = row['route_id']
            last_update = None
            if transport_type in ("TGV", "IC", "TER"):
                trip_id, last_update = trip_id.split(':', 1)
                last_update = datetime.fromisoformat(last_update)
            elif transport_type in ("ES", "RENFE"):
                trip_id = f"{transport_type}-{trip_id}"
            elif transport_type == "TI":
                trip_id = f"{transport_type}-{trip_id}"
                route_id = f"{transport_type}-{route_id}"
            yield Trip(
                id=trip_id,
                route_id=route_id,
                service_id=f"{transport_type}-{row['service_id']}",
                headsign=row.get('trip_headsign', ""),
                short_name=row.get('trip_short_name', ""),
                direction_id=row.get('direction_id', None) or None,
                block_id=row.get('block_id', ""),
                shape_id=row.get('shape_id', ""),
                wheelchair_accessible=row.get('wheelchair_accessible', None),
                bikes_allowed=row.get('bikes_allowed', None),
                last_update=last_update,
            )

    @staticmethod
    def _to_seconds(gtfs_time):
        """Convert an ``H:M:S`` string into a number of seconds (hours are not capped at 23)."""
        hours, minutes, seconds = map(int, gtfs_time.split(':'))
        return hours * 3600 + minutes * 60 + seconds

    def _build_stop_times(self, rows, transport_type):
        """Yield a ``StopTime`` per ``stop_times.txt`` row.

        Stop and trip ids are normalised the same way as in the stop and trip builders. For
        ES/RENFE/OBB (first stop sequence "1") and TI (first stop sequence "0"), the first stop of
        a trip gets ``drop_off_type = 1`` and any other stop whose arrival equals its departure
        gets ``pickup_type = 1``.
        """
        for row in rows:
            stop_id = row['stop_id']
            if transport_type in ("ES", "TI", "RENFE"):
                stop_id = f"{transport_type}-{stop_id}"

            trip_id = row['trip_id']
            if transport_type in ("TGV", "IC", "TER"):
                trip_id = trip_id.split(':', 1)[0]
            elif transport_type in ("ES", "TI", "RENFE"):
                trip_id = f"{transport_type}-{trip_id}"

            arr_time = self._to_seconds(row['arrival_time'])
            dep_time = self._to_seconds(row['departure_time'])

            pickup_type = row.get('pickup_type', 0)
            drop_off_type = row.get('drop_off_type', 0)

            if transport_type in ("ES", "RENFE", "OBB"):
                first_sequence = "1"
            elif transport_type == "TI":
                first_sequence = "0"
            else:
                first_sequence = None

            if first_sequence is not None:
                if row['stop_sequence'] == first_sequence:
                    drop_off_type = 1
                elif arr_time == dep_time:
                    pickup_type = 1

            yield StopTime(
                id=f"{trip_id}-{stop_id}-{row['departure_time']}",
                trip_id=trip_id,
                arrival_time=timedelta(seconds=arr_time),
                departure_time=timedelta(seconds=dep_time),
                stop_id=stop_id,
                stop_sequence=row['stop_sequence'],
                stop_headsign=row.get('stop_headsign', ""),
                pickup_type=pickup_type,
                drop_off_type=drop_off_type,
                timepoint=row.get('timepoint', None),
            )

    @staticmethod
    def _build_transfers(rows, transport_type):
        """Yield a ``Transfer`` per ``transfers.txt`` row."""
        for row in rows:
            from_stop_id = row['from_stop_id']
            to_stop_id = row['to_stop_id']
            if transport_type in ("ES", "RENFE", "OBB"):
                from_stop_id = f"{transport_type}-{from_stop_id}"
                to_stop_id = f"{transport_type}-{to_stop_id}"
            # NOTE(review): the prefixed stop ids only feed the transfer id; the from/to columns
            # keep the raw feed values, and the transport types prefixed here (ES, RENFE, OBB)
            # differ from those prefixed for stops (ES, TI, RENFE). Behaviour kept as-is — confirm
            # what is intended before enabling those feeds.
            yield Transfer(
                id=f"{from_stop_id}-{to_stop_id}",
                from_stop_id=row['from_stop_id'],
                to_stop_id=row['to_stop_id'],
                transfer_type=row['transfer_type'],
                min_transfer_time=row['min_transfer_time'],
            )

    @staticmethod
    def _import_feed_info(rows):
        """Create or update one ``FeedInfo`` per ``feed_info.txt`` row, keyed on publisher name."""
        for row in rows:
            FeedInfo.objects.update_or_create(
                publisher_name=row['feed_publisher_name'],
                defaults=dict(
                    publisher_url=row['feed_publisher_url'],
                    lang=row['feed_lang'],
                    start_date=row.get('feed_start_date', datetime.now().date()),
                    end_date=row.get('feed_end_date', datetime.now().date()),
                    version=row.get('feed_version', 1),
                )
            )