# tgvmax/app.py
#!/usr/bin/env python3
import csv
from datetime import date, datetime, time
import os
import json
from pytz import timezone
import requests
import click
from flask import Flask
from flask.cli import AppGroup
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
from tqdm import tqdm
import config
app = Flask(__name__)
# CLI group exposed as `flask tgvmax <command>` (update-dataset, parse-csv).
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)
app.config |= config.FLASK_CONFIG  # merge project settings from config.py
db = SQLAlchemy(app)
Migrate(app, db)  # enables the `flask db ...` migration commands
class Train(db.Model):
    """One scheduled train on a given day, as imported from the SNCF TGVMax open dataset."""
    __tablename__ = 'train'
    # Synthetic key built as "{number}-{day}-{orig_iata}-{dest_iata}" (see parse_trains).
    id = Column(String, primary_key=True)
    day = Column(Date, index=True)  # service date
    number = Column(Integer, index=True)  # commercial train number
    entity = Column(String(10))
    axe = Column(String(32), index=True)  # commercial axis, e.g. "IC NUIT"
    orig_iata = Column(String(5), index=True)  # origin station code
    dest_iata = Column(String(5), index=True)  # destination station code
    orig = Column(String(32))  # origin station display name
    dest = Column(String(32))  # destination station display name
    dep = Column(Time)  # departure time
    arr = Column(Time)  # arrival time
    tgvmax = Column(Boolean, index=True)  # True when bookable with a TGVMax pass ("OUI" in the CSV)
    remaining_seats = Column(Integer, default=-1)  # -1 = unknown
    last_modification = Column(DateTime)  # mtime of the CSV this row was imported from
    expiration_time = Column(DateTime)
@cli.command("update-dataset")
def update_dataset():
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.

    Scrapes the dataset information page for its embedded JSON-LD metadata,
    compares the published modification date against the mtime of the local
    ``tgvmax.csv``, and re-downloads the file only when it is stale.
    Best-effort: any failure is reported on stdout instead of raising.
    """
    try:
        resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/')
        resp.raise_for_status()
        # The JSON-LD blob embeds raw newlines inside string values, which is
        # invalid JSON; escape the known patterns before parsing.
        content = resp.content.decode().split('<script type="application/ld+json">')[1].split('</script>')[0].strip()
        content = content.replace('\r', '')
        content = content.replace('" \n', '" \\n')
        content = content.replace('.\n', '.\\n')
        content = content.replace('\n\n \nLa', '\\n\\n \\nLa')
        content = content.replace('\n"', '\\n"')
        info = json.loads(content)
        modified_date = datetime.fromisoformat(info['dateModified'])
        utc = timezone('UTC')
        # Timezone-aware mtime of the local copy; datetime(1, 1, 1) stands in
        # when the file has never been downloaded, so any upstream date wins.
        # (datetime.utcfromtimestamp is deprecated; use an aware conversion.)
        if os.path.isfile('tgvmax.csv'):
            last_modified = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=utc)
        else:
            last_modified = datetime(1, 1, 1, tzinfo=utc)
        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            with requests.get(info['distribution'][0]['contentUrl'], stream=True) as resp:
                resp.raise_for_status()
                with open('tgvmax.csv', 'wb') as f, tqdm(unit='io', unit_scale=True) as t:
                    for chunk in resp.iter_content(chunk_size=512 * 1024):
                        if chunk:
                            f.write(chunk)
                            t.update(len(chunk))
            # Mirror the upstream modification date onto the local file so the
            # next run can compare against it.
            os.utime('tgvmax.csv', (modified_date.timestamp(), modified_date.timestamp()))
            print("Done")
        print("Last modification:", modified_date)
    except Exception as e:
        # Deliberate best-effort catch-all for a CLI command: report and exit.
        print("An error occurred while updating tgvmax.csv")
        print(e)
@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
def parse_trains(flush: bool = False):
    """
    Parse the CSV file and store it to the database.

    Each row becomes a Train keyed by "{number}-{day}-{orig_iata}-{dest_iata}".
    With --flush the table is emptied first and rows are inserted; otherwise
    rows are merged (upserted) into the existing data.
    """
    if flush:
        print("Flush database…")
        db.session.query(Train).delete()
    # Timezone-aware mtime of the dataset, stamped on every imported row.
    last_modification = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=timezone('UTC'))
    with open('tgvmax.csv', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # skip the header row
        already_seen = set()
        for line in tqdm(reader):
            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
            # Trains known to be legitimately listed twice: some
            # « Intercités de nuit », the Brive-la-Gaillarde -- Paris (3614)
            # and, maybe, Roubaix-Tourcoing (FRADP -> FRADM).
            known_duplicate = (line[3] == "IC NUIT" or line[1] == '3614'
                               or (line[4] == 'FRADP' and line[5] == 'FRADM'))
            if train_id in already_seen:
                if not known_duplicate:
                    print("Duplicate:", train_id)
                continue
            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                db.session.add(train)  # table was emptied: plain insert is safe
            else:
                db.session.merge(train)  # upsert against existing rows
            if known_duplicate:
                # Remember only the expected duplicates so their repeats are
                # skipped silently; unexpected repeats get reported above.
                already_seen.add(train_id)
    db.session.commit()
def find_routes(day, orig, dest):
    """
    Find every chain of TGVMax-eligible trains linking *orig* to *dest* on *day*.

    Scans the day's trains in departure-time order, extending partial
    itineraries whenever a train leaves a station a previous itinerary arrived
    at, no earlier than that arrival (NOTE(review): no transfer margin and
    same-day connections only — confirm this is intended).

    :param day: datetime.date to search.
    :param orig: display name of the origin station (matches Train.orig).
    :param dest: display name of the destination station (matches Train.dest).
    :return: list of itineraries, each a list of Train, in discovery order.
    """
    # BUG FIX: the parameters were previously ignored — the search used a
    # hard-coded day and "STRASBOURG" / "LYON (intramuros)", and called
    # parse_trains() with keyword arguments it does not accept.
    trains = (db.session.query(Train)
              .filter(Train.day == day, Train.tgvmax.is_(True))
              .all())
    trains.sort(key=lambda train: train.dep)
    explore = []            # partial itineraries still worth extending
    per_arr_explore = {}    # arrival station name -> itineraries ending there
    valid_routes = []
    for train in tqdm(trains):
        if train.orig == orig:
            it = [train]
            if train.dest == dest:
                # We hope that we have a direct train
                valid_routes.append(it)
            else:
                explore.append(it)
                per_arr_explore.setdefault(train.dest, []).append(it)
            continue
        for it in list(per_arr_explore.get(train.orig, [])):
            # Avoid loops: never revisit a station already on the itinerary.
            if any(train.dest == tr.dest or train.dest == orig for tr in it):
                continue
            last_train = it[-1]
            if last_train.arr <= train.dep:
                new_it = it + [train]
                if train.dest == dest:
                    # Goal is achieved
                    valid_routes.append(new_it)
                else:
                    explore.append(new_it)
                    per_arr_explore.setdefault(train.dest, []).append(new_it)
    return valid_routes
def print_route(route: list[Train]):
    """Pretty-print an itinerary as a one-line chain of legs on stdout."""
    legs = [f"({tr.dep}) --> ({tr.arr}) {tr.dest}" for tr in route]
    print(f"{route[0].orig} " + ", ".join(legs))
@app.get('/')
def index():
    """Root endpoint: placeholder greeting."""
    greeting = "Hello world!"
    return greeting
if __name__ == '__main__':
    # Development entry point; use a WSGI server (not debug mode) in production.
    app.run(debug=True)