feat: Add shipment disabled flag to database

This commit adds a new column named "disabled" to the Shipment table in the database, allowing shipments to be disabled. The default value for this column is false. As part of this change, a new migration script has been added to update the database schema. Additionally, when retrieving shipments in the Database class, the get_shipments method now has an optional ignore_disabled parameter, which filters out disabled shipments if set to True. In the Tracker class, a database timeout handling has been added, and the database connection is disposed upon encountering a timeout.
This commit is contained in:
Kumi 2023-09-06 12:03:54 +02:00
parent c3633450bb
commit 74883d15bd
Signed by: kumi
GPG key ID: ECBCC9082395383F
3 changed files with 42 additions and 4 deletions

View file

@ -1,5 +1,4 @@
from sqlalchemy import Column, Integer, String from sqlalchemy import Column, Integer, String, Boolean, create_engine, ForeignKey, event
from sqlalchemy import create_engine, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship, scoped_session from sqlalchemy.orm import sessionmaker, relationship, scoped_session
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -8,6 +7,7 @@ from alembic import command
import json import json
import time import time
import logging
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
@ -37,6 +37,7 @@ class Shipment(Base):
tracking_number = Column(String) tracking_number = Column(String)
carrier = Column(String) carrier = Column(String)
description = Column(String) description = Column(String)
disabled = Column(Boolean, default=False)
events = relationship("Event") events = relationship("Event")
@ -53,8 +54,12 @@ class Event(Base):
class Database: class Database:
def __init__(self, database_uri): def __init__(self, database_uri):
self.engine = create_engine(database_uri) self.engine = create_engine(database_uri, pool_size=20, max_overflow=20)
self.session = scoped_session(sessionmaker(bind=self.engine)) self.session = scoped_session(sessionmaker(bind=self.engine))
event.listen(self.engine, "connect", lambda _, __: logging.debug("DB connected"))
event.listen(self.engine, "close", lambda _, __: logging.debug("DB connection closed"))
self.run_migrations() self.run_migrations()
@with_session @with_session
@ -94,8 +99,12 @@ class Database:
return shipment return shipment
@with_session @with_session
def get_shipments(self, session): def get_shipments(self, session, ignore_disabled=True):
shipments = session.query(Shipment).all() shipments = session.query(Shipment).all()
if ignore_disabled:
shipments = [s for s in shipments if not s.disabled]
return shipments return shipments
def create_event(self, shipment_id, event_time, event_description, raw_event): def create_event(self, shipment_id, event_time, event_description, raw_event):

View file

@ -169,6 +169,7 @@ class Tracker:
except sqlalchemy.exc.TimeoutError: except sqlalchemy.exc.TimeoutError:
logging.warning("Database timeout while processing shipments") logging.warning("Database timeout while processing shipments")
self.db.engine.dispose()
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("Keyboard interrupt, exiting") logging.info("Keyboard interrupt, exiting")

View file

@ -0,0 +1,28 @@
"""Shipment.disabled
Revision ID: 91ca1665ca83
Revises: 770fdbef1f4e
Create Date: 2023-09-06 11:44:48.214655
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '91ca1665ca83'
down_revision: Union[str, None] = '770fdbef1f4e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column('shipments', sa.Column('disabled', sa.Boolean(), nullable=True))
op.execute("UPDATE shipments SET disabled = true WHERE carrier = ''")
def downgrade() -> None:
op.execute("UPDATE shipments SET carrier = '' WHERE disabled = true")
op.drop_column('shipments', 'disabled')