Move migrations to a subdirectory

Add option for custom system messages per room
Fix some methods in the store
Kumi 2023-04-24 08:48:59 +00:00
parent f20b762558
commit 2bbc6a33ca
Signed by: kumi
GPG key ID: ECBCC9082395383F
10 changed files with 358 additions and 207 deletions

View file

@ -6,6 +6,9 @@ to generate responses to messages in a Matrix room.
It will also save a log of the tokens spent to a DuckDB database
(database.db in the working directory, by default).
Note that this bot does not yet support encryption - this is still work in
progress.
## Installation
Simply clone this repository and install the requirements.

View file

@ -1,6 +1,6 @@
import duckdb
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
from nio.crypto import OlmAccount, OlmDevice
from random import SystemRandom
@ -24,150 +24,6 @@ class DuckDBStore(MatrixStore):
self.conn = duckdb_conn
self.user_id = user_id
self.device_id = device_id
self._create_tables()
def _create_tables(self):
with self.conn.cursor() as cursor:
cursor.execute("""
DROP TABLE IF EXISTS sync_tokens CASCADE;
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
DROP TABLE IF EXISTS forwarded_chains CASCADE;
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
DROP TABLE IF EXISTS olm_sessions CASCADE;
DROP TABLE IF EXISTS device_trust_state CASCADE;
DROP TABLE IF EXISTS keys CASCADE;
DROP TABLE IF EXISTS device_keys_key CASCADE;
DROP TABLE IF EXISTS device_keys CASCADE;
DROP TABLE IF EXISTS accounts CASCADE;
""")
# Create accounts table
cursor.execute("""
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY,
user_id VARCHAR NOT NULL,
device_id VARCHAR NOT NULL,
shared_account INTEGER NOT NULL,
pickle VARCHAR NOT NULL
);
""")
# Create device_keys table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_keys (
device_id TEXT PRIMARY KEY,
account_id INTEGER NOT NULL,
user_id TEXT NOT NULL,
display_name TEXT,
deleted BOOLEAN NOT NULL DEFAULT 0,
UNIQUE (account_id, user_id, device_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS keys (
key_type TEXT NOT NULL,
key TEXT NOT NULL,
device_id VARCHAR NOT NULL,
UNIQUE (key_type, device_id),
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create device_trust_state table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_trust_state (
device_id VARCHAR PRIMARY KEY,
state INTEGER NOT NULL,
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create olm_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS olm_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
account_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
session BLOB NOT NULL,
session_id VARCHAR NOT NULL,
creation_time TIMESTAMP NOT NULL,
last_usage_date TIMESTAMP NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
);
""")
# Create inbound_group_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
account_id INTEGER NOT NULL,
session TEXT NOT NULL,
fp_key TEXT NOT NULL,
sender_key TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (account_id, sender_key, fp_key, room_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS forwarded_chains (
id INTEGER PRIMARY KEY,
session_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
);
""")
# Create outbound_group_sessions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
room_id VARCHAR NOT NULL,
session_id VARCHAR NOT NULL UNIQUE,
session BLOB NOT NULL,
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create outgoing_key_requests table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
request_id TEXT NOT NULL,
session_id TEXT NOT NULL,
room_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
UNIQUE (account_id, request_id)
);
""")
# Create encrypted_rooms table
cursor.execute("""
CREATE TABLE IF NOT EXISTS encrypted_rooms (
room_id TEXT NOT NULL,
account_id INTEGER NOT NULL,
PRIMARY KEY (room_id, account_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create sync_tokens table
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_tokens (
account_id INTEGER PRIMARY KEY,
token TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
def _get_account(self):
cursor = self.conn.cursor()
@ -387,18 +243,18 @@ class DuckDBStore(MatrixStore):
for d in device_keys:
cur.execute(
"SELECT * FROM keys WHERE device_id = ?",
(d["id"],)
(d[0],)
)
keys = cur.fetchall()
key_dict = {k["key_type"]: k["key"] for k in keys}
key_dict = {k[0]: k[1] for k in keys}
store.add(
OlmDevice(
d["user_id"],
d["device_id"],
d[2],
d[0],
key_dict,
display_name=d["display_name"],
deleted=d["deleted"],
display_name=d[3],
deleted=d[4],
)
)
@ -561,18 +417,21 @@ class DuckDBStore(MatrixStore):
)
for row in cursor.fetchall():
cursor.execute(
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
(row[1],),
)
chains = cursor.fetchall()
session = InboundGroupSession.from_pickle(
row["session"],
row["fp_key"],
row["sender_key"],
row["room_id"],
row[2].encode(),
row[3],
row[4],
row[5],
self.pickle_key,
[
chain["sender_key"]
for chain in cursor.execute(
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
(row["id"],),
)
chain[0]
for chain in chains
],
)
store.add(session)
@ -621,7 +480,7 @@ class DuckDBStore(MatrixStore):
)
rows = cur.fetchall()
return {row["room_id"] for row in rows}
return {row[0] for row in rows}
def save_sync_token(self, token):
"""Save the given token"""
@ -727,4 +586,29 @@ class DuckDBStore(MatrixStore):
key_request.room_id,
key_request.algorithm,
)
)
)
def load_account(self):
# type: () -> Optional[OlmAccount]
"""Load the Olm account from the database.
Returns:
``OlmAccount`` object, or ``None`` if it wasn't found for the
current device_id.
"""
cursor = self.conn.cursor()
query = """
SELECT pickle, shared_account
FROM accounts
WHERE device_id = ?;
"""
cursor.execute(query, (self.device_id,))
result = cursor.fetchone()
if not result:
return None
account_pickle, shared = result
return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)

View file

@ -5,6 +5,7 @@ from .botinfo import command_botinfo
from .unknown import command_unknown
from .coin import command_coin
from .ignoreolder import command_ignoreolder
from .systemmessage import command_systemmessage
COMMANDS = {
"help": command_help,
@ -13,5 +14,6 @@ COMMANDS = {
"botinfo": command_botinfo,
"coin": command_coin,
"ignoreolder": command_ignoreolder,
"systemmessage": command_systemmessage,
None: command_unknown,
}
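For context, a minimal sketch of how this COMMANDS mapping is presumably consumed (the actual dispatch lives in process_command in gptbot.py, which this diff only shows in part; the function name and prefix handling below are assumptions):

async def dispatch_command(room, event, context):
    # Assumed message shape: "<botname> <command> <arguments...>"
    parts = event.body.split()
    command = parts[1].lower() if len(parts) > 1 else None
    # Unknown commands fall back to the None entry (command_unknown)
    handler = COMMANDS.get(command, COMMANDS[None])
    return await handler(room, event, context)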

commands/systemmessage.py (new file, 34 lines)

@ -0,0 +1,34 @@
from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
system_message = " ".join(event.body.split()[2:])
if system_message:
context["logger"]("Adding system message...")
with context["database"].cursor() as cur:
cur.execute(
"INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
(room.room_id, event.event_id, event.sender,
system_message, event.server_timestamp)
)
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
context["logger"]("Retrieving system message...")
with context["database"].cursor() as cur:
cur.execute(
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
(room.room_id,)
)
system_message = cur.fetchone()
if system_message is None:
system_message = context.get("system_message", "No system message set")
elif context.get("force_system_message") and context.get("system_message"):
# fetchone() returns a one-element row, so take the body text before appending
system_message = system_message[0] + "\n\n" + context["system_message"]
else:
system_message = system_message[0]
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}
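How this is presumably invoked from chat, based on the [2:] slice above, which skips a two-token prefix such as the bot's name plus the command word (the prefix shown here is an assumption):

# <botname> systemmessage You are a pirate.   -> stores "You are a pirate." for this room
# <botname> systemmessage                     -> replies with the currently effective system message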

View file

@ -69,10 +69,15 @@ AccessToken = syt_yoursynapsetoken
#
# SystemMessage = You are a helpful bot.
# Force inclusion of the SystemMessage defined above even if a custom message is defined at the room level.
# If no custom message is defined for the room, the SystemMessage is always included.
#
# ForceSystemMessage = 0
[Database]
# Settings for the DuckDB database.
# Currently only used to store details on spent tokens per room.
# If not defined, the bot will not store this data.
# If not defined, the bot will not be able to remember anything, and will not support encryption.
# N.B.: Encryption doesn't work as it is supposed to anyway.
Path = database.db
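A worked example of how these settings are meant to combine, based on get_system_message in gptbot.py below; the room-level message here is hypothetical:

# SystemMessage = You are a helpful bot.
# ForceSystemMessage = 1
# Room-level message stored via the systemmessage command: "Answer only in French."
# Effective prompt for that room: "You are a helpful bot.\n\nAnswer only in French."
# With ForceSystemMessage = 0, the effective prompt would be just "Answer only in French."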

gptbot.py (104 lines changed)

@ -23,6 +23,7 @@ from typing import List, Dict, Union, Optional
from commands import COMMANDS
from classes import DuckDBStore
from migrations import MIGRATIONS
def logging(message: str, log_level: str = "info"):
@ -35,6 +36,7 @@ CONTEXT = {
"database": False,
"default_room_name": "GPTBot",
"system_message": "You are a helpful assistant.",
"force_system_message": False,
"max_tokens": 3000,
"max_messages": 20,
"model": "gpt-3.5-turbo",
@ -48,6 +50,8 @@ async def gpt_query(messages: list, model: Optional[str] = None):
model = model or CONTEXT["model"]
logging(f"Querying GPT with {len(messages)} messages")
logging(messages, "debug")
try:
response = openai.ChatCompletion.create(
model=model,
@ -143,8 +147,9 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
client = kwargs.get("client") or CONTEXT["client"]
database = kwargs.get("database") or CONTEXT["database"]
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
await client.room_typing(room.room_id, True)
@ -152,6 +157,12 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
last_messages = await fetch_last_n_messages(room.room_id, 20)
system_message = get_system_message(room, {
"database": database,
"system_message": system_message,
"force_system_message": force_system_message,
})
chat_messages = [{"role": "system", "content": system_message}]
for message in last_messages:
@ -163,8 +174,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
# Truncate messages to fit within the token limit
truncated_messages = truncate_messages_to_fit_tokens(
chat_messages, max_tokens - 1)
chat_messages, max_tokens - 1, system_message=system_message)
response, tokens_used = await gpt_query(truncated_messages)
if response:
@ -204,10 +214,29 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
if message:
room_id, event, content = message
rooms = await context["client"].joined_rooms()
await send_message(context["client"].rooms[room_id], content["body"],
True if content["msgtype"] == "m.notice" else False, context["client"])
def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
context = context or CONTEXT
default = context.get("system_message")
with context["database"].cursor() as cur:
cur.execute(
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
(room.room_id,)
)
system_message = cur.fetchone()
complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
"\n\n" + system_message[0] if system_message else "")).strip()
return complete
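The combined expression above packs three cases into one line; an equivalent sketch spelling them out (not code from this commit; assumes default is a string):

if not system_message:
    # no room-specific message stored: fall back to the configured default
    complete = default.strip()
elif context["force_system_message"]:
    # room message exists and forcing is on: default first, then the room message
    complete = f"{default}\n\n{system_message[0]}".strip()
else:
    # room message exists and forcing is off: the room message replaces the default
    complete = system_message[0].strip()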
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
context = kwargs.get("context") or CONTEXT
@ -286,7 +315,8 @@ async def send_message(room: MatrixRoom, message: str, notice: bool = False, cli
)
if msgtype != "m.reaction":
response = client.encrypt(room.room_id, "m.room.message", msgcontent)
response = client.encrypt(
room.room_id, "m.room.message", msgcontent)
msgtype, content = response
except Exception as e:
@ -318,7 +348,7 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
if isinstance(response, JoinResponse):
logging(response, "debug")
rooms = await client.joined_rooms()
await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client=client)
else:
logging(f"Error joining room {room_id}: {response}", "error")
@ -408,13 +438,16 @@ async def init(config: ConfigParser):
if "MaxMessages" in config["OpenAI"]:
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
# Listen for SIGTERM
# Override defaults with config
def sigterm_handler(_signo, _stack_frame):
logging("Received SIGTERM - exiting...")
exit()
signal.signal(signal.SIGTERM, sigterm_handler)
if "GPTBot" in config:
if "SystemMessage" in config["GPTBot"]:
CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
if "DefaultRoomName" in config["GPTBot"]:
CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
if "ForceSystemMessage" in config["GPTBot"]:
CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
"ForceSystemMessage")
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
@ -431,6 +464,14 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
await client.close()
return
# Listen for SIGTERM
def sigterm_handler(_signo, _stack_frame):
logging("Received SIGTERM - exiting...")
exit()
signal.signal(signal.SIGTERM, sigterm_handler)
logging("Starting bot...")
client.add_response_callback(sync_cb, SyncResponse)
@ -460,9 +501,9 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
def initialize_database(path: os.PathLike):
logging("Initializing database...")
database = duckdb.connect(path)
conn = duckdb.connect(path)
with database.cursor() as cursor:
with conn.cursor() as cursor:
# Get the latest migration ID if the migrations table exists
try:
cursor.execute(
@ -472,40 +513,17 @@ def initialize_database(path: os.PathLike):
)
latest_migration = int(cursor.fetchone()[0])
except:
latest_migration = 0
# Version 1
for migration, function in MIGRATIONS.items():
if latest_migration < migration:
logging(f"Running migration {migration}...")
function(conn)
latest_migration = migration
if latest_migration < 1:
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS token_usage (
message_id TEXT PRIMARY KEY,
room_id TEXT NOT NULL,
tokens INTEGER NOT NULL,
timestamp TIMESTAMP NOT NULL
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER NOT NULL,
timestamp TIMESTAMP NOT NULL
)
"""
)
cursor.execute(
"INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
(datetime.now(),)
)
database.commit()
return database
return conn
async def get_device_id(access_token, homeserver):

migrations/__init__.py (new file, 11 lines)

@ -0,0 +1,11 @@
from collections import OrderedDict
from .migration_1 import migration as migration_1
from .migration_2 import migration as migration_2
from .migration_3 import migration as migration_3
MIGRATIONS = OrderedDict()
MIGRATIONS[1] = migration_1
MIGRATIONS[2] = migration_2
MIGRATIONS[3] = migration_3
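Presumably, a later schema change would be added by dropping another module next to these and registering it here; a hypothetical sketch (migration_4 does not exist in this commit):

from .migration_4 import migration as migration_4

MIGRATIONS[4] = migration_4

The runner in gptbot.py then applies, in insertion order, every entry whose key is greater than the latest id recorded in the migrations table.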

migrations/migration_1.py (new file, 32 lines)

@ -0,0 +1,32 @@
# Initial migration, token usage logging
from datetime import datetime
def migration(conn):
with conn.cursor() as cursor:
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS token_usage (
message_id TEXT PRIMARY KEY,
room_id TEXT NOT NULL,
tokens INTEGER NOT NULL,
timestamp TIMESTAMP NOT NULL
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER NOT NULL,
timestamp TIMESTAMP NOT NULL
)
"""
)
cursor.execute(
"INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
(datetime.now(),)
)
conn.commit()

migrations/migration_2.py (new file, 138 lines)

@ -0,0 +1,138 @@
# Migration for Matrix Store
from datetime import datetime
def migration(conn):
with conn.cursor() as cursor:
# Create accounts table
cursor.execute("""
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY,
user_id VARCHAR NOT NULL,
device_id VARCHAR NOT NULL,
shared_account INTEGER NOT NULL,
pickle VARCHAR NOT NULL
);
""")
# Create device_keys table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_keys (
device_id TEXT PRIMARY KEY,
account_id INTEGER NOT NULL,
user_id TEXT NOT NULL,
display_name TEXT,
deleted BOOLEAN NOT NULL DEFAULT 0,
UNIQUE (account_id, user_id, device_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS keys (
key_type TEXT NOT NULL,
key TEXT NOT NULL,
device_id VARCHAR NOT NULL,
UNIQUE (key_type, device_id),
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create device_trust_state table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_trust_state (
device_id VARCHAR PRIMARY KEY,
state INTEGER NOT NULL,
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create olm_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS olm_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
account_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
session BLOB NOT NULL,
session_id VARCHAR NOT NULL,
creation_time TIMESTAMP NOT NULL,
last_usage_date TIMESTAMP NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
);
""")
# Create inbound_group_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
account_id INTEGER NOT NULL,
session TEXT NOT NULL,
fp_key TEXT NOT NULL,
sender_key TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (account_id, sender_key, fp_key, room_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS forwarded_chains (
id INTEGER PRIMARY KEY,
session_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
);
""")
# Create outbound_group_sessions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
room_id VARCHAR NOT NULL,
session_id VARCHAR NOT NULL UNIQUE,
session BLOB NOT NULL,
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create outgoing_key_requests table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
request_id TEXT NOT NULL,
session_id TEXT NOT NULL,
room_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
UNIQUE (account_id, request_id)
);
""")
# Create encrypted_rooms table
cursor.execute("""
CREATE TABLE IF NOT EXISTS encrypted_rooms (
room_id TEXT NOT NULL,
account_id INTEGER NOT NULL,
PRIMARY KEY (room_id, account_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create sync_tokens table
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_tokens (
account_id INTEGER PRIMARY KEY,
token TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
cursor.execute(
"INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
(datetime.now(),)
)
conn.commit()

migrations/migration_3.py (new file, 24 lines)

@ -0,0 +1,24 @@
# Migration for custom system messages
from datetime import datetime
def migration(conn):
with conn.cursor() as cursor:
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS system_messages (
room_id TEXT NOT NULL,
message_id TEXT NOT NULL,
user_id TEXT NOT NULL,
body TEXT NOT NULL,
timestamp BIGINT NOT NULL
)
"""
)
cursor.execute(
"INSERT INTO migrations (id, timestamp) VALUES (3, ?)",
(datetime.now(),)
)
conn.commit()