From 4332862419008d39a77e14e0a4b4cac61c895b05 Mon Sep 17 00:00:00 2001 From: Emmy D'Anello Date: Tue, 25 Apr 2023 15:05:23 +0200 Subject: [PATCH] Optimize computation Signed-off-by: Emmy D'Anello --- app.py | 107 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/app.py b/app.py index e27814a..c6d2dfe 100644 --- a/app.py +++ b/app.py @@ -120,8 +120,11 @@ def parse_trains(flush: bool = False, verbose: bool = False): with open('tgvmax.csv') as f: first_line = True already_seen = set() + already_updated = set(x[0] for x in db.session.query(Train).filter(Train.last_modification > last_modification)\ + .values(Train.id)) for line in (tqdm if verbose else lambda x: x)(csv.reader(f, delimiter=';')): if first_line: + # Skip first line first_line = False continue @@ -134,6 +137,10 @@ def parse_trains(flush: bool = False, verbose: bool = False): print("Duplicate:", train_id) continue + if train_id in already_updated: + # Already updated by the simulator + continue + train = Train( id=train_id, day=date.fromisoformat(line[0]), @@ -162,62 +169,57 @@ def parse_trains(flush: bool = False, verbose: bool = False): def find_routes(day: date | datetime, origin: str, destination: str | None, - verbose: bool = False): + verbose: bool = False, min_dep: time = time(0, 0), + explored: dict | None = None): if isinstance(day, datetime): day = day.date() - trains = db.session.query(Train).filter_by(day=day, tgvmax=True).all() + if explored is None: + explored = {} + + if origin not in explored: + explored[origin] = (min_dep, None) + valid_routes = [] + max_dep = time(23, 59) + else: + max_dep, valid_routes = explored[origin] + if max_dep < min_dep: + # Already parsed these trains + return {destination: valid_routes} + explored[origin] = min_dep, None + + trains = db.session.query(Train).filter_by(day=day, tgvmax=True, orig=origin)\ + .filter(Train.dep >= min_dep, Train.dep < max_dep).all() + if not trains: + # No train in 
the requested interval + explored[origin] = (min_dep, valid_routes) + return {destination: valid_routes} trains.sort(key=lambda train: train.dep) - - # For better results later, fetch all trains from the origin or to the destination - # This is not exhaustive, but can be a good approximation - queue_routes(day, origin=origin, verbose=verbose, autocommit=False) - if destination: - queue_routes(day, destination=destination, verbose=verbose, autocommit=False) - db.session.commit() - per_arr_explore = {} - valid_routes = [] - - for train in (t := tqdm(trains) if verbose else trains): - if train.orig == origin: - # Update from the TGVMax simulator - queue_route(day, train.orig_iata, train.dest_iata, verbose, False) - - it = [train] - if train.dest == destination: - # We hope that we have a direct train - valid_routes.append(it) - else: - per_arr_explore.setdefault(train.dest, []) - per_arr_explore[train.dest].append(it) - continue - - for it in list(per_arr_explore.get(train.orig, [])): - if any(train.dest == tr.dest or train.dest == origin for tr in it): - # Avoid loops + for train in (t := tqdm(trains, desc=origin) if verbose else trains): + if train.dest == destination: + # We hope that we have a direct train + valid_routes.append([train]) + else: + if train.dest in explored and explored[train.dest][1] is None: + # This is a loop continue + elif train.arr < min_dep: + # The train is not direct and arrives on the next day, we avoid that + continue + find_routes(day, train.dest, destination, verbose, train.arr, explored) - last_train = it[-1] + # Filter unusable trains + valid_routes += [[train] + it for it in explored[train.dest][1] if it[0].dep >= train.arr] - if last_train.arr <= train.dep: - # Update from the TGVMax simulator, this line can be useful later - queue_route(day, train.orig_iata, train.dest_iata, verbose, False) - - new_it = it + [train] - if train.dest == destination: - # Goal is achieved - valid_routes.append(new_it) - else: - 
per_arr_explore.setdefault(train.dest, []) - per_arr_explore[train.dest].append(new_it) + explored[origin] = (min_dep, valid_routes) # Send queued trains to the database db.session.commit() - return {destination: valid_routes} if destination else per_arr_explore + return {destination: valid_routes} if destination else {} # Don't use the decorator to keep the function callable @@ -241,19 +243,14 @@ def queue_route(day: date | datetime, origin: str, destination: str, verbose: bo if isinstance(day, datetime): day = day.date() - query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination, response_time=None) - if query.count(): - return - - query = db.session.query(RouteQueue).filter(RouteQueue.day == day, - RouteQueue.origin == origin, - RouteQueue.destination == destination, - RouteQueue.expiration_time >= datetime.now(timezone('UTC'))) + query = db.session.query(RouteQueue).filter_by(day=day, origin=origin, destination=destination)\ + .filter((RouteQueue.response_time == None) | (RouteQueue.expiration_time >= datetime.now(timezone('UTC')))) if query.count(): return db.session.add(RouteQueue(day=day, origin=origin, destination=destination)) - db.session.commit() + if autocommit: + db.session.commit() # Don't use the decorator to keep the function callable @@ -270,6 +267,10 @@ def queue_routes(day: date | datetime, origin: str | None = None, if isinstance(day, datetime): day = day.date() + valid_routes = set(db.session.query(RouteQueue).filter_by(day=day)\ + .filter((RouteQueue.response_time == None) | (RouteQueue.expiration_time >= datetime.now(timezone('UTC'))))\ + .values(RouteQueue.origin, RouteQueue.destination)) + query = db.session.query(Train).filter((Train.day == day)) if origin: query = query.filter((Train.orig_iata == origin) | (Train.orig == origin)) @@ -279,7 +280,9 @@ def queue_routes(day: date | datetime, origin: str | None = None, for train in (t := tqdm(query) if verbose else query): if verbose: 
t.set_description(f"{day}: {train.orig} --> {train.dest}") - queue_route(day, train.orig_iata, train.dest_iata, verbose, autocommit) + if (train.orig_iata, train.dest_iata) not in valid_routes: + queue_route(day, train.orig_iata, train.dest_iata, verbose, autocommit) + valid_routes.add((train.orig_iata, train.dest_iata)) # Same as above