Version bump to 0.2.5 and added a new function for handling database sessions

This commit is contained in:
Kumi 2023-08-30 12:30:21 +02:00
parent b12a7f71c5
commit ca16a01dda
Signed by: kumi
GPG key ID: ECBCC9082395383F
3 changed files with 55 additions and 31 deletions

View file

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "trackbert" name = "trackbert"
version = "0.2.4" version = "0.2.5"
authors = [ authors = [
{ name="Kumi Mitterer", email="trackbert@kumi.email" }, { name="Kumi Mitterer", email="trackbert@kumi.email" },
] ]

View file

@ -1,6 +1,6 @@
from sqlalchemy import Column, Integer, String from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine, ForeignKey from sqlalchemy import create_engine, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.orm import sessionmaker, relationship, scoped_session
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from alembic.config import Config from alembic.config import Config
@ -9,11 +9,27 @@ from alembic import command
import json import json
import time import time
from functools import wraps
from pathlib import Path from pathlib import Path
Base = declarative_base() Base = declarative_base()
def with_session(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
session = self.session()
try:
result = func(self, session, *args, **kwargs)
session.commit()
return result
except:
session.rollback()
raise
return wrapper
class Shipment(Base): class Shipment(Base):
__tablename__ = "shipments" __tablename__ = "shipments"
@ -38,45 +54,48 @@ class Event(Base):
class Database: class Database:
def __init__(self, database_uri): def __init__(self, database_uri):
self.engine = create_engine(database_uri) self.engine = create_engine(database_uri)
Session = sessionmaker(bind=self.engine) self.session = scoped_session(sessionmaker(bind=self.engine))
self.session = Session()
self.run_migrations() self.run_migrations()
def create_shipment(self, tracking_number, carrier, description=""): @with_session
def create_shipment(self, session, tracking_number, carrier, description=""):
new_shipment = Shipment( new_shipment = Shipment(
tracking_number=tracking_number, carrier=carrier, description=description tracking_number=tracking_number, carrier=carrier, description=description
) )
self.session.add(new_shipment) session.add(new_shipment)
self.session.commit() session.commit()
def update_shipment(self, tracking_number, carrier, description=""): @with_session
def update_shipment(self, session, tracking_number, carrier, description=""):
shipment = self.get_shipment(tracking_number) shipment = self.get_shipment(tracking_number)
if shipment: if shipment:
shipment.carrier = carrier shipment.carrier = carrier
shipment.description = description shipment.description = description
self.session.commit() session.commit()
else: else:
raise ValueError(f"Shipment {tracking_number} does not exist") raise ValueError(f"Shipment {tracking_number} does not exist")
def disable_shipment(self, tracking_number): @with_session
def disable_shipment(self, session, tracking_number):
shipment = self.get_shipment(tracking_number) shipment = self.get_shipment(tracking_number)
if shipment: if shipment:
shipment.carrier = "" shipment.carrier = ""
self.session.commit() session.commit()
else: else:
raise ValueError(f"Shipment {tracking_number} does not exist") raise ValueError(f"Shipment {tracking_number} does not exist")
def get_shipment(self, tracking_number): @with_session
def get_shipment(self, session, tracking_number):
shipment = ( shipment = (
self.session.query(Shipment) session.query(Shipment)
.filter(Shipment.tracking_number == tracking_number) .filter(Shipment.tracking_number == tracking_number)
.first() .first()
) )
return shipment return shipment
def get_shipments(self): @with_session
shipments = self.session.query(Shipment).all() def get_shipments(self, session):
shipments = session.query(Shipment).all()
return shipments return shipments
def create_event(self, shipment_id, event_time, event_description, raw_event): def create_event(self, shipment_id, event_time, event_description, raw_event):
@ -91,19 +110,20 @@ class Database:
) )
self.write_event(new_event) self.write_event(new_event)
def write_event(self, event): @with_session
self.session.add(event) def write_event(self, session, event):
self.session.commit() session.add(event)
session.commit()
def get_shipment_events(self, shipment_id): @with_session
shipment = ( def get_shipment_events(self, session, shipment_id):
self.session.query(Shipment).filter(Shipment.id == shipment_id).first() shipment = session.query(Shipment).filter(Shipment.id == shipment_id).first()
)
return shipment.events if shipment else None return shipment.events if shipment else None
def get_latest_event(self, shipment_id): @with_session
def get_latest_event(self, session, shipment_id):
event = ( event = (
self.session.query(Event) session.query(Event)
.filter(Event.shipment_id == shipment_id) .filter(Event.shipment_id == shipment_id)
.order_by(Event.event_time.desc()) .order_by(Event.event_time.desc())
.first() .first()
@ -112,14 +132,18 @@ class Database:
def make_migration(self, message): def make_migration(self, message):
alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini") alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)) alembic_cfg.set_main_option(
migrations_dir = Path(__file__).parent.parent / 'migrations' "sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)
)
migrations_dir = Path(__file__).parent.parent / "migrations"
alembic_cfg.set_main_option("script_location", str(migrations_dir)) alembic_cfg.set_main_option("script_location", str(migrations_dir))
command.revision(alembic_cfg, message=message, autogenerate=True) command.revision(alembic_cfg, message=message, autogenerate=True)
def run_migrations(self): def run_migrations(self):
alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini") alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)) alembic_cfg.set_main_option(
migrations_dir = Path(__file__).parent.parent / 'migrations' "sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)
)
migrations_dir = Path(__file__).parent.parent / "migrations"
alembic_cfg.set_main_option("script_location", str(migrations_dir)) alembic_cfg.set_main_option("script_location", str(migrations_dir))
command.upgrade(alembic_cfg, "head") command.upgrade(alembic_cfg, "head")

View file

@ -144,7 +144,7 @@ class Tracker:
logging.debug(f"No known events for {shipment.tracking_number}") logging.debug(f"No known events for {shipment.tracking_number}")
logging.debug( logging.debug(
f"Latest upstream event for {shipment.tracking_number}: {events[0].event_description} - {events[0].event_time}" f"Latest upstream event for {shipment.tracking_number}: {events[-1].event_description} - {events[-1].event_time}"
) )
for event in events: for event in events: