Make it an importable module

Abandon DuckDB in favor of sqlite3
This commit is contained in:
Kumi 2023-05-25 07:28:28 +00:00
parent 3a1d1ea86a
commit 55809a9a39
Signed by: kumi
GPG key ID: ECBCC9082395383F
56 changed files with 234 additions and 828 deletions

1
.gitignore vendored
View file

@ -4,3 +4,4 @@ config.ini
venv/ venv/
*.pyc *.pyc
__pycache__/ __pycache__/
*.bak

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.formatting.provider": "black"
}

View file

@ -26,12 +26,61 @@ probably add more in the future, so the name is a bit misleading.
## Installation ## Installation
Simply clone this repository and install the requirements. To run the bot, you will need Python 3.10 or newer.
### Requirements The bot has been tested with Python 3.11 on Arch, but should work with any
current version, and should not require any special dependencies or operating
system features.
- Python 3.10 or later ### Production
- Requirements from `requirements.txt` (install with `pip install -r requirements.txt` in a venv)
The easiest way to install the bot is to use pip to install it directly from
[its Git repository](https://kumig.it/kumitterer/matrix-gptbot/):
```shell
# If desired, activate a venv first
python -m venv venv
. venv/bin/activate
# Install the bot
pip install git+https://kumig.it/kumitterer/matrix-gptbot.git
```
This will install the bot from the main branch and all required dependencies.
A release to PyPI is planned, but not yet available.
### Development
Clone the repository and install the requirements to a virtual environment.
```shell
# Clone the repository
git clone https://kumig.it/kumitterer/matrix-gptbot.git
cd matrix-gptbot
# If desired, activate a venv first
python -m venv venv
. venv/bin/activate
# Install the requirements
pip install -Ur requirements.txt
# Install the bot in editable mode
pip install -e .
# Go to the bot directory and start working
cd src/gptbot
```
Of course, you can also fork the repository on [GitHub](https://github.com/kumitterer/matrix-gptbot/)
and work on your own copy.
### Configuration ### Configuration

View file

@ -1 +0,0 @@
from .store import DuckDBStore

View file

@ -1,637 +0,0 @@
import duckdb
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
from nio.crypto import OlmAccount, OlmDevice
from random import SystemRandom
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from .dict import AttrDict
import json
class DuckDBStore(MatrixStore):
@property
def account_id(self):
id = self._get_account()[0] if self._get_account() else None
if id is None:
id = SystemRandom().randint(0, 2**16)
return id
def __init__(self, user_id, device_id, duckdb_conn):
self.conn = duckdb_conn
self.user_id = user_id
self.device_id = device_id
def _get_account(self):
cursor = self.conn.cursor()
cursor.execute(
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
(self.user_id, self.device_id),
)
account = cursor.fetchone()
cursor.close()
return account
def _get_device(self, device):
acc = self._get_account()
if not acc:
return None
cursor = self.conn.cursor()
cursor.execute(
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
(device.user_id, device.id, acc[0]),
)
device_entry = cursor.fetchone()
cursor.close()
return device_entry
# Implementing methods with DuckDB equivalents
def verify_device(self, device):
if self.is_device_verified(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.verified),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.verified
return True
def unverify_device(self, device):
if not self.is_device_verified(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_verified(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.verified
def blacklist_device(self, device):
if self.is_device_blacklisted(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.blacklisted),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.blacklisted
return True
def unblacklist_device(self, device):
if not self.is_device_blacklisted(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_blacklisted(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.blacklisted
def ignore_device(self, device):
if self.is_device_ignored(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], int(TrustState.ignored.value)),
)
self.conn.commit()
cursor.close()
return True
def ignore_devices(self, devices):
for device in devices:
self.ignore_device(device)
def unignore_device(self, device):
if not self.is_device_ignored(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_ignored(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.ignored
def load_device_keys(self):
"""Load all the device keys from the database.
Returns DeviceStore containing the OlmDevices with the device keys.
"""
store = DeviceStore()
account = self.account_id
if not account:
return store
with self.conn.cursor() as cur:
cur.execute(
"SELECT * FROM device_keys WHERE account_id = ?",
(account,)
)
device_keys = cur.fetchall()
for d in device_keys:
cur.execute(
"SELECT * FROM keys WHERE device_id = ?",
(d[0],)
)
keys = cur.fetchall()
key_dict = {k[0]: k[1] for k in keys}
store.add(
OlmDevice(
d[2],
d[0],
key_dict,
display_name=d[3],
deleted=d[4],
)
)
return store
def save_device_keys(self, device_keys):
"""Save the provided device keys to the database."""
account = self.account_id
assert account
rows = []
for user_id, devices_dict in device_keys.items():
for device_id, device in devices_dict.items():
rows.append(
{
"account_id": account,
"user_id": user_id,
"device_id": device_id,
"display_name": device.display_name,
"deleted": device.deleted,
}
)
if not rows:
return
with self.conn.cursor() as cur:
for idx in range(0, len(rows), 100):
data = rows[idx: idx + 100]
cur.executemany(
"INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
[(r["account_id"], r["user_id"], r["device_id"],
r["display_name"], r["deleted"]) for r in data]
)
for user_id, devices_dict in device_keys.items():
for device_id, device in devices_dict.items():
cur.execute(
"UPDATE device_keys SET deleted = ? WHERE device_id = ?",
(device.deleted, device_id)
)
for key_type, key in device.keys.items():
cur.execute("""
INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
""",
(key_type, key, device_id, key)
)
self.conn.commit()
def save_group_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO inbound_group_sessions (
session_id, sender_key, signing_key, room_id, pickle, account_id
) VALUES (?, ?, ?, ?, ?, ?)
""", (
session.id,
session.sender_key,
session.signing_key,
session.room_id,
session.pickle,
self.account_id
))
self.conn.commit()
def save_olm_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO olm_sessions (
session_id, sender_key, pickle, account_id
) VALUES (?, ?, ?, ?)
""", (
session.id,
session.sender_key,
session.pickle,
self.account_id
))
self.conn.commit()
def save_outbound_group_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO outbound_group_sessions (
room_id, session_id, pickle, account_id
) VALUES (?, ?, ?, ?)
""", (
session.room_id,
session.id,
session.pickle,
self.account_id
))
self.conn.commit()
def save_account(self, account: OlmAccount):
with self.conn.cursor() as cur:
cur.execute("""
INSERT OR REPLACE INTO accounts (
id, user_id, device_id, shared_account, pickle
) VALUES (?, ?, ?, ?, ?)
""", (
self.account_id,
self.user_id,
self.device_id,
account.shared,
account.pickle(self.pickle_key),
))
self.conn.commit()
def load_sessions(self):
session_store = SessionStore()
with self.conn.cursor() as cur:
cur.execute("""
SELECT
os.sender_key, os.session, os.creation_time
FROM
olm_sessions os
INNER JOIN
accounts a ON os.account_id = a.id
WHERE
a.id = ?
""", (self.account_id,))
for row in cur.fetchall():
sender_key, session_pickle, creation_time = row
session = Session.from_pickle(
session_pickle, creation_time, self.pickle_key)
session_store.add(sender_key, session)
return session_store
def load_inbound_group_sessions(self):
# type: () -> GroupSessionStore
"""Load all Olm sessions from the database.
Returns:
``GroupSessionStore`` object, containing all the loaded sessions.
"""
store = GroupSessionStore()
account = self.account_id
if not account:
return store
with self.conn.cursor() as cursor:
cursor.execute(
"SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
account,)
)
for row in cursor.fetchall():
cursor.execute(
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
(row[1],),
)
chains = cursor.fetchall()
session = InboundGroupSession.from_pickle(
row[2].encode(),
row[3],
row[4],
row[5],
self.pickle_key,
[
chain[0]
for chain in chains
],
)
store.add(session)
return store
def load_outgoing_key_requests(self):
# type: () -> dict
"""Load all outgoing key requests from the database.
Returns:
``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
"""
account = self.account_id
if not account:
return store
with self.conn.cursor() as cur:
cur.execute(
"SELECT * FROM outgoing_key_requests WHERE account_id = ?",
(account,)
)
rows = cur.fetchall()
return {
row[1]: OutgoingKeyRequest.from_response(AttrDict({
"id": row[0],
"account_id": row[1],
"request_id": row[2],
"session_id": row[3],
"room_id": row[4],
"algorithm": row[5],
})) for row in rows
}
def load_encrypted_rooms(self):
"""Load the set of encrypted rooms for this account.
Returns:
``Set`` containing room ids of encrypted rooms.
"""
account = self.account_id
if not account:
return set()
with self.conn.cursor() as cur:
cur.execute(
"SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
(account,)
)
rows = cur.fetchall()
return {row[0] for row in rows}
def save_sync_token(self, token):
"""Save the given token"""
account = self.account_id
assert account
with self.conn.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
(account, token)
)
self.conn.commit()
def save_encrypted_rooms(self, rooms):
"""Save the set of room ids for this account."""
account = self.account_id
assert account
data = [(room_id, account) for room_id in rooms]
with self.conn.cursor() as cur:
for idx in range(0, len(data), 400):
rows = data[idx: idx + 400]
cur.executemany(
"INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
rows
)
self.conn.commit()
def save_session(self, sender_key, session):
"""Save the provided Olm session to the database.
Args:
sender_key (str): The curve key that owns the Olm session.
session (Session): The Olm session that will be pickled and
saved in the database.
"""
account = self.account_id
assert account
pickled_session = session.pickle(self.pickle_key)
with self.conn.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
(account, sender_key, pickled_session, session.id,
session.creation_time, session.use_time)
)
self.conn.commit()
def save_inbound_group_session(self, session):
"""Save the provided Megolm inbound group session to the database.
Args:
session (InboundGroupSession): The session to save.
"""
account = self.account_id
assert account
with self.conn.cursor() as cur:
# Insert a new session or update the existing one
query = """
INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (account_id, sender_key, fp_key, room_id)
DO UPDATE SET session = excluded.session
"""
cur.execute(query, (account, session.sender_key,
session.ed25519, session.room_id, session.pickle(self.pickle_key)))
# Delete existing forwarded chains for the session
delete_query = """
DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
"""
cur.execute(
delete_query, (account, session.sender_key, session.ed25519, session.room_id))
# Insert new forwarded chains for the session
insert_query = """
INSERT INTO forwarded_chains (session_id, sender_key)
VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
"""
for chain in session.forwarding_chain:
cur.execute(
insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
def add_outgoing_key_request(self, key_request):
"""Add a new outgoing key request to the database.
Args:
key_request (OutgoingKeyRequest): The key request to add.
"""
account_id = self.account_id
with self.conn.cursor() as cursor:
cursor.execute(
"""
SELECT MAX(id) FROM outgoing_key_requests
"""
)
row = cursor.fetchone()
request_id = row[0] + 1 if row[0] else 1
cursor.execute(
"""
INSERT INTO outgoing_key_requests (id, account_id, request_id, session_id, room_id, algorithm)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (account_id, request_id) DO NOTHING
""",
(
request_id,
account_id,
key_request.request_id,
key_request.session_id,
key_request.room_id,
key_request.algorithm,
)
)
def load_account(self):
# type: () -> Optional[OlmAccount]
"""Load the Olm account from the database.
Returns:
``OlmAccount`` object, or ``None`` if it wasn't found for the
current device_id.
"""
cursor = self.conn.cursor()
query = """
SELECT pickle, shared_account
FROM accounts
WHERE device_id = ?;
"""
cursor.execute(query, (self.device_id,))
result = cursor.fetchone()
if not result:
return None
account_pickle, shared = result
return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)

View file

@ -106,15 +106,35 @@ Operator = Contact details not set
[Database] [Database]
# Settings for the DuckDB database. # Path of the main database
# If not defined, the bot will not be able to remember anything, and will not support encryption # Used to "remember" settings, etc.
# N.B.: Encryption doesn't work as it is supposed to anyway. #
Path = database.db Path = database.db
# Path of the Crypto Store - required to support encrypted rooms
# (not tested/supported yet)
#
CryptoStore = store.db
[TrackingMore] [TrackingMore]
# API key for TrackingMore # API key for TrackingMore
# If not defined, the bot will not be able to provide parcel tracking # If not defined, the bot will not be able to provide parcel tracking
# #
# APIKey = abcde-fghij-klmnop # APIKey = abcde-fghij-klmnop
[Replicate]
# API key for replicate.com
# Can be used to run lots of different AI models
# If not defined, the features that depend on it are not available
#
# APIKey = r8_alotoflettersandnumbershere
[HuggingFace]
# API key for Hugging Face
# Can be used to run lots of different AI models
# If not defined, the features that depend on it are not available
#
# APIKey = __________________________

BIN
database.db-journal Normal file

Binary file not shown.

View file

@ -1,5 +1,5 @@
[Unit] [Unit]
Description=GPTbot - A GPT bot for Matrix Description=GPTbot - A multifunctional Chatbot for Matrix
Requires=network.target Requires=network.target
[Service] [Service]

View file

@ -1,138 +0,0 @@
# Migration for Matrix Store
from datetime import datetime
def migration(conn):
with conn.cursor() as cursor:
# Create accounts table
cursor.execute("""
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY,
user_id VARCHAR NOT NULL,
device_id VARCHAR NOT NULL,
shared_account INTEGER NOT NULL,
pickle VARCHAR NOT NULL
);
""")
# Create device_keys table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_keys (
device_id TEXT PRIMARY KEY,
account_id INTEGER NOT NULL,
user_id TEXT NOT NULL,
display_name TEXT,
deleted BOOLEAN NOT NULL DEFAULT 0,
UNIQUE (account_id, user_id, device_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS keys (
key_type TEXT NOT NULL,
key TEXT NOT NULL,
device_id VARCHAR NOT NULL,
UNIQUE (key_type, device_id),
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create device_trust_state table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_trust_state (
device_id VARCHAR PRIMARY KEY,
state INTEGER NOT NULL,
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create olm_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS olm_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
account_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
session BLOB NOT NULL,
session_id VARCHAR NOT NULL,
creation_time TIMESTAMP NOT NULL,
last_usage_date TIMESTAMP NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
);
""")
# Create inbound_group_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
account_id INTEGER NOT NULL,
session TEXT NOT NULL,
fp_key TEXT NOT NULL,
sender_key TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (account_id, sender_key, fp_key, room_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS forwarded_chains (
id INTEGER PRIMARY KEY,
session_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
);
""")
# Create outbound_group_sessions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
room_id VARCHAR NOT NULL,
session_id VARCHAR NOT NULL UNIQUE,
session BLOB NOT NULL,
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create outgoing_key_requests table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
request_id TEXT NOT NULL,
session_id TEXT NOT NULL,
room_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
UNIQUE (account_id, request_id)
);
""")
# Create encrypted_rooms table
cursor.execute("""
CREATE TABLE IF NOT EXISTS encrypted_rooms (
room_id TEXT NOT NULL,
account_id INTEGER NOT NULL,
PRIMARY KEY (room_id, account_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create sync_tokens table
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_tokens (
account_id INTEGER PRIMARY KEY,
token TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
cursor.execute(
"INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
(datetime.now(),)
)
conn.commit()

69
pyproject.toml Normal file
View file

@ -0,0 +1,69 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[project]
name = "matrix-gptbot"
version = "0.1.0-alpha1"
authors = [
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
]
description = "Multifunctional Chatbot for Matrix"
readme = "README.md"
license = { file="LICENSE" }
requires-python = ">=3.10"
packages = [
"src/gptbot"
]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"matrix-nio[e2e]",
"markdown2[all]",
"tiktoken",
"python-magic",
"pillow",
]
[project.optional-dependencies]
openai = [
"openai",
]
wolframalpha = [
"wolframalpha",
]
trackingmore = [
"trackingmore @ git+https://kumig.it/kumitterer/trackingmore-api-tool.git",
]
all = [
"matrix-gptbot[openai,wolframalpha,trackingmore]",
]
dev = [
"matrix-gptbot[all]",
"black",
]
[project.urls]
"Homepage" = "https://kumig.it/kumitterer/matrix-gptbot"
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
[project.scripts]
gptbot = "gptbot:main"
[tool.hatch.build.targets.wheel]
packages = ["src/gptbot"]

View file

@ -1,4 +1,4 @@
from classes.bot import GPTBot from .classes.bot import GPTBot
from argparse import ArgumentParser from argparse import ArgumentParser
from configparser import ConfigParser from configparser import ConfigParser
@ -15,7 +15,10 @@ if __name__ == "__main__":
# Parse command line arguments # Parse command line arguments
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument( parser.add_argument(
"--config", help="Path to config file (default: config.ini in working directory)", default="config.ini") "--config",
help="Path to config file (default: config.ini in working directory)",
default="config.ini",
)
args = parser.parse_args() args = parser.parse_args()
# Read config file # Read config file

View file

Before

Width:  |  Height:  |  Size: 186 KiB

After

Width:  |  Height:  |  Size: 186 KiB

View file

@ -1,10 +1,12 @@
from contextlib import closing
async def join_callback(response, bot): async def join_callback(response, bot):
bot.logger.log( bot.logger.log(
f"Join response received for room {response.room_id}", "debug") f"Join response received for room {response.room_id}", "debug")
bot.matrix_client.joined_rooms() bot.matrix_client.joined_rooms()
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,)) "SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
space = cursor.fetchone() space = cursor.fetchone()

View file

View file

@ -1,5 +1,4 @@
import markdown2 import markdown2
import duckdb
import tiktoken import tiktoken
import asyncio import asyncio
import functools import functools
@ -30,30 +29,34 @@ from nio import (
RoomCreateError, RoomCreateError,
) )
from nio.crypto import Olm from nio.crypto import Olm
from nio.store import SqliteStore
from typing import Optional, List from typing import Optional, List
from configparser import ConfigParser from configparser import ConfigParser
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from contextlib import closing
import uuid import uuid
import traceback import traceback
import json import json
import importlib.util
import sys
import sqlite3
from .logging import Logger from .logging import Logger
from migrations import migrate from ..migrations import migrate
from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from commands import COMMANDS from ..commands import COMMANDS
from .store import DuckDBStore
from .openai import OpenAI from .openai import OpenAI
from .wolframalpha import WolframAlpha from .wolframalpha import WolframAlpha
from .trackingmore import TrackingMore from .trackingmore import TrackingMore
class GPTBot: class GPTBot:
# Default values # Default values
database: Optional[duckdb.DuckDBPyConnection] = None database: Optional[sqlite3.Connection] = None
crypto_store_path: Optional[str|Path] = None
# Default name of rooms created by the bot # Default name of rooms created by the bot
display_name = default_room_name = "GPTBot" display_name = default_room_name = "GPTBot"
default_system_message: str = "You are a helpful assistant." default_system_message: str = "You are a helpful assistant."
@ -90,9 +93,11 @@ class GPTBot:
bot = cls() bot = cls()
# Set the database connection # Set the database connection
bot.database = duckdb.connect( bot.database = sqlite3.connect(
config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
bot.crypto_store_path = config["Database"]["CryptoStore"] if "Database" in config and "CryptoStore" in config["Database"] else None
# Override default values # Override default values
if "GPTBot" in config: if "GPTBot" in config:
bot.operator = config["GPTBot"].get("Operator", bot.operator) bot.operator = config["GPTBot"].get("Operator", bot.operator)
@ -290,7 +295,7 @@ class GPTBot:
""" """
room_id = room.room_id if isinstance(room, MatrixRoom) else room room_id = room.room_id if isinstance(room, MatrixRoom) else room
with self.database.cursor() as cursor: with closing(self.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification")) "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
result = cursor.fetchone() result = cursor.fetchone()
@ -362,7 +367,7 @@ class GPTBot:
""" """
room_id = room.room_id room_id = room.room_id
with self.database.cursor() as cursor: with closing(self.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_timing")) "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_timing"))
result = cursor.fetchone() result = cursor.fetchone()
@ -584,10 +589,17 @@ class GPTBot:
self.logger.log( self.logger.log(
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.") "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
IN_MEMORY = True IN_MEMORY = True
self.database = duckdb.DuckDBPyConnection(":memory:") self.database = sqlite3.connect(":memory:")
self.logger.log("Running migrations...") self.logger.log("Running migrations...")
try:
before, after = migrate(self.database) before, after = migrate(self.database)
except sqlite3.DatabaseError as e:
self.logger.log(f"Error migrating database: {e}", "fatal")
self.logger.log("If you have just updated the bot, the previous version of the database may be incompatible with this version. Please delete the database file and try again.", "fatal")
exit(1)
if before != after: if before != after:
self.logger.log(f"Migrated from version {before} to {after}.") self.logger.log(f"Migrated from version {before} to {after}.")
else: else:
@ -597,14 +609,14 @@ class GPTBot:
client_config = AsyncClientConfig( client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=False) store_sync_tokens=True, encryption_enabled=False)
else: else:
matrix_store = DuckDBStore matrix_store = SqliteStore
client_config = AsyncClientConfig( client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=True, store=matrix_store) store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
self.matrix_client.config = client_config self.matrix_client.config = client_config
self.matrix_client.store = matrix_store( self.matrix_client.store = matrix_store(
self.matrix_client.user_id, self.matrix_client.user_id,
self.matrix_client.device_id, self.matrix_client.device_id,
self.database self.crypto_store_path or ""
) )
self.matrix_client.olm = Olm( self.matrix_client.olm = Olm(
@ -722,7 +734,7 @@ class GPTBot:
if isinstance(room, MatrixRoom): if isinstance(room, MatrixRoom):
room = room.room_id room = room.room_id
with self.database.cursor() as cursor: with closing(self.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room, "always_reply")) "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room, "always_reply"))
result = cursor.fetchone() result = cursor.fetchone()
@ -830,7 +842,7 @@ class GPTBot:
else: else:
room_id = room.room_id room_id = room.room_id
with self.database.cursor() as cur: with closing(self.database.cursor()) as cur:
cur.execute( cur.execute(
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
(room_id, "system_message") (room_id, "system_message")

View file

@ -24,7 +24,7 @@ for command in [
"space", "space",
]: ]:
function = getattr(import_module( function = getattr(import_module(
"commands." + command), "command_" + command) "." + command, "gptbot.commands"), "command_" + command)
COMMANDS[command] = function COMMANDS[command] = function
COMMANDS[None] = command_unknown COMMANDS[None] = command_unknown

View file

@ -2,6 +2,7 @@ from nio.events.room_events import RoomMessageText
from nio import RoomCreateError, RoomInviteError from nio import RoomCreateError, RoomInviteError
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
from contextlib import closing
async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot): async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
room_name = " ".join(event.body.split()[ room_name = " ".join(event.body.split()[
@ -23,7 +24,7 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True) await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True)
return return
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,)) "SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
space = cursor.fetchone() space = cursor.fetchone()

View file

@ -1,6 +1,8 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
from contextlib import closing
async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot): async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
setting = event.body.split()[2] if len(event.body.split()) > 2 else None setting = event.body.split()[2] if len(event.body.split()) > 2 else None
@ -16,7 +18,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
if value: if value:
bot.logger.log("Adding system message...") bot.logger.log("Adding system message...")
with bot.database.cursor() as cur: with closing(bot.database.cursor()) as cur:
cur.execute( cur.execute(
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?) """INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""", ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
@ -40,7 +42,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log(f"Setting {setting} status for {room.room_id} to {value}...") bot.logger.log(f"Setting {setting} status for {room.room_id} to {value}...")
with bot.database.cursor() as cur: with closing(bot.database.cursor()) as cur:
cur.execute( cur.execute(
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?) """INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""", ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
@ -55,7 +57,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log(f"Retrieving {setting} status for {room.room_id}...") bot.logger.log(f"Retrieving {setting} status for {room.room_id}...")
with bot.database.cursor() as cur: with closing(bot.database.cursor()) as cur:
cur.execute( cur.execute(
"""SELECT value FROM room_settings WHERE room_id = ? AND setting = ?;""", """SELECT value FROM room_settings WHERE room_id = ? AND setting = ?;""",
(room.room_id, setting) (room.room_id, setting)

View file

@ -2,6 +2,8 @@ from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
from nio.responses import RoomInviteError from nio.responses import RoomInviteError
from contextlib import closing
async def command_space(room: MatrixRoom, event: RoomMessageText, bot): async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
if len(event.body.split()) == 3: if len(event.body.split()) == 3:
@ -10,7 +12,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
if request.lower() == "enable": if request.lower() == "enable":
bot.logger.log("Enabling space...") bot.logger.log("Enabling space...")
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,)) "SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
space = cursor.fetchone() space = cursor.fetchone()
@ -25,7 +27,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
"url": bot.logo_uri "url": bot.logo_uri
}, "") }, "")
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"INSERT INTO user_spaces (space_id, user_id) VALUES (?, ?)", (space, event.sender)) "INSERT INTO user_spaces (space_id, user_id) VALUES (?, ?)", (space, event.sender))
@ -48,7 +50,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
elif request.lower() == "disable": elif request.lower() == "disable":
bot.logger.log("Disabling space...") bot.logger.log("Disabling space...")
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,)) "SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
space = cursor.fetchone()[0] space = cursor.fetchone()[0]
@ -58,7 +60,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
await bot.send_message(room, "You don't have a space enabled.", True) await bot.send_message(room, "You don't have a space enabled.", True)
return return
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"UPDATE user_spaces SET active = FALSE WHERE user_id = ?", (event.sender,)) "UPDATE user_spaces SET active = FALSE WHERE user_id = ?", (event.sender,))
@ -69,7 +71,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
if request.lower() == "update": if request.lower() == "update":
bot.logger.log("Updating space...") bot.logger.log("Updating space...")
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,)) "SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
space = cursor.fetchone()[0] space = cursor.fetchone()[0]
@ -103,7 +105,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
if request.lower() == "invite": if request.lower() == "invite":
bot.logger.log("Inviting user to space...") bot.logger.log("Inviting user to space...")
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT space_id FROM user_spaces WHERE user_id = ?", (event.sender,)) "SELECT space_id FROM user_spaces WHERE user_id = ?", (event.sender,))
space = cursor.fetchone()[0] space = cursor.fetchone()[0]
@ -126,7 +128,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
await bot.send_message(room, "Invited you to the space.", True) await bot.send_message(room, "Invited you to the space.", True)
return return
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT active FROM user_spaces WHERE user_id = ?", (event.sender,)) "SELECT active FROM user_spaces WHERE user_id = ?", (event.sender,))
status = cursor.fetchone() status = cursor.fetchone()

View file

@ -1,6 +1,8 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
from contextlib import closing
async def command_stats(room: MatrixRoom, event: RoomMessageText, bot): async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log("Showing stats...") bot.logger.log("Showing stats...")
@ -10,7 +12,7 @@ async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True) bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True)
return return
with bot.database.cursor() as cursor: with closing(bot.database.cursor()) as cursor:
cursor.execute( cursor.execute(
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,)) "SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
total_tokens = cursor.fetchone()[0] or 0 total_tokens = cursor.fetchone()[0] or 0

View file

@ -1,6 +1,8 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
from contextlib import closing
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot): async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
system_message = " ".join(event.body.split()[2:]) system_message = " ".join(event.body.split()[2:])
@ -8,7 +10,7 @@ async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
if system_message: if system_message:
bot.logger.log("Adding system message...") bot.logger.log("Adding system message...")
with bot.database.cursor() as cur: with closing(bot.database.cursor()) as cur:
cur.execute( cur.execute(
""" """
INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?) INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)

View file

@ -1,9 +1,10 @@
# Initial migration, token usage logging # Initial migration, token usage logging
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
CREATE TABLE IF NOT EXISTS token_usage ( CREATE TABLE IF NOT EXISTS token_usage (

View file

@ -0,0 +1,7 @@
# Migration for Matrix Store - No longer used
from datetime import datetime
from contextlib import closing
def migration(conn):
pass

View file

@ -1,9 +1,10 @@
# Migration for custom system messages # Migration for custom system messages
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
CREATE TABLE IF NOT EXISTS system_messages ( CREATE TABLE IF NOT EXISTS system_messages (
@ -11,7 +12,7 @@ def migration(conn):
message_id TEXT NOT NULL, message_id TEXT NOT NULL,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
body TEXT NOT NULL, body TEXT NOT NULL,
timestamp BIGINT NOT NULL, timestamp BIGINT NOT NULL
) )
""" """
) )

View file

@ -1,9 +1,10 @@
# Migration to add API column to token usage table # Migration to add API column to token usage table
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
ALTER TABLE token_usage ADD COLUMN api TEXT DEFAULT 'openai' ALTER TABLE token_usage ADD COLUMN api TEXT DEFAULT 'openai'

View file

@ -1,9 +1,10 @@
# Migration to add room settings table # Migration to add room settings table
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
CREATE TABLE IF NOT EXISTS room_settings ( CREATE TABLE IF NOT EXISTS room_settings (

View file

@ -1,9 +1,10 @@
# Migration to drop primary key constraint from token_usage table # Migration to drop primary key constraint from token_usage table
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
CREATE TABLE token_usage_temp ( CREATE TABLE token_usage_temp (

View file

@ -1,9 +1,10 @@
# Migration to add user_spaces table # Migration to add user_spaces table
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
CREATE TABLE user_spaces ( CREATE TABLE user_spaces (

View file

@ -1,9 +1,10 @@
# Migration to add settings table # Migration to add settings table
from datetime import datetime from datetime import datetime
from contextlib import closing
def migration(conn): def migration(conn):
with conn.cursor() as cursor: with closing(conn.cursor()) as cursor:
cursor.execute( cursor.execute(
""" """
CREATE TABLE IF NOT EXISTS settings ( CREATE TABLE IF NOT EXISTS settings (