diff --git a/app.py b/app.py index 21079e5..4bbc4eb 100644 --- a/app.py +++ b/app.py @@ -156,17 +156,25 @@ def parse_trains(flush: bool = False): db.session.commit() -def find_routes(day, origin, destination): +def find_routes(day: date, origin: str, destination: str): trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all() trains.sort(key=lambda train: train.dep) + # For better results later, fetch all trains from the origin or to the destination + # This is not exhaustive, but can be a good approximation + queue_routes(day, origin=origin) + queue_routes(day, destination=destination) + explore = [] per_arr_explore = {} valid_routes = [] for train in tqdm(trains): if train.orig == origin: + # Update from the TGVMax simulator + queue_route(day, train.orig_iata, train.dest_iata) + it = [train] if train.dest == destination: # We hope that we have a direct train @@ -185,6 +193,9 @@ def find_routes(day, origin, destination): last_train = it[-1] if last_train.arr <= train.dep: + # Update from the TGVMax simulator, this line can be useful later + queue_route(day, train.orig_iata, train.dest_iata) + new_it = it + [train] if train.dest == destination: # Goal is achieved @@ -197,18 +208,7 @@ def find_routes(day, origin, destination): return valid_routes -def print_route(route: list[Train]): - s = f"{route[0].orig} " - for tr in route: - s += f"({tr.dep}) --> ({tr.arr}) {tr.dest}, " - print(s[:-2]) - - -@cli.command('queue-route') -@click.argument('day', type=click.DateTime(formats=['%Y-%m-%d'])) -@click.argument('origin', type=str) -@click.argument('destination', type=str) -def queue_route(day: date | datetime, origin: str, destination: str): +def queue_route(day: date | datetime, origin: str, destination: str, verbose: bool = False): """ Fetch the TGVMax simulator to refresh data. @@ -223,7 +223,8 @@ def queue_route(day: date | datetime, origin: str, destination: str): query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination, response_time=None) if query.count(): - print("Already queued") + if verbose: + print("Already queued") return query = db.session.query(RouteQueue).filter(RouteQueue.day == day, @@ -231,15 +232,46 @@ def queue_route(day: date | datetime, origin: str, destination: str): RouteQueue.destination == destination, RouteQueue.expiration_time >= datetime.now(timezone('UTC'))) if query.count(): - print("Using recent value") + if verbose: + print("Using recent value") return db.session.add(RouteQueue(day=day, origin=origin, destination=destination)) db.session.commit() +# Don't use the decorator to keep the function callable +cli.command('queue-route')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d'])) + (click.argument('origin', type=str) + (click.argument('destination', type=str) + (click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.") + (queue_route))))) + + +def queue_routes(day: date | datetime, origin: str | None = None, + destination: str | None = None, verbose: bool = False): + if isinstance(day, datetime): + day = day.date() + + query = db.session.query(Train).filter((Train.day == day)) + if origin: + query = query.filter((Train.orig_iata == origin) | (Train.orig == origin)) + if destination: + query = query.filter((Train.dest_iata == destination) | (Train.dest == destination)) + for train in query.all(): + queue_route(day, train.orig_iata, train.dest_iata, verbose) + + +# Same as above +cli.command('queue-routes')(click.argument('day', type=click.DateTime(formats=['%Y-%m-%d'])) + (click.option('--origin', '-o', default=None) + (click.option('--destination', '-d', default=None) + (click.option('--verbose', '-v', type=bool, is_flag=True, help="Display errors.") + (queue_routes))))) + + @cli.command('process-queue', help="Process the waiting list to refresh from the simulator.") -@click.argument('number', default=5, type=int) +@click.argument('number', default=30, type=int) def process_queue(number: int): queue = db.session.query(RouteQueue).filter_by(response_time=None).order_by(RouteQueue.queue_time) if number > 0: @@ -260,6 +292,7 @@ def process_queue(number: int): req.response_time = datetime.now() req.expiration_time = datetime.now() + timedelta(hours=1) db.session.add(req) + db.session.commit() continue resp.raise_for_status() @@ -276,13 +309,17 @@ def process_queue(number: int): for proposal in data['proposals']: train = db.session.query(Train).filter_by(day=req.day, number=int(proposal['trainNumber']), orig_iata=req.origin, dest_iata=req.destination).first() + if train is None: + # In a city with multiple stations + print(proposal) + continue train.tgvmax = True train.remaining_seats = proposal['freePlaces'] train.last_modification = req.response_time train.expiration_time = req.expiration_time db.session.add(train) - db.session.commit() + db.session.commit() @app.get('/') @@ -304,7 +341,10 @@ def iata_codes(): @app.get('/api/routes////') -def get_routes(day: date, origin: str, destination: str): +def get_routes(day: date | str, origin: str, destination: str): + if isinstance(day, str): + day = date.fromisoformat(day) + routes = find_routes(day, origin, destination) return [ [{