Version bump to 0.2.5 and added a new function for handling database sessions

This commit is contained in:
Kumi 2023-08-30 12:30:21 +02:00
parent b12a7f71c5
commit ca16a01dda
Signed by: kumi
GPG key ID: ECBCC9082395383F
3 changed files with 55 additions and 31 deletions

View file

@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "trackbert"
version = "0.2.4"
version = "0.2.5"
authors = [
{ name="Kumi Mitterer", email="trackbert@kumi.email" },
]

View file

@ -1,6 +1,6 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.orm import sessionmaker, relationship, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from alembic.config import Config
@ -9,11 +9,27 @@ from alembic import command
import json
import time
from functools import wraps
from pathlib import Path
Base = declarative_base()
def with_session(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
session = self.session()
try:
result = func(self, session, *args, **kwargs)
session.commit()
return result
except:
session.rollback()
raise
return wrapper
class Shipment(Base):
__tablename__ = "shipments"
@ -38,45 +54,48 @@ class Event(Base):
class Database:
def __init__(self, database_uri):
self.engine = create_engine(database_uri)
Session = sessionmaker(bind=self.engine)
self.session = Session()
self.session = scoped_session(sessionmaker(bind=self.engine))
self.run_migrations()
def create_shipment(self, tracking_number, carrier, description=""):
@with_session
def create_shipment(self, session, tracking_number, carrier, description=""):
new_shipment = Shipment(
tracking_number=tracking_number, carrier=carrier, description=description
)
self.session.add(new_shipment)
self.session.commit()
session.add(new_shipment)
session.commit()
def update_shipment(self, tracking_number, carrier, description=""):
@with_session
def update_shipment(self, session, tracking_number, carrier, description=""):
shipment = self.get_shipment(tracking_number)
if shipment:
shipment.carrier = carrier
shipment.description = description
self.session.commit()
session.commit()
else:
raise ValueError(f"Shipment {tracking_number} does not exist")
def disable_shipment(self, tracking_number):
@with_session
def disable_shipment(self, session, tracking_number):
shipment = self.get_shipment(tracking_number)
if shipment:
shipment.carrier = ""
self.session.commit()
session.commit()
else:
raise ValueError(f"Shipment {tracking_number} does not exist")
def get_shipment(self, tracking_number):
@with_session
def get_shipment(self, session, tracking_number):
shipment = (
self.session.query(Shipment)
session.query(Shipment)
.filter(Shipment.tracking_number == tracking_number)
.first()
)
return shipment
def get_shipments(self):
shipments = self.session.query(Shipment).all()
@with_session
def get_shipments(self, session):
shipments = session.query(Shipment).all()
return shipments
def create_event(self, shipment_id, event_time, event_description, raw_event):
@ -91,19 +110,20 @@ class Database:
)
self.write_event(new_event)
def write_event(self, event):
self.session.add(event)
self.session.commit()
@with_session
def write_event(self, session, event):
session.add(event)
session.commit()
def get_shipment_events(self, shipment_id):
shipment = (
self.session.query(Shipment).filter(Shipment.id == shipment_id).first()
)
@with_session
def get_shipment_events(self, session, shipment_id):
shipment = session.query(Shipment).filter(Shipment.id == shipment_id).first()
return shipment.events if shipment else None
def get_latest_event(self, shipment_id):
@with_session
def get_latest_event(self, session, shipment_id):
event = (
self.session.query(Event)
session.query(Event)
.filter(Event.shipment_id == shipment_id)
.order_by(Event.event_time.desc())
.first()
@ -112,14 +132,18 @@ class Database:
def make_migration(self, message):
alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False))
migrations_dir = Path(__file__).parent.parent / 'migrations'
alembic_cfg.set_main_option(
"sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)
)
migrations_dir = Path(__file__).parent.parent / "migrations"
alembic_cfg.set_main_option("script_location", str(migrations_dir))
command.revision(alembic_cfg, message=message, autogenerate=True)
def run_migrations(self):
alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False))
migrations_dir = Path(__file__).parent.parent / 'migrations'
alembic_cfg.set_main_option(
"sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)
)
migrations_dir = Path(__file__).parent.parent / "migrations"
alembic_cfg.set_main_option("script_location", str(migrations_dir))
command.upgrade(alembic_cfg, "head")
command.upgrade(alembic_cfg, "head")

View file

@ -144,7 +144,7 @@ class Tracker:
logging.debug(f"No known events for {shipment.tracking_number}")
logging.debug(
f"Latest upstream event for {shipment.tracking_number}: {events[0].event_description} - {events[0].event_time}"
f"Latest upstream event for {shipment.tracking_number}: {events[-1].event_description} - {events[-1].event_time}"
)
for event in events: