tgvmax/app.py

171 lines
5.1 KiB
Python

#!/usr/bin/env python3
import csv
from datetime import date, datetime, time
import os
import json
from pytz import timezone
import requests
from flask import Flask
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import Boolean, Column, Date, Integer, String, Time
from tqdm import tqdm
import config
# Flask application object; the route(s) defined later in this file are
# registered on it.
app = Flask(__name__)
# Merge the project settings from config.FLASK_CONFIG into the app config
# (in-place mapping union, so existing keys are overridden).
app.config |= config.FLASK_CONFIG
# SQLAlchemy handle bound to the app; `db.Model` is the base class of the
# models declared below.
db = SQLAlchemy(app)
# Attach Flask-Migrate to this app/database pair. The returned object is not
# kept, only the registration side effect is needed here.
Migrate(app, db)
class Train(db.Model):
    """A single train leg of the tgvmax dataset, stored in the ``train`` table.

    Instances are built by ``parse_trains`` from rows of ``tgvmax.csv``.
    ``id`` and ``remaining_seats`` are not filled in by that loader.
    """
    __tablename__ = 'train'

    # Primary key. NOTE(review): nothing in this file assigns it yet, so it
    # must be set before an instance can be persisted.
    id = Column(String, primary_key=True)
    # Day the train runs (parsed from an ISO date by parse_trains).
    day = Column(Date, index=True)
    # Train number (parsed as an integer by parse_trains).
    number = Column(Integer, index=True)
    # NOTE(review): `entity` and `axe` are copied verbatim from the dataset;
    # their exact meaning is not visible from this file.
    entity = Column(String(255))
    axe = Column(String(255), index=True)
    # Short station codes for origin and destination (at most 5 characters).
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    # Station names; find_routes matches on these.
    orig = Column(String(255))
    dest = Column(String(255))
    # Departure and arrival times of day. `dep` was previously declared as
    # String(255) although it is assigned a datetime.time (like `arr`) and is
    # compared against `arr` when chaining connections; both are Time now.
    # An existing database needs a migration for this column.
    dep = Column(Time)
    arr = Column(Time)
    # True when the dataset marks the leg with 'OUI' (see parse_trains).
    tgvmax = Column(Boolean, index=True)
    # NOTE(review): presumably the number of seats still bookable; not
    # populated anywhere in this file.
    remaining_seats = Column(Integer)
def _extract_json_ld(html):
    """Return the first JSON-LD object embedded in the HTML text *html*.

    Raises ValueError when the page has no ``application/ld+json`` script
    block (json.JSONDecodeError, a ValueError subclass, on malformed JSON).
    """
    _, marker, rest = html.partition('<script type="application/ld+json">')
    if not marker:
        raise ValueError("no JSON-LD block found in the dataset page")
    payload = rest.partition('</script>')[0].strip()
    # The payload may contain raw line breaks inside string values, which
    # strict JSON forbids. strict=False lets the decoder accept control
    # characters in strings, whatever text surrounds them.
    return json.loads(payload, strict=False)


def update_dataset():
    """Refresh the local ``tgvmax.csv`` if the published dataset is newer.

    Reads the ``dateModified`` field and the first distribution's
    ``contentUrl`` from the JSON-LD metadata of the dataset information
    page, and downloads the CSV when that date is later than the local
    file's modification time (or when the file is missing). After a
    successful download the file's mtime is set to the dataset's
    modification date so the next comparison is exact.

    Best effort: any failure is reported on stdout and swallowed, and the
    previous ``tgvmax.csv`` (if any) is left untouched.
    """
    csv_path = 'tgvmax.csv'
    try:
        page = requests.get(
            'https://ressources.data.sncf.com/explore/dataset/tgvmax/information/',
            timeout=30)
        page.raise_for_status()
        info = _extract_json_ld(page.content.decode())

        utc = timezone('UTC')
        modified_date = datetime.fromisoformat(info['dateModified'])
        if modified_date.tzinfo is None:
            # NOTE(review): a naive timestamp cannot be compared with the
            # aware local mtime below; UTC is assumed here, confirm against
            # what the portal actually publishes.
            modified_date = modified_date.replace(tzinfo=utc)

        if os.path.isfile(csv_path):
            last_modified = datetime.fromtimestamp(os.path.getmtime(csv_path), tz=utc)
        else:
            # No local copy yet: use the earliest date so any dataset is newer.
            last_modified = datetime(1, 1, 1, tzinfo=utc)

        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            # Download next to the target and swap it in only once complete.
            # Writing straight into tgvmax.csv would leave a truncated file
            # with a fresh mtime after an interrupted transfer, and that
            # file would then be considered up to date.
            tmp_path = csv_path + '.part'
            try:
                with requests.get(info['distribution'][0]['contentUrl'],
                                  stream=True, timeout=30) as download:
                    download.raise_for_status()
                    with open(tmp_path, 'wb') as f, \
                            tqdm(unit='io', unit_scale=True) as t:
                        for chunk in download.iter_content(chunk_size=512 * 1024):
                            if chunk:
                                f.write(chunk)
                                t.update(len(chunk))
                stamp = modified_date.timestamp()
                os.utime(tmp_path, (stamp, stamp))
                os.replace(tmp_path, csv_path)
            finally:
                # Only still present when the download or the swap failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            print("Done")

        print("Last modification:", modified_date)
    except Exception as e:
        # Deliberate best effort: the caller keeps working with whatever
        # CSV is already on disk.
        print("An error occured while updating tgvmax.csv")
        print(e)
def parse_trains(*, filter_day: date | None = None,
                 filter_number: int | None = None,
                 filter_tgvmax: bool | None = None):
    """Load trains from ``tgvmax.csv``, optionally filtered.

    Each filter is ignored when left to ``None``:

    * ``filter_day``: keep only trains running on that date;
    * ``filter_number``: keep only trains with that number;
    * ``filter_tgvmax``: keep only trains whose tgvmax flag equals it.

    Returns a list of unsaved ``Train`` objects in file order. Raises
    ``OSError`` if the file cannot be opened and ``ValueError`` on a row
    that is too short or holds an unparsable date, number or time.
    """
    trains = []
    # newline='' is what the csv module requires to handle line breaks
    # embedded in quoted fields correctly.
    with open('tgvmax.csv', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        next(reader, None)  # skip the header line
        for row in reader:
            if not row:
                # csv.reader yields an empty list for a blank line.
                continue
            # NOTE(review): the column order is assumed to follow the
            # model's fields, minus `id` and `remaining_seats` which the
            # file does not seem to provide. Verify against the CSV header.
            (raw_day, raw_number, entity, axe, orig_iata, dest_iata,
             orig, dest, raw_dep, raw_arr, raw_tgvmax) = row[:11]

            # Filter as early as possible so discarded rows cost nothing more.
            day = date.fromisoformat(raw_day)
            if filter_day is not None and day != filter_day:
                continue
            number = int(raw_number)
            if filter_number is not None and number != filter_number:
                continue
            tgvmax = raw_tgvmax == 'OUI'
            if filter_tgvmax is not None and tgvmax != filter_tgvmax:
                continue

            # The default SQLAlchemy model constructor only accepts keyword
            # arguments, so the row cannot be splatted positionally.
            trains.append(Train(
                day=day,
                number=number,
                entity=entity,
                axe=axe,
                orig_iata=orig_iata,
                dest_iata=dest_iata,
                orig=orig,
                dest=dest,
                dep=time.fromisoformat(raw_dep),
                arr=time.fromisoformat(raw_arr),
                tgvmax=tgvmax,
            ))
    return trains
def find_routes(day, orig, dest):
    """Find itineraries from *orig* to *dest* using tgvmax-eligible trains.

    Parameters:
        day: date of travel, forwarded to ``parse_trains`` as ``filter_day``
            (``None`` disables the day filter).
        orig: name of the departure station, compared with ``Train.orig``.
        dest: name of the arrival station, compared with ``Train.dest``.

    Returns a list of routes, each a list of ``Train`` objects in travel
    order. Direct trains and routes with connections are both included.
    A connection is accepted when the previous leg arrives no later than
    the next one departs, and a route never returns to *orig* or to a
    station it has already arrived at.

    Times are compared as times of day only, so a leg arriving after
    midnight is not handled specially.
    """
    trains = parse_trains(filter_day=day, filter_tgvmax=True)
    # Walking trains by departure time guarantees that every itinerary a
    # train could extend has already been recorded when the train is seen.
    trains.sort(key=lambda train: train.dep)

    # Unfinished itineraries, keyed by the station where they currently end.
    open_by_station = {}
    valid_routes = []

    for train in tqdm(trains):
        if train.orig == orig:
            # A train leaving the origin always starts a fresh itinerary.
            candidates = [[train]]
        else:
            # The list is fully built before anything is stored below, so
            # open_by_station is never mutated while being iterated.
            candidates = [
                itinerary + [train]
                for itinerary in open_by_station.get(train.orig, ())
                if train.dest != orig
                and all(train.dest != leg.dest for leg in itinerary)  # no loops
                and itinerary[-1].arr <= train.dep
            ]

        for itinerary in candidates:
            if train.dest == dest:
                valid_routes.append(itinerary)
            else:
                open_by_station.setdefault(train.dest, []).append(itinerary)

    return valid_routes
def print_route(route: list[Train]):
    """Print *route* on one line.

    The line starts with the first leg's origin station, followed by one
    "(departure) --> (arrival) destination" entry per leg, separated by
    commas. *route* must contain at least one train.
    """
    start = route[0].orig
    legs = ", ".join(f"({leg.dep}) --> ({leg.arr}) {leg.dest}" for leg in route)
    print(f"{start} {legs}")
@app.get('/')
def index() -> str:
    """Handle GET requests on the root URL with a fixed placeholder greeting."""
    return "Hello world!"
# Start Flask's built-in server only when this file is run as a script.
# NOTE(review): debug=True is hard-coded; presumably meant for local
# development only. Confirm this entry point is never used to serve
# production traffic.
if __name__ == '__main__':
    app.run(debug=True)