diff --git a/src/trackbert/classes/database.py b/src/trackbert/classes/database.py index f49c1a1..a42ab38 100644 --- a/src/trackbert/classes/database.py +++ b/src/trackbert/classes/database.py @@ -1,5 +1,4 @@ -from sqlalchemy import Column, Integer, String -from sqlalchemy import create_engine, ForeignKey +from sqlalchemy import Column, Integer, String, Boolean, create_engine, ForeignKey, event from sqlalchemy.orm import sessionmaker, relationship, scoped_session from sqlalchemy.ext.declarative import declarative_base @@ -8,6 +7,7 @@ from alembic import command import json import time +import logging from functools import wraps from pathlib import Path @@ -37,6 +37,7 @@ class Shipment(Base): tracking_number = Column(String) carrier = Column(String) description = Column(String) + disabled = Column(Boolean, default=False) events = relationship("Event") @@ -53,8 +54,12 @@ class Event(Base): class Database: def __init__(self, database_uri): - self.engine = create_engine(database_uri) + self.engine = create_engine(database_uri, pool_size=20, max_overflow=20) self.session = scoped_session(sessionmaker(bind=self.engine)) + + event.listen(self.engine, "connect", lambda _, __: logging.debug("DB connected")) + event.listen(self.engine, "close", lambda _, __: logging.debug("DB connection closed")) + self.run_migrations() @with_session @@ -94,8 +99,12 @@ class Database: return shipment @with_session - def get_shipments(self, session): + def get_shipments(self, session, ignore_disabled=True): shipments = session.query(Shipment).all() + + if ignore_disabled: + shipments = [s for s in shipments if not s.disabled] + return shipments def create_event(self, shipment_id, event_time, event_description, raw_event): diff --git a/src/trackbert/classes/tracker.py b/src/trackbert/classes/tracker.py index 6617231..3d9e4eb 100644 --- a/src/trackbert/classes/tracker.py +++ b/src/trackbert/classes/tracker.py @@ -169,6 +169,7 @@ class Tracker: except sqlalchemy.exc.TimeoutError: logging.warning("Database timeout while processing shipments") + self.db.engine.dispose() except KeyboardInterrupt: logging.info("Keyboard interrupt, exiting") diff --git a/src/trackbert/migrations/versions/91ca1665ca83_shipment_disabled.py b/src/trackbert/migrations/versions/91ca1665ca83_shipment_disabled.py new file mode 100644 index 0000000..72a9544 --- /dev/null +++ b/src/trackbert/migrations/versions/91ca1665ca83_shipment_disabled.py @@ -0,0 +1,28 @@ +"""Shipment.disabled + +Revision ID: 91ca1665ca83 +Revises: 770fdbef1f4e +Create Date: 2023-09-06 11:44:48.214655 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '91ca1665ca83' +down_revision: Union[str, None] = '770fdbef1f4e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('shipments', sa.Column('disabled', sa.Boolean(), nullable=True)) + op.execute("UPDATE shipments SET disabled = true WHERE carrier = ''") + + +def downgrade() -> None: + op.execute("UPDATE shipments SET carrier = '' WHERE disabled = true") + op.drop_column('shipments', 'disabled')