Some refactoring, starting implementation of encryption
This commit is contained in:
parent
202bed25c6
commit
f20b762558
11 changed files with 976 additions and 116 deletions
1
classes/__init__.py
Normal file
1
classes/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .store import DuckDBStore
|
730
classes/store.py
Normal file
730
classes/store.py
Normal file
|
@ -0,0 +1,730 @@
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
|
||||||
|
from nio.crypto import OlmAccount, OlmDevice
|
||||||
|
|
||||||
|
from random import SystemRandom
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class DuckDBStore(MatrixStore):
|
||||||
|
@property
|
||||||
|
def account_id(self):
|
||||||
|
id = self._get_account()[0] if self._get_account() else None
|
||||||
|
|
||||||
|
if id is None:
|
||||||
|
id = SystemRandom().randint(0, 2**16)
|
||||||
|
|
||||||
|
return id
|
||||||
|
|
||||||
|
def __init__(self, user_id, device_id, duckdb_conn):
|
||||||
|
self.conn = duckdb_conn
|
||||||
|
self.user_id = user_id
|
||||||
|
self.device_id = device_id
|
||||||
|
self._create_tables()
|
||||||
|
|
||||||
|
def _create_tables(self):
|
||||||
|
with self.conn.cursor() as cursor:
|
||||||
|
cursor.execute("""
|
||||||
|
DROP TABLE IF EXISTS sync_tokens CASCADE;
|
||||||
|
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
|
||||||
|
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
|
||||||
|
DROP TABLE IF EXISTS forwarded_chains CASCADE;
|
||||||
|
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
|
||||||
|
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
|
||||||
|
DROP TABLE IF EXISTS olm_sessions CASCADE;
|
||||||
|
DROP TABLE IF EXISTS device_trust_state CASCADE;
|
||||||
|
DROP TABLE IF EXISTS keys CASCADE;
|
||||||
|
DROP TABLE IF EXISTS device_keys_key CASCADE;
|
||||||
|
DROP TABLE IF EXISTS device_keys CASCADE;
|
||||||
|
DROP TABLE IF EXISTS accounts CASCADE;
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create accounts table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
user_id VARCHAR NOT NULL,
|
||||||
|
device_id VARCHAR NOT NULL,
|
||||||
|
shared_account INTEGER NOT NULL,
|
||||||
|
pickle VARCHAR NOT NULL
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create device_keys table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS device_keys (
|
||||||
|
device_id TEXT PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
|
deleted BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
UNIQUE (account_id, user_id, device_id),
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS keys (
|
||||||
|
key_type TEXT NOT NULL,
|
||||||
|
key TEXT NOT NULL,
|
||||||
|
device_id VARCHAR NOT NULL,
|
||||||
|
UNIQUE (key_type, device_id),
|
||||||
|
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create device_trust_state table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS device_trust_state (
|
||||||
|
device_id VARCHAR PRIMARY KEY,
|
||||||
|
state INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create olm_sessions table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS olm_sessions (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
sender_key TEXT NOT NULL,
|
||||||
|
session BLOB NOT NULL,
|
||||||
|
session_id VARCHAR NOT NULL,
|
||||||
|
creation_time TIMESTAMP NOT NULL,
|
||||||
|
last_usage_date TIMESTAMP NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create inbound_group_sessions table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
session TEXT NOT NULL,
|
||||||
|
fp_key TEXT NOT NULL,
|
||||||
|
sender_key TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
UNIQUE (account_id, sender_key, fp_key, room_id),
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
session_id INTEGER NOT NULL,
|
||||||
|
sender_key TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create outbound_group_sessions table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
room_id VARCHAR NOT NULL,
|
||||||
|
session_id VARCHAR NOT NULL UNIQUE,
|
||||||
|
session BLOB NOT NULL,
|
||||||
|
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create outgoing_key_requests table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
request_id TEXT NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE (account_id, request_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create encrypted_rooms table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (room_id, account_id),
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create sync_tokens table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sync_tokens (
|
||||||
|
account_id INTEGER PRIMARY KEY,
|
||||||
|
token TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
def _get_account(self):
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
|
||||||
|
(self.user_id, self.device_id),
|
||||||
|
)
|
||||||
|
account = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
return account
|
||||||
|
|
||||||
|
def _get_device(self, device):
|
||||||
|
acc = self._get_account()
|
||||||
|
|
||||||
|
if not acc:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
|
||||||
|
(device.user_id, device.id, acc[0]),
|
||||||
|
)
|
||||||
|
device_entry = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return device_entry
|
||||||
|
|
||||||
|
# Implementing methods with DuckDB equivalents
|
||||||
|
def verify_device(self, device):
|
||||||
|
if self.is_device_verified(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
d = self._get_device(device)
|
||||||
|
assert d
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||||
|
(d[0], TrustState.verified),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
device.trust_state = TrustState.verified
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def unverify_device(self, device):
|
||||||
|
if not self.is_device_verified(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
d = self._get_device(device)
|
||||||
|
assert d
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||||
|
(d[0], TrustState.unset),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
device.trust_state = TrustState.unset
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_device_verified(self, device):
|
||||||
|
d = self._get_device(device)
|
||||||
|
|
||||||
|
if not d:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
||||||
|
)
|
||||||
|
trust_state = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
if not trust_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return trust_state[0] == TrustState.verified
|
||||||
|
|
||||||
|
def blacklist_device(self, device):
|
||||||
|
if self.is_device_blacklisted(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
d = self._get_device(device)
|
||||||
|
assert d
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||||
|
(d[0], TrustState.blacklisted),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
device.trust_state = TrustState.blacklisted
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def unblacklist_device(self, device):
|
||||||
|
if not self.is_device_blacklisted(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
d = self._get_device(device)
|
||||||
|
assert d
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||||
|
(d[0], TrustState.unset),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
device.trust_state = TrustState.unset
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_device_blacklisted(self, device):
|
||||||
|
d = self._get_device(device)
|
||||||
|
|
||||||
|
if not d:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
||||||
|
)
|
||||||
|
trust_state = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
if not trust_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return trust_state[0] == TrustState.blacklisted
|
||||||
|
|
||||||
|
def ignore_device(self, device):
|
||||||
|
if self.is_device_ignored(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
d = self._get_device(device)
|
||||||
|
assert d
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||||
|
(d[0], int(TrustState.ignored.value)),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def ignore_devices(self, devices):
|
||||||
|
for device in devices:
|
||||||
|
self.ignore_device(device)
|
||||||
|
|
||||||
|
def unignore_device(self, device):
|
||||||
|
if not self.is_device_ignored(device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
d = self._get_device(device)
|
||||||
|
assert d
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||||
|
(d[0], TrustState.unset),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
device.trust_state = TrustState.unset
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_device_ignored(self, device):
|
||||||
|
d = self._get_device(device)
|
||||||
|
|
||||||
|
if not d:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
||||||
|
)
|
||||||
|
trust_state = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
if not trust_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return trust_state[0] == TrustState.ignored
|
||||||
|
|
||||||
|
def load_device_keys(self):
|
||||||
|
"""Load all the device keys from the database.
|
||||||
|
|
||||||
|
Returns DeviceStore containing the OlmDevices with the device keys.
|
||||||
|
"""
|
||||||
|
store = DeviceStore()
|
||||||
|
account = self.account_id
|
||||||
|
|
||||||
|
if not account:
|
||||||
|
return store
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT * FROM device_keys WHERE account_id = ?",
|
||||||
|
(account,)
|
||||||
|
)
|
||||||
|
device_keys = cur.fetchall()
|
||||||
|
|
||||||
|
for d in device_keys:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT * FROM keys WHERE device_id = ?",
|
||||||
|
(d["id"],)
|
||||||
|
)
|
||||||
|
keys = cur.fetchall()
|
||||||
|
key_dict = {k["key_type"]: k["key"] for k in keys}
|
||||||
|
|
||||||
|
store.add(
|
||||||
|
OlmDevice(
|
||||||
|
d["user_id"],
|
||||||
|
d["device_id"],
|
||||||
|
key_dict,
|
||||||
|
display_name=d["display_name"],
|
||||||
|
deleted=d["deleted"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return store
|
||||||
|
|
||||||
|
def save_device_keys(self, device_keys):
|
||||||
|
"""Save the provided device keys to the database."""
|
||||||
|
account = self.account_id
|
||||||
|
assert account
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
for user_id, devices_dict in device_keys.items():
|
||||||
|
for device_id, device in devices_dict.items():
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"account_id": account,
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"display_name": device.display_name,
|
||||||
|
"deleted": device.deleted,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
for idx in range(0, len(rows), 100):
|
||||||
|
data = rows[idx: idx + 100]
|
||||||
|
cur.executemany(
|
||||||
|
"INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
[(r["account_id"], r["user_id"], r["device_id"],
|
||||||
|
r["display_name"], r["deleted"]) for r in data]
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, devices_dict in device_keys.items():
|
||||||
|
for device_id, device in devices_dict.items():
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE device_keys SET deleted = ? WHERE device_id = ?",
|
||||||
|
(device.deleted, device_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
for key_type, key in device.keys.items():
|
||||||
|
cur.execute("""
|
||||||
|
INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
|
||||||
|
ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
|
||||||
|
""",
|
||||||
|
(key_type, key, device_id, key)
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_group_sessions(self, sessions):
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
for session in sessions:
|
||||||
|
cur.execute("""
|
||||||
|
INSERT OR REPLACE INTO inbound_group_sessions (
|
||||||
|
session_id, sender_key, signing_key, room_id, pickle, account_id
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
session.id,
|
||||||
|
session.sender_key,
|
||||||
|
session.signing_key,
|
||||||
|
session.room_id,
|
||||||
|
session.pickle,
|
||||||
|
self.account_id
|
||||||
|
))
|
||||||
|
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_olm_sessions(self, sessions):
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
for session in sessions:
|
||||||
|
cur.execute("""
|
||||||
|
INSERT OR REPLACE INTO olm_sessions (
|
||||||
|
session_id, sender_key, pickle, account_id
|
||||||
|
) VALUES (?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
session.id,
|
||||||
|
session.sender_key,
|
||||||
|
session.pickle,
|
||||||
|
self.account_id
|
||||||
|
))
|
||||||
|
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_outbound_group_sessions(self, sessions):
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
for session in sessions:
|
||||||
|
cur.execute("""
|
||||||
|
INSERT OR REPLACE INTO outbound_group_sessions (
|
||||||
|
room_id, session_id, pickle, account_id
|
||||||
|
) VALUES (?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
session.room_id,
|
||||||
|
session.id,
|
||||||
|
session.pickle,
|
||||||
|
self.account_id
|
||||||
|
))
|
||||||
|
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_account(self, account: OlmAccount):
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
cur.execute("""
|
||||||
|
INSERT OR REPLACE INTO accounts (
|
||||||
|
id, user_id, device_id, shared_account, pickle
|
||||||
|
) VALUES (?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
self.account_id,
|
||||||
|
self.user_id,
|
||||||
|
self.device_id,
|
||||||
|
account.shared,
|
||||||
|
account.pickle(self.pickle_key),
|
||||||
|
))
|
||||||
|
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def load_sessions(self):
|
||||||
|
session_store = SessionStore()
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
cur.execute("""
|
||||||
|
SELECT
|
||||||
|
os.sender_key, os.session, os.creation_time
|
||||||
|
FROM
|
||||||
|
olm_sessions os
|
||||||
|
INNER JOIN
|
||||||
|
accounts a ON os.account_id = a.id
|
||||||
|
WHERE
|
||||||
|
a.id = ?
|
||||||
|
""", (self.account_id,))
|
||||||
|
|
||||||
|
for row in cur.fetchall():
|
||||||
|
sender_key, session_pickle, creation_time = row
|
||||||
|
session = Session.from_pickle(
|
||||||
|
session_pickle, creation_time, self.pickle_key)
|
||||||
|
session_store.add(sender_key, session)
|
||||||
|
|
||||||
|
return session_store
|
||||||
|
|
||||||
|
def load_inbound_group_sessions(self):
|
||||||
|
# type: () -> GroupSessionStore
|
||||||
|
"""Load all Olm sessions from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``GroupSessionStore`` object, containing all the loaded sessions.
|
||||||
|
|
||||||
|
"""
|
||||||
|
store = GroupSessionStore()
|
||||||
|
|
||||||
|
account = self.account_id
|
||||||
|
|
||||||
|
if not account:
|
||||||
|
return store
|
||||||
|
|
||||||
|
with self.conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
|
||||||
|
account,)
|
||||||
|
)
|
||||||
|
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
session = InboundGroupSession.from_pickle(
|
||||||
|
row["session"],
|
||||||
|
row["fp_key"],
|
||||||
|
row["sender_key"],
|
||||||
|
row["room_id"],
|
||||||
|
self.pickle_key,
|
||||||
|
[
|
||||||
|
chain["sender_key"]
|
||||||
|
for chain in cursor.execute(
|
||||||
|
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
||||||
|
(row["id"],),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
store.add(session)
|
||||||
|
|
||||||
|
return store
|
||||||
|
|
||||||
|
def load_outgoing_key_requests(self):
|
||||||
|
# type: () -> dict
|
||||||
|
"""Load all outgoing key requests from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
|
||||||
|
"""
|
||||||
|
account = self.account_id
|
||||||
|
|
||||||
|
if not account:
|
||||||
|
return store
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT * FROM outgoing_key_requests WHERE account_id = ?",
|
||||||
|
(account,)
|
||||||
|
)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
return {
|
||||||
|
request.request_id: OutgoingKeyRequest.from_database(request)
|
||||||
|
for request in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_encrypted_rooms(self):
|
||||||
|
"""Load the set of encrypted rooms for this account.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``Set`` containing room ids of encrypted rooms.
|
||||||
|
"""
|
||||||
|
account = self.account_id
|
||||||
|
|
||||||
|
if not account:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
|
||||||
|
(account,)
|
||||||
|
)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
return {row["room_id"] for row in rows}
|
||||||
|
|
||||||
|
def save_sync_token(self, token):
|
||||||
|
"""Save the given token"""
|
||||||
|
account = self.account_id
|
||||||
|
assert account
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
|
||||||
|
(account, token)
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_encrypted_rooms(self, rooms):
|
||||||
|
"""Save the set of room ids for this account."""
|
||||||
|
account = self.account_id
|
||||||
|
assert account
|
||||||
|
|
||||||
|
data = [(room_id, account) for room_id in rooms]
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
for idx in range(0, len(data), 400):
|
||||||
|
rows = data[idx: idx + 400]
|
||||||
|
cur.executemany(
|
||||||
|
"INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
|
||||||
|
rows
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_session(self, sender_key, session):
|
||||||
|
"""Save the provided Olm session to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sender_key (str): The curve key that owns the Olm session.
|
||||||
|
session (Session): The Olm session that will be pickled and
|
||||||
|
saved in the database.
|
||||||
|
"""
|
||||||
|
account = self.account_id
|
||||||
|
assert account
|
||||||
|
|
||||||
|
pickled_session = session.pickle(self.pickle_key)
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(account, sender_key, pickled_session, session.id,
|
||||||
|
session.creation_time, session.use_time)
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def save_inbound_group_session(self, session):
|
||||||
|
"""Save the provided Megolm inbound group session to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session (InboundGroupSession): The session to save.
|
||||||
|
"""
|
||||||
|
account = self.account_id
|
||||||
|
assert account
|
||||||
|
|
||||||
|
with self.conn.cursor() as cur:
|
||||||
|
|
||||||
|
# Insert a new session or update the existing one
|
||||||
|
query = """
|
||||||
|
INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (account_id, sender_key, fp_key, room_id)
|
||||||
|
DO UPDATE SET session = excluded.session
|
||||||
|
"""
|
||||||
|
cur.execute(query, (account, session.sender_key,
|
||||||
|
session.ed25519, session.room_id, session.pickle(self.pickle_key)))
|
||||||
|
|
||||||
|
# Delete existing forwarded chains for the session
|
||||||
|
delete_query = """
|
||||||
|
DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
|
||||||
|
"""
|
||||||
|
cur.execute(
|
||||||
|
delete_query, (account, session.sender_key, session.ed25519, session.room_id))
|
||||||
|
|
||||||
|
# Insert new forwarded chains for the session
|
||||||
|
insert_query = """
|
||||||
|
INSERT INTO forwarded_chains (session_id, sender_key)
|
||||||
|
VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
|
||||||
|
"""
|
||||||
|
|
||||||
|
for chain in session.forwarding_chain:
|
||||||
|
cur.execute(
|
||||||
|
insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
|
||||||
|
|
||||||
|
def add_outgoing_key_request(self, key_request):
|
||||||
|
account_id = self.account_id
|
||||||
|
with self.conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (account_id, request_id) DO NOTHING
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
account_id,
|
||||||
|
key_request.request_id,
|
||||||
|
key_request.session_id,
|
||||||
|
key_request.room_id,
|
||||||
|
key_request.algorithm,
|
||||||
|
)
|
||||||
|
)
|
|
@ -1,11 +1,11 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
|
||||||
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
logging("Showing bot info...")
|
logging("Showing bot info...")
|
||||||
|
|
||||||
await context["client"].room_send(
|
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": f"""GPT Info:
|
"body": f"""GPT Info:
|
||||||
|
|
||||||
Model: {context["model"]}
|
Model: {context["model"]}
|
||||||
|
@ -19,4 +19,4 @@ Bot user ID: {context["client"].user_id}
|
||||||
Current room ID: {room.room_id}
|
Current room ID: {room.room_id}
|
||||||
|
|
||||||
For usage statistics, run !gptbot stats
|
For usage statistics, run !gptbot stats
|
||||||
"""})
|
"""}
|
||||||
|
|
|
@ -3,12 +3,11 @@ from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
from random import SystemRandom
|
from random import SystemRandom
|
||||||
|
|
||||||
|
|
||||||
async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
context["logger"]("Flipping a coin...")
|
context["logger"]("Flipping a coin...")
|
||||||
|
|
||||||
heads = SystemRandom().choice([True, False])
|
heads = SystemRandom().choice([True, False])
|
||||||
|
|
||||||
await context["client"].room_send(
|
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": "Heads!" if heads else "Tails!"}
|
"body": "Heads!" if heads else "Tails!"}
|
||||||
)
|
|
|
@ -1,9 +1,9 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
|
||||||
async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
await context["client"].room_send(
|
return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": """Available commands:
|
"body": """Available commands:
|
||||||
|
|
||||||
!gptbot help - Show this message
|
!gptbot help - Show this message
|
||||||
|
@ -13,4 +13,3 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
!gptbot coin - Flip a coin (heads or tails)
|
!gptbot coin - Flip a coin (heads or tails)
|
||||||
!gptbot ignoreolder - Ignore messages before this point as context
|
!gptbot ignoreolder - Ignore messages before this point as context
|
||||||
"""}
|
"""}
|
||||||
)
|
|
|
@ -2,9 +2,7 @@ from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
await context["client"].room_send(
|
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": """Alright, messages before this point will not be processed as context anymore.
|
"body": """Alright, messages before this point will not be processed as context anymore.
|
||||||
|
|
||||||
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
|
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
|
||||||
)
|
|
|
@ -1,8 +1,10 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
|
||||||
async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
room_name = " ".join(event.body.split()[2:]) or context["default_room_name"]
|
room_name = " ".join(event.body.split()[
|
||||||
|
2:]) or context["default_room_name"]
|
||||||
|
|
||||||
context["logger"]("Creating new room...")
|
context["logger"]("Creating new room...")
|
||||||
new_room = await context["client"].room_create(name=room_name)
|
new_room = await context["client"].room_create(name=room_name)
|
||||||
|
@ -12,5 +14,4 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dic
|
||||||
await context["client"].room_put_state(
|
await context["client"].room_put_state(
|
||||||
new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
|
new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
|
||||||
|
|
||||||
await context["client"].room_send(
|
return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}
|
||||||
new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"})
|
|
||||||
|
|
|
@ -1,23 +1,19 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
|
||||||
async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
context["logger"]("Showing stats...")
|
context["logger"]("Showing stats...")
|
||||||
|
|
||||||
if not (database := context.get("database")):
|
if not (database := context.get("database")):
|
||||||
context["logger"]("No database connection - cannot show stats")
|
context["logger"]("No database connection - cannot show stats")
|
||||||
context["client"].room_send(
|
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
|
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
with database.cursor() as cursor:
|
with database.cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
|
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
|
||||||
total_tokens = cursor.fetchone()[0] or 0
|
total_tokens = cursor.fetchone()[0] or 0
|
||||||
|
|
||||||
await context["client"].room_send(
|
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": f"Total tokens used: {total_tokens}"}
|
"body": f"Total tokens used: {total_tokens}"}
|
||||||
)
|
|
|
@ -1,10 +1,9 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
|
||||||
async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
|
async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
context["logger"]("Unknown command")
|
context["logger"]("Unknown command")
|
||||||
|
|
||||||
await context["client"].room_send(
|
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
|
||||||
"body": "Unknown command - try !gptbot help"}
|
"body": "Unknown command - try !gptbot help"}
|
||||||
)
|
|
299
gptbot.py
299
gptbot.py
|
@ -3,6 +3,7 @@ import inspect
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
import random
|
import random
|
||||||
|
import uuid
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -10,9 +11,10 @@ import markdown2
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import duckdb
|
import duckdb
|
||||||
|
|
||||||
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent
|
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api
|
||||||
from nio.api import MessageDirection
|
from nio.api import MessageDirection
|
||||||
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError
|
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse
|
||||||
|
from nio.crypto import Olm
|
||||||
|
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -20,6 +22,7 @@ from argparse import ArgumentParser
|
||||||
from typing import List, Dict, Union, Optional
|
from typing import List, Dict, Union, Optional
|
||||||
|
|
||||||
from commands import COMMANDS
|
from commands import COMMANDS
|
||||||
|
from classes import DuckDBStore
|
||||||
|
|
||||||
|
|
||||||
def logging(message: str, log_level: str = "info"):
|
def logging(message: str, log_level: str = "info"):
|
||||||
|
@ -85,6 +88,13 @@ async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
|
||||||
for event in response.chunk:
|
for event in response.chunk:
|
||||||
if len(messages) >= n:
|
if len(messages) >= n:
|
||||||
break
|
break
|
||||||
|
if isinstance(event, MegolmEvent):
|
||||||
|
try:
|
||||||
|
event = await client.decrypt_event(event)
|
||||||
|
except (GroupEncryptionError, EncryptionError):
|
||||||
|
logging(
|
||||||
|
f"Could not decrypt message {event.event_id} in room {room_id}", "error")
|
||||||
|
continue
|
||||||
if isinstance(event, RoomMessageText):
|
if isinstance(event, RoomMessageText):
|
||||||
if event.body.startswith("!gptbot ignoreolder"):
|
if event.body.startswith("!gptbot ignoreolder"):
|
||||||
break
|
break
|
||||||
|
@ -162,14 +172,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||||
|
|
||||||
# Convert markdown to HTML
|
# Convert markdown to HTML
|
||||||
|
|
||||||
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
|
message = await send_message(room, response)
|
||||||
formatted_body = markdowner.convert(response)
|
|
||||||
|
|
||||||
message = await client.room_send(
|
|
||||||
room.room_id, "m.room.message",
|
|
||||||
{"msgtype": "m.text", "body": response,
|
|
||||||
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
|
||||||
)
|
|
||||||
|
|
||||||
if database:
|
if database:
|
||||||
logging("Logging tokens used...")
|
logging("Logging tokens used...")
|
||||||
|
@ -183,11 +186,8 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||||
# Send a notice to the room if there was an error
|
# Send a notice to the room if there was an error
|
||||||
|
|
||||||
logging("Error during GPT API call - sending notice to room")
|
logging("Error during GPT API call - sending notice to room")
|
||||||
|
send_message(
|
||||||
await client.room_send(
|
room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
|
||||||
room.room_id, "m.room.message", {
|
|
||||||
"msgtype": "m.notice", "body": "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later."}
|
|
||||||
)
|
|
||||||
print("No response from GPT API")
|
print("No response from GPT API")
|
||||||
|
|
||||||
await client.room_typing(room.room_id, False)
|
await client.room_typing(room.room_id, False)
|
||||||
|
@ -199,14 +199,34 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
|
||||||
logging(
|
logging(
|
||||||
f"Received command {event.body} from {event.sender} in room {room.room_id}")
|
f"Received command {event.body} from {event.sender} in room {room.room_id}")
|
||||||
command = event.body.split()[1] if event.body.split()[1:] else None
|
command = event.body.split()[1] if event.body.split()[1:] else None
|
||||||
await COMMANDS.get(command, COMMANDS[None])(room, event, context)
|
|
||||||
|
message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
room_id, event, content = message
|
||||||
|
await send_message(context["client"].rooms[room_id], content["body"],
|
||||||
|
True if content["msgtype"] == "m.notice" else False, context["client"])
|
||||||
|
|
||||||
|
|
||||||
async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
||||||
context = kwargs.get("context") or CONTEXT
|
context = kwargs.get("context") or CONTEXT
|
||||||
|
|
||||||
logging(f"Received message from {event.sender} in room {room.room_id}")
|
logging(f"Received message from {event.sender} in room {room.room_id}")
|
||||||
|
|
||||||
|
if isinstance(event, MegolmEvent):
|
||||||
|
try:
|
||||||
|
event = await context["client"].decrypt_event(event)
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
logging("Requesting new encryption keys...")
|
||||||
|
await context["client"].request_room_key(event)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logging(f"Error decrypting message: {e}", "error")
|
||||||
|
await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
|
||||||
|
return
|
||||||
|
|
||||||
if event.sender == context["client"].user_id:
|
if event.sender == context["client"].user_id:
|
||||||
logging("Message is from bot itself - ignoring")
|
logging("Message is from bot itself - ignoring")
|
||||||
|
|
||||||
|
@ -221,18 +241,69 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
|
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
|
||||||
client = kwargs.get("client") or CONTEXT["client"]
|
client: AsyncClient = kwargs.get("client") or CONTEXT["client"]
|
||||||
|
|
||||||
|
if room.room_id in client.rooms:
|
||||||
|
logging(f"Already in room {room.room_id} - ignoring invite")
|
||||||
|
return
|
||||||
|
|
||||||
logging(f"Received invite to room {room.room_id} - joining...")
|
logging(f"Received invite to room {room.room_id} - joining...")
|
||||||
|
|
||||||
await client.join(room.room_id)
|
response = await client.join(room.room_id)
|
||||||
await client.room_send(
|
if isinstance(response, JoinResponse):
|
||||||
|
await send_message(room, "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||||
|
else:
|
||||||
|
logging(f"Error joining room {room.room_id}: {response}", "error")
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None):
|
||||||
|
client = client or CONTEXT["client"]
|
||||||
|
|
||||||
|
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
|
||||||
|
formatted_body = markdowner.convert(message)
|
||||||
|
|
||||||
|
msgtype = "m.notice" if notice else "m.text"
|
||||||
|
|
||||||
|
msgcontent = {"msgtype": msgtype, "body": message,
|
||||||
|
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
||||||
|
|
||||||
|
content = None
|
||||||
|
|
||||||
|
if client.olm and room.encrypted:
|
||||||
|
try:
|
||||||
|
if not room.members_synced:
|
||||||
|
responses = []
|
||||||
|
responses.append(await client.joined_members(room.room_id))
|
||||||
|
|
||||||
|
if client.olm.should_share_group_session(room.room_id):
|
||||||
|
try:
|
||||||
|
event = client.sharing_session[room.room_id]
|
||||||
|
await event.wait()
|
||||||
|
except KeyError:
|
||||||
|
await client.share_group_session(
|
||||||
room.room_id,
|
room.room_id,
|
||||||
"m.room.message",
|
ignore_unverified_devices=True,
|
||||||
{"msgtype": "m.text",
|
|
||||||
"body": "Hello! I'm a helpful assistant. How can I help you today?"}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if msgtype != "m.reaction":
|
||||||
|
response = client.encrypt(room.room_id, "m.room.message", msgcontent)
|
||||||
|
msgtype, content = response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging(
|
||||||
|
f"Error encrypting message: {e} - sending unencrypted", "error")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
msgtype = "m.room.message"
|
||||||
|
content = msgcontent
|
||||||
|
|
||||||
|
method, path, data = Api.room_send(
|
||||||
|
client.access_token, room.room_id, msgtype, content, uuid.uuid4()
|
||||||
|
)
|
||||||
|
|
||||||
|
return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
|
||||||
|
|
||||||
|
|
||||||
async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
||||||
client = client or CONTEXT["client"]
|
client = client or CONTEXT["client"]
|
||||||
|
@ -242,13 +313,14 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
||||||
for room_id in list(client.invited_rooms.keys()):
|
for room_id in list(client.invited_rooms.keys()):
|
||||||
logging(f"Joining room {room_id}...")
|
logging(f"Joining room {room_id}...")
|
||||||
|
|
||||||
await client.join(room_id)
|
response = await client.join(room_id)
|
||||||
await client.room_send(
|
|
||||||
room_id,
|
if isinstance(response, JoinResponse):
|
||||||
"m.room.message",
|
logging(response, "debug")
|
||||||
{"msgtype": "m.text",
|
rooms = await client.joined_rooms()
|
||||||
"body": "Hello! I'm a helpful assistant. How can I help you today?"}
|
await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||||
)
|
else:
|
||||||
|
logging(f"Error joining room {room_id}: {response}", "error")
|
||||||
|
|
||||||
|
|
||||||
async def sync_cb(response, write_global: bool = True):
|
async def sync_cb(response, write_global: bool = True):
|
||||||
|
@ -261,12 +333,95 @@ async def sync_cb(response, write_global: bool = True):
|
||||||
CONTEXT["sync_token"] = SYNC_TOKEN
|
CONTEXT["sync_token"] = SYNC_TOKEN
|
||||||
|
|
||||||
|
|
||||||
async def main(client: Optional[AsyncClient] = None):
|
async def test_callback(room: MatrixRoom, event: Event, **kwargs):
|
||||||
client = client or CONTEXT["client"]
|
logging(
|
||||||
|
f"Received event {event.__class__.__name__} in room {room.room_id}", "debug")
|
||||||
|
|
||||||
if not client.user_id:
|
|
||||||
whoami = await client.whoami()
|
async def init(config: ConfigParser):
|
||||||
client.user_id = whoami.user_id
|
# Set up Matrix client
|
||||||
|
try:
|
||||||
|
assert "Matrix" in config
|
||||||
|
assert "Homeserver" in config["Matrix"]
|
||||||
|
assert "AccessToken" in config["Matrix"]
|
||||||
|
except:
|
||||||
|
logging("Matrix config not found or incomplete", "critical")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
homeserver = config["Matrix"]["Homeserver"]
|
||||||
|
access_token = config["Matrix"]["AccessToken"]
|
||||||
|
|
||||||
|
device_id, user_id = await get_device_id(access_token, homeserver)
|
||||||
|
|
||||||
|
device_id = config["Matrix"].get("DeviceID", device_id)
|
||||||
|
user_id = config["Matrix"].get("UserID", user_id)
|
||||||
|
|
||||||
|
# Set up database
|
||||||
|
if "Database" in config and config["Database"].get("Path"):
|
||||||
|
database = CONTEXT["database"] = initialize_database(
|
||||||
|
config["Database"]["Path"])
|
||||||
|
matrix_store = DuckDBStore
|
||||||
|
|
||||||
|
client_config = AsyncClientConfig(
|
||||||
|
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
||||||
|
|
||||||
|
else:
|
||||||
|
client_config = AsyncClientConfig(
|
||||||
|
store_sync_tokens=True, encryption_enabled=False)
|
||||||
|
|
||||||
|
client = AsyncClient(
|
||||||
|
config["Matrix"]["Homeserver"], config=client_config)
|
||||||
|
|
||||||
|
if client.config.encryption_enabled:
|
||||||
|
client.store = client.config.store(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
database
|
||||||
|
)
|
||||||
|
assert client.store
|
||||||
|
|
||||||
|
client.olm = Olm(client.user_id, client.device_id, client.store)
|
||||||
|
client.encrypted_rooms = client.store.load_encrypted_rooms()
|
||||||
|
|
||||||
|
CONTEXT["client"] = client
|
||||||
|
|
||||||
|
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
|
||||||
|
CONTEXT["client"].user_id = user_id
|
||||||
|
CONTEXT["client"].device_id = device_id
|
||||||
|
|
||||||
|
# Set up GPT API
|
||||||
|
try:
|
||||||
|
assert "OpenAI" in config
|
||||||
|
assert "APIKey" in config["OpenAI"]
|
||||||
|
except:
|
||||||
|
logging("OpenAI config not found or incomplete", "critical")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
openai.api_key = config["OpenAI"]["APIKey"]
|
||||||
|
|
||||||
|
if "Model" in config["OpenAI"]:
|
||||||
|
CONTEXT["model"] = config["OpenAI"]["Model"]
|
||||||
|
|
||||||
|
if "MaxTokens" in config["OpenAI"]:
|
||||||
|
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
|
||||||
|
|
||||||
|
if "MaxMessages" in config["OpenAI"]:
|
||||||
|
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
||||||
|
|
||||||
|
# Listen for SIGTERM
|
||||||
|
|
||||||
|
def sigterm_handler(_signo, _stack_frame):
|
||||||
|
logging("Received SIGTERM - exiting...")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
||||||
|
if not client and not CONTEXT.get("client"):
|
||||||
|
await init(config)
|
||||||
|
|
||||||
|
client = client or CONTEXT["client"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert client.user_id
|
assert client.user_id
|
||||||
|
@ -285,7 +440,9 @@ async def main(client: Optional[AsyncClient] = None):
|
||||||
await client.sync(timeout=30000)
|
await client.sync(timeout=30000)
|
||||||
|
|
||||||
client.add_event_callback(message_callback, RoomMessageText)
|
client.add_event_callback(message_callback, RoomMessageText)
|
||||||
|
client.add_event_callback(message_callback, MegolmEvent)
|
||||||
client.add_event_callback(room_invite_callback, InviteEvent)
|
client.add_event_callback(room_invite_callback, InviteEvent)
|
||||||
|
client.add_event_callback(test_callback, Event)
|
||||||
|
|
||||||
await accept_pending_invites() # Accept pending invites
|
await accept_pending_invites() # Accept pending invites
|
||||||
|
|
||||||
|
@ -351,6 +508,31 @@ def initialize_database(path: os.PathLike):
|
||||||
return database
|
return database
|
||||||
|
|
||||||
|
|
||||||
|
async def get_device_id(access_token, homeserver):
|
||||||
|
client = AsyncClient(homeserver)
|
||||||
|
client.access_token = access_token
|
||||||
|
|
||||||
|
logging(f"Obtaining device ID for access token {access_token}...", "debug")
|
||||||
|
response = await client.whoami()
|
||||||
|
if isinstance(response, WhoamiResponse):
|
||||||
|
logging(
|
||||||
|
f"Authenticated as {response.user_id}.")
|
||||||
|
user_id = response.user_id
|
||||||
|
devices = await client.devices()
|
||||||
|
device_id = devices.devices[0].id
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
return device_id, user_id
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging(f"Failed to obtain device ID: {response}", "error")
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
@ -362,54 +544,9 @@ if __name__ == "__main__":
|
||||||
config = ConfigParser()
|
config = ConfigParser()
|
||||||
config.read(args.config)
|
config.read(args.config)
|
||||||
|
|
||||||
# Set up Matrix client
|
|
||||||
try:
|
|
||||||
assert "Matrix" in config
|
|
||||||
assert "Homeserver" in config["Matrix"]
|
|
||||||
assert "AccessToken" in config["Matrix"]
|
|
||||||
except:
|
|
||||||
logging("Matrix config not found or incomplete", "critical")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
CONTEXT["client"] = AsyncClient(config["Matrix"]["Homeserver"])
|
|
||||||
|
|
||||||
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
|
|
||||||
CONTEXT["client"].user_id = config["Matrix"].get("UserID")
|
|
||||||
|
|
||||||
# Set up GPT API
|
|
||||||
try:
|
|
||||||
assert "OpenAI" in config
|
|
||||||
assert "APIKey" in config["OpenAI"]
|
|
||||||
except:
|
|
||||||
logging("OpenAI config not found or incomplete", "critical")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
openai.api_key = config["OpenAI"]["APIKey"]
|
|
||||||
|
|
||||||
if "Model" in config["OpenAI"]:
|
|
||||||
CONTEXT["model"] = config["OpenAI"]["Model"]
|
|
||||||
|
|
||||||
if "MaxTokens" in config["OpenAI"]:
|
|
||||||
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
|
|
||||||
|
|
||||||
if "MaxMessages" in config["OpenAI"]:
|
|
||||||
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
|
||||||
|
|
||||||
# Set up database
|
|
||||||
if "Database" in config and config["Database"].get("Path"):
|
|
||||||
CONTEXT["database"] = initialize_database(config["Database"]["Path"])
|
|
||||||
|
|
||||||
# Listen for SIGTERM
|
|
||||||
|
|
||||||
def sigterm_handler(_signo, _stack_frame):
|
|
||||||
logging("Received SIGTERM - exiting...")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
|
||||||
|
|
||||||
# Start bot loop
|
# Start bot loop
|
||||||
try:
|
try:
|
||||||
asyncio.run(main())
|
asyncio.run(main(config))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging("Received KeyboardInterrupt - exiting...")
|
logging("Received KeyboardInterrupt - exiting...")
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
|
|
Loading…
Reference in a new issue