tgvmax/app.py

290 lines
9.6 KiB
Python

#!/usr/bin/env python3
import csv
from datetime import date, datetime, time, timedelta
import os
import json
from pytz import timezone
import requests
import click
from flask import Flask
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from sqlalchemy.sql import func
from tqdm import tqdm
import config
# Flask application object; the route decorators and CLI commands below are
# attached to it.
app = Flask(__name__)
# Group the dataset-management commands under a `tgvmax` sub-command of the
# Flask command line (e.g. `flask tgvmax update-dataset`).
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)
# Merge config.FLASK_CONFIG into the Flask configuration before the
# extensions below are initialised with the app.
app.config |= config.FLASK_CONFIG
# SQLAlchemy handle used by the models and commands of this module.
db = SQLAlchemy(app)
# Hook Flask-Migrate up to the application and its database handle.
Migrate(app, db)
class Train(db.Model):
    """
    A train of the TGVMax dataset, for one day and one origin/destination pair.

    Rows are created by the ``parse-csv`` command and refreshed by the
    ``process-queue`` command.
    """
    __tablename__ = 'train'
    # Identifier built by parse-csv as "<number>-<day>-<orig_iata>-<dest_iata>".
    id = Column(String, primary_key=True)
    # Day on which the train runs.
    day = Column(Date, index=True)
    # Train number.
    number = Column(Integer, index=True)
    # Third CSV column; presumably the operating entity -- TODO confirm against the dataset.
    entity = Column(String(10))
    # Fourth CSV column (e.g. "IC NUIT" for the night trains).
    axe = Column(String(32), index=True)
    # Station codes of the origin and destination (e.g. 'FRADP'); queued route
    # requests are matched against these two columns.
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    # Station names of the origin and destination (e.g. "STRASBOURG").
    orig = Column(String(32))
    dest = Column(String(32))
    # Departure and arrival times of day.
    dep = Column(Time)
    arr = Column(Time)
    # True when the CSV says 'OUI' or when the simulator proposed this train.
    tgvmax = Column(Boolean, index=True)
    # Free places reported by the simulator; -1 as long as no simulator answer
    # provided a value for this train.
    remaining_seats = Column(Integer, default=-1)
    # Timestamp of the data this row comes from: CSV modification time, or the
    # simulator's answer time.
    last_modification = Column(DateTime)
    # Set to the same value as last_modification by parse-csv, and to the
    # expiry announced by the simulator by process-queue.
    expiration_time = Column(DateTime)
class RouteQueue(db.Model):
    """
    A pending or answered request to refresh one route from the TGVMax simulator.

    Entries are added by the ``queue-route`` command and consumed by the
    ``process-queue`` command.
    """
    id = Column(Integer, autoincrement=True, primary_key=True)
    # Filled in by the database at insertion; the queue is processed in this order.
    queue_time = Column(DateTime(timezone=True), server_default=func.now())
    # Day and station codes of the requested route.
    day = Column(Date)
    origin = Column(String(5))
    destination = Column(String(5))
    # Stays NULL until process-queue has handled the request.
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)
    # End of validity of the answer, filled in together with response_time.
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)
@cli.command("update-dataset")
def update_dataset():
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.

    The file tgvmax.csv is downloaded again only when the modification date
    published on the dataset page is newer than the local copy. Errors are
    reported on the standard output and do not abort with a traceback.
    """
    csv_path = 'tgvmax.csv'
    # The download goes to a side file first, so that an interrupted transfer
    # never leaves a truncated tgvmax.csv with a fresh modification time (which
    # would make the next run believe the dataset is up to date).
    partial_path = csv_path + '.part'
    utc = timezone('UTC')

    try:
        resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/',
                            timeout=60)
        resp.raise_for_status()
        content = resp.content.decode().split('<script type="application/ld+json">')[1].split('</script>')[0].strip()
        # The embedded JSON-LD block contains raw line breaks inside its string
        # values; escape them so that the block becomes valid JSON. The order of
        # the replacements matters.
        content = content.replace('\r', '')
        content = content.replace('" \n', '" \\n')
        content = content.replace('.\n', '.\\n')
        content = content.replace('\n\n \nLa', '\\n\\n \\nLa')
        content = content.replace('\n"', '\\n"')
        info = json.loads(content)

        modified_date = datetime.fromisoformat(info['dateModified'])
        if modified_date.tzinfo is None:
            # A naive date cannot be compared with the aware local timestamp.
            modified_date = modified_date.replace(tzinfo=utc)

        if os.path.isfile(csv_path):
            last_modified = datetime.fromtimestamp(os.path.getmtime(csv_path), tz=utc)
        else:
            # No local copy yet: use the oldest possible date to force a download.
            last_modified = datetime(1, 1, 1, tzinfo=utc)

        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            with requests.get(info['distribution'][0]['contentUrl'], stream=True, timeout=60) as resp:
                resp.raise_for_status()
                with open(partial_path, 'wb') as f, tqdm(unit='io', unit_scale=True) as t:
                    for chunk in resp.iter_content(chunk_size=512 * 1024):
                        if chunk:
                            f.write(chunk)
                            t.update(len(chunk))
            # Stamp the file with the dataset's own modification date, which is
            # what the freshness check above relies on, then publish it.
            os.utime(partial_path, (modified_date.timestamp(), modified_date.timestamp()))
            os.replace(partial_path, csv_path)
            print("Done")

        print("Last modification:", modified_date)
    except Exception as e:
        # Command-line boundary: report the problem instead of a traceback.
        print("An error occured while updating tgvmax.csv")
        print(e)
    finally:
        # Only present when the download did not complete.
        if os.path.isfile(partial_path):
            os.remove(partial_path)
@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
def parse_trains(flush: bool = False):
    """
    Parse the CSV file and store it to the database.

    Every row of tgvmax.csv becomes a Train, stamped with the modification time
    of the file. With --flush the table is emptied first and rows are inserted;
    otherwise rows are merged with the existing ones.
    """
    if flush:
        print("Flush database…")
        db.session.query(Train).delete()

    last_modification = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=timezone('UTC'))

    # newline='' is what the csv module expects; the encoding is explicit so
    # that accented station names do not depend on the locale.
    with open('tgvmax.csv', encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # Skip the header line.

        already_seen = set()
        for line in tqdm(reader):
            if not line:
                # Blank line in the file.
                continue

            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            if train_id in already_seen:
                # Some trains are mysteriously duplicated, concerns only some « Intercités de nuit »
                # and the Brive-la-Gaillarde -- Paris
                # and, maybe, for Roubaix-Tourcoing
                known_duplicate = (line[3] == "IC NUIT" or line[1] == '3614'
                                   or (line[4] == 'FRADP' and line[5] == 'FRADM'))
                if not known_duplicate:
                    print("Duplicate:", train_id)
                continue
            # Every identifier is remembered: recording only the known
            # duplicates made the warning above unreachable and let unexpected
            # duplicates hit the primary key twice.
            already_seen.add(train_id)

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                # The table was just emptied, a plain insert is enough.
                db.session.add(train)
            else:
                db.session.merge(train)

    db.session.commit()
def find_routes(day, orig, dest):
    """
    Find the chains of TGVMax trains going from ``orig`` to ``dest`` on ``day``.

    ``day`` is matched against ``Train.day``; ``orig`` and ``dest`` are station
    names, matched against ``Train.orig`` and ``Train.dest``.

    Returns a list of routes, each route being the list of the trains to take,
    in travel order. A connection is accepted when the previous train arrives
    no later than the next one leaves, and a route never comes back to a
    station it already went through.

    NOTE(review): times of day are compared directly, so a train arriving after
    midnight is not handled specially.
    """
    # The search relies on trains being visited by increasing departure time.
    trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all()
    trains.sort(key=lambda train: train.dep)

    # Unfinished itineraries starting at ``orig``, grouped by the station where
    # they currently end.
    per_arr_explore = {}
    valid_routes = []

    for train in tqdm(trains):
        if train.orig == orig:
            it = [train]
            if train.dest == dest:
                # We hope that we have a direct train
                valid_routes.append(it)
            else:
                per_arr_explore.setdefault(train.dest, []).append(it)
            continue

        # Iterate over a copy: the dictionary is extended inside the loop.
        for it in list(per_arr_explore.get(train.orig, [])):
            if train.dest == orig or any(train.dest == tr.dest for tr in it):
                # Avoid loops
                continue

            last_train = it[-1]
            if last_train.arr <= train.dep:
                new_it = it + [train]
                if train.dest == dest:
                    # Goal is achieved
                    valid_routes.append(new_it)
                else:
                    per_arr_explore.setdefault(train.dest, []).append(new_it)

    return valid_routes
def print_route(route: list[Train]):
    """
    Print a route on a single line.

    The line starts with the name of the first departure station, followed for
    each train by its departure time, its arrival time and its arrival station.
    """
    legs = ", ".join(f"({leg.dep}) --> ({leg.arr}) {leg.dest}" for leg in route)
    print(f"{route[0].orig} {legs}")
@cli.command('queue-route')
@click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
@click.argument('origin', type=str)
@click.argument('destination', type=str)
def queue_route(day: date | datetime, origin: str, destination: str):
    """
    Fetch the TGVMax simulator to refresh data.

    DAY: The day to query, in format YYYY-MM-DD.

    ORIGIN: The origin of the route.

    DESTINATION: The destination of the route.
    """
    # The docstring above is what click displays as the help of the command.
    # This command only records the request; the simulator itself is contacted
    # by the process-queue command.

    # click hands over a datetime, while the queue stores plain dates.
    travel_day = day.date() if isinstance(day, datetime) else day

    # Requests without a response time have not been processed yet.
    pending = db.session.query(RouteQueue).filter_by(
        day=travel_day,
        origin=origin,
        destination=destination,
        response_time=None,
    )
    if pending.count() > 0:
        print("Already queued")
    else:
        db.session.add(RouteQueue(day=travel_day, origin=origin, destination=destination))
        db.session.commit()
@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=5, type=int)
def process_queue(number: int):
    """
    Send the oldest pending route requests to the simulator and store the answers.

    ``number`` is the maximum number of queued requests to process; when it is
    not positive, the whole queue is processed.

    Each request is committed as soon as it is answered, so an HTTP error
    raised by a later request does not discard the answers already received.
    """
    utc = timezone('UTC')
    pending = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time)
    # Load the requests up front: the session is committed inside the loop.
    queue = pending[:number] if number > 0 else pending.all()

    URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"

    for req in queue:
        req: RouteQueue
        resp = requests.post(
            URL,
            json={
                'departureDateTime': req.day.isoformat(),
                'origin': req.origin,
                'destination': req.destination,
            },
            timeout=30,
        )

        if resp.status_code == 404:
            # No travel found
            # The columns are timezone-aware, so store an aware timestamp.
            now = datetime.now(tz=utc)
            req.response_time = now
            req.expiration_time = now + timedelta(hours=1)
            db.session.add(req)
            db.session.commit()
            continue

        resp.raise_for_status()

        data = resp.json()
        # The simulator gives timestamps in milliseconds.
        req.response_time = datetime.fromtimestamp(data['updatedAt'] // 1000, tz=utc)
        req.expiration_time = datetime.fromtimestamp(data['expiresAt'] // 1000, tz=utc)
        db.session.add(req)

        # Reset every train of the route, then flag again those that the
        # simulator proposes.
        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination)\
            .update(dict(tgvmax=False, remaining_seats=-1))

        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=req.origin, dest_iata=req.destination).first()
            if train is None:
                # The simulator knows a train that is not in the database:
                # report it and keep processing the other proposals.
                print(f"Unknown train {proposal['trainNumber']} on {req.day} "
                      f"from {req.origin} to {req.destination}")
                continue
            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time
            db.session.add(train)

        db.session.commit()
@app.get('/')
def index() -> str:
    """
    Serve the root URL with a fixed greeting.
    """
    return "Hello world!"
# Start Flask's built-in server when the file is executed directly.
# NOTE(review): debug=True turns on the interactive debugger; confirm this entry
# point is never used for a publicly reachable deployment.
if __name__ == '__main__':
    app.run(debug=True)