#!/usr/bin/env python3
"""
Flask application and ``flask tgvmax`` CLI around the SNCF TGVMax dataset.

The CLI downloads the OpenData CSV, loads it into the database, queues
origin/destination pairs and refreshes them from the TGVMax simulator API.
The web application serves the index page and a small JSON API.
"""
from contextlib import nullcontext, suppress
import csv
from datetime import date, datetime, time, timedelta
import json
import os
import re
import sys
from time import sleep

import click
from flask import Flask, abort, render_template
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from pytz import timezone
import requests
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from sqlalchemy.sql import func
from tqdm import tqdm

import config

app = Flask(__name__)
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)
app.config |= config.FLASK_CONFIG
db = SQLAlchemy(app)
Migrate(app, db)

# Local copy of the OpenData dataset.
CSV_PATH = 'tgvmax.csv'
UTC = timezone('UTC')
# Timeout (seconds) of every HTTP request, so that a stalled server cannot hang a command forever.
HTTP_TIMEOUT = 60
DATASET_PAGE_URL = 'https://ressources.data.sncf.com/explore/dataset/tgvmax/information/'
SIMULATOR_URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"
# Extra validity granted to simulator answers (by default they only last about 5 minutes).
SIMULATOR_EXPIRATION_MARGIN = timedelta(hours=3)
# JSON-LD metadata block embedded in the dataset information page.
_JSON_LD_RE = re.compile(
    r'<script[^>]*type=["\']application/ld\+json["\'][^>]*>(.*?)</script>',
    re.DOTALL | re.IGNORECASE,
)


class Train(db.Model):
    """One train of the TGVMax dataset, on a given day."""
    __tablename__ = 'train'

    # "<number>-<day>-<orig_iata>-<dest_iata>", built by parse_trains.
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)
    number = Column(Integer, index=True)
    entity = Column(String(16))
    axe = Column(String(32), index=True)
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    # Station names (find_routes works on names, not codes).
    orig = Column(String(32))
    dest = Column(String(32))
    dep = Column(Time)
    arr = Column(Time)
    tgvmax = Column(Boolean, index=True)
    # -1 when no seat count is known from the simulator.
    remaining_seats = Column(Integer, default=-1)
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)


class RouteQueue(db.Model):
    """An origin/destination pair waiting to be (or already) refreshed from the simulator."""
    id = Column(Integer, autoincrement=True, primary_key=True)
    queue_time = Column(DateTime(timezone=True), server_default=func.now())
    day = Column(Date)
    origin = Column(String(5))
    destination = Column(String(5))
    # None while the request has not been processed yet.
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)


def _extract_json_ld(html: str) -> dict:
    """Return the first JSON-LD object embedded in an HTML page."""
    match = _JSON_LD_RE.search(html)
    if match is None:
        raise click.ClickException("No JSON-LD metadata found on the dataset page.")
    # strict=False accepts raw newlines/control characters inside JSON strings,
    # which this metadata embeds in some of its text values.
    return json.loads(match.group(1).strip(), strict=False)


def _parse_utc_datetime(value: str) -> datetime:
    """Parse an ISO 8601 timestamp into an aware datetime, assuming UTC when no offset is given."""
    if value.endswith('Z'):
        # datetime.fromisoformat only understands the "Z" suffix from Python 3.11 on
        value = value[:-1] + '+00:00'
    parsed = datetime.fromisoformat(value)
    return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=UTC)


@cli.command("update-dataset")
@click.option('--verbose', '-v', is_flag=True, help="Display errors.")
def update_dataset(verbose: bool = False):
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.
    """
    # Exits with status 2 when the local CSV is already up to date.
    resp = requests.get(DATASET_PAGE_URL, timeout=HTTP_TIMEOUT)
    resp.raise_for_status()
    info = _extract_json_ld(resp.content.decode())

    modified_date = _parse_utc_datetime(info['dateModified'])
    if os.path.isfile(CSV_PATH):
        last_modified = datetime.fromtimestamp(os.path.getmtime(CSV_PATH), tz=UTC)
    else:
        last_modified = datetime.min.replace(tzinfo=UTC)

    if last_modified >= modified_date:
        if verbose:
            print("Last modification:", modified_date)
        sys.exit(2)

    if verbose:
        print("Updating tgvmax.csv…")
    # Download next to the target and rename at the end: an interrupted download must not
    # leave a truncated CSV whose fresh mtime would make it look up to date.
    tmp_path = f"{CSV_PATH}.part"
    try:
        with requests.get(info['distribution'][0]['contentUrl'], stream=True, timeout=HTTP_TIMEOUT) as resp:
            resp.raise_for_status()
            with open(tmp_path, 'wb') as f, \
                    (tqdm(unit='io', unit_scale=True) if verbose else nullcontext()) as progress:
                for chunk in resp.iter_content(chunk_size=512 * 1024):
                    if chunk:
                        f.write(chunk)
                        if verbose:
                            progress.update(len(chunk))
    except BaseException:
        with suppress(FileNotFoundError):
            os.remove(tmp_path)
        raise

    # The file mtime records the dataset version, compared on the next run
    timestamp = modified_date.timestamp()
    os.utime(tmp_path, (timestamp, timestamp))
    os.replace(tmp_path, CSV_PATH)
    if verbose:
        print("Done")


def _is_known_duplicate(line: list[str]) -> bool:
    """Whether a CSV row belongs to a service known to appear several times in the dataset.

    Concerns only some « Intercités de nuit », the Brive-la-Gaillarde -- Paris (train 3614)
    and, maybe, Roubaix-Tourcoing (FRADP -> FRADM).
    """
    return line[3] == "IC NUIT" or line[1] == '3614' or (line[4] == 'FRADP' and line[5] == 'FRADM')


@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
@click.option('--verbose', '-v', is_flag=True, help="Display errors.")
def parse_trains(flush: bool = False, verbose: bool = False):
    """
    Parse the CSV file and store it to the database.
    """
    if flush:
        if verbose:
            print("Flush database…")
        db.session.query(Train).delete()

    last_modification = datetime.fromtimestamp(os.path.getmtime(CSV_PATH), tz=UTC)
    # Trains refreshed by the simulator after this CSV snapshot must not be overwritten
    already_updated = {
        train_id for (train_id,) in
        db.session.query(Train.id).filter(Train.last_modification > last_modification)
    }
    already_seen = set()

    with open(CSV_PATH, encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # Skip the header line
        for line in (tqdm(reader) if verbose else reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            if train_id in already_seen:
                # Some trains are mysteriously duplicated: skip them, report the unexpected ones
                if not _is_known_duplicate(line):
                    print("Duplicate:", train_id)
                continue
            already_seen.add(train_id)

            if train_id in already_updated:
                # Already updated by the simulator
                continue

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                db.session.add(train)
            else:
                db.session.merge(train)

    db.session.commit()


def find_routes(day: date | datetime, origin: str, destination: str | None, verbose: bool = False,
                min_dep: time = time(0, 0), explored: dict | None = None):
    """
    Find the TGVMax routes, possibly with connections, between two stations on a given day.

    DAY: The day to query, in format YYYY-MM-DD.
    ORIGIN: The name of the departure station.
    DESTINATION: The name of the arrival station.
    """
    # Depth-first search memoized in `explored`, which maps a station name to
    # (earliest departure time already explored from it, routes found from it);
    # the routes are None while the station is on the current search path.
    # Returns {destination: [[train, ...], ...]}, or {} when no destination is given.
    if isinstance(day, datetime):
        day = day.date()
    if explored is None:
        explored = {}

    if origin in explored:
        max_dep, valid_routes = explored[origin]
        if max_dep <= min_dep:
            # Trains leaving after min_dep were already explored
            return {destination: valid_routes} if destination else {}
    else:
        # time.max, so that trains leaving at 23:59 pass the strict upper bound below
        max_dep, valid_routes = time.max, []
    # Mark the station as being explored, to detect loops
    explored[origin] = (min_dep, None)

    trains = db.session.query(Train).filter_by(day=day, tgvmax=True, orig=origin) \
        .filter(Train.dep >= min_dep, Train.dep < max_dep).all()
    if not trains:
        # No train in the requested interval
        explored[origin] = (min_dep, valid_routes)
        return {destination: valid_routes} if destination else {}

    trains.sort(key=lambda train: train.dep)
    db.session.commit()

    for train in (tqdm(trains, desc=origin) if verbose else trains):
        if train.dest == destination:
            # We hope that we have a direct train
            valid_routes.append([train])
        elif train.dest in explored and explored[train.dest][1] is None:
            # The connection station is on the current search path: this is a loop
            continue
        elif train.arr < train.dep:
            # Not direct and arrives on the next day: connections would be on another day
            continue
        else:
            find_routes(day, train.dest, destination, verbose, train.arr, explored)
            # Filter unusable trains: onward routes must leave after this train arrives
            valid_routes += [[train] + route for route in explored[train.dest][1]
                             if route[0].dep >= train.arr]

    explored[origin] = (min_dep, valid_routes)
    # Send queued trains to the database
    db.session.commit()
    return {destination: valid_routes} if destination else {}


# Don't use the decorators, to keep the function callable from Python
cli.command('find-routes')(
    click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(
        click.argument('origin', type=str)(
            click.argument('destination', type=str, default=None)(
                click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(
                    find_routes)))))


def _fresh_or_pending_route_filter():
    """SQL condition matching queue entries that are either not processed yet or not expired."""
    return RouteQueue.response_time.is_(None) | (RouteQueue.expiration_time >= datetime.now(UTC))


def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False,
                autocommit: bool = True):
    """
    Fetch the TGVMax simulator to refresh data.

    DAY: The day to query, in format YYYY-MM-DD.

    ORIGIN: The origin of the route.

    DESTINATION: The destination of the route.
    """
    # Only adds a RouteQueue entry; process-queue performs the actual simulator requests.
    if isinstance(day, datetime):
        day = day.date()

    existing = db.session.query(RouteQueue) \
        .filter_by(day=day, origin=origin, destination=destination) \
        .filter(_fresh_or_pending_route_filter())
    if existing.first() is not None:
        # Already waiting in the queue, or refreshed recently enough
        return

    db.session.add(RouteQueue(day=day, origin=origin, destination=destination))
    if autocommit:
        db.session.commit()


# Don't use the decorators, to keep the function callable from Python
cli.command('queue-route')(
    click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(
        click.argument('origin', type=str)(
            click.argument('destination', type=str)(
                click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(
                    queue_route)))))


def queue_routes(day: date | datetime, origin: str | None = None, destination: str | None = None,
                 verbose: bool = False, autocommit: bool = True):
    """
    Queue every route of a day, optionally restricted to an origin and/or a destination.

    DAY: The day to query, in format YYYY-MM-DD.
    """
    # origin/destination may be given either as a station code or as a station name.
    if isinstance(day, datetime):
        day = day.date()

    already_queued = {
        tuple(row) for row in
        db.session.query(RouteQueue.origin, RouteQueue.destination)
        .filter(RouteQueue.day == day)
        .filter(_fresh_or_pending_route_filter())
    }

    query = db.session.query(Train).filter(Train.day == day)
    if origin:
        query = query.filter((Train.orig_iata == origin) | (Train.orig == origin))
    if destination:
        query = query.filter((Train.dest_iata == destination) | (Train.dest == destination))
    trains = query.all()

    for train in (progress := tqdm(trains) if verbose else trains):
        if verbose:
            progress.set_description(f"{day}: {train.orig} --> {train.dest}")
        pair = (train.orig_iata, train.dest_iata)
        if pair not in already_queued:
            queue_route(day, train.orig_iata, train.dest_iata, verbose, autocommit)
            already_queued.add(pair)


# Same as above
cli.command('queue-routes')(
    click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(
        click.option('--origin', '-o', default=None)(
            click.option('--destination', '-d', default=None)(
                click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(
                    queue_routes)))))


@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=30, type=int)
@click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
def process_queue(number: int, verbose: bool = False):
    # Processes the NUMBER oldest unprocessed requests (all of them when NUMBER <= 0).
    query = db.session.query(RouteQueue).filter(RouteQueue.response_time.is_(None)) \
        .order_by(RouteQueue.queue_time)
    if number > 0:
        query = query.limit(number)
    queue = query.all()

    iata_to_names = {}
    if verbose:
        iata_to_names = dict(
            db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct().all())

    for i, req in enumerate(progress := tqdm(queue) if verbose else queue):
        req: RouteQueue
        if verbose:
            # A station that never appears as an origin has no known name: show its code
            progress.set_description(f"{req.day:%d/%m/%Y}: {iata_to_names.get(req.origin, req.origin)} "
                                     f"--> {iata_to_names.get(req.destination, req.destination)}")

        while True:
            resp = requests.post(SIMULATOR_URL, json={
                'departureDateTime': req.day.isoformat(),
                'origin': req.origin,
                'destination': req.destination,
            }, timeout=HTTP_TIMEOUT)
            if resp.status_code != 429:
                break
            # Rate-limited: wait and retry
            sleep(1)

        if resp.status_code == 404:
            # No travel found. Store aware UTC times, like the success path below.
            now = datetime.now(UTC)
            req.response_time = now
            req.expiration_time = now + SIMULATOR_EXPIRATION_MARGIN
            db.session.add(req)
            continue

        resp.raise_for_status()
        data = resp.json()
        # updatedAt/expiresAt are millisecond timestamps
        req.response_time = datetime.fromtimestamp(data['updatedAt'] // 1000, tz=UTC)
        req.expiration_time = datetime.fromtimestamp(data['expiresAt'] // 1000, tz=UTC) \
            + SIMULATOR_EXPIRATION_MARGIN
        db.session.add(req)

        # Reset the route, then mark the trains proposed by the simulator
        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination) \
            .update(dict(tgvmax=False, remaining_seats=-1))
        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=proposal['origin']['rrCode'],
                                                      dest_iata=proposal['destination']['rrCode']).first()
            if train is None:
                # In a city with multiple stations
                print("ERROR")
                print(proposal)
                continue

            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time
            db.session.add(train)

        if i % 50 == 0:
            db.session.commit()

    db.session.commit()


@app.get('/')
def index():
    """Render the search page, allowing days from today up to 30 days ahead."""
    return render_template('index.html', today=date.today(), max_day=date.today() + timedelta(days=30))


@app.get('/api/iata-codes/')
def iata_codes():
    """Return the station code <-> station name mappings of every departure station."""
    pairs = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct().all()
    return {
        'iata2name': {iata: name for iata, name in pairs},
        'name2iata': {name: iata for iata, name in pairs},
    }


@app.get('/api/routes/<day>/<origin>/<destination>/')
def get_routes(day: date | str, origin: str, destination: str):
    """Return the routes found by find_routes, as JSON, grouped by destination.

    ``destination`` may be the literal string ``undefined`` to search without a destination.
    Responds 400 when ``day`` is not an ISO date (YYYY-MM-DD).
    """
    if isinstance(day, str):
        try:
            day = date.fromisoformat(day)
        except ValueError:
            abort(400)
    if destination == 'undefined':
        destination = None

    routes = find_routes(day, origin, destination)
    return {
        city: [
            [{
                'origin': tr.orig,
                'origin_iata': tr.orig_iata,
                'destination': tr.dest,
                'destination_iata': tr.dest_iata,
                'departure': tr.dep.isoformat(),
                'arrival': tr.arr.isoformat(),
                'number': tr.number,
                'free_seats': tr.remaining_seats,
            } for tr in route]
            for route in city_routes
        ]
        for city, city_routes in routes.items()
    }


if __name__ == '__main__':
    app.run(debug=True)