Version bump to 0.2.5 and added a new function for handling database sessions
This commit is contained in:
parent
b12a7f71c5
commit
ca16a01dda
3 changed files with 55 additions and 31 deletions
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||
|
||||
[project]
|
||||
name = "trackbert"
|
||||
version = "0.2.4"
|
||||
version = "0.2.5"
|
||||
authors = [
|
||||
{ name="Kumi Mitterer", email="trackbert@kumi.email" },
|
||||
]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from sqlalchemy import Column, Integer, String
|
||||
from sqlalchemy import create_engine, ForeignKey
|
||||
from sqlalchemy.orm import sessionmaker, relationship
|
||||
from sqlalchemy.orm import sessionmaker, relationship, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from alembic.config import Config
|
||||
|
@ -9,11 +9,27 @@ from alembic import command
|
|||
import json
|
||||
import time
|
||||
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def with_session(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
session = self.session()
|
||||
try:
|
||||
result = func(self, session, *args, **kwargs)
|
||||
session.commit()
|
||||
return result
|
||||
except:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Shipment(Base):
|
||||
__tablename__ = "shipments"
|
||||
|
||||
|
@ -38,45 +54,48 @@ class Event(Base):
|
|||
class Database:
|
||||
def __init__(self, database_uri):
|
||||
self.engine = create_engine(database_uri)
|
||||
Session = sessionmaker(bind=self.engine)
|
||||
self.session = Session()
|
||||
|
||||
self.session = scoped_session(sessionmaker(bind=self.engine))
|
||||
self.run_migrations()
|
||||
|
||||
def create_shipment(self, tracking_number, carrier, description=""):
|
||||
@with_session
|
||||
def create_shipment(self, session, tracking_number, carrier, description=""):
|
||||
new_shipment = Shipment(
|
||||
tracking_number=tracking_number, carrier=carrier, description=description
|
||||
)
|
||||
self.session.add(new_shipment)
|
||||
self.session.commit()
|
||||
session.add(new_shipment)
|
||||
session.commit()
|
||||
|
||||
def update_shipment(self, tracking_number, carrier, description=""):
|
||||
@with_session
|
||||
def update_shipment(self, session, tracking_number, carrier, description=""):
|
||||
shipment = self.get_shipment(tracking_number)
|
||||
if shipment:
|
||||
shipment.carrier = carrier
|
||||
shipment.description = description
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
else:
|
||||
raise ValueError(f"Shipment {tracking_number} does not exist")
|
||||
|
||||
def disable_shipment(self, tracking_number):
|
||||
@with_session
|
||||
def disable_shipment(self, session, tracking_number):
|
||||
shipment = self.get_shipment(tracking_number)
|
||||
if shipment:
|
||||
shipment.carrier = ""
|
||||
self.session.commit()
|
||||
session.commit()
|
||||
else:
|
||||
raise ValueError(f"Shipment {tracking_number} does not exist")
|
||||
|
||||
def get_shipment(self, tracking_number):
|
||||
@with_session
|
||||
def get_shipment(self, session, tracking_number):
|
||||
shipment = (
|
||||
self.session.query(Shipment)
|
||||
session.query(Shipment)
|
||||
.filter(Shipment.tracking_number == tracking_number)
|
||||
.first()
|
||||
)
|
||||
return shipment
|
||||
|
||||
def get_shipments(self):
|
||||
shipments = self.session.query(Shipment).all()
|
||||
@with_session
|
||||
def get_shipments(self, session):
|
||||
shipments = session.query(Shipment).all()
|
||||
return shipments
|
||||
|
||||
def create_event(self, shipment_id, event_time, event_description, raw_event):
|
||||
|
@ -91,19 +110,20 @@ class Database:
|
|||
)
|
||||
self.write_event(new_event)
|
||||
|
||||
def write_event(self, event):
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
@with_session
|
||||
def write_event(self, session, event):
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
def get_shipment_events(self, shipment_id):
|
||||
shipment = (
|
||||
self.session.query(Shipment).filter(Shipment.id == shipment_id).first()
|
||||
)
|
||||
@with_session
|
||||
def get_shipment_events(self, session, shipment_id):
|
||||
shipment = session.query(Shipment).filter(Shipment.id == shipment_id).first()
|
||||
return shipment.events if shipment else None
|
||||
|
||||
def get_latest_event(self, shipment_id):
|
||||
@with_session
|
||||
def get_latest_event(self, session, shipment_id):
|
||||
event = (
|
||||
self.session.query(Event)
|
||||
session.query(Event)
|
||||
.filter(Event.shipment_id == shipment_id)
|
||||
.order_by(Event.event_time.desc())
|
||||
.first()
|
||||
|
@ -112,14 +132,18 @@ class Database:
|
|||
|
||||
def make_migration(self, message):
|
||||
alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini")
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False))
|
||||
migrations_dir = Path(__file__).parent.parent / 'migrations'
|
||||
alembic_cfg.set_main_option(
|
||||
"sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)
|
||||
)
|
||||
migrations_dir = Path(__file__).parent.parent / "migrations"
|
||||
alembic_cfg.set_main_option("script_location", str(migrations_dir))
|
||||
command.revision(alembic_cfg, message=message, autogenerate=True)
|
||||
|
||||
def run_migrations(self):
|
||||
alembic_cfg = Config(Path(__file__).parent.parent / "alembic.ini")
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", self.engine.url.__to_string__(hide_password=False))
|
||||
migrations_dir = Path(__file__).parent.parent / 'migrations'
|
||||
alembic_cfg.set_main_option(
|
||||
"sqlalchemy.url", self.engine.url.__to_string__(hide_password=False)
|
||||
)
|
||||
migrations_dir = Path(__file__).parent.parent / "migrations"
|
||||
alembic_cfg.set_main_option("script_location", str(migrations_dir))
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
|
|
@ -144,7 +144,7 @@ class Tracker:
|
|||
logging.debug(f"No known events for {shipment.tracking_number}")
|
||||
|
||||
logging.debug(
|
||||
f"Latest upstream event for {shipment.tracking_number}: {events[0].event_description} - {events[0].event_time}"
|
||||
f"Latest upstream event for {shipment.tracking_number}: {events[-1].event_description} - {events[-1].event_time}"
|
||||
)
|
||||
|
||||
for event in events:
|
||||
|
|
Loading…
Reference in a new issue