#!/usr/bin/env python3
"""
Flask application and ``flask tgvmax`` CLI for the SNCF TGVMax dataset.

The CLI downloads the SNCF OpenData TGVMax CSV, loads it into the database
and refreshes seat availability from the TGVMax simulator; the web routes
expose station names and multi-train itineraries built from that data.
"""

import csv
import json
import os
import re
from collections import defaultdict
from datetime import date, datetime, time, timedelta

import click
import requests
from flask import Flask, render_template
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from pytz import timezone
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from sqlalchemy.sql import func
from tqdm import tqdm

import config

# Local copy of the dataset (path relative to the current working directory).
DATASET_FILE = 'tgvmax.csv'
DATASET_PAGE_URL = 'https://ressources.data.sncf.com/explore/dataset/tgvmax/information/'
SIMULATOR_URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"
# Connect/read timeout, in seconds, for every outgoing HTTP request.
REQUEST_TIMEOUT = 60
UTC = timezone('UTC')

# The dataset page embeds its metadata as a JSON-LD script element; the code
# below reads its ``dateModified`` and ``distribution[0].contentUrl`` keys.
_JSON_LD_PATTERN = re.compile(
    r'<script[^>]*\btype=["\']application/ld\+json["\'][^>]*>(.*?)</script>',
    re.DOTALL,
)

app = Flask(__name__)
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)

app.config |= config.FLASK_CONFIG

db = SQLAlchemy(app)
Migrate(app, db)


class Train(db.Model):
    """One train run of the dataset, identified by number, day, origin and destination."""
    __tablename__ = 'train'
    # "<number>-<day>-<orig_iata>-<dest_iata>", built by parse_trains.
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)
    number = Column(Integer, index=True)
    entity = Column(String(16))
    axe = Column(String(32), index=True)
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    orig = Column(String(32))  # origin station name
    dest = Column(String(32))  # destination station name
    dep = Column(Time)
    arr = Column(Time)
    tgvmax = Column(Boolean, index=True)
    # -1 when no seat count is known (default, and reset by process_queue).
    remaining_seats = Column(Integer, default=-1)
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)


class RouteQueue(db.Model):
    """A request to refresh one (day, origin, destination) route from the simulator."""
    id = Column(Integer, autoincrement=True, primary_key=True)
    queue_time = Column(DateTime(timezone=True), server_default=func.now())  # set by the database on insert
    day = Column(Date)
    origin = Column(String(5))  # station code, matched against Train.orig_iata
    destination = Column(String(5))  # station code, matched against Train.dest_iata
    # None while the request is still waiting to be processed.
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)
    # Until this instant, queue_route reuses the answer instead of queueing again.
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)


def _file_mtime(path: str) -> datetime:
    """Return the modification time of *path* as an aware UTC datetime."""
    return datetime.fromtimestamp(os.path.getmtime(path), tz=UTC)


def _ms_to_utc(milliseconds: int) -> datetime:
    """Convert a millisecond Unix timestamp to an aware UTC datetime (truncated to the second)."""
    return datetime.fromtimestamp(milliseconds // 1000, tz=UTC)


def _fetch_dataset_metadata() -> dict:
    """
    Download the dataset information page and return its JSON-LD metadata.

    :raises requests.HTTPError: if the page cannot be retrieved.
    :raises ValueError: if the page holds no JSON-LD block or it is not valid JSON.
    """
    resp = requests.get(DATASET_PAGE_URL, timeout=REQUEST_TIMEOUT)
    resp.raise_for_status()
    match = _JSON_LD_PATTERN.search(resp.content.decode())
    if match is None:
        raise ValueError("No JSON-LD metadata found on the dataset page")
    # Descriptions in the metadata contain raw newlines, which strict JSON
    # forbids inside strings; strict=False accepts any control character there.
    return json.loads(match.group(1).strip(), strict=False)


def _parse_iso_datetime(value: str) -> datetime:
    """Parse an ISO 8601 timestamp into an aware datetime, assuming UTC when no offset is given."""
    if value.endswith('Z'):
        # datetime.fromisoformat only understands "Z" from Python 3.11 onwards.
        value = value[:-1] + '+00:00'
    parsed = datetime.fromisoformat(value)
    if parsed.tzinfo is None:
        # Comparing naive and aware datetimes raises TypeError.
        parsed = parsed.replace(tzinfo=UTC)
    return parsed


def _download(url: str, destination: str) -> None:
    """Stream *url* into *destination*, replacing it only once the whole body was received."""
    partial = destination + '.part'
    with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as resp:
        resp.raise_for_status()
        with open(partial, 'wb') as f, tqdm(unit='io', unit_scale=True) as progress:
            for chunk in resp.iter_content(chunk_size=512 * 1024):
                if chunk:
                    f.write(chunk)
                    progress.update(len(chunk))
    # Writing to a side file means an interrupted download never leaves a
    # truncated dataset whose fresh mtime would hide it from the next update.
    os.replace(partial, destination)


@cli.command("update-dataset")
def update_dataset():
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.

    The CSV is only downloaded when the published modification date is more
    recent than the modification time of the local copy.
    """
    try:
        info = _fetch_dataset_metadata()
        modified_date = _parse_iso_datetime(info['dateModified'])

        if os.path.isfile(DATASET_FILE):
            last_modified = _file_mtime(DATASET_FILE)
        else:
            last_modified = datetime(1, 1, 1, tzinfo=UTC)

        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            _download(info['distribution'][0]['contentUrl'], DATASET_FILE)
            # Stamp the file with the dataset's own modification date, so the
            # next run compares like with like.
            os.utime(DATASET_FILE, (modified_date.timestamp(), modified_date.timestamp()))
            print("Done")

        print("Last modification:", modified_date)
    except Exception as e:
        # Top-level CLI boundary: report the failure instead of a traceback.
        print("An error occured while updating tgvmax.csv")
        print(e)


def _is_known_duplicate(row: list[str]) -> bool:
    """Tell whether a CSV row belongs to a train known to appear several times in the dataset."""
    # Some trains are mysteriously duplicated, concerns only some « Intercités de nuit »
    # and the Brive-la-Gaillarde -- Paris
    # and, maybe, for Roubaix-Tourcoing
    return row[3] == "IC NUIT" or row[1] == '3614' or (row[4] == 'FRADP' and row[5] == 'FRADM')


@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
def parse_trains(flush: bool = False):
    """
    Parse the CSV file and store it to the database.
    """
    if flush:
        print("Flush database…")
        db.session.query(Train).delete()

    last_modification = _file_mtime(DATASET_FILE)

    # newline='' is what the csv module expects; the file is read as UTF-8
    # rather than with the locale-dependent default encoding.
    with open(DATASET_FILE, encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # skip the header row

        already_seen = set()
        for line in tqdm(reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            if train_id in already_seen:
                if not _is_known_duplicate(line):
                    print("Duplicate:", train_id)
                continue
            # Remember every id: adding the same primary key twice to the
            # session would make the final commit fail.
            already_seen.add(train_id)

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                db.session.add(train)
            else:
                db.session.merge(train)

    db.session.commit()


def find_routes(day, origin, destination):
    """
    Enumerate the TGVMax itineraries from *origin* to *destination* on *day*.

    Stations are compared by name (``Train.orig`` / ``Train.dest``) and only
    trains with ``tgvmax=True`` are used. Each train of an itinerary departs
    from the previous train's arrival station no earlier than that train
    arrived, and no arrival station is visited twice.

    :return: a list of itineraries, each a list of Train objects.
    """
    trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all()
    # Extend itineraries in chronological order of departure.
    trains.sort(key=lambda train: train.dep)

    # Partial itineraries, indexed by the station where they currently end.
    per_arr_explore = defaultdict(list)
    valid_routes = []

    for train in tqdm(trains):
        if train.orig == origin:
            it = [train]
            if train.dest == destination:
                # We hope that we have a direct train
                valid_routes.append(it)
            else:
                per_arr_explore[train.dest].append(it)
            continue

        # Iterate over a copy: new itineraries are appended while looping.
        for it in list(per_arr_explore.get(train.orig, [])):
            if train.dest == origin or any(train.dest == tr.dest for tr in it):
                # Avoid loops
                continue
            if it[-1].arr <= train.dep:
                new_it = it + [train]
                if train.dest == destination:
                    # Goal is achieved
                    valid_routes.append(new_it)
                else:
                    per_arr_explore[train.dest].append(new_it)

    return valid_routes


def print_route(route: list[Train]):
    """Print a non-empty itinerary on one line: ``orig (dep) --> (arr) dest, (dep) --> (arr) dest…``."""
    legs = ", ".join(f"({tr.dep}) --> ({tr.arr}) {tr.dest}" for tr in route)
    print(f"{route[0].orig} {legs}")


@cli.command('queue-route')
@click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
@click.argument('origin', type=str)
@click.argument('destination', type=str)
def queue_route(day: date | datetime, origin: str, destination: str):
    """
    Queue a refresh of a route from the TGVMax simulator (see process-queue).

    DAY: The day to query, in format YYYY-MM-DD.

    ORIGIN: The origin of the route.

    DESTINATION: The destination of the route.
    """
    if isinstance(day, datetime):
        day = day.date()

    pending = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination,
                                                     response_time=None)
    if pending.count():
        print("Already queued")
        return

    still_valid = db.session.query(RouteQueue).filter(RouteQueue.day == day,
                                                      RouteQueue.origin == origin,
                                                      RouteQueue.destination == destination,
                                                      RouteQueue.expiration_time >= datetime.now(UTC))
    if still_valid.count():
        print("Using recent value")
        return

    db.session.add(RouteQueue(day=day, origin=origin, destination=destination))
    db.session.commit()


@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=5, type=int)
def process_queue(number: int):
    """
    Send the oldest unanswered RouteQueue requests to the TGVMax simulator.

    At most ``number`` requests are processed (all of them when ``number <= 0``).
    For each answered route, every train of the route is first marked as not
    TGVMax, then the trains listed in the simulator proposals are flagged back
    with their free seat count.

    :raises requests.HTTPError: on a simulator error other than 404.
    """
    pending = (
        db.session.query(RouteQueue)
        .filter_by(response_time=None)
        .order_by(RouteQueue.queue_time)
    )
    batch = pending.limit(number).all() if number > 0 else pending.all()

    for req in batch:
        resp = requests.post(SIMULATOR_URL, json={
            'departureDateTime': req.day.isoformat(),
            'origin': req.origin,
            'destination': req.destination,
        }, timeout=REQUEST_TIMEOUT)

        if resp.status_code == 404:
            # No travel found. Store aware UTC times, like the branch below:
            # queue_route compares expiration_time with datetime.now(UTC).
            now = datetime.now(UTC)
            req.response_time = now
            req.expiration_time = now + timedelta(hours=1)
            db.session.commit()
            continue

        resp.raise_for_status()
        data = resp.json()

        req.response_time = _ms_to_utc(data['updatedAt'])
        req.expiration_time = _ms_to_utc(data['expiresAt'])
        req.expiration_time += timedelta(hours=3)  # By default 5 minutes, extend it to 3 hours to be safe

        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination)\
            .update(dict(tgvmax=False, remaining_seats=-1))

        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=req.origin,
                                                      dest_iata=req.destination).first()
            if train is None:
                # The proposal references a train absent from the local dataset.
                print("Unknown train:", proposal['trainNumber'], req.day, req.origin, req.destination)
                continue
            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time

        # Commit route by route so answers already received survive a later failure.
        db.session.commit()


@app.get('/')
def index():
    """Render the home page with the bookable day range (today to today + 30 days)."""
    today = date.today()
    return render_template('index.html', today=today, max_day=today + timedelta(days=30))


@app.get('/api/iata-codes/')
def iata_codes():
    """Return the station code <-> name mappings, built from the origins of the stored trains."""
    pairs = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct().all()
    return {
        'iata2name': {code: name for code, name in pairs},
        'name2iata': {name: code for code, name in pairs},
    }


@app.get('/api/routes/<day>/<origin>/<destination>/')
def get_routes(day: date | str, origin: str, destination: str):
    """
    Return the itineraries found by find_routes as JSON.

    :param day: the travel day, as an ISO ``YYYY-MM-DD`` string from the URL.
    :param origin: origin station name.
    :param destination: destination station name.
    :return: a list of itineraries, each a list of train dicts; HTTP 400 if
        ``day`` is not a valid ISO date.
    """
    if isinstance(day, str):
        try:
            day = date.fromisoformat(day)
        except ValueError:
            return {'error': f"Invalid day: {day!r}, expected YYYY-MM-DD"}, 400

    routes = find_routes(day, origin, destination)
    return [
        [{
            'origin': tr.orig,
            'origin_iata': tr.orig_iata,
            'destination': tr.dest,
            'destination_iata': tr.dest_iata,
            'departure': tr.dep.isoformat(),
            'arrival': tr.arr.isoformat(),
            'number': tr.number,
            'free_seats': tr.remaining_seats,
        } for tr in route]
        for route in routes
    ]


if __name__ == '__main__':
    app.run(debug=True)