Parse CSV to database
This commit is contained in:
parent
e22a039aa8
commit
39fdfce38f
90
app.py
90
app.py
|
@ -7,10 +7,12 @@ import json
|
|||
from pytz import timezone
|
||||
import requests
|
||||
|
||||
import click
|
||||
from flask import Flask
|
||||
from flask.cli import AppGroup
|
||||
from flask_migrate import Migrate
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from sqlalchemy import Boolean, Column, Date, Integer, String, Time
|
||||
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String, Time
|
||||
from tqdm import tqdm
|
||||
|
||||
import config
|
||||
|
@ -18,6 +20,10 @@ import config
|
|||
|
||||
app = Flask(__name__)
|
||||
|
||||
cli = AppGroup('tgvmax', help="Manage the TGVMax dataset.")
|
||||
app.cli.add_command(cli)
|
||||
|
||||
|
||||
app.config |= config.FLASK_CONFIG
|
||||
|
||||
db = SQLAlchemy(app)
|
||||
|
@ -29,19 +35,25 @@ class Train(db.Model):
|
|||
id = Column(String, primary_key=True)
|
||||
day = Column(Date, index=True)
|
||||
number = Column(Integer, index=True)
|
||||
entity = Column(String(255))
|
||||
axe = Column(String(255), index=True)
|
||||
entity = Column(String(10))
|
||||
axe = Column(String(32), index=True)
|
||||
orig_iata = Column(String(5), index=True)
|
||||
dest_iata = Column(String(5), index=True)
|
||||
orig = Column(String(255))
|
||||
dest = Column(String(255))
|
||||
dep = Column(String(255))
|
||||
orig = Column(String(32))
|
||||
dest = Column(String(32))
|
||||
dep = Column(Time)
|
||||
arr = Column(Time)
|
||||
tgvmax = Column(Boolean, index=True)
|
||||
remaining_seats = Column(Integer)
|
||||
remaining_seats = Column(Integer, default=-1)
|
||||
last_modification = Column(DateTime)
|
||||
expiration_time = Column(DateTime)
|
||||
|
||||
|
||||
@cli.command("update-dataset")
|
||||
def update_dataset():
|
||||
"""
|
||||
Query the latest version of the SNCF OpenData dataset, as a CSV file.
|
||||
"""
|
||||
try:
|
||||
resp = requests.get('https://ressources.data.sncf.com/explore/dataset/tgvmax/information/')
|
||||
content = resp.content.decode().split('<script type="application/ld+json">')[1].split('</script>')[0].strip()
|
||||
|
@ -76,37 +88,61 @@ def update_dataset():
|
|||
print(e)
|
||||
|
||||
|
||||
def parse_trains(*, filter_day: date | None = None,
|
||||
filter_number: int | None = None,
|
||||
filter_tgvmax: bool | None = None):
|
||||
trains = []
|
||||
@cli.command("parse-csv")
|
||||
@click.option('-F', '--flush', type=bool, is_flag=True, help="Flush the database before filling it.")
|
||||
def parse_trains(flush: bool = False):
|
||||
"""
|
||||
Parse the CSV file and store it to the database.
|
||||
"""
|
||||
|
||||
if flush:
|
||||
print("Flush database…")
|
||||
db.session.query(Train).delete()
|
||||
|
||||
last_modification = datetime.utcfromtimestamp(os.path.getmtime('tgvmax.csv')).replace(tzinfo=timezone('UTC'))
|
||||
|
||||
with open('tgvmax.csv') as f:
|
||||
first_line = True
|
||||
for line in csv.reader(f, delimiter=';'):
|
||||
already_seen = set()
|
||||
for line in tqdm(csv.reader(f, delimiter=';')):
|
||||
if first_line:
|
||||
first_line = False
|
||||
continue
|
||||
|
||||
train = Train(*line)
|
||||
train.day = date.fromisoformat(train.day)
|
||||
train.number = int(train.number)
|
||||
train.dep = time.fromisoformat(train.dep)
|
||||
train.arr = time.fromisoformat(train.arr)
|
||||
train.tgvmax = train.tgvmax == 'OUI'
|
||||
|
||||
if filter_day is not None and train.day != filter_day:
|
||||
train_id = f"{line[1]}-{line[0]}-{line[4]}-{line[5]}"
|
||||
if train_id in already_seen:
|
||||
# Some trains are mysteriously duplicated, concerns only some « Intercités de nuit »
|
||||
# and the Brive-la-Gaillarde -- Paris
|
||||
# and, maybe, for Roubaix-Tourcoing
|
||||
if line[3] != "IC NUIT" and line[1] != '3614' and not (line[4] == 'FRADP' and line[5] == 'FRADM'):
|
||||
print("Duplicate:", train_id)
|
||||
continue
|
||||
|
||||
if filter_number is not None and train.number != filter_number:
|
||||
continue
|
||||
train = Train(
|
||||
id=train_id,
|
||||
day=date.fromisoformat(line[0]),
|
||||
number=int(line[1]),
|
||||
entity=line[2],
|
||||
axe=line[3],
|
||||
orig_iata=line[4],
|
||||
dest_iata=line[5],
|
||||
orig=line[6],
|
||||
dest=line[7],
|
||||
dep=time.fromisoformat(line[8]),
|
||||
arr=time.fromisoformat(line[9]),
|
||||
tgvmax=line[10] == 'OUI',
|
||||
last_modification=last_modification,
|
||||
expiration_time=last_modification,
|
||||
)
|
||||
if flush:
|
||||
db.session.add(train)
|
||||
else:
|
||||
db.session.merge(train)
|
||||
|
||||
if filter_tgvmax is not None and train.tgvmax != filter_tgvmax:
|
||||
continue
|
||||
if line[3] == "IC NUIT" or line[1] == '3614' or (line[4] == 'FRADP' and line[5] == 'FRADM'):
|
||||
already_seen.add(train_id)
|
||||
|
||||
trains.append(train)
|
||||
|
||||
return trains
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def find_routes(day, orig, dest):
|
||||
|
|
Loading…
Reference in New Issue