diff --git a/app.py b/app.py index 14f7317..e27814a 100644 --- a/app.py +++ b/app.py @@ -7,6 +7,7 @@ import os import json from pytz import timezone import requests +from time import sleep import click from flask import Flask, render_template @@ -160,24 +161,30 @@ def parse_trains(flush: bool = False, verbose: bool = False): db.session.commit() -def find_routes(day: date, origin: str, destination: str | None): +def find_routes(day: date | datetime, origin: str, destination: str | None, + verbose: bool = False): + if isinstance(day, datetime): + day = day.date() + trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all() trains.sort(key=lambda train: train.dep) # For better results later, fetch all trains from the origin or to the destination # This is not exhaustive, but can be a good approximation - queue_routes(day, origin=origin) + queue_routes(day, origin=origin, verbose=verbose, autocommit=False) if destination: - queue_routes(day, destination=destination) + queue_routes(day, destination=destination, verbose=verbose, autocommit=False) + + db.session.commit() per_arr_explore = {} valid_routes = [] - for train in trains: + for train in (t := tqdm(trains) if verbose else trains): if train.orig == origin: # Update from the TGVMax simulator - queue_route(day, train.orig_iata, train.dest_iata) + queue_route(day, train.orig_iata, train.dest_iata, verbose, False) it = [train] if train.dest == destination: @@ -197,7 +204,7 @@ def find_routes(day: date, origin: str, destination: str | None): if last_train.arr <= train.dep: # Update from the TGVMax simulator, this line can be useful later - queue_route(day, train.orig_iata, train.dest_iata) + queue_route(day, train.orig_iata, train.dest_iata, verbose, False) new_it = it + [train] if train.dest == destination: @@ -207,10 +214,21 @@ def find_routes(day: date, origin: str, destination: str | None): per_arr_explore.setdefault(train.dest, []) per_arr_explore[train.dest].append(new_it) + # Send queued trains to the database + db.session.commit() + return {destination: valid_routes} if destination else per_arr_explore -def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False): +# Don't use the decorator to keep the function callable +cli.command('find-routes')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d'])) + (click.argument('origin', type=str) + (click.argument('destination', type=str, default=None) + (click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.") + (find_routes))))) + + +def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False, autocommit: bool = True): """ Fetch the TGVMax simulator to refresh data. @@ -225,8 +243,6 @@ def queue_route(day: date | datetime, origin: str, destination: str, verbose: bo query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination, response_time=None) if query.count(): - if verbose: - print("Already queued") return query = db.session.query(RouteQueue).filter(RouteQueue.day == day, @@ -234,8 +250,6 @@ def queue_route(day: date | datetime, origin: str, destination: str, verbose: bo RouteQueue.destination == destination, RouteQueue.expiration_time >= datetime.now(timezone('UTC'))) if query.count(): - if verbose: - print("Using recent value") return db.session.add(RouteQueue(day=day, origin=origin, destination=destination)) @@ -251,7 +265,8 @@ cli.command('queue-route')(click.argument('day', type=click.DateTime(formats=['% def queue_routes(day: date | datetime, origin: str | None = None, - destination: str | None = None, verbose: bool = False): + destination: str | None = None, verbose: bool = False, + autocommit: bool = True): if isinstance(day, datetime): day = day.date() @@ -260,8 +275,11 @@ def queue_routes(day: date | datetime, origin: str | None = None, query = query.filter((Train.orig_iata == origin) | (Train.orig == origin)) if destination: query = query.filter((Train.dest_iata == destination) | (Train.dest == destination)) - for train in query.all(): - queue_route(day, train.orig_iata, train.dest_iata, verbose) + query = query.all() + for train in (t := tqdm(query) if verbose else query): + if verbose: + t.set_description(f"{day}: {train.orig} --> {train.dest}") + queue_route(day, train.orig_iata, train.dest_iata, verbose, autocommit) # Same as above @@ -274,27 +292,38 @@ cli.command('queue-routes')(click.argument('day', type=click.DateTime(formats=[' @cli.command('process-queue', help="Process the waiting list to refresh from the simulator.") @click.argument('number', default=30, type=int) -def process_queue(number: int): - queue = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time) +@click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.") +def process_queue(number: int, verbose: bool = False): + queue = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time).all() if number > 0: queue = queue[:number] URL = "https://www.maxjeune-tgvinoui.sncf/api/public/refdata/search-freeplaces-proposals" - for req in queue: + if verbose: + query = db.session.query(Train).with_entities(Train.orig_iata, Train.orig).distinct() + iata_to_names = {k: v for (k, v) in query.all()} + + for i, req in enumerate(t := tqdm(queue) if verbose else queue): req: RouteQueue - resp = requests.post(URL, json={ - 'departureDateTime': req.day.isoformat(), - 'origin': req.origin, - 'destination': req.destination, - }) + if verbose: + t.set_description(f"{req.day:%d/%m/%Y}: {iata_to_names[req.origin]} --> {iata_to_names[req.destination]}") + + resp = None + while resp is None or resp.status_code == 429: + resp = requests.post(URL, json={ + 'departureDateTime': req.day.isoformat(), + 'origin': req.origin, + 'destination': req.destination, + }) + if resp.status_code == 429: + sleep(1) if resp.status_code == 404: # No travel found req.response_time = datetime.now() - req.expiration_time = datetime.now() + timedelta(hours=1) + req.expiration_time = datetime.now() + timedelta(hours=3) db.session.add(req) - db.session.commit() continue resp.raise_for_status() @@ -310,9 +339,11 @@ def process_queue(number: int): for proposal in data['proposals']: train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']), - orig_iata=req.origin, dest_iata=req.destination).first() + orig_iata=proposal['origin']['rrCode'], + dest_iata=proposal['destination']['rrCode']).first() if train is None: # In a city with multiple stations + print("ERROR") print(proposal) continue train.tgvmax = True @@ -321,7 +352,10 @@ def process_queue(number: int): train.expiration_time = req.expiration_time db.session.add(train) - db.session.commit() + if i % 50 == 0: + db.session.commit() + + db.session.commit() @app.get('/') diff --git a/templates/index.html b/templates/index.html index 0adb271..bec254c 100644 --- a/templates/index.html +++ b/templates/index.html @@ -42,8 +42,8 @@ fetch('/api/iata-codes/').then(res => res.json()).then(out => { let datalist = document.getElementById('iataCodes') datalist.innerHTML = '' - for (let iata in out.iata2name) { - let name = out.iata2name[iata] + for (let name in out.name2iata) { + let iata = out.name2iata[name] let elem = document.createElement('option') elem.value = name elem.setAttribute('data-iata', iata)