365 lines
13 KiB
Python
365 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import csv
|
|
from datetime import date, datetime, time, timedelta
|
|
import os
|
|
import json
|
|
from pytz import timezone
|
|
import requests
|
|
|
|
import click
|
|
from flask import Flask, render_template
|
|
from flask.cli import AppGroup
|
|
from flask_migrate import Migrate
|
|
from flask_sqlalchemy import SQLAlchemy
|
|
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
|
|
from sqlalchemy.sql import func
|
|
from tqdm import tqdm
|
|
|
|
import config
|
|
|
|
|
|
# Flask application object for this module.
app = Flask(__name__)

# Dedicated CLI group: the commands registered on ``cli`` further down are
# exposed through the application's command line under the ``tgvmax`` name.
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
app.cli.add_command(cli)

# Merge the project settings into the Flask configuration. This must happen
# before the database extension is bound to the application.
app.config.update(config.FLASK_CONFIG)

# Database handle used by the models and queries of this module, with
# migration support attached to the same application.
db = SQLAlchemy(app)
Migrate(app, db)
|
|
|
|
|
|
class Train(db.Model):
    """
    One train leg on a given day, as loaded from the ``tgvmax.csv`` dataset.
    """
    __tablename__ = 'train'

    # Built by ``parse_trains`` as "<number>-<day>-<orig_iata>-<dest_iata>".
    id = Column(String, primary_key=True)

    # Day of circulation and train number.
    day = Column(Date, index=True)
    number = Column(Integer, index=True)

    # Third and fourth CSV columns, stored verbatim.
    entity = Column(String(16))
    axe = Column(String(32), index=True)

    # Station codes (at most 5 characters) and station names.
    orig_iata = Column(String(5), index=True)
    dest_iata = Column(String(5), index=True)
    orig = Column(String(32))
    dest = Column(String(32))

    # Departure and arrival times of day.
    dep = Column(Time)
    arr = Column(Time)

    # True when the CSV flags the train with 'OUI'; ``process_queue`` rewrites
    # this flag from the simulator answer.
    tgvmax = Column(Boolean, index=True)

    # -1 until ``process_queue`` stores the number of free places.
    remaining_seats = Column(Integer, default=-1)

    # Freshness markers: set from the CSV modification time by ``parse_trains``
    # and from the simulator answer by ``process_queue``.
    # NOTE(review): these columns are not declared timezone-aware although the
    # values written to them are - confirm this is intended.
    last_modification = Column(DateTime)
    expiration_time = Column(DateTime)
|
|
|
|
|
|
class RouteQueue(db.Model):
    """
    A request to refresh one (day, origin, destination) route from the simulator.
    """
    id = Column(Integer, autoincrement=True, primary_key=True)

    # Filled by the database when the row is inserted; ``process_queue``
    # handles pending rows in this order.
    queue_time = Column(DateTime(timezone=True), server_default=func.now())

    # The route to refresh: a day and two station codes (at most 5 characters).
    day = Column(Date)
    origin = Column(String(5))
    destination = Column(String(5))

    # NULL while the request is pending, set once it has been processed.
    response_time = Column(DateTime(timezone=True), nullable=True, default=None)

    # Until this instant ``queue_route`` does not queue the same route again.
    expiration_time = Column(DateTime(timezone=True), nullable=True, default=None)
|
|
|
|
|
|
@cli.command("update-dataset")
def update_dataset():
    """
    Query the latest version of the SNCF OpenData dataset, as a CSV file.
    """
    # The dataset information page embeds a JSON-LD document giving the date
    # of last modification and the download URL. The CSV is downloaded again
    # only when the local copy is older than that date.
    # Any failure is reported on standard output; nothing is raised.
    csv_path = 'tgvmax.csv'
    # Download next to the final file and rename once complete, so that an
    # interrupted transfer never leaves a truncated, fresh-looking tgvmax.csv.
    tmp_path = csv_path + '.part'

    try:
        resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/')
        resp.raise_for_status()
        content = resp.content.decode().split('<script type="application/ld+json">')[1].split('</script>')[0].strip()
        # The embedded JSON contains raw line breaks inside its strings, which
        # ``json.loads`` rejects: escape them before parsing.
        content = content.replace('\r', '')
        content = content.replace('" \n', '" \\n')
        content = content.replace('.\n', '.\\n')
        content = content.replace('\n\n \nLa', '\\n\\n \\nLa')
        content = content.replace('\n"', '\\n"')

        info = json.loads(content)
        modified_date = datetime.fromisoformat(info['dateModified'])

        utc = timezone('UTC')
        if modified_date.tzinfo is None:
            # A naive date cannot be compared with the aware local timestamp.
            # NOTE(review): assumes a date published without offset is UTC.
            modified_date = modified_date.replace(tzinfo=utc)

        if os.path.isfile(csv_path):
            last_modified = datetime.fromtimestamp(os.path.getmtime(csv_path), tz=utc)
        else:
            # No local copy yet: any published version is newer.
            last_modified = datetime(1, 1, 1, tzinfo=utc)

        if last_modified < modified_date:
            print("Updating tgvmax.csv…")
            try:
                with requests.get(info['distribution'][0]['contentUrl'], stream=True) as resp:
                    resp.raise_for_status()
                    with open(tmp_path, 'wb') as f, tqdm(unit='io', unit_scale=True) as t:
                        for chunk in resp.iter_content(chunk_size=512 * 1024):
                            if chunk:
                                f.write(chunk)
                                t.update(len(chunk))
                # Stamp the file with the publication date: it is the reference
                # for the next freshness check.
                os.utime(tmp_path, (modified_date.timestamp(), modified_date.timestamp()))
                os.replace(tmp_path, csv_path)
            finally:
                # Only present when the download or the rename failed.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            print("Done")

        print("Last modification:", modified_date)
    except Exception as e:
        print("An error occured while updating tgvmax.csv")
        print(e)
|
|
|
|
|
|
@cli.command("parse-csv")
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
def parse_trains(flush: bool = False):
    """
    Parse the CSV file and store it to the database.
    """
    # With ``flush`` the train table is emptied first and rows are inserted;
    # otherwise each row is merged into the existing data by primary key.
    # Everything is committed at the end.
    if flush:
        print("Flush database…")
        db.session.query(Train).delete()

    # The modification time of the file is used as the freshness of each row.
    last_modification = datetime.fromtimestamp(os.path.getmtime('tgvmax.csv'), tz=timezone('UTC'))

    # Explicit encoding so that station names decode the same way on every
    # platform; ``newline=''`` is what the csv module expects.
    # NOTE(review): assumes the dataset is published as UTF-8 - confirm.
    with open('tgvmax.csv', encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter=';')
        # The first row holds the column names.
        next(reader, None)

        already_seen = set()
        for line in tqdm(reader):
            if not line:
                # Blank line in the file.
                continue

            train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"

            # Some trains are mysteriously duplicated, concerns only some « Intercités de nuit »
            # and the Brive-la-Gaillarde -- Paris
            # and, maybe, for Roubaix-Tourcoing
            known_duplicate = (line[3] == "IC NUIT" or line[1] == '3614'
                               or (line[4] == 'FRADP' and line[5] == 'FRADM'))

            if train_id in already_seen:
                if not known_duplicate:
                    print("Duplicate:", train_id)
                continue

            train = Train(
                id=train_id,
                day=date.fromisoformat(line[0]),
                number=int(line[1]),
                entity=line[2],
                axe=line[3],
                orig_iata=line[4],
                dest_iata=line[5],
                orig=line[6],
                dest=line[7],
                dep=time.fromisoformat(line[8]),
                arr=time.fromisoformat(line[9]),
                tgvmax=line[10] == 'OUI',
                last_modification=last_modification,
                expiration_time=last_modification,
            )
            if flush:
                # The table was just emptied: a plain insert is enough.
                db.session.add(train)
            else:
                db.session.merge(train)

            # Only the identifiers of the known duplicated trains are remembered.
            # NOTE(review): a duplicate of any other train is therefore not
            # detected here - confirm this is intended.
            if known_duplicate:
                already_seen.add(train_id)

    db.session.commit()
|
|
|
|
|
|
def find_routes(day: date, origin: str, destination: str) -> 'list[list[Train]]':
    """
    Build the itineraries from ``origin`` to ``destination`` on ``day``.

    Only trains flagged ``tgvmax`` are used. Trains are scanned by increasing
    departure time and chained when a train leaves a station at or after the
    arrival of the previous one. Every leg that is used is also queued for a
    refresh from the simulator.

    :param day: The day of travel.
    :param origin: The departure station, compared with ``Train.orig``.
    :param destination: The arrival station, compared with ``Train.dest``.
    :return: A list of itineraries, each being the ordered list of its trains.
    """
    trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all()
    trains.sort(key=lambda train: train.dep)

    # For better results later, fetch all trains from the origin or to the destination
    # This is not exhaustive, but can be a good approximation
    queue_routes(day, origin=origin)
    queue_routes(day, destination=destination)

    # Unfinished itineraries, indexed by the station where they currently end.
    per_arr_explore = {}
    valid_routes = []

    for train in tqdm(trains):
        if train.orig == origin:
            # Update from the TGVMax simulator
            queue_route(day, train.orig_iata, train.dest_iata)

            it = [train]
            if train.dest == destination:
                # We hope that we have a direct train
                valid_routes.append(it)
            else:
                per_arr_explore.setdefault(train.dest, []).append(it)
            continue

        if train.dest == origin:
            # Avoid loops: never come back to the starting station
            continue

        # Iterate over a copy: the index is extended inside the loop.
        for it in list(per_arr_explore.get(train.orig, [])):
            if any(train.dest == tr.dest for tr in it):
                # Avoid loops: this itinerary already went through that station
                continue

            last_train = it[-1]
            if last_train.arr > train.dep:
                # This train leaves before the itinerary reaches the station
                continue

            # Update from the TGVMax simulator, this line can be useful later
            queue_route(day, train.orig_iata, train.dest_iata)

            new_it = it + [train]
            if train.dest == destination:
                # Goal is achieved
                valid_routes.append(new_it)
            else:
                per_arr_explore.setdefault(train.dest, []).append(new_it)

    return valid_routes
|
|
|
|
|
|
def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False):
    """
    Fetch the TGVMax simulator to refresh data.

    DAY: The day to query, in format YYYY-MM-DD.

    ORIGIN: The origin of the route.

    DESTINATION: The destination of the route.
    """
    # This function only adds a row to the waiting list; the request itself is
    # sent later by ``process_queue``. Nothing is added when the same route is
    # already waiting, or when an earlier answer has not expired yet.
    day = day.date() if isinstance(day, datetime) else day

    same_route = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination)

    # A request without response time has not been processed yet.
    if same_route.filter_by(response_time=None).count() > 0:
        if verbose:
            print("Already queued")
        return

    # An answer whose expiration time is not reached is still usable.
    if same_route.filter(RouteQueue.expiration_time >= datetime.now(timezone('UTC'))).count() > 0:
        if verbose:
            print("Using recent value")
        return

    new_request = RouteQueue(day=day, origin=origin, destination=destination)
    db.session.add(new_request)
    db.session.commit()
|
|
|
|
|
|
# Register ``queue_route`` as the ``queue-route`` command without decorator
# syntax, so that the module-level name stays a plain function that the rest
# of the code can call directly.
# Each click parameter decorator attaches its parameter to the function and
# returns that same function, so applying them one after the other (innermost
# first) is equivalent to nesting the calls.
click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(queue_route)
click.argument('destination', type=str)(queue_route)
click.argument('origin', type=str)(queue_route)
click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(queue_route)
cli.command('queue-route')(queue_route)
|
|
|
|
|
|
def queue_routes(day: date | datetime, origin: str | None = None,
                 destination: str | None = None, verbose: bool = False):
    # Queue a refresh for the route of every train running on ``day``.
    # ``origin`` and ``destination`` are optional filters, each matched against
    # either the station code or the station name of the train.
    # Documented with comments rather than a docstring so that the help text
    # of the command line version of this function is left unchanged.
    if isinstance(day, datetime):
        day = day.date()

    criteria = [Train.day == day]
    if origin:
        criteria.append((Train.orig_iata == origin) | (Train.orig == origin))
    if destination:
        criteria.append((Train.dest_iata == destination) | (Train.dest == destination))

    # The result is fully loaded first: ``queue_route`` commits while we loop.
    matching_trains = db.session.query(Train).filter(*criteria).all()
    for train in matching_trains:
        queue_route(day, train.orig_iata, train.dest_iata, verbose)
|
|
|
|
|
|
# Register ``queue_routes`` as the ``queue-routes`` command without decorator
# syntax, so that the module-level name stays a plain, directly callable
# function.
# Each click parameter decorator attaches its parameter to the function and
# returns that same function, so applying them one after the other (innermost
# first) is equivalent to nesting the calls.
click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")(queue_routes)
click.option('--destination', '-d', default=None)(queue_routes)
click.option('--origin', '-o', default=None)(queue_routes)
click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))(queue_routes)
cli.command('queue-routes')(queue_routes)
|
|
|
|
|
|
@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=30, type=int)
def process_queue(number: int):
    """
    Send the pending queued requests to the simulator and store the answers.

    Pending requests are handled by increasing queue time. For each answer,
    the trains of the route are first marked as not available, then the trains
    listed in the answer are marked available with their number of free places.

    :param number: Maximum number of requests to process; zero or a negative
        value processes the whole waiting list.
    :raises requests.HTTPError: When the simulator answers with an error
        status other than 404. Requests handled before the failure stay
        committed.
    """
    utc = timezone('UTC')

    pending = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time)
    if number > 0:
        pending = pending.limit(number)
    # Load the rows up front: the session is committed inside the loop.
    queue = pending.all()

    URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals"

    for req in queue:
        req: RouteQueue
        resp = requests.post(URL, json={
            'departureDateTime': req.day.isoformat(),
            'origin': req.origin,
            'destination': req.destination,
        })

        if resp.status_code == 404:
            # No travel found
            # Timezone-aware UTC, like the values stored for a successful
            # answer and the comparison done in ``queue_route``.
            now = datetime.now(utc)
            req.response_time = now
            req.expiration_time = now + timedelta(hours=1)
            db.session.add(req)
            db.session.commit()
            continue

        resp.raise_for_status()

        data = resp.json()
        # Both timestamps are given in milliseconds.
        req.response_time = datetime.fromtimestamp(data['updatedAt'] // 1000, tz=utc)
        req.expiration_time = datetime.fromtimestamp(data['expiresAt'] // 1000, tz=utc)
        req.expiration_time += timedelta(hours=3)  # By default 5 minutes, extend it to 3 hours to be safe
        db.session.add(req)

        # Reset the whole route first: a train absent from the answer is
        # considered as not available.
        db.session.query(Train).filter_by(day=req.day, orig_iata=req.origin, dest_iata=req.destination)\
            .update(dict(tgvmax=False, remaining_seats=-1))

        for proposal in data['proposals']:
            train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
                                                      orig_iata=req.origin, dest_iata=req.destination).first()
            if train is None:
                # In a city with multiple stations
                print(proposal)
                continue
            train.tgvmax = True
            train.remaining_seats = proposal['freePlaces']
            train.last_modification = req.response_time
            train.expiration_time = req.expiration_time
            db.session.add(train)

        # Commit each request as soon as it is complete, so that a later
        # failure does not discard the answers already received.
        db.session.commit()
|
|
|
|
|
|
@app.get('/')
def index():
    """
    Render the home page, given today's date and the date 30 days later.
    """
    # Read the clock once so that both values always refer to the same day.
    today = date.today()
    return render_template('index.html', today=today, max_day=today + timedelta(days=30))
|
|
|
|
|
|
@app.get('/api/iata-codes/')
def iata_codes():
    """
    Return the two-way mapping between station codes and station names.

    The pairs are the distinct (code, name) values found as train origins.

    :return: A dictionary with the keys ``iata2name`` and ``name2iata``.
    """
    # Run the query once and build both mappings from the same rows.
    stations = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct().all()
    return {
        'iata2name': {iata: name for (iata, name) in stations},
        'name2iata': {name: iata for (iata, name) in stations},
    }
|
|
|
|
|
|
@app.get('/api/routes/<day>/<origin>/<destination>/')
def get_routes(day: date | str, origin: str, destination: str):
    """
    Return the itineraries between two stations for a given day.

    :param day: The day of travel, as a date or as an ISO formatted string.
    :param origin: The departure station, passed to ``find_routes``.
    :param destination: The arrival station, passed to ``find_routes``.
    :return: A list of itineraries, each being a list of dictionaries that
        describe one train. When ``day`` is a string that is not a valid date,
        an error dictionary with the HTTP status 400.
    """
    if isinstance(day, str):
        try:
            day = date.fromisoformat(day)
        except ValueError:
            # The value comes from the URL: answer with a client error
            # instead of letting the exception become a server error.
            return {'error': "Invalid day, expected an ISO date (YYYY-MM-DD)."}, 400

    routes = find_routes(day, origin, destination)
    return [
        [{
            'origin': tr.orig,
            'origin_iata': tr.orig_iata,
            'destination': tr.dest,
            'destination_iata': tr.dest_iata,
            'departure': tr.dep.isoformat(),
            'arrival': tr.arr.isoformat(),
            'number': tr.number,
            'free_seats': tr.remaining_seats,
        } for tr in route] for route in routes
    ]
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: start the Flask server with debug mode enabled.
    # NOTE(review): debug mode is meant for local development - confirm this
    # entry point is never used for a deployment.
    app.run(debug=True)
|