#!/usr/bin/env python3
"""Flask application and CLI used to download, store and search the TGVMax dataset."""

import csv
from datetime import date, datetime, time
import json
import os

import click
from flask import Flask
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from pytz import timezone
import requests
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from tqdm import tqdm

import config

app = Flask(__name__)
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)

app.config |= config.FLASK_CONFIG

db = SQLAlchemy(app)
Migrate(app, db)

# Local copy of the dataset, and the page describing its latest version.
DATASET_PATH = 'tgvmax.csv'
DATASET_INFO_URL = 'https://ressources.data.sncf.com/explore/dataset/tgvmax/information/'

# Seconds to wait for the remote server before giving up on a request.
REQUEST_TIMEOUT = 30

UTC = timezone('UTC')

# Markers delimiting the JSON-LD metadata inside the dataset page.
# NOTE(review): the original separator was an empty string (which always raises
# ValueError); these markers are a reconstruction based on the keys read from
# the decoded document ('dateModified', 'distribution') -- confirm against the page.
_JSON_LD_OPEN = '<script type="application/ld+json">'
_JSON_LD_CLOSE = '</script>'


class Train(db.Model):
    """One row of the TGVMax CSV file, as stored by the ``parse-csv`` command."""

    __tablename__ = 'train'

    # Built by parse_trains as "<number>-<day>-<orig_iata>-<dest_iata>".
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)
    number = Column(Integer, index=True)
    entity = Column(String(10))
    axe = Column(String(32), index=True)
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    orig = Column(String(32))
    dest = Column(String(32))
    dep = Column(Time)
    arr = Column(Time)
    # True when the CSV availability column equals 'OUI'.
    tgvmax = Column(Boolean, index=True)
    remaining_seats = Column(Integer, default=-1)
    # Both timestamps are initialised to the modification time of the CSV file.
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)


def _file_mtime_utc(path):
    """Return the modification time of *path* as a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(os.path.getmtime(path), tz=UTC)


def _extract_json_ld(html):
    """Return the text located between the JSON-LD script markers of *html*.

    Raises ValueError when the opening marker is absent.
    """
    _, marker, tail = html.partition(_JSON_LD_OPEN)
    if not marker:
        raise ValueError("No JSON-LD metadata found in the dataset page")
    return tail.partition(_JSON_LD_CLOSE)[0].strip()


def _escape_raw_newlines(content):
    """Apply the textual fix-ups required before handing *content* to json.loads.

    Carriage returns are dropped, then a fixed list of line-break patterns is
    rewritten into escaped ``\\n`` sequences, in this exact order.
    Presumably the page embeds raw line breaks inside JSON strings -- TODO confirm.
    """
    content = content.replace('\r', '')
    for raw, escaped in (
        ('" \n', '" \\n'),
        ('.\n', '.\\n'),
        ('\n\n \nLa', '\\n\\n \\nLa'),
        ('\n"', '\\n"'),
    ):
        content = content.replace(raw, escaped)
    return content


# The docstring below is displayed by click as the command help, so the
# implementation notes live in comments:
#  * the dataset page is fetched and its metadata decoded as JSON;
#  * the CSV is downloaded only when 'dateModified' is newer than the local file;
#  * the download goes to a temporary file that replaces the dataset once
#    complete, so an interrupted transfer never leaves a truncated tgvmax.csv
#    carrying a fresh modification time;
#  * the local file's timestamps are set to the remote modification date;
#  * any failure is reported on stdout instead of being raised (best effort).
@cli.command("update-dataset")
def update_dataset():
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.
    """
    try:
        page = requests.get(DATASET_INFO_URL, timeout=REQUEST_TIMEOUT)
        page.raise_for_status()
        raw_info = _extract_json_ld(page.content.decode())
        info = json.loads(_escape_raw_newlines(raw_info))
        modified_date = datetime.fromisoformat(info['dateModified'])

        if os.path.isfile(DATASET_PATH):
            last_modified = _file_mtime_utc(DATASET_PATH)
        else:
            # No local copy yet: use the earliest date so any remote version wins.
            last_modified = datetime(1, 1, 1, tzinfo=UTC)

        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            partial_path = DATASET_PATH + '.part'
            download_url = info['distribution'][0]['contentUrl']
            with requests.get(download_url, stream=True, timeout=REQUEST_TIMEOUT) as resp:
                resp.raise_for_status()
                with open(partial_path, 'wb') as f, tqdm(unit='io', unit_scale=True) as progress:
                    for chunk in resp.iter_content(chunk_size=512 * 1024):
                        if chunk:
                            f.write(chunk)
                            progress.update(len(chunk))
            remote_timestamp = modified_date.timestamp()
            os.utime(partial_path, (remote_timestamp, remote_timestamp))
            os.replace(partial_path, DATASET_PATH)
            print("Done")

        print("Last modification:", modified_date)
    except Exception as e:
        # Top-level boundary of a CLI command: report the problem and exit normally.
        print("An error occured while updating tgvmax.csv")
        print(e)


def _is_known_duplicate(row):
    """Tell whether *row* belongs to one of the train families known to be duplicated.

    Those are the rows whose axe is "IC NUIT", whose train number is '3614',
    or that go from 'FRADP' to 'FRADM'.
    """
    return row[3] == "IC NUIT" or row[1] == '3614' or (row[4] == 'FRADP' and row[5] == 'FRADM')


# The docstring below is displayed by click as the command help, so the
# implementation notes live in comments:
#  * with --flush every Train row is deleted first and rows are inserted;
#    otherwise rows are merged into the existing ones;
#  * the first line of the file is skipped as a header;
#  * columns are read by position: day, number, entity, axe, origin code,
#    destination code, origin, destination, departure, arrival, availability;
#  * everything is committed once, after the whole file has been read.
@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
def parse_trains(flush: bool = False):
    """
    Parse the CSV file and store it to the database.
    """
    if flush:
        print("Flush database…")
        db.session.query(Train).delete()

    last_modification = _file_mtime_utc(DATASET_PATH)

    # Explicit encoding and newline handling, so that parsing does not depend
    # on the platform defaults (newline='' is what the csv module expects).
    with open(DATASET_PATH, encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # Skip the header line.

        # Only the identifiers of the known duplicated families are recorded.
        already_seen = set()
        for line in tqdm(reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            if train_id in already_seen:
                # Some trains are mysteriously duplicated; this concerns only some
                # « Intercités de nuit », the Brive-la-Gaillarde -- Paris
                # and, maybe, Roubaix-Tourcoing. Anything else is reported.
                if not _is_known_duplicate(line):
                    print("Duplicate:", train_id)
                continue

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                # The table was emptied above, so a plain insert is enough.
                db.session.add(train)
            else:
                db.session.merge(train)

            if _is_known_duplicate(line):
                already_seen.add(train_id)

    db.session.commit()


def _search_routes(trains, origin, destination):
    """Return every chain of connecting trains going from *origin* to *destination*.

    *trains* must be sorted by departure time. A train may extend an itinerary
    when it leaves from the station where that itinerary currently ends, no
    earlier than the arrival of its last train, and when it neither returns to
    *origin* nor reaches a station the itinerary has already reached.
    """
    # Unfinished itineraries, grouped by the station where they currently end.
    pending_by_station = {}
    valid_routes = []

    def register(itinerary, last_stop):
        """File *itinerary* as complete, or as extendable from *last_stop*."""
        if last_stop == destination:
            valid_routes.append(itinerary)
        else:
            pending_by_station.setdefault(last_stop, []).append(itinerary)

    for train in tqdm(trains):
        if train.orig == origin:
            # A train leaving the origin always starts a new itinerary.
            register([train], train.dest)
            continue

        # Iterate over a snapshot: register() may append to these lists.
        for itinerary in tuple(pending_by_station.get(train.orig, ())):
            if train.dest == origin or any(train.dest == leg.dest for leg in itinerary):
                # Avoid loops.
                continue
            if itinerary[-1].arr <= train.dep:
                register(itinerary + [train], train.dest)

    return valid_routes


def find_routes(day, orig, dest):
    """Find the TGVMax itineraries available on *day* from *orig* to *dest*.

    The trains are read from the database, keeping those of the given day
    whose ``tgvmax`` flag is set, and are explored by increasing departure time.

    :param day: day of travel, compared with ``Train.day``.
    :param orig: name of the departure station, compared with ``Train.orig``.
    :param dest: name of the arrival station, compared with ``Train.dest``.
    :return: a list of itineraries, each being a list of ``Train`` objects.

    NOTE(review): this goes through ``db.session``, so it presumably has to be
    called within a Flask application context -- confirm with the callers.
    """
    trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all()
    trains.sort(key=lambda train: train.dep)
    return _search_routes(trains, orig, dest)


def print_route(route: list[Train]):
    """Print *route* on one line: the origin, then each train's times and destination.

    Raises IndexError when *route* is empty.
    """
    legs = ", ".join(f"({tr.dep}) --> ({tr.arr}) {tr.dest}" for tr in route)
    print(f"{route[0].orig} {legs}")


@app.get('/')
def index():
    """Serve a placeholder page at the root URL."""
    return "Hello world!"


if __name__ == '__main__':
    app.run(debug=True)