Some refactoring, starting implementation of encryption
This commit is contained in:
parent
202bed25c6
commit
f20b762558
11 changed files with 976 additions and 116 deletions
1
classes/__init__.py
Normal file
1
classes/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .store import DuckDBStore
|
730
classes/store.py
Normal file
730
classes/store.py
Normal file
|
@ -0,0 +1,730 @@
|
|||
import duckdb
|
||||
|
||||
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
|
||||
from nio.crypto import OlmAccount, OlmDevice
|
||||
|
||||
from random import SystemRandom
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class DuckDBStore(MatrixStore):
|
||||
@property
|
||||
def account_id(self):
|
||||
id = self._get_account()[0] if self._get_account() else None
|
||||
|
||||
if id is None:
|
||||
id = SystemRandom().randint(0, 2**16)
|
||||
|
||||
return id
|
||||
|
||||
def __init__(self, user_id, device_id, duckdb_conn):
|
||||
self.conn = duckdb_conn
|
||||
self.user_id = user_id
|
||||
self.device_id = device_id
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self):
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
DROP TABLE IF EXISTS sync_tokens CASCADE;
|
||||
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
|
||||
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
|
||||
DROP TABLE IF EXISTS forwarded_chains CASCADE;
|
||||
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
|
||||
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
|
||||
DROP TABLE IF EXISTS olm_sessions CASCADE;
|
||||
DROP TABLE IF EXISTS device_trust_state CASCADE;
|
||||
DROP TABLE IF EXISTS keys CASCADE;
|
||||
DROP TABLE IF EXISTS device_keys_key CASCADE;
|
||||
DROP TABLE IF EXISTS device_keys CASCADE;
|
||||
DROP TABLE IF EXISTS accounts CASCADE;
|
||||
""")
|
||||
|
||||
# Create accounts table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id VARCHAR NOT NULL,
|
||||
device_id VARCHAR NOT NULL,
|
||||
shared_account INTEGER NOT NULL,
|
||||
pickle VARCHAR NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
# Create device_keys table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS device_keys (
|
||||
device_id TEXT PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
deleted BOOLEAN NOT NULL DEFAULT 0,
|
||||
UNIQUE (account_id, user_id, device_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS keys (
|
||||
key_type TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
device_id VARCHAR NOT NULL,
|
||||
UNIQUE (key_type, device_id),
|
||||
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create device_trust_state table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS device_trust_state (
|
||||
device_id VARCHAR PRIMARY KEY,
|
||||
state INTEGER NOT NULL,
|
||||
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create olm_sessions table
|
||||
cursor.execute("""
|
||||
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS olm_sessions (
|
||||
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
||||
account_id INTEGER NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
session BLOB NOT NULL,
|
||||
session_id VARCHAR NOT NULL,
|
||||
creation_time TIMESTAMP NOT NULL,
|
||||
last_usage_date TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create inbound_group_sessions table
|
||||
cursor.execute("""
|
||||
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
||||
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
||||
account_id INTEGER NOT NULL,
|
||||
session TEXT NOT NULL,
|
||||
fp_key TEXT NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
UNIQUE (account_id, sender_key, fp_key, room_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
||||
id INTEGER PRIMARY KEY,
|
||||
session_id INTEGER NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create outbound_group_sessions table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
room_id VARCHAR NOT NULL,
|
||||
session_id VARCHAR NOT NULL UNIQUE,
|
||||
session BLOB NOT NULL,
|
||||
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create outgoing_key_requests table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
request_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
UNIQUE (account_id, request_id)
|
||||
);
|
||||
|
||||
""")
|
||||
|
||||
# Create encrypted_rooms table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
||||
room_id TEXT NOT NULL,
|
||||
account_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (room_id, account_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create sync_tokens table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sync_tokens (
|
||||
account_id INTEGER PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
def _get_account(self):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
|
||||
(self.user_id, self.device_id),
|
||||
)
|
||||
account = cursor.fetchone()
|
||||
cursor.close()
|
||||
return account
|
||||
|
||||
def _get_device(self, device):
|
||||
acc = self._get_account()
|
||||
|
||||
if not acc:
|
||||
return None
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
|
||||
(device.user_id, device.id, acc[0]),
|
||||
)
|
||||
device_entry = cursor.fetchone()
|
||||
cursor.close()
|
||||
|
||||
return device_entry
|
||||
|
||||
# Implementing methods with DuckDB equivalents
|
||||
def verify_device(self, device):
|
||||
if self.is_device_verified(device):
|
||||
return False
|
||||
|
||||
d = self._get_device(device)
|
||||
assert d
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||
(d[0], TrustState.verified),
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
device.trust_state = TrustState.verified
|
||||
|
||||
return True
|
||||
|
||||
def unverify_device(self, device):
|
||||
if not self.is_device_verified(device):
|
||||
return False
|
||||
|
||||
d = self._get_device(device)
|
||||
assert d
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||
(d[0], TrustState.unset),
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
device.trust_state = TrustState.unset
|
||||
|
||||
return True
|
||||
|
||||
def is_device_verified(self, device):
|
||||
d = self._get_device(device)
|
||||
|
||||
if not d:
|
||||
return False
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
||||
)
|
||||
trust_state = cursor.fetchone()
|
||||
cursor.close()
|
||||
|
||||
if not trust_state:
|
||||
return False
|
||||
|
||||
return trust_state[0] == TrustState.verified
|
||||
|
||||
def blacklist_device(self, device):
|
||||
if self.is_device_blacklisted(device):
|
||||
return False
|
||||
|
||||
d = self._get_device(device)
|
||||
assert d
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||
(d[0], TrustState.blacklisted),
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
device.trust_state = TrustState.blacklisted
|
||||
|
||||
return True
|
||||
|
||||
def unblacklist_device(self, device):
|
||||
if not self.is_device_blacklisted(device):
|
||||
return False
|
||||
|
||||
d = self._get_device(device)
|
||||
assert d
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||
(d[0], TrustState.unset),
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
device.trust_state = TrustState.unset
|
||||
|
||||
return True
|
||||
|
||||
def is_device_blacklisted(self, device):
|
||||
d = self._get_device(device)
|
||||
|
||||
if not d:
|
||||
return False
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
||||
)
|
||||
trust_state = cursor.fetchone()
|
||||
cursor.close()
|
||||
|
||||
if not trust_state:
|
||||
return False
|
||||
|
||||
return trust_state[0] == TrustState.blacklisted
|
||||
|
||||
def ignore_device(self, device):
|
||||
if self.is_device_ignored(device):
|
||||
return False
|
||||
|
||||
d = self._get_device(device)
|
||||
assert d
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||
(d[0], int(TrustState.ignored.value)),
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
return True
|
||||
|
||||
def ignore_devices(self, devices):
|
||||
for device in devices:
|
||||
self.ignore_device(device)
|
||||
|
||||
def unignore_device(self, device):
|
||||
if not self.is_device_ignored(device):
|
||||
return False
|
||||
|
||||
d = self._get_device(device)
|
||||
assert d
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
||||
(d[0], TrustState.unset),
|
||||
)
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
|
||||
device.trust_state = TrustState.unset
|
||||
|
||||
return True
|
||||
|
||||
def is_device_ignored(self, device):
|
||||
d = self._get_device(device)
|
||||
|
||||
if not d:
|
||||
return False
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
||||
)
|
||||
trust_state = cursor.fetchone()
|
||||
cursor.close()
|
||||
|
||||
if not trust_state:
|
||||
return False
|
||||
|
||||
return trust_state[0] == TrustState.ignored
|
||||
|
||||
def load_device_keys(self):
|
||||
"""Load all the device keys from the database.
|
||||
|
||||
Returns DeviceStore containing the OlmDevices with the device keys.
|
||||
"""
|
||||
store = DeviceStore()
|
||||
account = self.account_id
|
||||
|
||||
if not account:
|
||||
return store
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM device_keys WHERE account_id = ?",
|
||||
(account,)
|
||||
)
|
||||
device_keys = cur.fetchall()
|
||||
|
||||
for d in device_keys:
|
||||
cur.execute(
|
||||
"SELECT * FROM keys WHERE device_id = ?",
|
||||
(d["id"],)
|
||||
)
|
||||
keys = cur.fetchall()
|
||||
key_dict = {k["key_type"]: k["key"] for k in keys}
|
||||
|
||||
store.add(
|
||||
OlmDevice(
|
||||
d["user_id"],
|
||||
d["device_id"],
|
||||
key_dict,
|
||||
display_name=d["display_name"],
|
||||
deleted=d["deleted"],
|
||||
)
|
||||
)
|
||||
|
||||
return store
|
||||
|
||||
def save_device_keys(self, device_keys):
|
||||
"""Save the provided device keys to the database."""
|
||||
account = self.account_id
|
||||
assert account
|
||||
rows = []
|
||||
|
||||
for user_id, devices_dict in device_keys.items():
|
||||
for device_id, device in devices_dict.items():
|
||||
rows.append(
|
||||
{
|
||||
"account_id": account,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"display_name": device.display_name,
|
||||
"deleted": device.deleted,
|
||||
}
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
for idx in range(0, len(rows), 100):
|
||||
data = rows[idx: idx + 100]
|
||||
cur.executemany(
|
||||
"INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
|
||||
[(r["account_id"], r["user_id"], r["device_id"],
|
||||
r["display_name"], r["deleted"]) for r in data]
|
||||
)
|
||||
|
||||
for user_id, devices_dict in device_keys.items():
|
||||
for device_id, device in devices_dict.items():
|
||||
cur.execute(
|
||||
"UPDATE device_keys SET deleted = ? WHERE device_id = ?",
|
||||
(device.deleted, device_id)
|
||||
)
|
||||
|
||||
for key_type, key in device.keys.items():
|
||||
cur.execute("""
|
||||
INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
|
||||
ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
|
||||
""",
|
||||
(key_type, key, device_id, key)
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def save_group_sessions(self, sessions):
|
||||
with self.conn.cursor() as cur:
|
||||
for session in sessions:
|
||||
cur.execute("""
|
||||
INSERT OR REPLACE INTO inbound_group_sessions (
|
||||
session_id, sender_key, signing_key, room_id, pickle, account_id
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
session.id,
|
||||
session.sender_key,
|
||||
session.signing_key,
|
||||
session.room_id,
|
||||
session.pickle,
|
||||
self.account_id
|
||||
))
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def save_olm_sessions(self, sessions):
|
||||
with self.conn.cursor() as cur:
|
||||
for session in sessions:
|
||||
cur.execute("""
|
||||
INSERT OR REPLACE INTO olm_sessions (
|
||||
session_id, sender_key, pickle, account_id
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""", (
|
||||
session.id,
|
||||
session.sender_key,
|
||||
session.pickle,
|
||||
self.account_id
|
||||
))
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def save_outbound_group_sessions(self, sessions):
|
||||
with self.conn.cursor() as cur:
|
||||
for session in sessions:
|
||||
cur.execute("""
|
||||
INSERT OR REPLACE INTO outbound_group_sessions (
|
||||
room_id, session_id, pickle, account_id
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""", (
|
||||
session.room_id,
|
||||
session.id,
|
||||
session.pickle,
|
||||
self.account_id
|
||||
))
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def save_account(self, account: OlmAccount):
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
INSERT OR REPLACE INTO accounts (
|
||||
id, user_id, device_id, shared_account, pickle
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
self.account_id,
|
||||
self.user_id,
|
||||
self.device_id,
|
||||
account.shared,
|
||||
account.pickle(self.pickle_key),
|
||||
))
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def load_sessions(self):
|
||||
session_store = SessionStore()
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
SELECT
|
||||
os.sender_key, os.session, os.creation_time
|
||||
FROM
|
||||
olm_sessions os
|
||||
INNER JOIN
|
||||
accounts a ON os.account_id = a.id
|
||||
WHERE
|
||||
a.id = ?
|
||||
""", (self.account_id,))
|
||||
|
||||
for row in cur.fetchall():
|
||||
sender_key, session_pickle, creation_time = row
|
||||
session = Session.from_pickle(
|
||||
session_pickle, creation_time, self.pickle_key)
|
||||
session_store.add(sender_key, session)
|
||||
|
||||
return session_store
|
||||
|
||||
def load_inbound_group_sessions(self):
|
||||
# type: () -> GroupSessionStore
|
||||
"""Load all Olm sessions from the database.
|
||||
|
||||
Returns:
|
||||
``GroupSessionStore`` object, containing all the loaded sessions.
|
||||
|
||||
"""
|
||||
store = GroupSessionStore()
|
||||
|
||||
account = self.account_id
|
||||
|
||||
if not account:
|
||||
return store
|
||||
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
|
||||
account,)
|
||||
)
|
||||
|
||||
for row in cursor.fetchall():
|
||||
session = InboundGroupSession.from_pickle(
|
||||
row["session"],
|
||||
row["fp_key"],
|
||||
row["sender_key"],
|
||||
row["room_id"],
|
||||
self.pickle_key,
|
||||
[
|
||||
chain["sender_key"]
|
||||
for chain in cursor.execute(
|
||||
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
||||
(row["id"],),
|
||||
)
|
||||
],
|
||||
)
|
||||
store.add(session)
|
||||
|
||||
return store
|
||||
|
||||
def load_outgoing_key_requests(self):
|
||||
# type: () -> dict
|
||||
"""Load all outgoing key requests from the database.
|
||||
|
||||
Returns:
|
||||
``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
|
||||
"""
|
||||
account = self.account_id
|
||||
|
||||
if not account:
|
||||
return store
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM outgoing_key_requests WHERE account_id = ?",
|
||||
(account,)
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
return {
|
||||
request.request_id: OutgoingKeyRequest.from_database(request)
|
||||
for request in rows
|
||||
}
|
||||
|
||||
def load_encrypted_rooms(self):
|
||||
"""Load the set of encrypted rooms for this account.
|
||||
|
||||
Returns:
|
||||
``Set`` containing room ids of encrypted rooms.
|
||||
"""
|
||||
account = self.account_id
|
||||
|
||||
if not account:
|
||||
return set()
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
|
||||
(account,)
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
return {row["room_id"] for row in rows}
|
||||
|
||||
def save_sync_token(self, token):
|
||||
"""Save the given token"""
|
||||
account = self.account_id
|
||||
assert account
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
|
||||
(account, token)
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def save_encrypted_rooms(self, rooms):
|
||||
"""Save the set of room ids for this account."""
|
||||
account = self.account_id
|
||||
assert account
|
||||
|
||||
data = [(room_id, account) for room_id in rooms]
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
for idx in range(0, len(data), 400):
|
||||
rows = data[idx: idx + 400]
|
||||
cur.executemany(
|
||||
"INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
|
||||
rows
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def save_session(self, sender_key, session):
|
||||
"""Save the provided Olm session to the database.
|
||||
|
||||
Args:
|
||||
sender_key (str): The curve key that owns the Olm session.
|
||||
session (Session): The Olm session that will be pickled and
|
||||
saved in the database.
|
||||
"""
|
||||
account = self.account_id
|
||||
assert account
|
||||
|
||||
pickled_session = session.pickle(self.pickle_key)
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(account, sender_key, pickled_session, session.id,
|
||||
session.creation_time, session.use_time)
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def save_inbound_group_session(self, session):
|
||||
"""Save the provided Megolm inbound group session to the database.
|
||||
|
||||
Args:
|
||||
session (InboundGroupSession): The session to save.
|
||||
"""
|
||||
account = self.account_id
|
||||
assert account
|
||||
|
||||
with self.conn.cursor() as cur:
|
||||
|
||||
# Insert a new session or update the existing one
|
||||
query = """
|
||||
INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT (account_id, sender_key, fp_key, room_id)
|
||||
DO UPDATE SET session = excluded.session
|
||||
"""
|
||||
cur.execute(query, (account, session.sender_key,
|
||||
session.ed25519, session.room_id, session.pickle(self.pickle_key)))
|
||||
|
||||
# Delete existing forwarded chains for the session
|
||||
delete_query = """
|
||||
DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
|
||||
"""
|
||||
cur.execute(
|
||||
delete_query, (account, session.sender_key, session.ed25519, session.room_id))
|
||||
|
||||
# Insert new forwarded chains for the session
|
||||
insert_query = """
|
||||
INSERT INTO forwarded_chains (session_id, sender_key)
|
||||
VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
|
||||
"""
|
||||
|
||||
for chain in session.forwarding_chain:
|
||||
cur.execute(
|
||||
insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
|
||||
|
||||
def add_outgoing_key_request(self, key_request):
|
||||
account_id = self.account_id
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT (account_id, request_id) DO NOTHING
|
||||
""",
|
||||
(
|
||||
account_id,
|
||||
key_request.request_id,
|
||||
key_request.session_id,
|
||||
key_request.room_id,
|
||||
key_request.algorithm,
|
||||
)
|
||||
)
|
|
@ -1,12 +1,12 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
logging("Showing bot info...")
|
||||
|
||||
await context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": f"""GPT Info:
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": f"""GPT Info:
|
||||
|
||||
Model: {context["model"]}
|
||||
Maximum context tokens: {context["max_tokens"]}
|
||||
|
@ -19,4 +19,4 @@ Bot user ID: {context["client"].user_id}
|
|||
Current room ID: {room.room_id}
|
||||
|
||||
For usage statistics, run !gptbot stats
|
||||
"""})
|
||||
"""}
|
||||
|
|
|
@ -3,12 +3,11 @@ from nio.rooms import MatrixRoom
|
|||
|
||||
from random import SystemRandom
|
||||
|
||||
|
||||
async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
context["logger"]("Flipping a coin...")
|
||||
|
||||
heads = SystemRandom().choice([True, False])
|
||||
|
||||
await context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Heads!" if heads else "Tails!"}
|
||||
)
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Heads!" if heads else "Tails!"}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
await context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": """Available commands:
|
||||
return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": """Available commands:
|
||||
|
||||
!gptbot help - Show this message
|
||||
!gptbot newroom <room name> - Create a new room and invite yourself to it
|
||||
|
@ -13,4 +13,3 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
|||
!gptbot coin - Flip a coin (heads or tails)
|
||||
!gptbot ignoreolder - Ignore messages before this point as context
|
||||
"""}
|
||||
)
|
|
@ -2,9 +2,7 @@ from nio.events.room_events import RoomMessageText
|
|||
from nio.rooms import MatrixRoom
|
||||
|
||||
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
await context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": """Alright, messages before this point will not be processed as context anymore.
|
||||
|
||||
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
|
||||
)
|
|
@ -1,8 +1,10 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
room_name = " ".join(event.body.split()[2:]) or context["default_room_name"]
|
||||
room_name = " ".join(event.body.split()[
|
||||
2:]) or context["default_room_name"]
|
||||
|
||||
context["logger"]("Creating new room...")
|
||||
new_room = await context["client"].room_create(name=room_name)
|
||||
|
@ -12,5 +14,4 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dic
|
|||
await context["client"].room_put_state(
|
||||
new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
|
||||
|
||||
await context["client"].room_send(
|
||||
new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"})
|
||||
return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}
|
||||
|
|
|
@ -1,23 +1,19 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
context["logger"]("Showing stats...")
|
||||
|
||||
if not (database := context.get("database")):
|
||||
context["logger"]("No database connection - cannot show stats")
|
||||
context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
|
||||
)
|
||||
return
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
|
||||
|
||||
with database.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
|
||||
total_tokens = cursor.fetchone()[0] or 0
|
||||
|
||||
await context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": f"Total tokens used: {total_tokens}"}
|
||||
)
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": f"Total tokens used: {total_tokens}"}
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
context["logger"]("Unknown command")
|
||||
|
||||
await context["client"].room_send(
|
||||
room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Unknown command - try !gptbot help"}
|
||||
)
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Unknown command - try !gptbot help"}
|
||||
|
|
301
gptbot.py
301
gptbot.py
|
@ -3,6 +3,7 @@ import inspect
|
|||
import logging
|
||||
import signal
|
||||
import random
|
||||
import uuid
|
||||
|
||||
import openai
|
||||
import asyncio
|
||||
|
@ -10,9 +11,10 @@ import markdown2
|
|||
import tiktoken
|
||||
import duckdb
|
||||
|
||||
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent
|
||||
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api
|
||||
from nio.api import MessageDirection
|
||||
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError
|
||||
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse
|
||||
from nio.crypto import Olm
|
||||
|
||||
from configparser import ConfigParser
|
||||
from datetime import datetime
|
||||
|
@ -20,6 +22,7 @@ from argparse import ArgumentParser
|
|||
from typing import List, Dict, Union, Optional
|
||||
|
||||
from commands import COMMANDS
|
||||
from classes import DuckDBStore
|
||||
|
||||
|
||||
def logging(message: str, log_level: str = "info"):
|
||||
|
@ -85,6 +88,13 @@ async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
|
|||
for event in response.chunk:
|
||||
if len(messages) >= n:
|
||||
break
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await client.decrypt_event(event)
|
||||
except (GroupEncryptionError, EncryptionError):
|
||||
logging(
|
||||
f"Could not decrypt message {event.event_id} in room {room_id}", "error")
|
||||
continue
|
||||
if isinstance(event, RoomMessageText):
|
||||
if event.body.startswith("!gptbot ignoreolder"):
|
||||
break
|
||||
|
@ -162,14 +172,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|||
|
||||
# Convert markdown to HTML
|
||||
|
||||
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
|
||||
formatted_body = markdowner.convert(response)
|
||||
|
||||
message = await client.room_send(
|
||||
room.room_id, "m.room.message",
|
||||
{"msgtype": "m.text", "body": response,
|
||||
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
||||
)
|
||||
message = await send_message(room, response)
|
||||
|
||||
if database:
|
||||
logging("Logging tokens used...")
|
||||
|
@ -183,11 +186,8 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|||
# Send a notice to the room if there was an error
|
||||
|
||||
logging("Error during GPT API call - sending notice to room")
|
||||
|
||||
await client.room_send(
|
||||
room.room_id, "m.room.message", {
|
||||
"msgtype": "m.notice", "body": "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later."}
|
||||
)
|
||||
send_message(
|
||||
room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
|
||||
print("No response from GPT API")
|
||||
|
||||
await client.room_typing(room.room_id, False)
|
||||
|
@ -199,14 +199,34 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
|
|||
logging(
|
||||
f"Received command {event.body} from {event.sender} in room {room.room_id}")
|
||||
command = event.body.split()[1] if event.body.split()[1:] else None
|
||||
await COMMANDS.get(command, COMMANDS[None])(room, event, context)
|
||||
|
||||
message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)
|
||||
|
||||
if message:
|
||||
room_id, event, content = message
|
||||
await send_message(context["client"].rooms[room_id], content["body"],
|
||||
True if content["msgtype"] == "m.notice" else False, context["client"])
|
||||
|
||||
|
||||
async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
||||
context = kwargs.get("context") or CONTEXT
|
||||
|
||||
logging(f"Received message from {event.sender} in room {room.room_id}")
|
||||
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await context["client"].decrypt_event(event)
|
||||
except Exception as e:
|
||||
try:
|
||||
logging("Requesting new encryption keys...")
|
||||
await context["client"].request_room_key(event)
|
||||
except:
|
||||
pass
|
||||
|
||||
logging(f"Error decrypting message: {e}", "error")
|
||||
await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
|
||||
return
|
||||
|
||||
if event.sender == context["client"].user_id:
|
||||
logging("Message is from bot itself - ignoring")
|
||||
|
||||
|
@ -221,18 +241,69 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|||
|
||||
|
||||
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
|
||||
client = kwargs.get("client") or CONTEXT["client"]
|
||||
client: AsyncClient = kwargs.get("client") or CONTEXT["client"]
|
||||
|
||||
if room.room_id in client.rooms:
|
||||
logging(f"Already in room {room.room_id} - ignoring invite")
|
||||
return
|
||||
|
||||
logging(f"Received invite to room {room.room_id} - joining...")
|
||||
|
||||
await client.join(room.room_id)
|
||||
await client.room_send(
|
||||
room.room_id,
|
||||
"m.room.message",
|
||||
{"msgtype": "m.text",
|
||||
"body": "Hello! I'm a helpful assistant. How can I help you today?"}
|
||||
response = await client.join(room.room_id)
|
||||
if isinstance(response, JoinResponse):
|
||||
await send_message(room, "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||
else:
|
||||
logging(f"Error joining room {room.room_id}: {response}", "error")
|
||||
|
||||
|
||||
async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None):
|
||||
client = client or CONTEXT["client"]
|
||||
|
||||
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
|
||||
formatted_body = markdowner.convert(message)
|
||||
|
||||
msgtype = "m.notice" if notice else "m.text"
|
||||
|
||||
msgcontent = {"msgtype": msgtype, "body": message,
|
||||
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
||||
|
||||
content = None
|
||||
|
||||
if client.olm and room.encrypted:
|
||||
try:
|
||||
if not room.members_synced:
|
||||
responses = []
|
||||
responses.append(await client.joined_members(room.room_id))
|
||||
|
||||
if client.olm.should_share_group_session(room.room_id):
|
||||
try:
|
||||
event = client.sharing_session[room.room_id]
|
||||
await event.wait()
|
||||
except KeyError:
|
||||
await client.share_group_session(
|
||||
room.room_id,
|
||||
ignore_unverified_devices=True,
|
||||
)
|
||||
|
||||
if msgtype != "m.reaction":
|
||||
response = client.encrypt(room.room_id, "m.room.message", msgcontent)
|
||||
msgtype, content = response
|
||||
|
||||
except Exception as e:
|
||||
logging(
|
||||
f"Error encrypting message: {e} - sending unencrypted", "error")
|
||||
raise
|
||||
|
||||
if not content:
|
||||
msgtype = "m.room.message"
|
||||
content = msgcontent
|
||||
|
||||
method, path, data = Api.room_send(
|
||||
client.access_token, room.room_id, msgtype, content, uuid.uuid4()
|
||||
)
|
||||
|
||||
return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
|
||||
|
||||
|
||||
async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
||||
client = client or CONTEXT["client"]
|
||||
|
@ -242,13 +313,14 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
|||
for room_id in list(client.invited_rooms.keys()):
|
||||
logging(f"Joining room {room_id}...")
|
||||
|
||||
await client.join(room_id)
|
||||
await client.room_send(
|
||||
room_id,
|
||||
"m.room.message",
|
||||
{"msgtype": "m.text",
|
||||
"body": "Hello! I'm a helpful assistant. How can I help you today?"}
|
||||
)
|
||||
response = await client.join(room_id)
|
||||
|
||||
if isinstance(response, JoinResponse):
|
||||
logging(response, "debug")
|
||||
rooms = await client.joined_rooms()
|
||||
await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||
else:
|
||||
logging(f"Error joining room {room_id}: {response}", "error")
|
||||
|
||||
|
||||
async def sync_cb(response, write_global: bool = True):
|
||||
|
@ -261,12 +333,95 @@ async def sync_cb(response, write_global: bool = True):
|
|||
CONTEXT["sync_token"] = SYNC_TOKEN
|
||||
|
||||
|
||||
async def main(client: Optional[AsyncClient] = None):
|
||||
client = client or CONTEXT["client"]
|
||||
async def test_callback(room: MatrixRoom, event: Event, **kwargs):
|
||||
logging(
|
||||
f"Received event {event.__class__.__name__} in room {room.room_id}", "debug")
|
||||
|
||||
if not client.user_id:
|
||||
whoami = await client.whoami()
|
||||
client.user_id = whoami.user_id
|
||||
|
||||
async def init(config: ConfigParser):
|
||||
# Set up Matrix client
|
||||
try:
|
||||
assert "Matrix" in config
|
||||
assert "Homeserver" in config["Matrix"]
|
||||
assert "AccessToken" in config["Matrix"]
|
||||
except:
|
||||
logging("Matrix config not found or incomplete", "critical")
|
||||
exit(1)
|
||||
|
||||
homeserver = config["Matrix"]["Homeserver"]
|
||||
access_token = config["Matrix"]["AccessToken"]
|
||||
|
||||
device_id, user_id = await get_device_id(access_token, homeserver)
|
||||
|
||||
device_id = config["Matrix"].get("DeviceID", device_id)
|
||||
user_id = config["Matrix"].get("UserID", user_id)
|
||||
|
||||
# Set up database
|
||||
if "Database" in config and config["Database"].get("Path"):
|
||||
database = CONTEXT["database"] = initialize_database(
|
||||
config["Database"]["Path"])
|
||||
matrix_store = DuckDBStore
|
||||
|
||||
client_config = AsyncClientConfig(
|
||||
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
||||
|
||||
else:
|
||||
client_config = AsyncClientConfig(
|
||||
store_sync_tokens=True, encryption_enabled=False)
|
||||
|
||||
client = AsyncClient(
|
||||
config["Matrix"]["Homeserver"], config=client_config)
|
||||
|
||||
if client.config.encryption_enabled:
|
||||
client.store = client.config.store(
|
||||
user_id,
|
||||
device_id,
|
||||
database
|
||||
)
|
||||
assert client.store
|
||||
|
||||
client.olm = Olm(client.user_id, client.device_id, client.store)
|
||||
client.encrypted_rooms = client.store.load_encrypted_rooms()
|
||||
|
||||
CONTEXT["client"] = client
|
||||
|
||||
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
|
||||
CONTEXT["client"].user_id = user_id
|
||||
CONTEXT["client"].device_id = device_id
|
||||
|
||||
# Set up GPT API
|
||||
try:
|
||||
assert "OpenAI" in config
|
||||
assert "APIKey" in config["OpenAI"]
|
||||
except:
|
||||
logging("OpenAI config not found or incomplete", "critical")
|
||||
exit(1)
|
||||
|
||||
openai.api_key = config["OpenAI"]["APIKey"]
|
||||
|
||||
if "Model" in config["OpenAI"]:
|
||||
CONTEXT["model"] = config["OpenAI"]["Model"]
|
||||
|
||||
if "MaxTokens" in config["OpenAI"]:
|
||||
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
|
||||
|
||||
if "MaxMessages" in config["OpenAI"]:
|
||||
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
||||
|
||||
# Listen for SIGTERM
|
||||
|
||||
def sigterm_handler(_signo, _stack_frame):
|
||||
logging("Received SIGTERM - exiting...")
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
|
||||
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
||||
if not client and not CONTEXT.get("client"):
|
||||
await init(config)
|
||||
|
||||
client = client or CONTEXT["client"]
|
||||
|
||||
try:
|
||||
assert client.user_id
|
||||
|
@ -285,7 +440,9 @@ async def main(client: Optional[AsyncClient] = None):
|
|||
await client.sync(timeout=30000)
|
||||
|
||||
client.add_event_callback(message_callback, RoomMessageText)
|
||||
client.add_event_callback(message_callback, MegolmEvent)
|
||||
client.add_event_callback(room_invite_callback, InviteEvent)
|
||||
client.add_event_callback(test_callback, Event)
|
||||
|
||||
await accept_pending_invites() # Accept pending invites
|
||||
|
||||
|
@ -351,6 +508,31 @@ def initialize_database(path: os.PathLike):
|
|||
return database
|
||||
|
||||
|
||||
async def get_device_id(access_token, homeserver):
|
||||
client = AsyncClient(homeserver)
|
||||
client.access_token = access_token
|
||||
|
||||
logging(f"Obtaining device ID for access token {access_token}...", "debug")
|
||||
response = await client.whoami()
|
||||
if isinstance(response, WhoamiResponse):
|
||||
logging(
|
||||
f"Authenticated as {response.user_id}.")
|
||||
user_id = response.user_id
|
||||
devices = await client.devices()
|
||||
device_id = devices.devices[0].id
|
||||
|
||||
await client.close()
|
||||
|
||||
return device_id, user_id
|
||||
|
||||
else:
|
||||
logging(f"Failed to obtain device ID: {response}", "error")
|
||||
|
||||
await client.close()
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse command line arguments
|
||||
parser = ArgumentParser()
|
||||
|
@ -362,54 +544,9 @@ if __name__ == "__main__":
|
|||
config = ConfigParser()
|
||||
config.read(args.config)
|
||||
|
||||
# Set up Matrix client
|
||||
try:
|
||||
assert "Matrix" in config
|
||||
assert "Homeserver" in config["Matrix"]
|
||||
assert "AccessToken" in config["Matrix"]
|
||||
except:
|
||||
logging("Matrix config not found or incomplete", "critical")
|
||||
exit(1)
|
||||
|
||||
CONTEXT["client"] = AsyncClient(config["Matrix"]["Homeserver"])
|
||||
|
||||
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
|
||||
CONTEXT["client"].user_id = config["Matrix"].get("UserID")
|
||||
|
||||
# Set up GPT API
|
||||
try:
|
||||
assert "OpenAI" in config
|
||||
assert "APIKey" in config["OpenAI"]
|
||||
except:
|
||||
logging("OpenAI config not found or incomplete", "critical")
|
||||
exit(1)
|
||||
|
||||
openai.api_key = config["OpenAI"]["APIKey"]
|
||||
|
||||
if "Model" in config["OpenAI"]:
|
||||
CONTEXT["model"] = config["OpenAI"]["Model"]
|
||||
|
||||
if "MaxTokens" in config["OpenAI"]:
|
||||
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
|
||||
|
||||
if "MaxMessages" in config["OpenAI"]:
|
||||
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
||||
|
||||
# Set up database
|
||||
if "Database" in config and config["Database"].get("Path"):
|
||||
CONTEXT["database"] = initialize_database(config["Database"]["Path"])
|
||||
|
||||
# Listen for SIGTERM
|
||||
|
||||
def sigterm_handler(_signo, _stack_frame):
|
||||
logging("Received SIGTERM - exiting...")
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
# Start bot loop
|
||||
try:
|
||||
asyncio.run(main())
|
||||
asyncio.run(main(config))
|
||||
except KeyboardInterrupt:
|
||||
logging("Received KeyboardInterrupt - exiting...")
|
||||
except SystemExit:
|
||||
|
|
Loading…
Reference in a new issue