#!/usr/bin/env python3
"""Flask application and ``flask tgvmax`` CLI commands for the TGVMax dataset.

The commands download the SNCF OpenData CSV, load it into the ``train`` table,
and refresh seat availability from the TGVMax simulator through a request queue.
"""
import csv
import json
import os
import re
from datetime import date, datetime, time, timedelta

import click
import requests
from flask import Flask
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from pytz import timezone
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from sqlalchemy.sql import func
from tqdm import tqdm

import config

# Local copy of the SNCF OpenData dataset, written by update-dataset and read by parse-csv.
CSV_PATH = 'tgvmax.csv'
DATASET_PAGE_URL = 'https://ressources.data.sncf.com/explore/dataset/tgvmax/information/'
SIMULATOR_URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"
# Seconds requests waits to connect / between received bytes before giving up.
HTTP_TIMEOUT = 60
UTC = timezone('UTC')

app = Flask(__name__)

cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)

app.config |= config.FLASK_CONFIG

db = SQLAlchemy(app)
Migrate(app, db)


class Train(db.Model):
    """One train run on a given day, loaded from the TGVMax CSV by ``parse-csv``."""

    __tablename__ = 'train'

    # "{number}-{day}-{orig_iata}-{dest_iata}", as built in parse_trains
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)
    number = Column(Integer, index=True)
    entity = Column(String(10))
    axe = Column(String(32), index=True)
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    orig = Column(String(32))
    dest = Column(String(32))
    dep = Column(Time)
    arr = Column(Time)
    tgvmax = Column(Boolean, index=True)
    # -1 when no seat count is known (default, and reset by process-queue)
    remaining_seats = Column(Integer, default=-1)
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)


class RouteQueue(db.Model):
    """A request to refresh one (day, origin, destination) route from the simulator."""

    id = Column(Integer, autoincrement=True, primary_key=True)
    queue_time = Column(DateTime(timezone=True), server_default=func.now())
    day = Column(Date)
    origin = Column(String(5))
    destination = Column(String(5))
    # NULL while the request has not been processed yet (see process-queue)
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)


# NOTE(review): assumes the dataset page embeds its metadata as a
# <script type="application/ld+json"> block (schema.org Dataset: dateModified,
# distribution[].contentUrl) — confirm if the page layout changes.
_JSON_LD_RE = re.compile(r"<script[^>]*application/ld\+json[^>]*>(.*?)</script>", re.DOTALL | re.IGNORECASE)


def _utc_mtime(path: str) -> datetime:
    """Return the modification time of *path* as an aware UTC datetime."""
    return datetime.fromtimestamp(os.path.getmtime(path), tz=UTC)


def _from_ms_timestamp(milliseconds: int) -> datetime:
    """Convert a Unix timestamp in milliseconds to an aware UTC datetime (second precision)."""
    return datetime.fromtimestamp(milliseconds // 1000, tz=UTC)


def _fetch_dataset_metadata() -> dict:
    """Download the dataset information page and return its embedded JSON-LD metadata.

    Raises:
        requests.HTTPError: if the page cannot be fetched.
        ValueError: if no JSON-LD block is found or it is not valid JSON.
    """
    resp = requests.get(DATASET_PAGE_URL, timeout=HTTP_TIMEOUT)
    resp.raise_for_status()
    match = _JSON_LD_RE.search(resp.content.decode())
    if match is None:
        raise ValueError("No JSON-LD metadata found on the dataset page")
    # Free-text fields contain raw newlines, which strict JSON forbids inside
    # strings; strict=False accepts them instead of hand-escaping them.
    return json.loads(match.group(1).strip(), strict=False)


def _parse_modified_date(value: str) -> datetime:
    """Parse an ISO-8601 timestamp, treating a 'Z' suffix or a missing offset as UTC."""
    # datetime.fromisoformat() only accepts a trailing 'Z' since Python 3.11.
    if value.endswith('Z'):
        value = value[:-1] + '+00:00'
    parsed = datetime.fromisoformat(value)
    # Keep the result aware so it can be compared with the file mtime.
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=UTC)
    return parsed


def _download(url: str, path: str) -> None:
    """Stream *url* into *path* with a progress bar.

    Data is written to ``path + '.part'`` and only renamed over *path* once the
    transfer completed, so an interrupted download never leaves a truncated file
    whose fresh mtime would make update-dataset consider it up to date.
    """
    tmp_path = path + '.part'
    try:
        with requests.get(url, stream=True, timeout=HTTP_TIMEOUT) as resp:
            resp.raise_for_status()
            with open(tmp_path, 'wb') as f, tqdm(unit='io', unit_scale=True) as progress:
                for chunk in resp.iter_content(chunk_size=512 * 1024):
                    if chunk:
                        f.write(chunk)
                        progress.update(len(chunk))
        os.replace(tmp_path, path)
    finally:
        # Only still present if the download failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


@cli.command("update-dataset")
def update_dataset():
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.
    """
    try:
        info = _fetch_dataset_metadata()
        modified_date = _parse_modified_date(info['dateModified'])

        if os.path.isfile(CSV_PATH):
            last_modified = _utc_mtime(CSV_PATH)
        else:
            last_modified = datetime(1, 1, 1, tzinfo=UTC)

        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            _download(info['distribution'][0]['contentUrl'], CSV_PATH)
            # Stamp the file with the upstream modification date: the check above
            # then skips the download until the dataset is modified again.
            os.utime(CSV_PATH, (modified_date.timestamp(), modified_date.timestamp()))
            print("Done")

        print("Last modification:", modified_date)
    except Exception as e:
        # Best-effort CLI command: report the failure instead of a traceback.
        print("An error occured while updating tgvmax.csv")
        print(e)


def _is_known_duplicate(row: list[str]) -> bool:
    """Return True for CSV rows the dataset is known to duplicate.

    These concern some « Intercités de nuit » (axe "IC NUIT"), train 3614
    (Brive-la-Gaillarde -- Paris) and, maybe, FRADP -> FRADM (Roubaix-Tourcoing).
    """
    return row[3] == "IC NUIT" or row[1] == '3614' or (row[4] == 'FRADP' and row[5] == 'FRADM')


@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
def parse_trains(flush: bool = False):
    """
    Parse the CSV file and store it to the database.
    """
    if flush:
        print("Flush database…")
        db.session.query(Train).delete()

    last_modification = _utc_mtime(CSV_PATH)

    # newline='' is required by the csv module; utf-8-sig also tolerates a BOM.
    with open(CSV_PATH, encoding='utf-8-sig', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # skip the header row
        already_seen = set()
        for line in tqdm(reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            if train_id in already_seen:
                # Keep the first occurrence; only report duplicates we do not expect.
                if not _is_known_duplicate(line):
                    print("Duplicate:", train_id)
                continue
            already_seen.add(train_id)

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                # Table was just emptied and ids are deduplicated: plain INSERT is safe.
                db.session.add(train)
            else:
                db.session.merge(train)

    db.session.commit()


def find_routes(day: date, orig: str, dest: str) -> list[list[Train]]:
    """Enumerate itineraries from station *orig* to station *dest* on *day*.

    Only trains flagged ``tgvmax`` are used. Trains are scanned by departure
    time; an itinerary is extended by a train leaving the station where it
    ends, no earlier than its last arrival, and never revisiting a station.

    Args:
        day: the travel day.
        orig: origin station name, as stored in ``Train.orig``.
        dest: destination station name, as stored in ``Train.dest``.

    Returns:
        Every itinerary found, each a list of consecutive trains.

    NOTE(review): ``dep``/``arr`` are times of day only, so trains arriving
    after midnight are not handled specially.
    """
    trains = db.session.query(Train).filter_by(day=day, tgvmax=True).order_by(Train.dep).all()

    # Station name -> partial itineraries currently ending at that station.
    itineraries_to = {}
    valid_routes = []
    for train in tqdm(trains):
        if train.orig == orig:
            itinerary = [train]
            if train.dest == dest:
                # Direct train
                valid_routes.append(itinerary)
            else:
                itineraries_to.setdefault(train.dest, []).append(itinerary)
            continue

        # Iterate over a snapshot: the dict's lists are extended inside the loop.
        for itinerary in list(itineraries_to.get(train.orig, [])):
            # Avoid loops back to the origin or to an already visited station.
            if train.dest == orig or any(train.dest == tr.dest for tr in itinerary):
                continue
            if itinerary[-1].arr <= train.dep:
                new_itinerary = itinerary + [train]
                if train.dest == dest:
                    valid_routes.append(new_itinerary)
                else:
                    itineraries_to.setdefault(train.dest, []).append(new_itinerary)
    return valid_routes


def print_route(route: list[Train]):
    """Print *route* as "ORIG (dep) --> (arr) STOP, (dep) --> (arr) DEST". Prints nothing if empty."""
    if not route:
        return
    legs = ", ".join(f"({tr.dep}) --> ({tr.arr}) {tr.dest}" for tr in route)
    print(f"{route[0].orig} {legs}")


@cli.command('queue-route')
@click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
@click.argument('origin', type=str)
@click.argument('destination', type=str)
def queue_route(day: date | datetime, origin: str, destination: str):
    """
    Fetch the TGVMax simulator to refresh data.

    DAY: The day to query, in format YYYY-MM-DD.

    ORIGIN: The origin of the route.

    DESTINATION: The destination of the route.
    """
    # click.DateTime yields a datetime; the queue stores a plain date.
    if isinstance(day, datetime):
        day = day.date()

    # Do not queue the same route twice while a request is still pending.
    pending = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination,
                                                     response_time=None)
    if pending.count():
        print("Already queued")
        return

    db.session.add(RouteQueue(day=day, origin=origin, destination=destination))
    db.session.commit()


@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=5, type=int)
def process_queue(number: int):
    """Process up to *number* pending requests (all of them if *number* <= 0), oldest first.

    For each route, every train of that route on that day is reset to
    ``tgvmax=False, remaining_seats=-1``, then the trains returned by the
    simulator are flagged with their free seat count.
    """
    query = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time)
    # Materialise before mutating rows inside the loop.
    pending: list[RouteQueue] = query.limit(number).all() if number > 0 else query.all()

    for req in pending:
        resp = requests.post(SIMULATOR_URL, json={
            'departureDateTime': req.day.isoformat(),
            'origin': req.origin,
            'destination': req.destination,
        }, timeout=HTTP_TIMEOUT)

        if resp.status_code == 404:
            # No travel found: retry in one hour. Aware UTC, like the success path.
            now = datetime.now(UTC)
            req.response_time = now
            req.expiration_time = now + timedelta(hours=1)
            db.session.add(req)
            continue

        resp.raise_for_status()
        data = resp.json()

        req.response_time = _from_ms_timestamp(data['updatedAt'])
        req.expiration_time = _from_ms_timestamp(data['expiresAt'])
        db.session.add(req)

        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination)\
            .update(dict(tgvmax=False, remaining_seats=-1))
        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=req.origin, dest_iata=req.destination).first()
            if train is None:
                # Train missing from the local dataset (e.g. CSV not refreshed yet).
                print("Unknown train:", proposal['trainNumber'], req.day, req.origin, req.destination)
                continue
            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time
            db.session.add(train)

    db.session.commit()


@app.get('/')
def index():
    return "Hello world!"


if __name__ == '__main__':
    app.run(debug=True)