class Train(db.Model):
    """One train journey from the SNCF TGVMax open-data CSV."""

    # Synthetic primary key; built by parse_trains as
    # "<number>-<day>-<orig_iata>-<dest_iata>".
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)               # circulation date (CSV col 0)
    number = Column(Integer, index=True)         # train number (CSV col 1)
    entity = Column(String(10))                  # operating entity (CSV col 2)
    axe = Column(String(32), index=True)         # route/axis, e.g. "IC NUIT" (CSV col 3)
    orig_iata = Column(String(5), index=True)    # origin station code (CSV col 4)
    dest_iata = Column(String(5), index=True)    # destination station code (CSV col 5)
    orig = Column(String(32))                    # origin station name (CSV col 6)
    dest = Column(String(32))                    # destination station name (CSV col 7)
    dep = Column(Time)                           # departure time (CSV col 8)
    arr = Column(Time)                           # arrival time (CSV col 9)
    tgvmax = Column(Boolean, index=True)         # True when CSV col 10 == "OUI"
    # Sentinel default -1: seat count not (yet) fetched from the API.
    # NOTE(review): assumed meaning of -1 — confirm against the updater code.
    remaining_seats = Column(Integer, default=-1)
    last_modification = Column(DateTime)         # mtime of the source CSV at parse time
    expiration_time = Column(DateTime)           # set to the same timestamp by parse_trains
+ """ try: resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/') content = resp.content.decode().split('')[0].strip() @@ -76,37 +88,61 @@ def update_dataset(): print(e) -def parse_trains(*, filter_day: date | None = None, - filter_number: int | None = None, - filter_tgvmax: bool | None = None): - trains = [] +@cli.command("parse-csv") +@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.") +def parse_trains(flush: bool = False): + """ + Parse the CSV file and store it to the database. + """ + + if flush: + print("Flush database…") + db.session.query(Train).delete() + + last_modification = datetime.utcfromtimestamp(os.path.getmtime('tgvmax.csv')).replace(tzinfo=timezone('UTC')) with open('tgvmax.csv') as f: first_line = True - for line in csv.reader(f, delimiter=';'): + already_seen = set() + for line in tqdm(csv.reader(f, delimiter=';')): if first_line: first_line = False continue - train = Train(*line) - train.day = date.fromisoformat(train.day) - train.number = int(train.number) - train.dep = time.fromisoformat(train.dep) - train.arr = time.fromisoformat(train.arr) - train.tgvmax = train.tgvmax == 'OUI' - - if filter_day is not None and train.day != filter_day: + train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}" + if train_id in already_seen: + # Some trains are mysteriously duplicated, concerns only some « Intercités de nuit » + # and the Brive-la-Gaillarde -- Paris + # and, maybe, for Roubaix-Tourcoing + if line[3] != "IC NUIT" and line[1] != '3614' and not (line[4] == 'FRADP' and line[5] == 'FRADM'): + print("Duplicate:", train_id) continue - if filter_number is not None and train.number != filter_number: - continue + train = Train( + id=train_id, + day=date.fromisoformat(line[0]), + number=int(line[1]), + entity=line[2], + axe=line[3], + orig_iata=line[4], + dest_iata=line[5], + orig=line[6], + dest=line[7], + dep=time.fromisoformat(line[8]), + 
arr=time.fromisoformat(line[9]), + tgvmax=line[10] == 'OUI', + last_modification=last_modification, + expiration_time=last_modification, + ) + if flush: + db.session.add(train) + else: + db.session.merge(train) - if filter_tgvmax is not None and train.tgvmax != filter_tgvmax: - continue + if line[3] == "IC NUIT" or line[1] == '3614' or (line[4] == 'FRADP' and line[5] == 'FRADM'): + already_seen.add(train_id) - trains.append(train) - - return trains + db.session.commit() def find_routes(day, orig, dest):