diff --git a/classes/__init__.py b/classes/__init__.py new file mode 100644 index 0000000..0e22a56 --- /dev/null +++ b/classes/__init__.py @@ -0,0 +1 @@ +from .store import DuckDBStore \ No newline at end of file diff --git a/classes/store.py b/classes/store.py new file mode 100644 index 0000000..8de4431 --- /dev/null +++ b/classes/store.py @@ -0,0 +1,730 @@ +import duckdb + +from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore +from nio.crypto import OlmAccount, OlmDevice + +from random import SystemRandom +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import json + + +class DuckDBStore(MatrixStore): + @property + def account_id(self): + id = self._get_account()[0] if self._get_account() else None + + if id is None: + id = SystemRandom().randint(0, 2**16) + + return id + + def __init__(self, user_id, device_id, duckdb_conn): + self.conn = duckdb_conn + self.user_id = user_id + self.device_id = device_id + self._create_tables() + + def _create_tables(self): + with self.conn.cursor() as cursor: + cursor.execute(""" + DROP TABLE IF EXISTS sync_tokens CASCADE; + DROP TABLE IF EXISTS encrypted_rooms CASCADE; + DROP TABLE IF EXISTS outgoing_key_requests CASCADE; + DROP TABLE IF EXISTS forwarded_chains CASCADE; + DROP TABLE IF EXISTS outbound_group_sessions CASCADE; + DROP TABLE IF EXISTS inbound_group_sessions CASCADE; + DROP TABLE IF EXISTS olm_sessions CASCADE; + DROP TABLE IF EXISTS device_trust_state CASCADE; + DROP TABLE IF EXISTS keys CASCADE; + DROP TABLE IF EXISTS device_keys_key CASCADE; + DROP TABLE IF EXISTS device_keys CASCADE; + DROP TABLE IF EXISTS accounts CASCADE; + """) + + # Create accounts table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS accounts ( + id INTEGER PRIMARY KEY, + user_id VARCHAR NOT NULL, + device_id VARCHAR NOT NULL, + shared_account INTEGER NOT NULL, + pickle 
VARCHAR NOT NULL + ); + """) + + # Create device_keys table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS device_keys ( + device_id TEXT PRIMARY KEY, + account_id INTEGER NOT NULL, + user_id TEXT NOT NULL, + display_name TEXT, + deleted BOOLEAN NOT NULL DEFAULT 0, + UNIQUE (account_id, user_id, device_id), + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS keys ( + key_type TEXT NOT NULL, + key TEXT NOT NULL, + device_id VARCHAR NOT NULL, + UNIQUE (key_type, device_id), + FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE + ); + """) + + # Create device_trust_state table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS device_trust_state ( + device_id VARCHAR PRIMARY KEY, + state INTEGER NOT NULL, + FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE + ); + """) + + # Create olm_sessions table + cursor.execute(""" + CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1; + + CREATE TABLE IF NOT EXISTS olm_sessions ( + id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'), + account_id INTEGER NOT NULL, + sender_key TEXT NOT NULL, + session BLOB NOT NULL, + session_id VARCHAR NOT NULL, + creation_time TIMESTAMP NOT NULL, + last_usage_date TIMESTAMP NOT NULL, + FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE + ); + """) + + # Create inbound_group_sessions table + cursor.execute(""" + CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1; + + CREATE TABLE IF NOT EXISTS inbound_group_sessions ( + id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'), + account_id INTEGER NOT NULL, + session TEXT NOT NULL, + fp_key TEXT NOT NULL, + sender_key TEXT NOT NULL, + room_id TEXT NOT NULL, + UNIQUE (account_id, sender_key, fp_key, room_id), + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS forwarded_chains ( + id INTEGER PRIMARY KEY, + session_id INTEGER 
NOT NULL, + sender_key TEXT NOT NULL, + FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE + ); + """) + + # Create outbound_group_sessions table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS outbound_group_sessions ( + id INTEGER PRIMARY KEY, + account_id INTEGER NOT NULL, + room_id VARCHAR NOT NULL, + session_id VARCHAR NOT NULL UNIQUE, + session BLOB NOT NULL, + FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE + ); + """) + + # Create outgoing_key_requests table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS outgoing_key_requests ( + id INTEGER PRIMARY KEY, + account_id INTEGER NOT NULL, + request_id TEXT NOT NULL, + session_id TEXT NOT NULL, + room_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE, + UNIQUE (account_id, request_id) + ); + + """) + + # Create encrypted_rooms table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS encrypted_rooms ( + room_id TEXT NOT NULL, + account_id INTEGER NOT NULL, + PRIMARY KEY (room_id, account_id), + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE + ); + """) + + # Create sync_tokens table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS sync_tokens ( + account_id INTEGER PRIMARY KEY, + token TEXT NOT NULL, + FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE + ); + """) + + def _get_account(self): + cursor = self.conn.cursor() + cursor.execute( + "SELECT * FROM accounts WHERE user_id = ? AND device_id = ?", + (self.user_id, self.device_id), + ) + account = cursor.fetchone() + cursor.close() + return account + + def _get_device(self, device): + acc = self._get_account() + + if not acc: + return None + + cursor = self.conn.cursor() + cursor.execute( + "SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? 
AND account_id = ?", + (device.user_id, device.id, acc[0]), + ) + device_entry = cursor.fetchone() + cursor.close() + + return device_entry + + # Implementing methods with DuckDB equivalents + def verify_device(self, device): + if self.is_device_verified(device): + return False + + d = self._get_device(device) + assert d + + cursor = self.conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)", + (d[0], TrustState.verified), + ) + self.conn.commit() + cursor.close() + + device.trust_state = TrustState.verified + + return True + + def unverify_device(self, device): + if not self.is_device_verified(device): + return False + + d = self._get_device(device) + assert d + + cursor = self.conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)", + (d[0], TrustState.unset), + ) + self.conn.commit() + cursor.close() + + device.trust_state = TrustState.unset + + return True + + def is_device_verified(self, device): + d = self._get_device(device) + + if not d: + return False + + cursor = self.conn.cursor() + cursor.execute( + "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],) + ) + trust_state = cursor.fetchone() + cursor.close() + + if not trust_state: + return False + + return trust_state[0] == TrustState.verified + + def blacklist_device(self, device): + if self.is_device_blacklisted(device): + return False + + d = self._get_device(device) + assert d + + cursor = self.conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)", + (d[0], TrustState.blacklisted), + ) + self.conn.commit() + cursor.close() + + device.trust_state = TrustState.blacklisted + + return True + + def unblacklist_device(self, device): + if not self.is_device_blacklisted(device): + return False + + d = self._get_device(device) + assert d + + cursor = self.conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO 
device_trust_state (device_id, state) VALUES (?, ?)", + (d[0], TrustState.unset), + ) + self.conn.commit() + cursor.close() + + device.trust_state = TrustState.unset + + return True + + def is_device_blacklisted(self, device): + d = self._get_device(device) + + if not d: + return False + + cursor = self.conn.cursor() + cursor.execute( + "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],) + ) + trust_state = cursor.fetchone() + cursor.close() + + if not trust_state: + return False + + return trust_state[0] == TrustState.blacklisted + + def ignore_device(self, device): + if self.is_device_ignored(device): + return False + + d = self._get_device(device) + assert d + + cursor = self.conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)", + (d[0], int(TrustState.ignored.value)), + ) + self.conn.commit() + cursor.close() + + return True + + def ignore_devices(self, devices): + for device in devices: + self.ignore_device(device) + + def unignore_device(self, device): + if not self.is_device_ignored(device): + return False + + d = self._get_device(device) + assert d + + cursor = self.conn.cursor() + cursor.execute( + "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)", + (d[0], TrustState.unset), + ) + self.conn.commit() + cursor.close() + + device.trust_state = TrustState.unset + + return True + + def is_device_ignored(self, device): + d = self._get_device(device) + + if not d: + return False + + cursor = self.conn.cursor() + cursor.execute( + "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],) + ) + trust_state = cursor.fetchone() + cursor.close() + + if not trust_state: + return False + + return trust_state[0] == TrustState.ignored + + def load_device_keys(self): + """Load all the device keys from the database. + + Returns DeviceStore containing the OlmDevices with the device keys. 
+ """ + store = DeviceStore() + account = self.account_id + + if not account: + return store + + with self.conn.cursor() as cur: + cur.execute( + "SELECT * FROM device_keys WHERE account_id = ?", + (account,) + ) + device_keys = cur.fetchall() + + for d in device_keys: + cur.execute( + "SELECT * FROM keys WHERE device_id = ?", + (d["id"],) + ) + keys = cur.fetchall() + key_dict = {k["key_type"]: k["key"] for k in keys} + + store.add( + OlmDevice( + d["user_id"], + d["device_id"], + key_dict, + display_name=d["display_name"], + deleted=d["deleted"], + ) + ) + + return store + + def save_device_keys(self, device_keys): + """Save the provided device keys to the database.""" + account = self.account_id + assert account + rows = [] + + for user_id, devices_dict in device_keys.items(): + for device_id, device in devices_dict.items(): + rows.append( + { + "account_id": account, + "user_id": user_id, + "device_id": device_id, + "display_name": device.display_name, + "deleted": device.deleted, + } + ) + + if not rows: + return + + with self.conn.cursor() as cur: + for idx in range(0, len(rows), 100): + data = rows[idx: idx + 100] + cur.executemany( + "INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)", + [(r["account_id"], r["user_id"], r["device_id"], + r["display_name"], r["deleted"]) for r in data] + ) + + for user_id, devices_dict in device_keys.items(): + for device_id, device in devices_dict.items(): + cur.execute( + "UPDATE device_keys SET deleted = ? WHERE device_id = ?", + (device.deleted, device_id) + ) + + for key_type, key in device.keys.items(): + cur.execute(""" + INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?) + ON CONFLICT (key_type, device_id) DO UPDATE SET key = ? 
+ """, + (key_type, key, device_id, key) + ) + self.conn.commit() + + def save_group_sessions(self, sessions): + with self.conn.cursor() as cur: + for session in sessions: + cur.execute(""" + INSERT OR REPLACE INTO inbound_group_sessions ( + session_id, sender_key, signing_key, room_id, pickle, account_id + ) VALUES (?, ?, ?, ?, ?, ?) + """, ( + session.id, + session.sender_key, + session.signing_key, + session.room_id, + session.pickle, + self.account_id + )) + + self.conn.commit() + + def save_olm_sessions(self, sessions): + with self.conn.cursor() as cur: + for session in sessions: + cur.execute(""" + INSERT OR REPLACE INTO olm_sessions ( + session_id, sender_key, pickle, account_id + ) VALUES (?, ?, ?, ?) + """, ( + session.id, + session.sender_key, + session.pickle, + self.account_id + )) + + self.conn.commit() + + def save_outbound_group_sessions(self, sessions): + with self.conn.cursor() as cur: + for session in sessions: + cur.execute(""" + INSERT OR REPLACE INTO outbound_group_sessions ( + room_id, session_id, pickle, account_id + ) VALUES (?, ?, ?, ?) + """, ( + session.room_id, + session.id, + session.pickle, + self.account_id + )) + + self.conn.commit() + + def save_account(self, account: OlmAccount): + with self.conn.cursor() as cur: + cur.execute(""" + INSERT OR REPLACE INTO accounts ( + id, user_id, device_id, shared_account, pickle + ) VALUES (?, ?, ?, ?, ?) + """, ( + self.account_id, + self.user_id, + self.device_id, + account.shared, + account.pickle(self.pickle_key), + )) + + self.conn.commit() + + def load_sessions(self): + session_store = SessionStore() + + with self.conn.cursor() as cur: + cur.execute(""" + SELECT + os.sender_key, os.session, os.creation_time + FROM + olm_sessions os + INNER JOIN + accounts a ON os.account_id = a.id + WHERE + a.id = ? 
+ """, (self.account_id,)) + + for row in cur.fetchall(): + sender_key, session_pickle, creation_time = row + session = Session.from_pickle( + session_pickle, creation_time, self.pickle_key) + session_store.add(sender_key, session) + + return session_store + + def load_inbound_group_sessions(self): + # type: () -> GroupSessionStore + """Load all Olm sessions from the database. + + Returns: + ``GroupSessionStore`` object, containing all the loaded sessions. + + """ + store = GroupSessionStore() + + account = self.account_id + + if not account: + return store + + with self.conn.cursor() as cursor: + cursor.execute( + "SELECT * FROM inbound_group_sessions WHERE account_id = ?", ( + account,) + ) + + for row in cursor.fetchall(): + session = InboundGroupSession.from_pickle( + row["session"], + row["fp_key"], + row["sender_key"], + row["room_id"], + self.pickle_key, + [ + chain["sender_key"] + for chain in cursor.execute( + "SELECT sender_key FROM forwarded_chains WHERE session_id = ?", + (row["id"],), + ) + ], + ) + store.add(session) + + return store + + def load_outgoing_key_requests(self): + # type: () -> dict + """Load all outgoing key requests from the database. + + Returns: + ``OutgoingKeyRequestStore`` object, containing all the loaded key requests. + """ + account = self.account_id + + if not account: + return {} + + with self.conn.cursor() as cur: + cur.execute( + "SELECT * FROM outgoing_key_requests WHERE account_id = ?", + (account,) + ) + rows = cur.fetchall() + + return { + request.request_id: OutgoingKeyRequest.from_database(request) + for request in rows + } + + def load_encrypted_rooms(self): + """Load the set of encrypted rooms for this account. + + Returns: + ``Set`` containing room ids of encrypted rooms. 
+ """ + account = self.account_id + + if not account: + return set() + + with self.conn.cursor() as cur: + cur.execute( + "SELECT room_id FROM encrypted_rooms WHERE account_id = ?", + (account,) + ) + rows = cur.fetchall() + + return {row["room_id"] for row in rows} + + def save_sync_token(self, token): + """Save the given token""" + account = self.account_id + assert account + + with self.conn.cursor() as cur: + cur.execute( + "INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)", + (account, token) + ) + self.conn.commit() + + def save_encrypted_rooms(self, rooms): + """Save the set of room ids for this account.""" + account = self.account_id + assert account + + data = [(room_id, account) for room_id in rooms] + + with self.conn.cursor() as cur: + for idx in range(0, len(data), 400): + rows = data[idx: idx + 400] + cur.executemany( + "INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)", + rows + ) + self.conn.commit() + + def save_session(self, sender_key, session): + """Save the provided Olm session to the database. + + Args: + sender_key (str): The curve key that owns the Olm session. + session (Session): The Olm session that will be pickled and + saved in the database. + """ + account = self.account_id + assert account + + pickled_session = session.pickle(self.pickle_key) + + with self.conn.cursor() as cur: + + cur.execute( + "INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)", + (account, sender_key, pickled_session, session.id, + session.creation_time, session.use_time) + ) + self.conn.commit() + + def save_inbound_group_session(self, session): + """Save the provided Megolm inbound group session to the database. + + Args: + session (InboundGroupSession): The session to save. 
+ """ + account = self.account_id + assert account + + with self.conn.cursor() as cur: + + # Insert a new session or update the existing one + query = """ + INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (account_id, sender_key, fp_key, room_id) + DO UPDATE SET session = excluded.session + """ + cur.execute(query, (account, session.sender_key, + session.ed25519, session.room_id, session.pickle(self.pickle_key))) + + # Delete existing forwarded chains for the session + delete_query = """ + DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?) + """ + cur.execute( + delete_query, (account, session.sender_key, session.ed25519, session.room_id)) + + # Insert new forwarded chains for the session + insert_query = """ + INSERT INTO forwarded_chains (session_id, sender_key) + VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?) + """ + + for chain in session.forwarding_chain: + cur.execute( + insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain)) + + def add_outgoing_key_request(self, key_request): + account_id = self.account_id + with self.conn.cursor() as cursor: + cursor.execute( + """ + INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm) + VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT (account_id, request_id) DO NOTHING + """, + ( + account_id, + key_request.request_id, + key_request.session_id, + key_request.room_id, + key_request.algorithm, + ) + ) \ No newline at end of file diff --git a/commands/__init__.py b/commands/__init__.py index df29381..52fefab 100644 --- a/commands/__init__.py +++ b/commands/__init__.py @@ -14,4 +14,4 @@ COMMANDS = { "coin": command_coin, "ignoreolder": command_ignoreolder, None: command_unknown, -} \ No newline at end of file +} diff --git a/commands/botinfo.py b/commands/botinfo.py index b7e9bb2..430912a 100644 --- a/commands/botinfo.py +++ b/commands/botinfo.py @@ -1,12 +1,12 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom + async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict): logging("Showing bot info...") - await context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": f"""GPT Info: + return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": f"""GPT Info: Model: {context["model"]} Maximum context tokens: {context["max_tokens"]} @@ -19,4 +19,4 @@ Bot user ID: {context["client"].user_id} Current room ID: {room.room_id} For usage statistics, run !gptbot stats -"""}) \ No newline at end of file +"""} diff --git a/commands/coin.py b/commands/coin.py index f792718..82a3741 100644 --- a/commands/coin.py +++ b/commands/coin.py @@ -3,12 +3,11 @@ from nio.rooms import MatrixRoom from random import SystemRandom + async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict): context["logger"]("Flipping a coin...") heads = SystemRandom().choice([True, False]) - await context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": "Heads!" if heads else "Tails!"} - ) \ No newline at end of file + return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": "Heads!" 
if heads else "Tails!"} diff --git a/commands/help.py b/commands/help.py index 6554d23..d5fb235 100644 --- a/commands/help.py +++ b/commands/help.py @@ -1,10 +1,10 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom + async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict): - await context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": """Available commands: + return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": """Available commands: !gptbot help - Show this message !gptbot newroom - Create a new room and invite yourself to it @@ -13,4 +13,3 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict): !gptbot coin - Flip a coin (heads or tails) !gptbot ignoreolder - Ignore messages before this point as context """} - ) \ No newline at end of file diff --git a/commands/ignoreolder.py b/commands/ignoreolder.py index 8da3d7c..348abe8 100644 --- a/commands/ignoreolder.py +++ b/commands/ignoreolder.py @@ -2,9 +2,7 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict): - await context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", + return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": """Alright, messages before this point will not be processed as context anymore. 
-If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""} - ) \ No newline at end of file +If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""} \ No newline at end of file diff --git a/commands/newroom.py b/commands/newroom.py index 1395545..823fb7a 100644 --- a/commands/newroom.py +++ b/commands/newroom.py @@ -1,8 +1,10 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom + async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict): - room_name = " ".join(event.body.split()[2:]) or context["default_room_name"] + room_name = " ".join(event.body.split()[ + 2:]) or context["default_room_name"] context["logger"]("Creating new room...") new_room = await context["client"].room_create(name=room_name) @@ -12,5 +14,4 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dic await context["client"].room_put_state( new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}}) - await context["client"].room_send( - new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}) + return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"} diff --git a/commands/stats.py b/commands/stats.py index e7ae307..cffb3e4 100644 --- a/commands/stats.py +++ b/commands/stats.py @@ -1,23 +1,19 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom + async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict): context["logger"]("Showing stats...") if not (database := context.get("database")): context["logger"]("No database connection - cannot show stats") - context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."} - ) - return 
+ return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."} with database.cursor() as cursor: cursor.execute( "SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,)) total_tokens = cursor.fetchone()[0] or 0 - await context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": f"Total tokens used: {total_tokens}"} - ) \ No newline at end of file + return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": f"Total tokens used: {total_tokens}"} diff --git a/commands/unknown.py b/commands/unknown.py index 4b97c95..ce01eb9 100644 --- a/commands/unknown.py +++ b/commands/unknown.py @@ -1,10 +1,9 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom + async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict): context["logger"]("Unknown command") - await context["client"].room_send( - room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": "Unknown command - try !gptbot help"} - ) \ No newline at end of file + return room.room_id, "m.room.message", {"msgtype": "m.notice", + "body": "Unknown command - try !gptbot help"} diff --git a/gptbot.py b/gptbot.py index ae310cf..bc8624e 100644 --- a/gptbot.py +++ b/gptbot.py @@ -3,6 +3,7 @@ import inspect import logging import signal import random +import uuid import openai import asyncio @@ -10,9 +11,10 @@ import markdown2 import tiktoken import duckdb -from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent +from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api from nio.api import MessageDirection -from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError +from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, 
JoinResponse, RoomSendResponse +from nio.crypto import Olm from configparser import ConfigParser from datetime import datetime @@ -20,6 +22,7 @@ from argparse import ArgumentParser from typing import List, Dict, Union, Optional from commands import COMMANDS +from classes import DuckDBStore def logging(message: str, log_level: str = "info"): @@ -85,6 +88,13 @@ async def fetch_last_n_messages(room_id: str, n: Optional[int] = None, for event in response.chunk: if len(messages) >= n: break + if isinstance(event, MegolmEvent): + try: + event = await client.decrypt_event(event) + except (GroupEncryptionError, EncryptionError): + logging( + f"Could not decrypt message {event.event_id} in room {room_id}", "error") + continue if isinstance(event, RoomMessageText): if event.body.startswith("!gptbot ignoreolder"): break @@ -162,14 +172,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs): # Convert markdown to HTML - markdowner = markdown2.Markdown(extras=["fenced-code-blocks"]) - formatted_body = markdowner.convert(response) - - message = await client.room_send( - room.room_id, "m.room.message", - {"msgtype": "m.text", "body": response, - "format": "org.matrix.custom.html", "formatted_body": formatted_body} - ) + message = await send_message(room, response) if database: logging("Logging tokens used...") @@ -183,11 +186,8 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs): # Send a notice to the room if there was an error logging("Error during GPT API call - sending notice to room") - - await client.room_send( - room.room_id, "m.room.message", { - "msgtype": "m.notice", "body": "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later."} - ) + send_message( + room, "Sorry, I'm having trouble connecting to the GPT API right now. 
Please try again later.", True) print("No response from GPT API") await client.room_typing(room.room_id, False) @@ -199,14 +199,34 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt logging( f"Received command {event.body} from {event.sender} in room {room.room_id}") command = event.body.split()[1] if event.body.split()[1:] else None - await COMMANDS.get(command, COMMANDS[None])(room, event, context) + + message = await COMMANDS.get(command, COMMANDS[None])(room, event, context) + + if message: + room_id, event, content = message + await send_message(context["client"].rooms[room_id], content["body"], + True if content["msgtype"] == "m.notice" else False, context["client"]) -async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs): +async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs): context = kwargs.get("context") or CONTEXT - + logging(f"Received message from {event.sender} in room {room.room_id}") + if isinstance(event, MegolmEvent): + try: + event = await context["client"].decrypt_event(event) + except Exception as e: + try: + logging("Requesting new encryption keys...") + await context["client"].request_room_key(event) + except: + pass + + logging(f"Error decrypting message: {e}", "error") + await send_message(room, "Sorry, I couldn't decrypt that message. 
Please try again later or switch to a room without encryption.", True, context["client"]) + return + if event.sender == context["client"].user_id: logging("Message is from bot itself - ignoring") @@ -221,18 +241,69 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs): async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs): - client = kwargs.get("client") or CONTEXT["client"] + client: AsyncClient = kwargs.get("client") or CONTEXT["client"] + + if room.room_id in client.rooms: + logging(f"Already in room {room.room_id} - ignoring invite") + return logging(f"Received invite to room {room.room_id} - joining...") - await client.join(room.room_id) - await client.room_send( - room.room_id, - "m.room.message", - {"msgtype": "m.text", - "body": "Hello! I'm a helpful assistant. How can I help you today?"} + response = await client.join(room.room_id) + if isinstance(response, JoinResponse): + await send_message(room, "Hello! I'm a helpful assistant. 
How can I help you today?", client=client) + else: + logging(f"Error joining room {room.room_id}: {response}", "error") + + +async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None): + client = client or CONTEXT["client"] + + markdowner = markdown2.Markdown(extras=["fenced-code-blocks"]) + formatted_body = markdowner.convert(message) + + msgtype = "m.notice" if notice else "m.text" + + msgcontent = {"msgtype": msgtype, "body": message, + "format": "org.matrix.custom.html", "formatted_body": formatted_body} + + content = None + + if client.olm and room.encrypted: + try: + if not room.members_synced: + responses = [] + responses.append(await client.joined_members(room.room_id)) + + if client.olm.should_share_group_session(room.room_id): + try: + event = client.sharing_session[room.room_id] + await event.wait() + except KeyError: + await client.share_group_session( + room.room_id, + ignore_unverified_devices=True, + ) + + if msgtype != "m.reaction": + response = client.encrypt(room.room_id, "m.room.message", msgcontent) + msgtype, content = response + + except Exception as e: + logging( + f"Error encrypting message: {e} - sending unencrypted", "error") + raise + + if not content: + msgtype = "m.room.message" + content = msgcontent + + method, path, data = Api.room_send( + client.access_token, room.room_id, msgtype, content, uuid.uuid4() ) + return await client._send(RoomSendResponse, method, path, data, (room.room_id,)) + async def accept_pending_invites(client: Optional[AsyncClient] = None): client = client or CONTEXT["client"] @@ -242,13 +313,14 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None): for room_id in list(client.invited_rooms.keys()): logging(f"Joining room {room_id}...") - await client.join(room_id) - await client.room_send( - room_id, - "m.room.message", - {"msgtype": "m.text", - "body": "Hello! I'm a helpful assistant. 
How can I help you today?"} - ) + response = await client.join(room_id) + + if isinstance(response, JoinResponse): + logging(response, "debug") + rooms = await client.joined_rooms() + await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client) + else: + logging(f"Error joining room {room_id}: {response}", "error") async def sync_cb(response, write_global: bool = True): @@ -261,12 +333,95 @@ async def sync_cb(response, write_global: bool = True): CONTEXT["sync_token"] = SYNC_TOKEN -async def main(client: Optional[AsyncClient] = None): - client = client or CONTEXT["client"] +async def test_callback(room: MatrixRoom, event: Event, **kwargs): + logging( + f"Received event {event.__class__.__name__} in room {room.room_id}", "debug") - if not client.user_id: - whoami = await client.whoami() - client.user_id = whoami.user_id + +async def init(config: ConfigParser): + # Set up Matrix client + try: + assert "Matrix" in config + assert "Homeserver" in config["Matrix"] + assert "AccessToken" in config["Matrix"] + except: + logging("Matrix config not found or incomplete", "critical") + exit(1) + + homeserver = config["Matrix"]["Homeserver"] + access_token = config["Matrix"]["AccessToken"] + + device_id, user_id = await get_device_id(access_token, homeserver) + + device_id = config["Matrix"].get("DeviceID", device_id) + user_id = config["Matrix"].get("UserID", user_id) + + # Set up database + if "Database" in config and config["Database"].get("Path"): + database = CONTEXT["database"] = initialize_database( + config["Database"]["Path"]) + matrix_store = DuckDBStore + + client_config = AsyncClientConfig( + store_sync_tokens=True, encryption_enabled=True, store=matrix_store) + + else: + client_config = AsyncClientConfig( + store_sync_tokens=True, encryption_enabled=False) + + client = AsyncClient( + config["Matrix"]["Homeserver"], config=client_config) + + if client.config.encryption_enabled: + client.store = client.config.store( + 
user_id, + device_id, + database + ) + assert client.store + + client.olm = Olm(client.user_id, client.device_id, client.store) + client.encrypted_rooms = client.store.load_encrypted_rooms() + + CONTEXT["client"] = client + + CONTEXT["client"].access_token = config["Matrix"]["AccessToken"] + CONTEXT["client"].user_id = user_id + CONTEXT["client"].device_id = device_id + + # Set up GPT API + try: + assert "OpenAI" in config + assert "APIKey" in config["OpenAI"] + except: + logging("OpenAI config not found or incomplete", "critical") + exit(1) + + openai.api_key = config["OpenAI"]["APIKey"] + + if "Model" in config["OpenAI"]: + CONTEXT["model"] = config["OpenAI"]["Model"] + + if "MaxTokens" in config["OpenAI"]: + CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"]) + + if "MaxMessages" in config["OpenAI"]: + CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"]) + + # Listen for SIGTERM + + def sigterm_handler(_signo, _stack_frame): + logging("Received SIGTERM - exiting...") + exit() + + signal.signal(signal.SIGTERM, sigterm_handler) + + +async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None): + if not client and not CONTEXT.get("client"): + await init(config) + + client = client or CONTEXT["client"] try: assert client.user_id @@ -285,7 +440,9 @@ async def main(client: Optional[AsyncClient] = None): await client.sync(timeout=30000) client.add_event_callback(message_callback, RoomMessageText) + client.add_event_callback(message_callback, MegolmEvent) client.add_event_callback(room_invite_callback, InviteEvent) + client.add_event_callback(test_callback, Event) await accept_pending_invites() # Accept pending invites @@ -351,6 +508,31 @@ def initialize_database(path: os.PathLike): return database +async def get_device_id(access_token, homeserver): + client = AsyncClient(homeserver) + client.access_token = access_token + + logging(f"Obtaining device ID for access token {access_token}...", "debug") + response = await 
client.whoami() + if isinstance(response, WhoamiResponse): + logging( + f"Authenticated as {response.user_id}.") + user_id = response.user_id + devices = await client.devices() + device_id = devices.devices[0].id + + await client.close() + + return device_id, user_id + + else: + logging(f"Failed to obtain device ID: {response}", "error") + + await client.close() + + return None, None + + if __name__ == "__main__": # Parse command line arguments parser = ArgumentParser() @@ -362,54 +544,9 @@ if __name__ == "__main__": config = ConfigParser() config.read(args.config) - # Set up Matrix client - try: - assert "Matrix" in config - assert "Homeserver" in config["Matrix"] - assert "AccessToken" in config["Matrix"] - except: - logging("Matrix config not found or incomplete", "critical") - exit(1) - - CONTEXT["client"] = AsyncClient(config["Matrix"]["Homeserver"]) - - CONTEXT["client"].access_token = config["Matrix"]["AccessToken"] - CONTEXT["client"].user_id = config["Matrix"].get("UserID") - - # Set up GPT API - try: - assert "OpenAI" in config - assert "APIKey" in config["OpenAI"] - except: - logging("OpenAI config not found or incomplete", "critical") - exit(1) - - openai.api_key = config["OpenAI"]["APIKey"] - - if "Model" in config["OpenAI"]: - CONTEXT["model"] = config["OpenAI"]["Model"] - - if "MaxTokens" in config["OpenAI"]: - CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"]) - - if "MaxMessages" in config["OpenAI"]: - CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"]) - - # Set up database - if "Database" in config and config["Database"].get("Path"): - CONTEXT["database"] = initialize_database(config["Database"]["Path"]) - - # Listen for SIGTERM - - def sigterm_handler(_signo, _stack_frame): - logging("Received SIGTERM - exiting...") - exit() - - signal.signal(signal.SIGTERM, sigterm_handler) - # Start bot loop try: - asyncio.run(main()) + asyncio.run(main(config)) except KeyboardInterrupt: logging("Received KeyboardInterrupt - exiting...") 
except SystemExit: