From ca16a01dda5389ac3785deedbcfe1948341804dd Mon Sep 17 00:00:00 2001 From: Kumi Date: Wed, 30 Aug 2023 12:30:21 +0200 Subject: [PATCH] Version bump to 0.2.5 and added a new function for handling database sessions --- pyproject.toml | 2 +- src/trackbert/classes/database.py | 82 ++++++++++++++++++++----------- src/trackbert/classes/tracker.py | 2 +- 3 files changed, 55 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 04738ad..65759bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "trackbert" -version = "0.2.4" +version = "0.2.5" authors = [ { name="Kumi Mitterer", email="trackbert@kumi.email" }, ] diff --git a/src/trackbert/classes/database.py b/src/trackbert/classes/database.py index d9ca303..f49c1a1 100644 --- a/src/trackbert/classes/database.py +++ b/src/trackbert/classes/database.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, Integer, String from sqlalchemy import create_engine, ForeignKey -from sqlalchemy.orm import sessionmaker, relationship +from sqlalchemy.orm import sessionmaker, relationship, scoped_session from sqlalchemy.ext.declarative import declarative_base from alembic.config import Config @@ -9,11 +9,27 @@ from alembic import command import json import time +from functools import wraps from pathlib import Path Base = declarative_base() +def with_session(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + session = self.session() + try: + result = func(self, session, *args, **kwargs) + session.commit() + return result + except: + session.rollback() + raise + + return wrapper + + class Shipment(Base): __tablename__ = "shipments" @@ -38,45 +54,48 @@ class Event(Base): class Database: def __init__(self, database_uri): self.engine = create_engine(database_uri) - Session = sessionmaker(bind=self.engine) - self.session = Session() - + self.session = scoped_session(sessionmaker(bind=self.engine)) self.run_migrations() - def create_shipment(self, tracking_number, carrier, description=""): + @with_session + def create_shipment(self, session, tracking_number, carrier, description=""): new_shipment = Shipment( tracking_number=tracking_number, carrier=carrier, description=description ) - self.session.add(new_shipment) - self.session.commit() + session.add(new_shipment) + session.commit() - def update_shipment(self, tracking_number, carrier, description=""): + @with_session + def update_shipment(self, session, tracking_number, carrier, description=""): shipment = self.get_shipment(tracking_number) if shipment: shipment.carrier = carrier shipment.description = description - self.session.commit() + session.commit() else: raise ValueError(f"Shipment {tracking_number} does not exist") - def disable_shipment(self, tracking_number): + @with_session + def disable_shipment(self, session, tracking_number): shipment = self.get_shipment(tracking_number) if shipment: shipment.carrier = "" - self.session.commit() + session.commit() else: raise ValueError(f"Shipment {tracking_number} does not exist") - def get_shipment(self, tracking_number): + @with_session + def get_shipment(self, session, tracking_number): shipment = ( - self.session.query(Shipment) + session.query(Shipment) .filter(Shipment.tracking_number == tracking_number) .first() ) return shipment - def get_shipments(self): - shipments = self.session.query(Shipment).all() + @with_session + def get_shipments(self, session): + shipments = session.query(Shipment).all() return shipments def create_event(self, shipment_id, event_time, event_description, raw_event): @@ -91,19 +110,20 @@ class Database: ) self.write_event(new_event) - def write_event(self, event): - self.session.add(event) - self.session.commit() + @with_session + def write_event(self, session, event): + session.add(event) + session.commit() - def get_shipment_events(self, shipment_id): - shipment = ( - self.session.query(Shipment).filter(Shipment.id == shipment_id).first() - ) + @with_session + def get_shipment_events(self, session, shipment_id): + shipment = session.query(Shipment).filter(Shipment.id == shipment_id).first() return shipment.events if shipment else None - def get_latest_event(self, shipment_id): + @with_session + def get_latest_event(self, session, shipment_id): event = ( - self.session.query(Event) + session.query(Event) .filter(Event.shipment_id == shipment_id) .order_by(Event.event_time.desc()) .first() @@ -112,14 +132,18 @@ class Database: def make_migration(self, message): alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini") - alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)) - migrations_dir = Path(__file__).parent.parent / 'migrations' + alembic_cfg.set_main_option( + "sqlalchemy.url", self.engine.url.__to_string__(hide_password=False) + ) + migrations_dir = Path(__file__).parent.parent / "migrations" alembic_cfg.set_main_option("script_location", str(migrations_dir)) command.revision(alembic_cfg, message=message, autogenerate=True) def run_migrations(self): alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini") - alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)) - migrations_dir = Path(__file__).parent.parent / 'migrations' + alembic_cfg.set_main_option( + "sqlalchemy.url", self.engine.url.__to_string__(hide_password=False) + ) + migrations_dir = Path(__file__).parent.parent / "migrations" alembic_cfg.set_main_option("script_location", str(migrations_dir)) - command.upgrade(alembic_cfg, "head") \ No newline at end of file + command.upgrade(alembic_cfg, "head") diff --git a/src/trackbert/classes/tracker.py b/src/trackbert/classes/tracker.py index c430bd4..42df6b5 100644 --- a/src/trackbert/classes/tracker.py +++ b/src/trackbert/classes/tracker.py @@ -144,7 +144,7 @@ class Tracker: logging.debug(f"No known events for {shipment.tracking_number}") logging.debug( - f"Latest upstream event for {shipment.tracking_number}: {events[0].event_description} - {events[0].event_time}" + f"Latest upstream event for {shipment.tracking_number}: {events[-1].event_description} - {events[-1].event_time}" ) for event in events: