# trainvel/sncfgtfs/management/commands/update_sncf_gtfs.py
import csv
from datetime import datetime, timedelta
from io import BytesIO
from zipfile import ZipFile

import requests
from django.core.management import BaseCommand

from sncfgtfs.models import (Agency, Calendar, CalendarDate, FeedInfo, Route,
                             Stop, StopTime, Transfer, Trip)


class Command(BaseCommand):
    help = "Update the SNCF GTFS database."

    GTFS_FEEDS = {
        "TGV": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export_gtfs_voyages.zip",
        "IC": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-intercites-gtfs-last.zip",
        "TER": "https://eu.ftp.opendatasoft.com/sncf/gtfs/export-ter-gtfs-last.zip",
        "TN": "https://eu.ftp.opendatasoft.com/sncf/gtfs/transilien-gtfs.zip",
    }
    def add_arguments(self, parser):
        parser.add_argument('--bulk_size', type=int, default=1000,
                            help="Number of objects to create in bulk.")
        parser.add_argument('--dry-run', action='store_true',
                            help="Do not update the database, only print what would be done.")
        parser.add_argument('--force', '-f', action='store_true',
                            help="Force the update of the database.")
    def handle(self, *args, **options):
        bulk_size = options['bulk_size']
        dry_run = options['dry_run']
        force = options['force']
        if dry_run:
            self.stdout.write(self.style.WARNING("Dry run mode activated."))

        if not FeedInfo.objects.exists():
            last_update_date = "1970-01-01"
        else:
            last_update_date = FeedInfo.objects.get().version

        for url in self.GTFS_FEEDS.values():
            last_modified = requests.head(url).headers["Last-Modified"]
            last_modified = datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z")
            if last_modified.date().isoformat() > last_update_date:
                break
        else:
            # The loop finished without a break: no feed is newer than the
            # stored version, so skip the import unless --force was given.
            if not force:
                self.stdout.write(self.style.WARNING("Database already up-to-date."))
                return

        self.stdout.write("Updating database...")

        all_trips = []
        for transport_type, feed_url in self.GTFS_FEEDS.items():
            self.stdout.write(f"Downloading {transport_type} GTFS feed...")
            with ZipFile(BytesIO(requests.get(feed_url).content)) as zipfile:
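                # agency.txt: upsert every agency, keyed on agency_id.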
                agencies = []
                for agency_dict in csv.DictReader(zipfile.read("agency.txt").decode().splitlines()):
                    agency = Agency(
                        id=agency_dict['agency_id'],
                        name=agency_dict['agency_name'],
                        url=agency_dict['agency_url'],
                        timezone=agency_dict['agency_timezone'],
                        lang=agency_dict['agency_lang'],
                        phone=agency_dict.get('agency_phone', ""),
                        email=agency_dict.get('agency_email', ""),
                    )
                    agencies.append(agency)
                if agencies and not dry_run:
                    Agency.objects.bulk_create(agencies,
                                               update_conflicts=True,
                                               update_fields=['name', 'url', 'timezone',
                                                              'lang', 'phone', 'email'],
                                               unique_fields=['id'])
                    agencies.clear()

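                # stops.txt: upsert stops in chunks of bulk_size. On the very
                # first import, Transilien (TN) rows leave parent_station
                # unset, presumably because the referenced parent stations
                # are not guaranteed to exist yet at that point.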
                stops = []
                for stop_dict in csv.DictReader(zipfile.read("stops.txt").decode().splitlines()):
                    stop = Stop(
                        id=stop_dict['stop_id'],
                        name=stop_dict['stop_name'],
                        desc=stop_dict['stop_desc'],
                        lat=stop_dict['stop_lat'],
                        lon=stop_dict['stop_lon'],
                        zone_id=stop_dict['zone_id'],
                        url=stop_dict['stop_url'],
                        location_type=stop_dict['location_type'],
                        parent_station_id=(stop_dict['parent_station'] or None)
                        if last_update_date != "1970-01-01" or transport_type != "TN"
                        else None,
                        timezone=stop_dict.get('stop_timezone', ""),
                        wheelchair_boarding=stop_dict.get('wheelchair_boarding', 0),
                        level_id=stop_dict.get('level_id', ""),
                        platform_code=stop_dict.get('platform_code', ""),
                    )
                    stops.append(stop)
                    if len(stops) >= bulk_size and not dry_run:
                        Stop.objects.bulk_create(stops,
                                                 update_conflicts=True,
                                                 update_fields=['name', 'desc', 'lat', 'lon',
                                                                'zone_id', 'url', 'location_type',
                                                                'parent_station_id', 'timezone',
                                                                'wheelchair_boarding', 'level_id',
                                                                'platform_code'],
                                                 unique_fields=['id'])
                        stops.clear()
                if stops and not dry_run:
                    Stop.objects.bulk_create(stops,
                                             update_conflicts=True,
                                             update_fields=['name', 'desc', 'lat', 'lon',
                                                            'zone_id', 'url', 'location_type',
                                                            'parent_station_id', 'timezone',
                                                            'wheelchair_boarding', 'level_id',
                                                            'platform_code'],
                                             unique_fields=['id'])
                    stops.clear()

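                # routes.txt: upsert routes, tagging each row with the feed
                # it came from (transport_type).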
                routes = []
                for route_dict in csv.DictReader(zipfile.read("routes.txt").decode().splitlines()):
                    route = Route(
                        id=route_dict['route_id'],
                        agency_id=route_dict['agency_id'],
                        short_name=route_dict['route_short_name'],
                        long_name=route_dict['route_long_name'],
                        desc=route_dict['route_desc'],
                        type=route_dict['route_type'],
                        url=route_dict['route_url'],
                        color=route_dict['route_color'],
                        text_color=route_dict['route_text_color'],
                        transport_type=transport_type,
                    )
                    routes.append(route)
                    if len(routes) >= bulk_size and not dry_run:
                        Route.objects.bulk_create(routes,
                                                  update_conflicts=True,
                                                  update_fields=['agency_id', 'short_name',
                                                                 'long_name', 'desc', 'type',
                                                                 'url', 'color', 'text_color',
                                                                 'transport_type'],
                                                  unique_fields=['id'])
                        routes.clear()
                if routes and not dry_run:
                    Route.objects.bulk_create(routes,
                                              update_conflicts=True,
                                              update_fields=['agency_id', 'short_name',
                                                             'long_name', 'desc', 'type',
                                                             'url', 'color', 'text_color',
                                                             'transport_type'],
                                              unique_fields=['id'])
                    routes.clear()

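                # calendar.txt is optional in GTFS. Service ids are prefixed
                # with the transport type, presumably to avoid collisions
                # between feeds that reuse the same ids.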
                calendar_ids = []
                if "calendar.txt" in zipfile.namelist():
                    calendars = []
                    for calendar_dict in csv.DictReader(zipfile.read("calendar.txt").decode().splitlines()):
                        calendar = Calendar(
                            id=f"{transport_type}-{calendar_dict['service_id']}",
                            monday=calendar_dict['monday'],
                            tuesday=calendar_dict['tuesday'],
                            wednesday=calendar_dict['wednesday'],
                            thursday=calendar_dict['thursday'],
                            friday=calendar_dict['friday'],
                            saturday=calendar_dict['saturday'],
                            sunday=calendar_dict['sunday'],
                            start_date=calendar_dict['start_date'],
                            end_date=calendar_dict['end_date'],
                            transport_type=transport_type,
                        )
                        calendars.append(calendar)
                        calendar_ids.append(calendar.id)
                        if len(calendars) >= bulk_size and not dry_run:
                            Calendar.objects.bulk_create(calendars,
                                                         update_conflicts=True,
                                                         update_fields=['monday', 'tuesday',
                                                                        'wednesday', 'thursday',
                                                                        'friday', 'saturday',
                                                                        'sunday', 'start_date',
                                                                        'end_date', 'transport_type'],
                                                         unique_fields=['id'])
                            calendars.clear()
                    if calendars and not dry_run:
                        Calendar.objects.bulk_create(calendars,
                                                     update_conflicts=True,
                                                     update_fields=['monday', 'tuesday',
                                                                    'wednesday', 'thursday',
                                                                    'friday', 'saturday',
                                                                    'sunday', 'start_date',
                                                                    'end_date', 'transport_type'],
                                                     unique_fields=['id'])
                        calendars.clear()

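                # calendar_dates.txt: exceptions to the weekly pattern. For
                # services that only exist here (no calendar.txt row), a stub
                # Calendar with no active weekdays is synthesised; each later
                # date re-upserts it with update_fields=['end_date'], which
                # stretches the stub to cover the whole date range.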
                calendars = []
                calendar_dates = []
                for calendar_date_dict in csv.DictReader(zipfile.read("calendar_dates.txt").decode().splitlines()):
                    calendar_date = CalendarDate(
                        id=f"{transport_type}-{calendar_date_dict['service_id']}-{calendar_date_dict['date']}",
                        service_id=f"{transport_type}-{calendar_date_dict['service_id']}",
                        date=calendar_date_dict['date'],
                        exception_type=calendar_date_dict['exception_type'],
                    )
                    calendar_dates.append(calendar_date)
                    if calendar_date.service_id not in calendar_ids:
                        calendar = Calendar(
                            id=f"{transport_type}-{calendar_date_dict['service_id']}",
                            monday=False,
                            tuesday=False,
                            wednesday=False,
                            thursday=False,
                            friday=False,
                            saturday=False,
                            sunday=False,
                            start_date=calendar_date_dict['date'],
                            end_date=calendar_date_dict['date'],
                            transport_type=transport_type,
                        )
                        calendars.append(calendar)
                    if len(calendar_dates) >= bulk_size and not dry_run:
                        Calendar.objects.bulk_create(calendars,
                                                     update_conflicts=True,
                                                     update_fields=['end_date'],
                                                     unique_fields=['id'])
                        CalendarDate.objects.bulk_create(calendar_dates,
                                                         update_conflicts=True,
                                                         update_fields=['service_id', 'date',
                                                                        'exception_type'],
                                                         unique_fields=['id'])
                        calendars.clear()
                        calendar_dates.clear()
                if calendar_dates and not dry_run:
                    Calendar.objects.bulk_create(calendars,
                                                 update_conflicts=True,
                                                 update_fields=['end_date'],
                                                 unique_fields=['id'])
                    CalendarDate.objects.bulk_create(calendar_dates,
                                                     update_conflicts=True,
                                                     update_fields=['service_id', 'date',
                                                                    'exception_type'],
                                                     unique_fields=['id'])
                    calendars.clear()
                    calendar_dates.clear()

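                # trips.txt: outside Transilien, trip ids embed a last-update
                # timestamp after the first colon; it is split off so the
                # bare id stays stable across imports.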
                trips = []
                for trip_dict in csv.DictReader(zipfile.read("trips.txt").decode().splitlines()):
                    trip_id = trip_dict['trip_id']
                    if transport_type != "TN":
                        trip_id, last_update = trip_id.split(':', 1)
                        last_update = datetime.fromisoformat(last_update)
                    else:
                        last_update = None
                    trip = Trip(
                        id=trip_id,
                        route_id=trip_dict['route_id'],
                        service_id=f"{transport_type}-{trip_dict['service_id']}",
                        headsign=trip_dict['trip_headsign'],
                        short_name=trip_dict.get('trip_short_name', ""),
                        direction_id=trip_dict['direction_id'] or None,
                        block_id=trip_dict['block_id'],
                        shape_id=trip_dict['shape_id'],
                        wheelchair_accessible=trip_dict.get('wheelchair_accessible', None),
                        bikes_allowed=trip_dict.get('bikes_allowed', None),
                        last_update=last_update,
                    )
                    trips.append(trip)
                    # Keep a copy of every trip here: the trips buffer is
                    # cleared after each flush below.
                    all_trips.append(trip)
                    if len(trips) >= bulk_size and not dry_run:
                        Trip.objects.bulk_create(trips,
                                                 update_conflicts=True,
                                                 update_fields=['route_id', 'service_id',
                                                                'headsign', 'short_name',
                                                                'direction_id', 'block_id',
                                                                'shape_id', 'wheelchair_accessible',
                                                                'bikes_allowed'],
                                                 unique_fields=['id'])
                        trips.clear()
                if trips and not dry_run:
                    Trip.objects.bulk_create(trips,
                                             update_conflicts=True,
                                             update_fields=['route_id', 'service_id',
                                                            'headsign', 'short_name',
                                                            'direction_id', 'block_id',
                                                            'shape_id', 'wheelchair_accessible',
                                                            'bikes_allowed'],
                                             unique_fields=['id'])
                    trips.clear()

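                # stop_times.txt: rows reference the bare trip id (timestamp
                # suffix stripped), but the StopTime primary key keeps the
                # full original id.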
                stop_times = []
                for stop_time_dict in csv.DictReader(zipfile.read("stop_times.txt").decode().splitlines()):
                    trip_id = stop_time_dict['trip_id']
                    if transport_type != "TN":
                        trip_id = trip_id.split(':', 1)[0]
                    # GTFS times are HH:MM:SS but may exceed 23:59:59 for
                    # services running past midnight, so parse them by hand.
                    arr_h, arr_m, arr_s = map(int, stop_time_dict['arrival_time'].split(':'))
                    dep_h, dep_m, dep_s = map(int, stop_time_dict['departure_time'].split(':'))
                    st = StopTime(
                        id=f"{stop_time_dict['trip_id']}-{stop_time_dict['stop_id']}",
                        trip_id=trip_id,
                        arrival_time=timedelta(hours=arr_h, minutes=arr_m, seconds=arr_s),
                        departure_time=timedelta(hours=dep_h, minutes=dep_m, seconds=dep_s),
                        stop_id=stop_time_dict['stop_id'],
                        stop_sequence=stop_time_dict['stop_sequence'],
                        stop_headsign=stop_time_dict['stop_headsign'],
                        pickup_type=stop_time_dict['pickup_type'],
                        drop_off_type=stop_time_dict['drop_off_type'],
                        timepoint=stop_time_dict.get('timepoint', None),
                    )
                    stop_times.append(st)
                    if len(stop_times) >= bulk_size and not dry_run:
                        StopTime.objects.bulk_create(stop_times,
                                                     update_conflicts=True,
                                                     update_fields=['stop_id', 'arrival_time',
                                                                    'departure_time', 'stop_headsign',
                                                                    'pickup_type', 'drop_off_type',
                                                                    'timepoint'],
                                                     unique_fields=['id'])
                        stop_times.clear()
                if stop_times and not dry_run:
                    StopTime.objects.bulk_create(stop_times,
                                                 update_conflicts=True,
                                                 update_fields=['stop_id', 'arrival_time',
                                                                'departure_time', 'stop_headsign',
                                                                'pickup_type', 'drop_off_type',
                                                                'timepoint'],
                                                 unique_fields=['id'])
                    stop_times.clear()

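                # transfers.txt: station-to-station connections, keyed on the
                # (from, to) stop pair.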
                transfers = []
                for transfer_dict in csv.DictReader(zipfile.read("transfers.txt").decode().splitlines()):
                    transfer = Transfer(
                        id=f"{transfer_dict['from_stop_id']}-{transfer_dict['to_stop_id']}",
                        from_stop_id=transfer_dict['from_stop_id'],
                        to_stop_id=transfer_dict['to_stop_id'],
                        transfer_type=transfer_dict['transfer_type'],
                        min_transfer_time=transfer_dict['min_transfer_time'],
                    )
                    transfers.append(transfer)
                    if len(transfers) >= bulk_size and not dry_run:
                        Transfer.objects.bulk_create(transfers,
                                                     update_conflicts=True,
                                                     update_fields=['transfer_type',
                                                                    'min_transfer_time'],
                                                     unique_fields=['id'])
                        transfers.clear()
                if transfers and not dry_run:
                    Transfer.objects.bulk_create(transfers,
                                                 update_conflicts=True,
                                                 update_fields=['transfer_type',
                                                                'min_transfer_time'],
                                                 unique_fields=['id'])
                    transfers.clear()

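                # feed_info.txt: stores the feed version that the freshness
                # check at the top of handle() compares against Last-Modified.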
if "feed_info.txt" in zipfile.namelist() and not dry_run:
for feed_info_dict in csv.DictReader(zipfile.read("feed_info.txt").decode().splitlines()):
feed_info_dict: dict
FeedInfo.objects.update_or_create(
publisher_name=feed_info_dict['feed_publisher_name'],
defaults=dict(
publisher_url=feed_info_dict['feed_publisher_url'],
lang=feed_info_dict['feed_lang'],
start_date=feed_info_dict['feed_start_date'],
end_date=feed_info_dict['feed_end_date'],
version=feed_info_dict['feed_version'],
)
)