Queue routes when necessary

This commit is contained in:
Emmy D'Anello 2023-02-13 12:08:24 +01:00
parent 995a320208
commit ba3bef3d27
Signed by: ynerant
GPG Key ID: 3A75C55819C8CF85
1 changed files with 58 additions and 18 deletions

70
app.py
View File

@ -156,17 +156,25 @@ def parse_trains(flush: bool = False):
db.session.commit() db.session.commit()
def find_routes(day, origin, destination): def find_routes(day: date, origin: str, destination: str):
trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all() trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all()
trains.sort(key=lambda train: train.dep) trains.sort(key=lambda train: train.dep)
# For better results later, fetch all trains from the origin or to the destination
# This is not exhaustive, but can be a good approximation
queue_routes(day, origin=origin)
queue_routes(day, destination=destination)
explore = [] explore = []
per_arr_explore = {} per_arr_explore = {}
valid_routes = [] valid_routes = []
for train in tqdm(trains): for train in tqdm(trains):
if train.orig == origin: if train.orig == origin:
# Update from the TGVMax simulator
queue_route(day, train.orig_iata, train.dest_iata)
it = [train] it = [train]
if train.dest == destination: if train.dest == destination:
# We hope that we have a direct train # We hope that we have a direct train
@ -185,6 +193,9 @@ def find_routes(day, origin, destination):
last_train = it[-1] last_train = it[-1]
if last_train.arr <= train.dep: if last_train.arr <= train.dep:
# Update from the TGVMax simulator, this line can be useful later
queue_route(day, train.orig_iata, train.dest_iata)
new_it = it + [train] new_it = it + [train]
if train.dest == destination: if train.dest == destination:
# Goal is achieved # Goal is achieved
@ -197,18 +208,7 @@ def find_routes(day, origin, destination):
return valid_routes return valid_routes
def print_route(route: list[Train]): def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False):
s = f"{route[0].orig} "
for tr in route:
s += f"({tr.dep}) --> ({tr.arr}) {tr.dest}, "
print(s[:-2])
@cli.command('queue-route')
@click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
@click.argument('origin', type=str)
@click.argument('destination', type=str)
def queue_route(day: date | datetime, origin: str, destination: str):
""" """
Fetch the TGVMax simulator to refresh data. Fetch the TGVMax simulator to refresh data.
@ -223,6 +223,7 @@ def queue_route(day: date | datetime, origin: str, destination: str):
query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination, response_time=None) query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination, response_time=None)
if query.count(): if query.count():
if verbose:
print("Already queued") print("Already queued")
return return
@ -231,6 +232,7 @@ def queue_route(day: date | datetime, origin: str, destination: str):
RouteQueue.destination == destination, RouteQueue.destination == destination,
RouteQueue.expiration_time >= datetime.now(timezone('UTC'))) RouteQueue.expiration_time >= datetime.now(timezone('UTC')))
if query.count(): if query.count():
if verbose:
print("Using recent value") print("Using recent value")
return return
@ -238,8 +240,38 @@ def queue_route(day: date | datetime, origin: str, destination: str):
db.session.commit() db.session.commit()
# Don't use the decorator to keep the function callable
cli.command('queue-route')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
(click.argument('origin', type=str)
(click.argument('destination', type=str)
(click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
(queue_route)))))
def queue_routes(day: date | datetime, origin: str | None = None,
destination: str | None = None, verbose: bool = False):
if isinstance(day, datetime):
day = day.date()
query = db.session.query(Train).filter((Train.day == day))
if origin:
query = query.filter((Train.orig_iata == origin) | (Train.orig == origin))
if destination:
query = query.filter((Train.dest_iata == destination) | (Train.dest == destination))
for train in query.all():
queue_route(day, train.orig_iata, train.dest_iata, verbose)
# Same as above
cli.command('queue-routes')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d']))
(click.option('--origin', '-o', default=None)
(click.option('--destination', '-d', default=None)
(click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.")
(queue_routes)))))
@cli.command('process-queue', help="Process the waiting list to refresh from the simulator.") @cli.command('process-queue', help="Process the waiting list to refresh from the simulator.")
@click.argument('number', default=5, type=int) @click.argument('number', default=30, type=int)
def process_queue(number: int): def process_queue(number: int):
queue = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time) queue = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time)
if number > 0: if number > 0:
@ -260,6 +292,7 @@ def process_queue(number: int):
req.response_time = datetime.now() req.response_time = datetime.now()
req.expiration_time = datetime.now() + timedelta(hours=1) req.expiration_time = datetime.now() + timedelta(hours=1)
db.session.add(req) db.session.add(req)
db.session.commit()
continue continue
resp.raise_for_status() resp.raise_for_status()
@ -276,6 +309,10 @@ def process_queue(number: int):
for proposal in data['proposals']: for proposal in data['proposals']:
train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']), train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']),
orig_iata=req.origin, dest_iata=req.destination).first() orig_iata=req.origin, dest_iata=req.destination).first()
if train is None:
# In a city with multiple stations
print(proposal)
continue
train.tgvmax = True train.tgvmax = True
train.remaining_seats = proposal['freePlaces'] train.remaining_seats = proposal['freePlaces']
train.last_modification = req.response_time train.last_modification = req.response_time
@ -304,7 +341,10 @@ def iata_codes():
@app.get('/api/routes/<day>/<origin>/<destination>/') @app.get('/api/routes/<day>/<origin>/<destination>/')
def get_routes(day: date, origin: str, destination: str): def get_routes(day: date | str, origin: str, destination: str):
if isinstance(day, str):
day = date.fromisoformat(day)
routes = find_routes(day, origin, destination) routes = find_routes(day, origin, destination)
return [ return [
[{ [{