tgvmax/app.py

410 lines
15 KiB
Python

#!/usr/bin/env python3
from contextlib import nullcontext
import csv
from datetime import date, datetime, time, timedelta
import os
import json
from pytz import timezone
import requests
from time import sleep
import click
from flask import Flask, render_template
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from sqlalchemy.sql import func
from tqdm import tqdm
import config
# Flask application; its settings are merged from the local ``config`` module below.
app = Flask(__name__)
# Command group exposed as ``flask tgvmax <command>``; the CLI commands of this file are registered on it.
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)
# Merge the project settings into Flask's config. Keep this before SQLAlchemy(app):
# Flask-SQLAlchemy reads its settings from app.config when it is initialised.
app.config |= config.FLASK_CONFIG
db = SQLAlchemy(app)
# Registers Flask-Migrate (the ``flask db`` migration commands) for this app and database.
Migrate(app, db)
class Train(db.Model):
    """
    One train run on a given day, as listed in the SNCF ``tgvmax.csv`` dataset.

    Rows are created by the ``parse-csv`` command and their TGVMax availability
    is refreshed by ``process-queue`` from the simulator's answers.
    """
    __tablename__ = 'train'
    # "<number>-<day>-<orig_iata>-<dest_iata>", as built by parse_trains.
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)
    # Train number.
    number = Column(Integer, index=True)
    # Third CSV column, stored verbatim (its exact meaning is not visible here).
    entity = Column(String(16))
    # Line/axis label from the CSV, e.g. "IC NUIT".
    axe = Column(String(32), index=True)
    # Station codes such as 'FRADP'; the simulator's ``rrCode`` values are matched
    # against them (despite the "iata" name).
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    # Station names.
    orig = Column(String(32))
    dest = Column(String(32))
    # Departure / arrival times of day; an arrival earlier than the departure
    # presumably denotes an overnight train.
    dep = Column(Time)
    arr = Column(Time)
    # Whether TGVMax seats are offered ('OUI' in the CSV, or listed in a simulator proposal).
    tgvmax = Column(Boolean, index=True)
    # Free seats reported by the simulator; -1 when unknown.
    remaining_seats = Column(Integer, default=-1)
    # When the row was last refreshed (CSV mtime or simulator answer) and when that data expires.
    # NOTE(review): these columns are timezone-naive while the code writes UTC-aware datetimes — confirm.
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)
class RouteQueue(db.Model):
    """
    A TGVMax simulator request for one (day, origin, destination) route.

    Requests are added by ``queue_route`` and sent, oldest first, by
    ``process_queue``, which records when the answer was received and when it expires.
    """
    id = Column(Integer, autoincrement=True, primary_key=True)
    # Filled in by the database when the request is inserted.
    queue_time = Column(DateTime(timezone=True), server_default=func.now())
    day = Column(Date)
    # Station codes sent to the simulator.
    origin = Column(String(5))
    destination = Column(String(5))
    # NULL while the request is pending; a request counts as fresh while
    # expiration_time is still in the future.
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)
@cli.command("update-dataset")
@click.option('--verbose', '-v', is_flag=True, help="Display errors.")
def update_dataset(verbose: bool = False):
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.

    tgvmax.csv is only downloaded again when the dataset page announces a
    newer modification date than the local copy. Exits with status 2 when the
    local copy is already up to date.
    """
    utc = timezone('UTC')
    csv_path = 'tgvmax.csv'

    resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/', timeout=60)
    resp.raise_for_status()
    page = resp.content.decode()
    marker = '<script type="application/ld+json">'
    if marker not in page:
        raise click.ClickException("Dataset metadata (JSON-LD) not found on the SNCF page.")
    content = page.split(marker, 1)[1].split('</script>', 1)[0].strip()
    # The JSON-LD description embeds raw line breaks inside string literals;
    # strict=False lets the decoder accept such control characters instead of
    # patching the text by hand.
    info = json.loads(content, strict=False)

    # datetime.fromisoformat() only understands a trailing "Z" from Python 3.11 on.
    raw_date = info['dateModified']
    if raw_date.endswith('Z'):
        raw_date = raw_date[:-1] + '+00:00'
    modified_date = datetime.fromisoformat(raw_date)
    if modified_date.tzinfo is None:
        # NOTE(review): a zone-less dateModified is assumed to be UTC — confirm.
        modified_date = modified_date.replace(tzinfo=utc)

    if os.path.isfile(csv_path):
        last_modified = datetime.fromtimestamp(os.path.getmtime(csv_path), tz=utc)
    else:
        last_modified = datetime(1, 1, 1, tzinfo=utc)

    if last_modified >= modified_date:
        if verbose:
            print("Last modification:", modified_date)
        raise SystemExit(2)

    if verbose:
        print("Updating tgvmax.csv…")
    # Download next to the target and swap it in only once complete: a failed
    # download must not leave a truncated tgvmax.csv with a fresh mtime, which
    # would make later runs believe the dataset is already up to date.
    part_path = csv_path + '.part'
    try:
        with requests.get(info['distribution'][0]['contentUrl'], stream=True, timeout=60) as resp:
            resp.raise_for_status()
            with open(part_path, 'wb') as f:
                with (tqdm(unit='B', unit_scale=True) if verbose else nullcontext()) as progress:
                    for chunk in resp.iter_content(chunk_size=512 * 1024):
                        if chunk:
                            f.write(chunk)
                            if verbose:
                                progress.update(len(chunk))
        os.replace(part_path, csv_path)
    except BaseException:
        if os.path.exists(part_path):
            os.remove(part_path)
        raise
    # Stamp the file with the remote modification date; it is compared on the next run.
    os.utime(csv_path, (modified_date.timestamp(), modified_date.timestamp()))
    if verbose:
        print("Done")
def _is_known_duplicate(line: list[str]) -> bool:
    """Return True for CSV rows of trains that the SNCF dataset is known to list several times."""
    # « Intercités de nuit », train 3614 (presumably the Brive-la-Gaillarde -- Paris)
    # and, maybe, Roubaix-Tourcoing (FRADP -> FRADM).
    return line[3] == "IC NUIT" or line[1] == '3614' or (line[4] == 'FRADP' and line[5] == 'FRADM')


@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
@click.option('--verbose', '-v', is_flag=True, help="Display errors.")
def parse_trains(flush: bool = False, verbose: bool = False):
    """
    Parse the CSV file and store it to the database.

    Trains refreshed by the simulator after the CSV was downloaded are left
    untouched. When a train appears several times in the CSV, only its first
    row is kept.
    """
    if flush:
        if verbose:
            print("Flush database…")
        db.session.query(Train).delete()
    last_modification = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=timezone('UTC'))
    # Query.values() is deprecated (removed in SQLAlchemy 2.0); with_entities() is the replacement.
    already_updated = {
        train_id for (train_id,) in db.session.query(Train)
        .filter(Train.last_modification > last_modification)
        .with_entities(Train.id)
    }
    already_seen = set()
    with open('tgvmax.csv', encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # Skip the header line
        for line in (tqdm(reader) if verbose else reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            if train_id in already_seen:
                # Keep the first row; only report duplicates that are not expected.
                if not _is_known_duplicate(line):
                    print("Duplicate:", train_id)
                continue
            # Track every train, not only the known duplicates. Otherwise an
            # unexpected duplicate would be inserted twice under --flush (same
            # primary key) or silently overwrite the first row through merge().
            already_seen.add(train_id)
            if train_id in already_updated:
                # Already updated by the simulator
                continue
            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                db.session.add(train)
            else:
                db.session.merge(train)
    db.session.commit()
def find_routes(day: date | datetime, origin: str, destination: str | None,
                verbose: bool = False, min_dep: time = time(0, 0),
                explored: dict | None = None):
    """
    Find TGVMax itineraries from ORIGIN to DESTINATION on DAY.
    \f
    Depth-first search over the ``tgvmax`` trains of ``day``, memoised in
    ``explored``: ``explored[station]`` is ``(min_dep, routes)``. Here
    ``min_dep`` is the earliest departure time explored from that station, and
    ``routes`` lists the itineraries (lists of ``Train``) from it to
    ``destination``. ``routes`` is ``None`` while the station is being explored,
    which is how loops are detected.

    :param day: day of travel (a ``datetime`` is truncated to its date).
    :param origin: departure station name (compared with ``Train.orig``).
    :param destination: arrival station name (compared with ``Train.dest``), or None.
    :param verbose: show a progress bar for each explored station.
    :param min_dep: earliest allowed departure time from ``origin``.
    :param explored: memo shared by the recursive calls.
    :return: ``{destination: routes}``, or ``{}`` when ``destination`` is None.

    NOTE(review): routes through a station that is "in progress" are pruned and
    that pruned result is memoised, so some multi-leg itineraries can be missed.
    """
    if isinstance(day, datetime):
        day = day.date()
    if explored is None:
        explored = {}

    def as_result(routes):
        # Without a destination nothing can match: always answer {} then.
        return {destination: routes} if destination else {}

    if origin in explored:
        previous_min_dep, valid_routes = explored[origin]
        if previous_min_dep <= min_dep:
            # Every train leaving at or after min_dep has already been explored.
            return as_result(valid_routes)
    else:
        previous_min_dep, valid_routes = None, []
    # Mark the station as "in progress" so that recursive calls detect loops.
    explored[origin] = (min_dep, None)

    query = db.session.query(Train).filter_by(day=day, tgvmax=True, orig=origin) \
        .filter(Train.dep >= min_dep)
    if previous_min_dep is not None:
        # Only the time window that the previous visit did not cover. On a first
        # visit there is no upper bound (the former 23:59 bound dropped trains
        # leaving at 23:59).
        query = query.filter(Train.dep < previous_min_dep)
    trains = sorted(query.all(), key=lambda train: train.dep)
    # No commit here: this function only reads, and committing would expire every
    # loaded Train, forcing one reload query per attribute access afterwards.

    for train in (tqdm(trains, desc=origin) if verbose else trains):
        if train.dest == destination:
            # We hope that we have a direct train
            valid_routes.append([train])
            continue
        if train.dest in explored and explored[train.dest][1] is None:
            # This is a loop
            continue
        if train.arr < train.dep:
            # The train is not direct and arrives on the next day, we avoid that.
            # (Comparing with min_dep missed overnight trains arriving after min_dep.)
            continue
        find_routes(day, train.dest, destination, verbose, train.arr, explored)
        # Keep only the connections leaving after this train arrives
        valid_routes += [[train] + route for route in explored[train.dest][1] if route[0].dep >= train.arr]

    explored[origin] = (min_dep, valid_routes)
    return as_result(valid_routes)
# Register find_routes as the ``find-routes`` CLI command. The click decorators
# are applied by hand (innermost first) and the Command built by cli.command()
# is not bound to any name, so ``find_routes`` stays a plain function that
# get_routes can call directly.
cli.command('find-routes')(
    click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(
        click.argument('origin', type=str)(
            click.argument('destination', type=str, default=None)(
                click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(
                    find_routes,
                ),
            ),
        ),
    ),
)
def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False, autocommit: bool = True):
    """
    Queue a TGVMax simulator request to refresh the data of a route.

    DAY: The day to query, in format YYYY-MM-DD.
    ORIGIN: The origin station code of the route.
    DESTINATION: The destination station code of the route.

    Nothing is queued when a request for the same route is still pending or its
    answer has not expired yet. Queued requests are sent by process-queue.
    """
    # verbose is accepted for consistency with the other commands; it is unused here.
    if isinstance(day, datetime):
        day = day.date()
    now = datetime.now(timezone('UTC'))
    already_queued = db.session.query(RouteQueue) \
        .filter_by(day=day, origin=origin, destination=destination) \
        .filter(RouteQueue.response_time.is_(None) | (RouteQueue.expiration_time >= now)) \
        .first() is not None
    if already_queued:
        return
    db.session.add(RouteQueue(day=day, origin=origin, destination=destination))
    if autocommit:
        db.session.commit()
# Register queue_route as the ``queue-route`` CLI command. The click decorators
# are applied by hand (innermost first) and the resulting Command is not bound
# to any name, so ``queue_route`` remains a plain function (queue_routes calls it).
cli.command('queue-route')(
    click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(
        click.argument('origin', type=str)(
            click.argument('destination', type=str)(
                click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(
                    queue_route,
                ),
            ),
        ),
    ),
)
def queue_routes(day: date | datetime, origin: str | None = None,
                 destination: str | None = None, verbose: bool = False,
                 autocommit: bool = True):
    """
    Queue a simulator request for every route served on DAY.

    DAY: The day to query, in format YYYY-MM-DD.

    --origin / --destination restrict the routes to those whose origin /
    destination matches, either by station code or by station name. Routes
    with a pending or still-valid request are skipped.
    """
    if isinstance(day, datetime):
        day = day.date()
    now = datetime.now(timezone('UTC'))
    # Routes already pending or answered and not yet expired.
    # Query.values() is deprecated (removed in SQLAlchemy 2.0); with_entities() replaces it.
    known_routes = {
        (route_origin, route_destination)
        for route_origin, route_destination in db.session.query(RouteQueue)
        .filter_by(day=day)
        .filter(RouteQueue.response_time.is_(None) | (RouteQueue.expiration_time >= now))
        .with_entities(RouteQueue.origin, RouteQueue.destination)
    }
    trains_query = db.session.query(Train).filter(Train.day == day)
    if origin:
        trains_query = trains_query.filter((Train.orig_iata == origin) | (Train.orig == origin))
    if destination:
        trains_query = trains_query.filter((Train.dest_iata == destination) | (Train.dest == destination))
    trains = trains_query.all()
    progress = tqdm(trains) if verbose else trains
    for train in progress:
        if verbose:
            progress.set_description(f"{day}: {train.orig} --> {train.dest}")
        route = (train.orig_iata, train.dest_iata)
        if route not in known_routes:
            queue_route(day, train.orig_iata, train.dest_iata, verbose, autocommit)
            known_routes.add(route)
# Register queue_routes as the ``queue-routes`` CLI command. The click decorators
# are applied by hand (innermost first) and the resulting Command is not bound
# to any name, so ``queue_routes`` remains a plain, directly callable function.
cli.command('queue-routes')(
    click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(
        click.option('--origin', '-o', default=None)(
            click.option('--destination', '-d', default=None)(
                click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(
                    queue_routes,
                ),
            ),
        ),
    ),
)
def _post_until_not_rate_limited(url: str, payload: dict, timeout: float = 30, retry_delay: float = 1):
    """POST ``payload`` as JSON to ``url``, sleeping and retrying while the server answers 429."""
    while True:
        resp = requests.post(url, json=payload, timeout=timeout)
        if resp.status_code != 429:
            return resp
        sleep(retry_delay)


@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=30, type=int)
@click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
def process_queue(number: int, verbose: bool = False):
    """
    Send the oldest pending RouteQueue requests to the TGVMax simulator.

    :param number: maximum number of requests to send; all pending ones when <= 0.
    :param verbose: show a progress bar with the route being processed.

    For each answer, the trains of the route are marked as not TGVMax and then
    the proposed ones are flagged again with their free seats.
    """
    URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"
    utc = timezone('UTC')
    query = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time)
    if number > 0:
        # Let the database limit the rows instead of loading the whole queue.
        query = query.limit(number)
    queue = query.all()
    if verbose:
        names = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct()
        code_to_name = {code: name for (code, name) in names.all()}
        progress = tqdm(queue)
    else:
        progress = queue
    for i, req in enumerate(progress):
        if verbose:
            # Fall back to the raw code for stations that are never a train origin.
            origin_name = code_to_name.get(req.origin, req.origin)
            destination_name = code_to_name.get(req.destination, req.destination)
            progress.set_description(f"{req.day:%d/%m/%Y}: {origin_name} --> {destination_name}")
        resp = _post_until_not_rate_limited(URL, {
            'departureDateTime': req.day.isoformat(),
            'origin': req.origin,
            'destination': req.destination,
        })
        if resp.status_code == 404:
            # No travel found. Use aware UTC like the other paths, since queue_route
            # compares expiration_time with datetime.now(UTC).
            # NOTE(review): unlike a 200 answer, the route's trains keep their
            # previous tgvmax flags here — confirm this is intended.
            now = datetime.now(utc)
            req.response_time = now
            req.expiration_time = now + timedelta(hours=3)
            db.session.add(req)
            continue
        resp.raise_for_status()
        data = resp.json()
        # updatedAt / expiresAt look like epoch timestamps in milliseconds (hence // 1000).
        req.response_time = datetime.fromtimestamp(data['updatedAt'] // 1000, tz=utc)
        req.expiration_time = datetime.fromtimestamp(data['expiresAt'] // 1000, tz=utc)
        req.expiration_time += timedelta(hours=3)  # By default 5 minutes, extend it to 3 hours to be safe
        db.session.add(req)
        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination) \
            .update(dict(tgvmax=False, remaining_seats=-1))
        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=proposal['origin']['rrCode'],
                                                      dest_iata=proposal['destination']['rrCode']).first()
            if train is None:
                # In a city with multiple stations
                print("ERROR")
                print(proposal)
                continue
            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time
            db.session.add(train)
        if i % 50 == 0:
            db.session.commit()
    db.session.commit()
@app.get('/')
def index():
    """
    Render index.html with today's date and the date 30 days later.

    The clock is read once so that both dates are consistent even when the
    request is served around midnight.
    """
    today = date.today()
    return render_template('index.html', today=today, max_day=today + timedelta(days=30))
@app.get('/api/iata-codes/')
def iata_codes():
    """
    Return the station code <-> station name mappings as JSON.

    Only stations that are the origin of at least one train are listed. The
    query runs once, so both mappings come from the same result set.
    """
    pairs = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct().all()
    return {
        'iata2name': {code: name for (code, name) in pairs},
        'name2iata': {name: code for (code, name) in pairs},
    }
def _train_to_json(tr: 'Train') -> dict:
    """Serialize one leg (a Train row) of an itinerary for the JSON API."""
    return {
        'origin': tr.orig,
        'origin_iata': tr.orig_iata,
        'destination': tr.dest,
        'destination_iata': tr.dest_iata,
        'departure': tr.dep.isoformat(),
        'arrival': tr.arr.isoformat(),
        'number': tr.number,
        'free_seats': tr.remaining_seats,
    }


@app.get('/api/routes/<day>/<origin>/<destination>/')
def get_routes(day: date | str, origin: str, destination: str):
    """
    Return the TGVMax itineraries from ``origin`` to ``destination`` on ``day``.

    ``day`` must be an ISO date (YYYY-MM-DD); anything else yields a 400
    response instead of an unhandled ValueError (HTTP 500). The literal
    ``undefined`` (presumably an unset value on the JavaScript side) means "no
    destination". The response maps each destination to a list of itineraries,
    each a list of serialized legs.
    """
    if isinstance(day, str):
        try:
            day = date.fromisoformat(day)
        except ValueError:
            return {'error': f"Invalid day {day!r}, expected YYYY-MM-DD."}, 400
    if destination == 'undefined':
        destination = None
    routes = find_routes(day, origin, destination)
    return {
        city: [[_train_to_json(tr) for tr in route] for route in city_routes]
        for city, city_routes in routes.items()
    }
if __name__ == '__main__':
    # Running this file directly starts Flask's built-in development server in
    # debug mode (interactive debugger and reloader); not meant for production.
    app.run(debug=True)