Some refactoring, starting implementation of encryption

This commit is contained in:
Kumi 2023-04-23 13:26:46 +00:00
parent 202bed25c6
commit f20b762558
Signed by: kumi
GPG key ID: ECBCC9082395383F
11 changed files with 976 additions and 116 deletions

1
classes/__init__.py Normal file
View file

@ -0,0 +1 @@
from .store import DuckDBStore

730
classes/store.py Normal file
View file

@ -0,0 +1,730 @@
import duckdb
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
from nio.crypto import OlmAccount, OlmDevice
from random import SystemRandom
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import json
class DuckDBStore(MatrixStore):
@property
def account_id(self):
id = self._get_account()[0] if self._get_account() else None
if id is None:
id = SystemRandom().randint(0, 2**16)
return id
def __init__(self, user_id, device_id, duckdb_conn):
self.conn = duckdb_conn
self.user_id = user_id
self.device_id = device_id
self._create_tables()
def _create_tables(self):
with self.conn.cursor() as cursor:
cursor.execute("""
DROP TABLE IF EXISTS sync_tokens CASCADE;
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
DROP TABLE IF EXISTS forwarded_chains CASCADE;
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
DROP TABLE IF EXISTS olm_sessions CASCADE;
DROP TABLE IF EXISTS device_trust_state CASCADE;
DROP TABLE IF EXISTS keys CASCADE;
DROP TABLE IF EXISTS device_keys_key CASCADE;
DROP TABLE IF EXISTS device_keys CASCADE;
DROP TABLE IF EXISTS accounts CASCADE;
""")
# Create accounts table
cursor.execute("""
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY,
user_id VARCHAR NOT NULL,
device_id VARCHAR NOT NULL,
shared_account INTEGER NOT NULL,
pickle VARCHAR NOT NULL
);
""")
# Create device_keys table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_keys (
device_id TEXT PRIMARY KEY,
account_id INTEGER NOT NULL,
user_id TEXT NOT NULL,
display_name TEXT,
deleted BOOLEAN NOT NULL DEFAULT 0,
UNIQUE (account_id, user_id, device_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS keys (
key_type TEXT NOT NULL,
key TEXT NOT NULL,
device_id VARCHAR NOT NULL,
UNIQUE (key_type, device_id),
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create device_trust_state table
cursor.execute("""
CREATE TABLE IF NOT EXISTS device_trust_state (
device_id VARCHAR PRIMARY KEY,
state INTEGER NOT NULL,
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
);
""")
# Create olm_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS olm_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
account_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
session BLOB NOT NULL,
session_id VARCHAR NOT NULL,
creation_time TIMESTAMP NOT NULL,
last_usage_date TIMESTAMP NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
);
""")
# Create inbound_group_sessions table
cursor.execute("""
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
account_id INTEGER NOT NULL,
session TEXT NOT NULL,
fp_key TEXT NOT NULL,
sender_key TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (account_id, sender_key, fp_key, room_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS forwarded_chains (
id INTEGER PRIMARY KEY,
session_id INTEGER NOT NULL,
sender_key TEXT NOT NULL,
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
);
""")
# Create outbound_group_sessions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
room_id VARCHAR NOT NULL,
session_id VARCHAR NOT NULL UNIQUE,
session BLOB NOT NULL,
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create outgoing_key_requests table
cursor.execute("""
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL,
request_id TEXT NOT NULL,
session_id TEXT NOT NULL,
room_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
UNIQUE (account_id, request_id)
);
""")
# Create encrypted_rooms table
cursor.execute("""
CREATE TABLE IF NOT EXISTS encrypted_rooms (
room_id TEXT NOT NULL,
account_id INTEGER NOT NULL,
PRIMARY KEY (room_id, account_id),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
# Create sync_tokens table
cursor.execute("""
CREATE TABLE IF NOT EXISTS sync_tokens (
account_id INTEGER PRIMARY KEY,
token TEXT NOT NULL,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
""")
def _get_account(self):
cursor = self.conn.cursor()
cursor.execute(
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
(self.user_id, self.device_id),
)
account = cursor.fetchone()
cursor.close()
return account
def _get_device(self, device):
acc = self._get_account()
if not acc:
return None
cursor = self.conn.cursor()
cursor.execute(
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
(device.user_id, device.id, acc[0]),
)
device_entry = cursor.fetchone()
cursor.close()
return device_entry
# Implementing methods with DuckDB equivalents
def verify_device(self, device):
if self.is_device_verified(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.verified),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.verified
return True
def unverify_device(self, device):
if not self.is_device_verified(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_verified(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.verified
def blacklist_device(self, device):
if self.is_device_blacklisted(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.blacklisted),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.blacklisted
return True
def unblacklist_device(self, device):
if not self.is_device_blacklisted(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_blacklisted(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.blacklisted
def ignore_device(self, device):
if self.is_device_ignored(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], int(TrustState.ignored.value)),
)
self.conn.commit()
cursor.close()
return True
def ignore_devices(self, devices):
for device in devices:
self.ignore_device(device)
def unignore_device(self, device):
if not self.is_device_ignored(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_ignored(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.ignored
def load_device_keys(self):
"""Load all the device keys from the database.
Returns DeviceStore containing the OlmDevices with the device keys.
"""
store = DeviceStore()
account = self.account_id
if not account:
return store
with self.conn.cursor() as cur:
cur.execute(
"SELECT * FROM device_keys WHERE account_id = ?",
(account,)
)
device_keys = cur.fetchall()
for d in device_keys:
cur.execute(
"SELECT * FROM keys WHERE device_id = ?",
(d["id"],)
)
keys = cur.fetchall()
key_dict = {k["key_type"]: k["key"] for k in keys}
store.add(
OlmDevice(
d["user_id"],
d["device_id"],
key_dict,
display_name=d["display_name"],
deleted=d["deleted"],
)
)
return store
def save_device_keys(self, device_keys):
"""Save the provided device keys to the database."""
account = self.account_id
assert account
rows = []
for user_id, devices_dict in device_keys.items():
for device_id, device in devices_dict.items():
rows.append(
{
"account_id": account,
"user_id": user_id,
"device_id": device_id,
"display_name": device.display_name,
"deleted": device.deleted,
}
)
if not rows:
return
with self.conn.cursor() as cur:
for idx in range(0, len(rows), 100):
data = rows[idx: idx + 100]
cur.executemany(
"INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
[(r["account_id"], r["user_id"], r["device_id"],
r["display_name"], r["deleted"]) for r in data]
)
for user_id, devices_dict in device_keys.items():
for device_id, device in devices_dict.items():
cur.execute(
"UPDATE device_keys SET deleted = ? WHERE device_id = ?",
(device.deleted, device_id)
)
for key_type, key in device.keys.items():
cur.execute("""
INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
""",
(key_type, key, device_id, key)
)
self.conn.commit()
def save_group_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO inbound_group_sessions (
session_id, sender_key, signing_key, room_id, pickle, account_id
) VALUES (?, ?, ?, ?, ?, ?)
""", (
session.id,
session.sender_key,
session.signing_key,
session.room_id,
session.pickle,
self.account_id
))
self.conn.commit()
def save_olm_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO olm_sessions (
session_id, sender_key, pickle, account_id
) VALUES (?, ?, ?, ?)
""", (
session.id,
session.sender_key,
session.pickle,
self.account_id
))
self.conn.commit()
def save_outbound_group_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO outbound_group_sessions (
room_id, session_id, pickle, account_id
) VALUES (?, ?, ?, ?)
""", (
session.room_id,
session.id,
session.pickle,
self.account_id
))
self.conn.commit()
def save_account(self, account: OlmAccount):
with self.conn.cursor() as cur:
cur.execute("""
INSERT OR REPLACE INTO accounts (
id, user_id, device_id, shared_account, pickle
) VALUES (?, ?, ?, ?, ?)
""", (
self.account_id,
self.user_id,
self.device_id,
account.shared,
account.pickle(self.pickle_key),
))
self.conn.commit()
def load_sessions(self):
session_store = SessionStore()
with self.conn.cursor() as cur:
cur.execute("""
SELECT
os.sender_key, os.session, os.creation_time
FROM
olm_sessions os
INNER JOIN
accounts a ON os.account_id = a.id
WHERE
a.id = ?
""", (self.account_id,))
for row in cur.fetchall():
sender_key, session_pickle, creation_time = row
session = Session.from_pickle(
session_pickle, creation_time, self.pickle_key)
session_store.add(sender_key, session)
return session_store
def load_inbound_group_sessions(self):
# type: () -> GroupSessionStore
"""Load all Olm sessions from the database.
Returns:
``GroupSessionStore`` object, containing all the loaded sessions.
"""
store = GroupSessionStore()
account = self.account_id
if not account:
return store
with self.conn.cursor() as cursor:
cursor.execute(
"SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
account,)
)
for row in cursor.fetchall():
session = InboundGroupSession.from_pickle(
row["session"],
row["fp_key"],
row["sender_key"],
row["room_id"],
self.pickle_key,
[
chain["sender_key"]
for chain in cursor.execute(
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
(row["id"],),
)
],
)
store.add(session)
return store
def load_outgoing_key_requests(self):
# type: () -> dict
"""Load all outgoing key requests from the database.
Returns:
``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
"""
account = self.account_id
if not account:
return store
with self.conn.cursor() as cur:
cur.execute(
"SELECT * FROM outgoing_key_requests WHERE account_id = ?",
(account,)
)
rows = cur.fetchall()
return {
request.request_id: OutgoingKeyRequest.from_database(request)
for request in rows
}
def load_encrypted_rooms(self):
"""Load the set of encrypted rooms for this account.
Returns:
``Set`` containing room ids of encrypted rooms.
"""
account = self.account_id
if not account:
return set()
with self.conn.cursor() as cur:
cur.execute(
"SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
(account,)
)
rows = cur.fetchall()
return {row["room_id"] for row in rows}
def save_sync_token(self, token):
"""Save the given token"""
account = self.account_id
assert account
with self.conn.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
(account, token)
)
self.conn.commit()
def save_encrypted_rooms(self, rooms):
"""Save the set of room ids for this account."""
account = self.account_id
assert account
data = [(room_id, account) for room_id in rooms]
with self.conn.cursor() as cur:
for idx in range(0, len(data), 400):
rows = data[idx: idx + 400]
cur.executemany(
"INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
rows
)
self.conn.commit()
def save_session(self, sender_key, session):
"""Save the provided Olm session to the database.
Args:
sender_key (str): The curve key that owns the Olm session.
session (Session): The Olm session that will be pickled and
saved in the database.
"""
account = self.account_id
assert account
pickled_session = session.pickle(self.pickle_key)
with self.conn.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
(account, sender_key, pickled_session, session.id,
session.creation_time, session.use_time)
)
self.conn.commit()
def save_inbound_group_session(self, session):
"""Save the provided Megolm inbound group session to the database.
Args:
session (InboundGroupSession): The session to save.
"""
account = self.account_id
assert account
with self.conn.cursor() as cur:
# Insert a new session or update the existing one
query = """
INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (account_id, sender_key, fp_key, room_id)
DO UPDATE SET session = excluded.session
"""
cur.execute(query, (account, session.sender_key,
session.ed25519, session.room_id, session.pickle(self.pickle_key)))
# Delete existing forwarded chains for the session
delete_query = """
DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
"""
cur.execute(
delete_query, (account, session.sender_key, session.ed25519, session.room_id))
# Insert new forwarded chains for the session
insert_query = """
INSERT INTO forwarded_chains (session_id, sender_key)
VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
"""
for chain in session.forwarding_chain:
cur.execute(
insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
def add_outgoing_key_request(self, key_request):
account_id = self.account_id
with self.conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (account_id, request_id) DO NOTHING
""",
(
account_id,
key_request.request_id,
key_request.session_id,
key_request.room_id,
key_request.algorithm,
)
)

View file

@ -14,4 +14,4 @@ COMMANDS = {
"coin": command_coin, "coin": command_coin,
"ignoreolder": command_ignoreolder, "ignoreolder": command_ignoreolder,
None: command_unknown, None: command_unknown,
} }

View file

@ -1,12 +1,12 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
logging("Showing bot info...") logging("Showing bot info...")
await context["client"].room_send( return room.room_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"""GPT Info:
"body": f"""GPT Info:
Model: {context["model"]} Model: {context["model"]}
Maximum context tokens: {context["max_tokens"]} Maximum context tokens: {context["max_tokens"]}
@ -19,4 +19,4 @@ Bot user ID: {context["client"].user_id}
Current room ID: {room.room_id} Current room ID: {room.room_id}
For usage statistics, run !gptbot stats For usage statistics, run !gptbot stats
"""}) """}

View file

@ -3,12 +3,11 @@ from nio.rooms import MatrixRoom
from random import SystemRandom from random import SystemRandom
async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
context["logger"]("Flipping a coin...") context["logger"]("Flipping a coin...")
heads = SystemRandom().choice([True, False]) heads = SystemRandom().choice([True, False])
await context["client"].room_send( return room.room_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice", "body": "Heads!" if heads else "Tails!"}
"body": "Heads!" if heads else "Tails!"}
)

View file

@ -1,10 +1,10 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
await context["client"].room_send( return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice", "body": """Available commands:
"body": """Available commands:
!gptbot help - Show this message !gptbot help - Show this message
!gptbot newroom <room name> - Create a new room and invite yourself to it !gptbot newroom <room name> - Create a new room and invite yourself to it
@ -13,4 +13,3 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
!gptbot coin - Flip a coin (heads or tails) !gptbot coin - Flip a coin (heads or tails)
!gptbot ignoreolder - Ignore messages before this point as context !gptbot ignoreolder - Ignore messages before this point as context
"""} """}
)

View file

@ -2,9 +2,7 @@ from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
await context["client"].room_send( return room.room_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": """Alright, messages before this point will not be processed as context anymore. "body": """Alright, messages before this point will not be processed as context anymore.
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""} If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
)

View file

@ -1,8 +1,10 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
room_name = " ".join(event.body.split()[2:]) or context["default_room_name"] room_name = " ".join(event.body.split()[
2:]) or context["default_room_name"]
context["logger"]("Creating new room...") context["logger"]("Creating new room...")
new_room = await context["client"].room_create(name=room_name) new_room = await context["client"].room_create(name=room_name)
@ -12,5 +14,4 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dic
await context["client"].room_put_state( await context["client"].room_put_state(
new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}}) new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
await context["client"].room_send( return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}
new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"})

View file

@ -1,23 +1,19 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
context["logger"]("Showing stats...") context["logger"]("Showing stats...")
if not (database := context.get("database")): if not (database := context.get("database")):
context["logger"]("No database connection - cannot show stats") context["logger"]("No database connection - cannot show stats")
context["client"].room_send( return room.room_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice", "body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
)
return
with database.cursor() as cursor: with database.cursor() as cursor:
cursor.execute( cursor.execute(
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,)) "SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
total_tokens = cursor.fetchone()[0] or 0 total_tokens = cursor.fetchone()[0] or 0
await context["client"].room_send( return room.room_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"Total tokens used: {total_tokens}"}
"body": f"Total tokens used: {total_tokens}"}
)

View file

@ -1,10 +1,9 @@
from nio.events.room_events import RoomMessageText from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom from nio.rooms import MatrixRoom
async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict): async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
context["logger"]("Unknown command") context["logger"]("Unknown command")
await context["client"].room_send( return room.room_id, "m.room.message", {"msgtype": "m.notice",
room.room_id, "m.room.message", {"msgtype": "m.notice", "body": "Unknown command - try !gptbot help"}
"body": "Unknown command - try !gptbot help"}
)

303
gptbot.py
View file

@ -3,6 +3,7 @@ import inspect
import logging import logging
import signal import signal
import random import random
import uuid
import openai import openai
import asyncio import asyncio
@ -10,9 +11,10 @@ import markdown2
import tiktoken import tiktoken
import duckdb import duckdb
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api
from nio.api import MessageDirection from nio.api import MessageDirection
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse
from nio.crypto import Olm
from configparser import ConfigParser from configparser import ConfigParser
from datetime import datetime from datetime import datetime
@ -20,6 +22,7 @@ from argparse import ArgumentParser
from typing import List, Dict, Union, Optional from typing import List, Dict, Union, Optional
from commands import COMMANDS from commands import COMMANDS
from classes import DuckDBStore
def logging(message: str, log_level: str = "info"): def logging(message: str, log_level: str = "info"):
@ -85,6 +88,13 @@ async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
for event in response.chunk: for event in response.chunk:
if len(messages) >= n: if len(messages) >= n:
break break
if isinstance(event, MegolmEvent):
try:
event = await client.decrypt_event(event)
except (GroupEncryptionError, EncryptionError):
logging(
f"Could not decrypt message {event.event_id} in room {room_id}", "error")
continue
if isinstance(event, RoomMessageText): if isinstance(event, RoomMessageText):
if event.body.startswith("!gptbot ignoreolder"): if event.body.startswith("!gptbot ignoreolder"):
break break
@ -162,14 +172,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
# Convert markdown to HTML # Convert markdown to HTML
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"]) message = await send_message(room, response)
formatted_body = markdowner.convert(response)
message = await client.room_send(
room.room_id, "m.room.message",
{"msgtype": "m.text", "body": response,
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
)
if database: if database:
logging("Logging tokens used...") logging("Logging tokens used...")
@ -183,11 +186,8 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
# Send a notice to the room if there was an error # Send a notice to the room if there was an error
logging("Error during GPT API call - sending notice to room") logging("Error during GPT API call - sending notice to room")
send_message(
await client.room_send( room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
room.room_id, "m.room.message", {
"msgtype": "m.notice", "body": "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later."}
)
print("No response from GPT API") print("No response from GPT API")
await client.room_typing(room.room_id, False) await client.room_typing(room.room_id, False)
@ -199,14 +199,34 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
logging( logging(
f"Received command {event.body} from {event.sender} in room {room.room_id}") f"Received command {event.body} from {event.sender} in room {room.room_id}")
command = event.body.split()[1] if event.body.split()[1:] else None command = event.body.split()[1] if event.body.split()[1:] else None
await COMMANDS.get(command, COMMANDS[None])(room, event, context)
message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)
if message:
room_id, event, content = message
await send_message(context["client"].rooms[room_id], content["body"],
True if content["msgtype"] == "m.notice" else False, context["client"])
async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs): async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
context = kwargs.get("context") or CONTEXT context = kwargs.get("context") or CONTEXT
logging(f"Received message from {event.sender} in room {room.room_id}") logging(f"Received message from {event.sender} in room {room.room_id}")
if isinstance(event, MegolmEvent):
try:
event = await context["client"].decrypt_event(event)
except Exception as e:
try:
logging("Requesting new encryption keys...")
await context["client"].request_room_key(event)
except:
pass
logging(f"Error decrypting message: {e}", "error")
await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
return
if event.sender == context["client"].user_id: if event.sender == context["client"].user_id:
logging("Message is from bot itself - ignoring") logging("Message is from bot itself - ignoring")
@ -221,18 +241,69 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs): async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
client = kwargs.get("client") or CONTEXT["client"] client: AsyncClient = kwargs.get("client") or CONTEXT["client"]
if room.room_id in client.rooms:
logging(f"Already in room {room.room_id} - ignoring invite")
return
logging(f"Received invite to room {room.room_id} - joining...") logging(f"Received invite to room {room.room_id} - joining...")
await client.join(room.room_id) response = await client.join(room.room_id)
await client.room_send( if isinstance(response, JoinResponse):
room.room_id, await send_message(room, "Hello! I'm a helpful assistant. How can I help you today?", client)
"m.room.message", else:
{"msgtype": "m.text", logging(f"Error joining room {room.room_id}: {response}", "error")
"body": "Hello! I'm a helpful assistant. How can I help you today?"}
async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None):
client = client or CONTEXT["client"]
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
formatted_body = markdowner.convert(message)
msgtype = "m.notice" if notice else "m.text"
msgcontent = {"msgtype": msgtype, "body": message,
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
content = None
if client.olm and room.encrypted:
try:
if not room.members_synced:
responses = []
responses.append(await client.joined_members(room.room_id))
if client.olm.should_share_group_session(room.room_id):
try:
event = client.sharing_session[room.room_id]
await event.wait()
except KeyError:
await client.share_group_session(
room.room_id,
ignore_unverified_devices=True,
)
if msgtype != "m.reaction":
response = client.encrypt(room.room_id, "m.room.message", msgcontent)
msgtype, content = response
except Exception as e:
logging(
f"Error encrypting message: {e} - sending unencrypted", "error")
raise
if not content:
msgtype = "m.room.message"
content = msgcontent
method, path, data = Api.room_send(
client.access_token, room.room_id, msgtype, content, uuid.uuid4()
) )
return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
async def accept_pending_invites(client: Optional[AsyncClient] = None): async def accept_pending_invites(client: Optional[AsyncClient] = None):
client = client or CONTEXT["client"] client = client or CONTEXT["client"]
@ -242,13 +313,14 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
for room_id in list(client.invited_rooms.keys()): for room_id in list(client.invited_rooms.keys()):
logging(f"Joining room {room_id}...") logging(f"Joining room {room_id}...")
await client.join(room_id) response = await client.join(room_id)
await client.room_send(
room_id, if isinstance(response, JoinResponse):
"m.room.message", logging(response, "debug")
{"msgtype": "m.text", rooms = await client.joined_rooms()
"body": "Hello! I'm a helpful assistant. How can I help you today?"} await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
) else:
logging(f"Error joining room {room_id}: {response}", "error")
async def sync_cb(response, write_global: bool = True): async def sync_cb(response, write_global: bool = True):
@ -261,12 +333,95 @@ async def sync_cb(response, write_global: bool = True):
CONTEXT["sync_token"] = SYNC_TOKEN CONTEXT["sync_token"] = SYNC_TOKEN
async def main(client: Optional[AsyncClient] = None): async def test_callback(room: MatrixRoom, event: Event, **kwargs):
client = client or CONTEXT["client"] logging(
f"Received event {event.__class__.__name__} in room {room.room_id}", "debug")
if not client.user_id:
whoami = await client.whoami() async def init(config: ConfigParser):
client.user_id = whoami.user_id # Set up Matrix client
try:
assert "Matrix" in config
assert "Homeserver" in config["Matrix"]
assert "AccessToken" in config["Matrix"]
except:
logging("Matrix config not found or incomplete", "critical")
exit(1)
homeserver = config["Matrix"]["Homeserver"]
access_token = config["Matrix"]["AccessToken"]
device_id, user_id = await get_device_id(access_token, homeserver)
device_id = config["Matrix"].get("DeviceID", device_id)
user_id = config["Matrix"].get("UserID", user_id)
# Set up database
if "Database" in config and config["Database"].get("Path"):
database = CONTEXT["database"] = initialize_database(
config["Database"]["Path"])
matrix_store = DuckDBStore
client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
else:
client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=False)
client = AsyncClient(
config["Matrix"]["Homeserver"], config=client_config)
if client.config.encryption_enabled:
client.store = client.config.store(
user_id,
device_id,
database
)
assert client.store
client.olm = Olm(client.user_id, client.device_id, client.store)
client.encrypted_rooms = client.store.load_encrypted_rooms()
CONTEXT["client"] = client
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
CONTEXT["client"].user_id = user_id
CONTEXT["client"].device_id = device_id
# Set up GPT API
try:
assert "OpenAI" in config
assert "APIKey" in config["OpenAI"]
except:
logging("OpenAI config not found or incomplete", "critical")
exit(1)
openai.api_key = config["OpenAI"]["APIKey"]
if "Model" in config["OpenAI"]:
CONTEXT["model"] = config["OpenAI"]["Model"]
if "MaxTokens" in config["OpenAI"]:
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
if "MaxMessages" in config["OpenAI"]:
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
# Listen for SIGTERM
def sigterm_handler(_signo, _stack_frame):
logging("Received SIGTERM - exiting...")
exit()
signal.signal(signal.SIGTERM, sigterm_handler)
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
if not client and not CONTEXT.get("client"):
await init(config)
client = client or CONTEXT["client"]
try: try:
assert client.user_id assert client.user_id
@ -285,7 +440,9 @@ async def main(client: Optional[AsyncClient] = None):
await client.sync(timeout=30000) await client.sync(timeout=30000)
client.add_event_callback(message_callback, RoomMessageText) client.add_event_callback(message_callback, RoomMessageText)
client.add_event_callback(message_callback, MegolmEvent)
client.add_event_callback(room_invite_callback, InviteEvent) client.add_event_callback(room_invite_callback, InviteEvent)
client.add_event_callback(test_callback, Event)
await accept_pending_invites() # Accept pending invites await accept_pending_invites() # Accept pending invites
@ -351,6 +508,31 @@ def initialize_database(path: os.PathLike):
return database return database
async def get_device_id(access_token, homeserver):
client = AsyncClient(homeserver)
client.access_token = access_token
logging(f"Obtaining device ID for access token {access_token}...", "debug")
response = await client.whoami()
if isinstance(response, WhoamiResponse):
logging(
f"Authenticated as {response.user_id}.")
user_id = response.user_id
devices = await client.devices()
device_id = devices.devices[0].id
await client.close()
return device_id, user_id
else:
logging(f"Failed to obtain device ID: {response}", "error")
await client.close()
return None, None
if __name__ == "__main__": if __name__ == "__main__":
# Parse command line arguments # Parse command line arguments
parser = ArgumentParser() parser = ArgumentParser()
@ -362,54 +544,9 @@ if __name__ == "__main__":
config = ConfigParser() config = ConfigParser()
config.read(args.config) config.read(args.config)
# Set up Matrix client
try:
assert "Matrix" in config
assert "Homeserver" in config["Matrix"]
assert "AccessToken" in config["Matrix"]
except:
logging("Matrix config not found or incomplete", "critical")
exit(1)
CONTEXT["client"] = AsyncClient(config["Matrix"]["Homeserver"])
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
CONTEXT["client"].user_id = config["Matrix"].get("UserID")
# Set up GPT API
try:
assert "OpenAI" in config
assert "APIKey" in config["OpenAI"]
except:
logging("OpenAI config not found or incomplete", "critical")
exit(1)
openai.api_key = config["OpenAI"]["APIKey"]
if "Model" in config["OpenAI"]:
CONTEXT["model"] = config["OpenAI"]["Model"]
if "MaxTokens" in config["OpenAI"]:
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
if "MaxMessages" in config["OpenAI"]:
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
# Set up database
if "Database" in config and config["Database"].get("Path"):
CONTEXT["database"] = initialize_database(config["Database"]["Path"])
# Listen for SIGTERM
def sigterm_handler(_signo, _stack_frame):
logging("Received SIGTERM - exiting...")
exit()
signal.signal(signal.SIGTERM, sigterm_handler)
# Start bot loop # Start bot loop
try: try:
asyncio.run(main()) asyncio.run(main(config))
except KeyboardInterrupt: except KeyboardInterrupt:
logging("Received KeyboardInterrupt - exiting...") logging("Received KeyboardInterrupt - exiting...")
except SystemExit: except SystemExit: