#!/usr/bin/env python3
"""Flask application and command-line tools around the SNCF TGVMax dataset.

The module defines:

* two database models, ``Train`` and ``RouteQueue``;
* a ``tgvmax`` click command group (``update-dataset``, ``parse-csv``,
  ``find-routes``, ``queue-route``, ``queue-routes``, ``process-queue``);
* a small HTTP API (``/``, ``/api/iata-codes/`` and ``/api/routes/...``).
"""

from contextlib import nullcontext
import csv
from datetime import date, datetime, time, timedelta
import os
import json
from time import sleep

import click
from flask import Flask, render_template
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from pytz import timezone
import requests
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from sqlalchemy.sql import func
from tqdm import tqdm

import config


app = Flask(__name__)
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)

app.config |= config.FLASK_CONFIG

db = SQLAlchemy(app)
Migrate(app, db)


class Train(db.Model):
    """One row of the TGVMax dataset, as loaded by ``parse-csv``."""

    __tablename__ = 'train'

    # Built by ``parse_trains`` as "<number>-<day>-<orig_iata>-<dest_iata>".
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)
    number = Column(Integer, index=True)
    entity = Column(String(16))
    axe = Column(String(32), index=True)
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    orig = Column(String(32))
    dest = Column(String(32))
    dep = Column(Time)
    arr = Column(Time)
    # True when the CSV says 'OUI' or when the simulator returned a proposal.
    tgvmax = Column(Boolean, index=True)
    # -1 is the default and the value written when a route is reset by
    # ``process_queue`` (presumably meaning "unknown" -- TODO confirm).
    remaining_seats = Column(Integer, default=-1)
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)


class RouteQueue(db.Model):
    """A pending or answered request to the TGVMax simulator for one route."""

    id = Column(Integer, autoincrement=True, primary_key=True)
    queue_time = Column(DateTime(timezone=True), server_default=func.now())
    day = Column(Date)
    origin = Column(String(5))
    destination = Column(String(5))
    # Both stay NULL until ``process_queue`` has handled the entry.
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)


@cli.command("update-dataset")
@click.option('--verbose', '-v', is_flag=True, help="Display errors.")
def update_dataset(verbose: bool = False) -> None:
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.
    """
    resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/')
    resp.raise_for_status()

    # The dataset metadata is read from the JSON-LD block embedded in the page.
    # NOTE(review): the two markers below were reconstructed. The previous code
    # called ``str.split('')``, which always raises ``ValueError``; the keys
    # read afterwards (``dateModified``, ``distribution[].contentUrl``) are
    # those of a schema.org JSON-LD document. Confirm against the live page.
    page = resp.content.decode()
    content = page.split('<script type="application/ld+json">')[1].split('</script>')[0].strip()

    # The embedded document contains raw line breaks inside string values,
    # which ``json.loads`` rejects: escape them before parsing.
    content = content.replace('\r', '')
    content = content.replace('" \n', '" \\n')
    content = content.replace('.\n', '.\\n')
    content = content.replace('\n\n \nLa', '\\n\\n \\nLa')
    content = content.replace('\n"', '\\n"')
    info = json.loads(content)

    modified_date = datetime.fromisoformat(info['dateModified'])

    utc = timezone('UTC')
    if os.path.isfile('tgvmax.csv'):
        last_modified = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=utc)
    else:
        # No local copy yet: use the oldest possible date to force a download.
        last_modified = datetime(1, 1, 1, tzinfo=utc)

    if last_modified < modified_date:
        if verbose:
            print("Updating tgvmax.csv…")
        with requests.get(info['distribution'][0]['contentUrl'], stream=True) as resp:
            resp.raise_for_status()
            with open('tgvmax.csv', 'wb') as f:
                with tqdm(unit='io', unit_scale=True) if verbose else nullcontext() as t:
                    for chunk in resp.iter_content(chunk_size=512 * 1024):
                        if chunk:
                            f.write(chunk)
                            if verbose:
                                t.update(len(chunk))
        # Stamp the file with the remote modification date so that the next
        # run can compare against it.
        os.utime('tgvmax.csv', (modified_date.timestamp(), modified_date.timestamp()))
        if verbose:
            print("Done")
    else:
        if verbose:
            print("Last modification:", modified_date)
        # Exit status 2 reports that the local file was already up to date.
        raise SystemExit(2)


@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
@click.option('--verbose', '-v', is_flag=True, help="Display errors.")
def parse_trains(flush: bool = False, verbose: bool = False) -> None:
    """
    Parse the CSV file and store it to the database.
    """
    if flush:
        if verbose:
            print("Flush database…")
        db.session.query(Train).delete()

    last_modification = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=timezone('UTC'))

    # ``newline=''`` is what the csv module expects for files handed to a reader.
    with open('tgvmax.csv', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # Skip the header row.
        already_seen = set()

        for line in (tqdm(reader) if verbose else reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"

            # Some trains are mysteriously duplicated, concerns only some « Intercités de nuit »
            # and the Brive-la-Gaillarde -- Paris
            # and, maybe, for Roubaix-Tourcoing
            known_duplicate = (line[3] == "IC NUIT" or line[1] == '3614'
                               or (line[4] == 'FRADP' and line[5] == 'FRADM'))

            if train_id in already_seen:
                if not known_duplicate:
                    print("Duplicate:", train_id)
                continue

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                # The table was just emptied: a plain insert is enough.
                db.session.add(train)
            else:
                db.session.merge(train)

            if known_duplicate:
                already_seen.add(train_id)

    db.session.commit()


def find_routes(day: date | datetime, origin: str, destination: str | None, verbose: bool = False):
    """
    Find the itineraries of a day that leave ORIGIN using TGVMax trains.

    Returns ``{destination: [itinerary, ...]}`` when a destination is given,
    otherwise a dictionary mapping each reached station name to the
    itineraries ending there. An itinerary is a list of ``Train`` objects.
    """
    if isinstance(day, datetime):
        day = day.date()

    trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all()
    # The exploration below relies on trains being visited by departure time.
    trains.sort(key=lambda train: train.dep)

    # For better results later, fetch all trains from the origin or to the destination
    # This is not exhaustive, but can be a good approximation
    queue_routes(day, origin=origin, verbose=verbose, autocommit=False)
    if destination:
        queue_routes(day, destination=destination, verbose=verbose, autocommit=False)
    db.session.commit()

    # Station name -> itineraries from the origin that currently end there.
    per_arr_explore = {}
    valid_routes = []

    for train in (tqdm(trains) if verbose else trains):
        if train.orig == origin:
            # Update from the TGVMax simulator
            queue_route(day, train.orig_iata, train.dest_iata, verbose, False)

            it = [train]
            if train.dest == destination:
                # We hope that we have a direct train
                valid_routes.append(it)
            else:
                per_arr_explore.setdefault(train.dest, []).append(it)
            continue

        # Iterate over a copy: the dictionary lists are extended in the loop.
        for it in list(per_arr_explore.get(train.orig, [])):
            if any(train.dest == tr.dest or train.dest == origin for tr in it):
                # Avoid loops
                continue

            last_train = it[-1]
            if last_train.arr <= train.dep:
                # Update from the TGVMax simulator, this line can be useful later
                queue_route(day, train.orig_iata, train.dest_iata, verbose, False)

                new_it = it + [train]
                if train.dest == destination:
                    # Goal is achieved
                    valid_routes.append(new_it)
                else:
                    per_arr_explore.setdefault(train.dest, []).append(new_it)

    # Send queued trains to the database
    db.session.commit()

    return {destination: valid_routes} if destination else per_arr_explore


# Don't use the decorator to keep the function callable
cli.command('find-routes')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
                           (click.argument('origin', type=str)
                            (click.argument('destination', type=str, default=None)
                             (click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
                              (find_routes)))))


def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False,
                autocommit: bool = True) -> None:
    """
    Fetch the TGVMax simulator to refresh data.

    DAY: The day to query, in format YYYY-MM-DD.

    ORIGIN: The origin of the route.

    DESTINATION: The destination of the route.
    """
    if isinstance(day, datetime):
        day = day.date()

    # Nothing to do if the same route is already waiting for an answer.
    query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination,
                                                   response_time=None)
    if query.count():
        return

    # Nothing to do either if a previous answer has not expired yet.
    query = db.session.query(RouteQueue).filter(RouteQueue.day == day,
                                                RouteQueue.origin == origin,
                                                RouteQueue.destination == destination,
                                                RouteQueue.expiration_time >= datetime.now(timezone('UTC')))
    if query.count():
        return

    db.session.add(RouteQueue(day=day, origin=origin, destination=destination))
    # Callers that queue many routes pass autocommit=False and commit once
    # themselves; this flag used to be ignored.
    if autocommit:
        db.session.commit()


# Don't use the decorator to keep the function callable
cli.command('queue-route')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
                           (click.argument('origin', type=str)
                            (click.argument('destination', type=str)
                             (click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
                              (queue_route)))))


def queue_routes(day: date | datetime, origin: str | None = None, destination: str | None = None,
                 verbose: bool = False, autocommit: bool = True) -> None:
    """
    Queue every route of a day, optionally restricted by origin and/or destination.

    The origin and the destination may be given either as a station code or as
    a station name.
    """
    if isinstance(day, datetime):
        day = day.date()

    query = db.session.query(Train).filter(Train.day == day)
    if origin:
        query = query.filter((Train.orig_iata == origin) | (Train.orig == origin))
    if destination:
        query = query.filter((Train.dest_iata == destination) | (Train.dest == destination))
    trains = query.all()

    for train in (t := tqdm(trains) if verbose else trains):
        if verbose:
            t.set_description(f"{day}: {train.orig} --> {train.dest}")
        queue_route(day, train.orig_iata, train.dest_iata, verbose, autocommit)


# Same as above
cli.command('queue-routes')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
                            (click.option('--origin', '-o', default=None)
                             (click.option('--destination', '-d', default=None)
                              (click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
                               (queue_routes)))))


@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=30, type=int)
@click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
def process_queue(number: int, verbose: bool = False) -> None:
    """
    Send the oldest unanswered ``RouteQueue`` entries to the TGVMax simulator.

    At most ``number`` entries are processed; a value of zero or less means no
    limit. For each answered entry, the matching ``Train`` rows are first reset
    (``tgvmax=False``, ``remaining_seats=-1``), then updated from the returned
    proposals.
    """
    utc = timezone('UTC')

    pending = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time)
    if number > 0:
        # Let the database apply the limit instead of loading the whole queue.
        pending = pending.limit(number)
    queue = pending.all()

    URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"

    # Station names are only needed for the progress bar description.
    iata_to_names = {}
    if verbose:
        query = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct()
        iata_to_names = dict(query.all())

    for i, req in enumerate(t := tqdm(queue) if verbose else queue):
        req: RouteQueue
        if verbose:
            # Fall back to the raw code for stations that never appear as an origin.
            origin_name = iata_to_names.get(req.origin, req.origin)
            destination_name = iata_to_names.get(req.destination, req.destination)
            t.set_description(f"{req.day:%d/%m/%Y}: {origin_name} --> {destination_name}")

        # Retry for as long as the simulator answers 429 (rate limited).
        while True:
            resp = requests.post(URL, json={
                'departureDateTime': req.day.isoformat(),
                'origin': req.origin,
                'destination': req.destination,
            })
            if resp.status_code != 429:
                break
            sleep(1)

        if resp.status_code == 404:
            # No travel found
            # The columns are timezone-aware and compared with UTC elsewhere,
            # so store an aware UTC timestamp rather than naive local time.
            now = datetime.now(utc)
            req.response_time = now
            req.expiration_time = now + timedelta(hours=3)
            db.session.add(req)
            continue

        resp.raise_for_status()

        data = resp.json()
        # The simulator timestamps are divided by 1000 before conversion.
        req.response_time = datetime.fromtimestamp(data['updatedAt'] // 1000, tz=utc)
        req.expiration_time = datetime.fromtimestamp(data['expiresAt'] // 1000, tz=utc)
        req.expiration_time += timedelta(hours=3)  # By default 5 minutes, extend it to 3 hours to be safe
        db.session.add(req)

        # Reset every train of the route; the proposals below re-enable the available ones.
        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination) \
            .update(dict(tgvmax=False, remaining_seats=-1))

        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=proposal['origin']['rrCode'],
                                                      dest_iata=proposal['destination']['rrCode']).first()
            if train is None:
                # In a city with multiple stations
                print("ERROR")
                print(proposal)
                continue
            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time
            db.session.add(train)

        # Commit regularly so that a late failure does not lose all the work.
        if i % 50 == 0:
            db.session.commit()

    db.session.commit()


@app.get('/')
def index():
    """Render the home page with the range of days offered to the user."""
    today = date.today()
    return render_template('index.html', today=today, max_day=today + timedelta(days=30))


@app.get('/api/iata-codes/')
def iata_codes():
    """Return the station code <-> station name mappings of departure stations."""
    query = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct()
    # Run the query once and build both mappings from the same rows.
    pairs = query.all()

    return {
        'iata2name': {code: name for (code, name) in pairs},
        'name2iata': {name: code for (code, name) in pairs},
    }


# The URL variables were missing from the rule ('/api/routes////'), so the view
# arguments could never be filled in; they are restored from the view signature.
@app.get('/api/routes/<day>/<origin>/<destination>/')
def get_routes(day: date | str, origin: str, destination: str):
    """
    Return, as JSON, the itineraries found by ``find_routes``.

    ``day`` is an ISO date (YYYY-MM-DD). The literal destination
    ``undefined`` means "any destination".
    """
    if isinstance(day, str):
        try:
            day = date.fromisoformat(day)
        except ValueError:
            # A malformed date is a client error, not a server failure.
            return {'error': f"Invalid day {day!r}, expected YYYY-MM-DD."}, 400

    if destination == 'undefined':
        destination = None

    routes = find_routes(day, origin, destination)

    return {
        city: [
            [{
                'origin': tr.orig,
                'origin_iata': tr.orig_iata,
                'destination': tr.dest,
                'destination_iata': tr.dest_iata,
                'departure': tr.dep.isoformat(),
                'arrival': tr.arr.isoformat(),
                'number': tr.number,
                'free_seats': tr.remaining_seats,
            } for tr in route] for route in city_routes
        ] for city, city_routes in routes.items()
    }


if __name__ == '__main__':
    app.run(debug=True)