Import test code, add small Flask server
Signed-off-by: Emmy D'Anello <ynerant@emy.lu>
This commit is contained in:
		
							
								
								
									
										9
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
__pycache__
 | 
			
		||||
.idea
 | 
			
		||||
env
 | 
			
		||||
venv
 | 
			
		||||
instance/
 | 
			
		||||
 | 
			
		||||
config.py
 | 
			
		||||
tgvmax.csv
 | 
			
		||||
migrations/versions
 | 
			
		||||
							
								
								
									
										170
									
								
								app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								app.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,170 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
 | 
			
		||||
import csv
 | 
			
		||||
from datetime import date, datetime, time
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
from pytz import timezone
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from flask import Flask
 | 
			
		||||
from flask_migrate import Migrate
 | 
			
		||||
from flask_sqlalchemy import SQLAlchemy
 | 
			
		||||
from sqlalchemy import Boolean, Column, Date, Integer, String, Time
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
import config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app = Flask(__name__)
 | 
			
		||||
 | 
			
		||||
app.config |= config.FLASK_CONFIG
 | 
			
		||||
 | 
			
		||||
db = SQLAlchemy(app)
 | 
			
		||||
Migrate(app, db)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Train(db.Model):
 | 
			
		||||
    __tablename__ = 'train'
 | 
			
		||||
    id = Column(String, primary_key=True)
 | 
			
		||||
    day = Column(Date, index=True)
 | 
			
		||||
    number = Column(Integer, index=True)
 | 
			
		||||
    entity = Column(String(255))
 | 
			
		||||
    axe = Column(String(255), index=True)
 | 
			
		||||
    orig_iata = Column(String(5), index=True)
 | 
			
		||||
    dest_iata = Column(String(5), index=True)
 | 
			
		||||
    orig = Column(String(255))
 | 
			
		||||
    dest = Column(String(255))
 | 
			
		||||
    dep = Column(String(255))
 | 
			
		||||
    arr = Column(Time)
 | 
			
		||||
    tgvmax = Column(Boolean, index=True)
 | 
			
		||||
    remaining_seats = Column(Integer)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_dataset():
 | 
			
		||||
    try:
 | 
			
		||||
        resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/')
 | 
			
		||||
        content = resp.content.decode().split('<script type="application/ld+json">')[1].split('</script>')[0].strip()
 | 
			
		||||
        content = content.replace('\r', '')
 | 
			
		||||
        content = content.replace('" \n', '" \\n')
 | 
			
		||||
        content = content.replace('.\n', '.\\n')
 | 
			
		||||
        content = content.replace('\n\n \nLa', '\\n\\n \\nLa')
 | 
			
		||||
        content = content.replace('\n"', '\\n"')
 | 
			
		||||
 | 
			
		||||
        info = json.loads(content)
 | 
			
		||||
        modified_date = datetime.fromisoformat(info['dateModified'])
 | 
			
		||||
 | 
			
		||||
        utc = timezone('UTC')
 | 
			
		||||
        last_modified = datetime.utcfromtimestamp(os.path.getmtime('tgvmax.csv')).replace(tzinfo=utc) if os.path.isfile('tgvmax.csv') else datetime(1, 1, 1, tzinfo=utc)
 | 
			
		||||
 | 
			
		||||
        if last_modified < modified_date:
 | 
			
		||||
            print("Updating tgvmax.csv…")
 | 
			
		||||
            with requests.get(info['distribution'][0]['contentUrl'], stream=True) as resp:
 | 
			
		||||
                resp.raise_for_status()
 | 
			
		||||
                with open('tgvmax.csv', 'wb') as f:
 | 
			
		||||
                    with tqdm(unit='io', unit_scale=True) as t:
 | 
			
		||||
                        for chunk in resp.iter_content(chunk_size=512 * 1024):
 | 
			
		||||
                            if chunk:
 | 
			
		||||
                                f.write(chunk)
 | 
			
		||||
                                t.update(len(chunk))
 | 
			
		||||
            os.utime('tgvmax.csv', (modified_date.timestamp(), modified_date.timestamp()))
 | 
			
		||||
            print("Done")
 | 
			
		||||
 | 
			
		||||
        print("Last modification:", modified_date)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print("An error occured while updating tgvmax.csv")
 | 
			
		||||
        print(e)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_trains(*, filter_day: date | None = None,
 | 
			
		||||
                 filter_number: int | None = None,
 | 
			
		||||
                 filter_tgvmax: bool | None = None):
 | 
			
		||||
    trains = []
 | 
			
		||||
 | 
			
		||||
    with open('tgvmax.csv') as f:
 | 
			
		||||
        first_line = True
 | 
			
		||||
        for line in csv.reader(f, delimiter=';'):
 | 
			
		||||
            if first_line:
 | 
			
		||||
                first_line = False
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            train = Train(*line)
 | 
			
		||||
            train.day = date.fromisoformat(train.day)
 | 
			
		||||
            train.number = int(train.number)
 | 
			
		||||
            train.dep = time.fromisoformat(train.dep)
 | 
			
		||||
            train.arr = time.fromisoformat(train.arr)
 | 
			
		||||
            train.tgvmax = train.tgvmax == 'OUI'
 | 
			
		||||
 | 
			
		||||
            if filter_day is not None and train.day != filter_day:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if filter_number is not None and train.number != filter_number:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if filter_tgvmax is not None and train.tgvmax != filter_tgvmax:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            trains.append(train)
 | 
			
		||||
 | 
			
		||||
    return trains
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_routes(day, orig, dest):
 | 
			
		||||
    trains = parse_trains(filter_day=date(2023, 2, 17),
 | 
			
		||||
                          filter_tgvmax=True)
 | 
			
		||||
 | 
			
		||||
    trains.sort(key=lambda train: train.dep)
 | 
			
		||||
 | 
			
		||||
    origin = "STRASBOURG"
 | 
			
		||||
    dest = "LYON (intramuros)"
 | 
			
		||||
 | 
			
		||||
    explore = []
 | 
			
		||||
    per_arr_explore = {}
 | 
			
		||||
    valid_routes = []
 | 
			
		||||
 | 
			
		||||
    for train in tqdm(trains):
 | 
			
		||||
        if train.orig == origin:
 | 
			
		||||
            it = [train]
 | 
			
		||||
            if train.dest == dest:
 | 
			
		||||
                # We hope that we have a direct train
 | 
			
		||||
                valid_routes.append(it)
 | 
			
		||||
            else:
 | 
			
		||||
                explore.append(it)
 | 
			
		||||
                per_arr_explore.setdefault(train.dest, [])
 | 
			
		||||
                per_arr_explore[train.dest].append(it)
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        for it in list(per_arr_explore.get(train.orig, [])):
 | 
			
		||||
            if any(train.dest == tr.dest or train.dest == origin for tr in it):
 | 
			
		||||
                # Avoid loops
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            last_train = it[-1]
 | 
			
		||||
 | 
			
		||||
            if last_train.arr <= train.dep:
 | 
			
		||||
                new_it = it + [train]
 | 
			
		||||
                if train.dest == dest:
 | 
			
		||||
                    # Goal is achieved
 | 
			
		||||
                    valid_routes.append(new_it)
 | 
			
		||||
                else:
 | 
			
		||||
                    explore.append(new_it)
 | 
			
		||||
                    per_arr_explore.setdefault(train.dest, [])
 | 
			
		||||
                    per_arr_explore[train.dest].append(new_it)
 | 
			
		||||
 | 
			
		||||
    return valid_routes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_route(route: list[Train]):
 | 
			
		||||
    s = f"{route[0].orig} "
 | 
			
		||||
    for tr in route:
 | 
			
		||||
        s += f"({tr.dep}) --> ({tr.arr}) {tr.dest}, "
 | 
			
		||||
    print(s[:-2])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get('/')
 | 
			
		||||
def index():
 | 
			
		||||
    return "Hello world!"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    app.run(debug=True)
 | 
			
		||||
							
								
								
									
										5
									
								
								config.example.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								config.example.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
FLASK_CONFIG = {
 | 
			
		||||
    "SQLALCHEMY_DATABASE_URI": "postgresql://user:password@host:5432/dbname",
 | 
			
		||||
    'SQLALCHEMY_TRACK_MODIFICATIONS': True,
 | 
			
		||||
    'SECRET_KEY': "random string",
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1
									
								
								migrations/README
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								migrations/README
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
Single-database configuration for Flask.
 | 
			
		||||
							
								
								
									
										50
									
								
								migrations/alembic.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								migrations/alembic.ini
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
# A generic, single database configuration.
 | 
			
		||||
 | 
			
		||||
[alembic]
 | 
			
		||||
# template used to generate migration files
 | 
			
		||||
# file_template = %%(rev)s_%%(slug)s
 | 
			
		||||
 | 
			
		||||
# set to 'true' to run the environment during
 | 
			
		||||
# the 'revision' command, regardless of autogenerate
 | 
			
		||||
# revision_environment = false
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Logging configuration
 | 
			
		||||
[loggers]
 | 
			
		||||
keys = root,sqlalchemy,alembic,flask_migrate
 | 
			
		||||
 | 
			
		||||
[handlers]
 | 
			
		||||
keys = console
 | 
			
		||||
 | 
			
		||||
[formatters]
 | 
			
		||||
keys = generic
 | 
			
		||||
 | 
			
		||||
[logger_root]
 | 
			
		||||
level = WARN
 | 
			
		||||
handlers = console
 | 
			
		||||
qualname =
 | 
			
		||||
 | 
			
		||||
[logger_sqlalchemy]
 | 
			
		||||
level = WARN
 | 
			
		||||
handlers =
 | 
			
		||||
qualname = sqlalchemy.engine
 | 
			
		||||
 | 
			
		||||
[logger_alembic]
 | 
			
		||||
level = INFO
 | 
			
		||||
handlers =
 | 
			
		||||
qualname = alembic
 | 
			
		||||
 | 
			
		||||
[logger_flask_migrate]
 | 
			
		||||
level = INFO
 | 
			
		||||
handlers =
 | 
			
		||||
qualname = flask_migrate
 | 
			
		||||
 | 
			
		||||
[handler_console]
 | 
			
		||||
class = StreamHandler
 | 
			
		||||
args = (sys.stderr,)
 | 
			
		||||
level = NOTSET
 | 
			
		||||
formatter = generic
 | 
			
		||||
 | 
			
		||||
[formatter_generic]
 | 
			
		||||
format = %(levelname)-5.5s [%(name)s] %(message)s
 | 
			
		||||
datefmt = %H:%M:%S
 | 
			
		||||
							
								
								
									
										110
									
								
								migrations/env.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								migrations/env.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,110 @@
 | 
			
		||||
import logging
 | 
			
		||||
from logging.config import fileConfig
 | 
			
		||||
 | 
			
		||||
from flask import current_app
 | 
			
		||||
 | 
			
		||||
from alembic import context
 | 
			
		||||
 | 
			
		||||
# this is the Alembic Config object, which provides
 | 
			
		||||
# access to the values within the .ini file in use.
 | 
			
		||||
config = context.config
 | 
			
		||||
 | 
			
		||||
# Interpret the config file for Python logging.
 | 
			
		||||
# This line sets up loggers basically.
 | 
			
		||||
fileConfig(config.config_file_name)
 | 
			
		||||
logger = logging.getLogger('alembic.env')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_engine():
 | 
			
		||||
    try:
 | 
			
		||||
        # this works with Flask-SQLAlchemy<3 and Alchemical
 | 
			
		||||
        return current_app.extensions['migrate'].db.get_engine()
 | 
			
		||||
    except TypeError:
 | 
			
		||||
        # this works with Flask-SQLAlchemy>=3
 | 
			
		||||
        return current_app.extensions['migrate'].db.engine
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_engine_url():
 | 
			
		||||
    try:
 | 
			
		||||
        return get_engine().url.render_as_string(hide_password=False).replace(
 | 
			
		||||
            '%', '%%')
 | 
			
		||||
    except AttributeError:
 | 
			
		||||
        return str(get_engine().url).replace('%', '%%')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# add your model's MetaData object here
 | 
			
		||||
# for 'autogenerate' support
 | 
			
		||||
# from myapp import mymodel
 | 
			
		||||
# target_metadata = mymodel.Base.metadata
 | 
			
		||||
config.set_main_option('sqlalchemy.url', get_engine_url())
 | 
			
		||||
target_db = current_app.extensions['migrate'].db
 | 
			
		||||
 | 
			
		||||
# other values from the config, defined by the needs of env.py,
 | 
			
		||||
# can be acquired:
 | 
			
		||||
# my_important_option = config.get_main_option("my_important_option")
 | 
			
		||||
# ... etc.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_metadata():
 | 
			
		||||
    if hasattr(target_db, 'metadatas'):
 | 
			
		||||
        return target_db.metadatas[None]
 | 
			
		||||
    return target_db.metadata
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_migrations_offline():
 | 
			
		||||
    """Run migrations in 'offline' mode.
 | 
			
		||||
 | 
			
		||||
    This configures the context with just a URL
 | 
			
		||||
    and not an Engine, though an Engine is acceptable
 | 
			
		||||
    here as well.  By skipping the Engine creation
 | 
			
		||||
    we don't even need a DBAPI to be available.
 | 
			
		||||
 | 
			
		||||
    Calls to context.execute() here emit the given string to the
 | 
			
		||||
    script output.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    url = config.get_main_option("sqlalchemy.url")
 | 
			
		||||
    context.configure(
 | 
			
		||||
        url=url, target_metadata=get_metadata(), literal_binds=True
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    with context.begin_transaction():
 | 
			
		||||
        context.run_migrations()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_migrations_online():
 | 
			
		||||
    """Run migrations in 'online' mode.
 | 
			
		||||
 | 
			
		||||
    In this scenario we need to create an Engine
 | 
			
		||||
    and associate a connection with the context.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # this callback is used to prevent an auto-migration from being generated
 | 
			
		||||
    # when there are no changes to the schema
 | 
			
		||||
    # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
 | 
			
		||||
    def process_revision_directives(context, revision, directives):
 | 
			
		||||
        if getattr(config.cmd_opts, 'autogenerate', False):
 | 
			
		||||
            script = directives[0]
 | 
			
		||||
            if script.upgrade_ops.is_empty():
 | 
			
		||||
                directives[:] = []
 | 
			
		||||
                logger.info('No changes in schema detected.')
 | 
			
		||||
 | 
			
		||||
    connectable = get_engine()
 | 
			
		||||
 | 
			
		||||
    with connectable.connect() as connection:
 | 
			
		||||
        context.configure(
 | 
			
		||||
            connection=connection,
 | 
			
		||||
            target_metadata=get_metadata(),
 | 
			
		||||
            process_revision_directives=process_revision_directives,
 | 
			
		||||
            **current_app.extensions['migrate'].configure_args
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        with context.begin_transaction():
 | 
			
		||||
            context.run_migrations()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if context.is_offline_mode():
 | 
			
		||||
    run_migrations_offline()
 | 
			
		||||
else:
 | 
			
		||||
    run_migrations_online()
 | 
			
		||||
							
								
								
									
										24
									
								
								migrations/script.py.mako
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								migrations/script.py.mako
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,24 @@
 | 
			
		||||
"""${message}
 | 
			
		||||
 | 
			
		||||
Revision ID: ${up_revision}
 | 
			
		||||
Revises: ${down_revision | comma,n}
 | 
			
		||||
Create Date: ${create_date}
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
from alembic import op
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
${imports if imports else ""}
 | 
			
		||||
 | 
			
		||||
# revision identifiers, used by Alembic.
 | 
			
		||||
revision = ${repr(up_revision)}
 | 
			
		||||
down_revision = ${repr(down_revision)}
 | 
			
		||||
branch_labels = ${repr(branch_labels)}
 | 
			
		||||
depends_on = ${repr(depends_on)}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upgrade():
 | 
			
		||||
    ${upgrades if upgrades else "pass"}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def downgrade():
 | 
			
		||||
    ${downgrades if downgrades else "pass"}
 | 
			
		||||
		Reference in New Issue
	
	Block a user