matrix-gptbot/classes/store.py
2023-05-02 06:58:49 +00:00

637 lines
19 KiB
Python

import duckdb
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
from nio.crypto import OlmAccount, OlmDevice
from random import SystemRandom
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from .dict import AttrDict
import json
class DuckDBStore(MatrixStore):
@property
def account_id(self):
id = self._get_account()[0] if self._get_account() else None
if id is None:
id = SystemRandom().randint(0, 2**16)
return id
def __init__(self, user_id, device_id, duckdb_conn):
self.conn = duckdb_conn
self.user_id = user_id
self.device_id = device_id
def _get_account(self):
cursor = self.conn.cursor()
cursor.execute(
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
(self.user_id, self.device_id),
)
account = cursor.fetchone()
cursor.close()
return account
def _get_device(self, device):
acc = self._get_account()
if not acc:
return None
cursor = self.conn.cursor()
cursor.execute(
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
(device.user_id, device.id, acc[0]),
)
device_entry = cursor.fetchone()
cursor.close()
return device_entry
# Implementing methods with DuckDB equivalents
def verify_device(self, device):
if self.is_device_verified(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.verified),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.verified
return True
def unverify_device(self, device):
if not self.is_device_verified(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_verified(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.verified
def blacklist_device(self, device):
if self.is_device_blacklisted(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.blacklisted),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.blacklisted
return True
def unblacklist_device(self, device):
if not self.is_device_blacklisted(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_blacklisted(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.blacklisted
def ignore_device(self, device):
if self.is_device_ignored(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], int(TrustState.ignored.value)),
)
self.conn.commit()
cursor.close()
return True
def ignore_devices(self, devices):
for device in devices:
self.ignore_device(device)
def unignore_device(self, device):
if not self.is_device_ignored(device):
return False
d = self._get_device(device)
assert d
cursor = self.conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
(d[0], TrustState.unset),
)
self.conn.commit()
cursor.close()
device.trust_state = TrustState.unset
return True
def is_device_ignored(self, device):
d = self._get_device(device)
if not d:
return False
cursor = self.conn.cursor()
cursor.execute(
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
)
trust_state = cursor.fetchone()
cursor.close()
if not trust_state:
return False
return trust_state[0] == TrustState.ignored
def load_device_keys(self):
"""Load all the device keys from the database.
Returns DeviceStore containing the OlmDevices with the device keys.
"""
store = DeviceStore()
account = self.account_id
if not account:
return store
with self.conn.cursor() as cur:
cur.execute(
"SELECT * FROM device_keys WHERE account_id = ?",
(account,)
)
device_keys = cur.fetchall()
for d in device_keys:
cur.execute(
"SELECT * FROM keys WHERE device_id = ?",
(d[0],)
)
keys = cur.fetchall()
key_dict = {k[0]: k[1] for k in keys}
store.add(
OlmDevice(
d[2],
d[0],
key_dict,
display_name=d[3],
deleted=d[4],
)
)
return store
def save_device_keys(self, device_keys):
"""Save the provided device keys to the database."""
account = self.account_id
assert account
rows = []
for user_id, devices_dict in device_keys.items():
for device_id, device in devices_dict.items():
rows.append(
{
"account_id": account,
"user_id": user_id,
"device_id": device_id,
"display_name": device.display_name,
"deleted": device.deleted,
}
)
if not rows:
return
with self.conn.cursor() as cur:
for idx in range(0, len(rows), 100):
data = rows[idx: idx + 100]
cur.executemany(
"INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
[(r["account_id"], r["user_id"], r["device_id"],
r["display_name"], r["deleted"]) for r in data]
)
for user_id, devices_dict in device_keys.items():
for device_id, device in devices_dict.items():
cur.execute(
"UPDATE device_keys SET deleted = ? WHERE device_id = ?",
(device.deleted, device_id)
)
for key_type, key in device.keys.items():
cur.execute("""
INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
""",
(key_type, key, device_id, key)
)
self.conn.commit()
def save_group_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO inbound_group_sessions (
session_id, sender_key, signing_key, room_id, pickle, account_id
) VALUES (?, ?, ?, ?, ?, ?)
""", (
session.id,
session.sender_key,
session.signing_key,
session.room_id,
session.pickle,
self.account_id
))
self.conn.commit()
def save_olm_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO olm_sessions (
session_id, sender_key, pickle, account_id
) VALUES (?, ?, ?, ?)
""", (
session.id,
session.sender_key,
session.pickle,
self.account_id
))
self.conn.commit()
def save_outbound_group_sessions(self, sessions):
with self.conn.cursor() as cur:
for session in sessions:
cur.execute("""
INSERT OR REPLACE INTO outbound_group_sessions (
room_id, session_id, pickle, account_id
) VALUES (?, ?, ?, ?)
""", (
session.room_id,
session.id,
session.pickle,
self.account_id
))
self.conn.commit()
def save_account(self, account: OlmAccount):
with self.conn.cursor() as cur:
cur.execute("""
INSERT OR REPLACE INTO accounts (
id, user_id, device_id, shared_account, pickle
) VALUES (?, ?, ?, ?, ?)
""", (
self.account_id,
self.user_id,
self.device_id,
account.shared,
account.pickle(self.pickle_key),
))
self.conn.commit()
def load_sessions(self):
session_store = SessionStore()
with self.conn.cursor() as cur:
cur.execute("""
SELECT
os.sender_key, os.session, os.creation_time
FROM
olm_sessions os
INNER JOIN
accounts a ON os.account_id = a.id
WHERE
a.id = ?
""", (self.account_id,))
for row in cur.fetchall():
sender_key, session_pickle, creation_time = row
session = Session.from_pickle(
session_pickle, creation_time, self.pickle_key)
session_store.add(sender_key, session)
return session_store
def load_inbound_group_sessions(self):
# type: () -> GroupSessionStore
"""Load all Olm sessions from the database.
Returns:
``GroupSessionStore`` object, containing all the loaded sessions.
"""
store = GroupSessionStore()
account = self.account_id
if not account:
return store
with self.conn.cursor() as cursor:
cursor.execute(
"SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
account,)
)
for row in cursor.fetchall():
cursor.execute(
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
(row[1],),
)
chains = cursor.fetchall()
session = InboundGroupSession.from_pickle(
row[2].encode(),
row[3],
row[4],
row[5],
self.pickle_key,
[
chain[0]
for chain in chains
],
)
store.add(session)
return store
def load_outgoing_key_requests(self):
# type: () -> dict
"""Load all outgoing key requests from the database.
Returns:
``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
"""
account = self.account_id
if not account:
return store
with self.conn.cursor() as cur:
cur.execute(
"SELECT * FROM outgoing_key_requests WHERE account_id = ?",
(account,)
)
rows = cur.fetchall()
return {
row[1]: OutgoingKeyRequest.from_response(AttrDict({
"id": row[0],
"account_id": row[1],
"request_id": row[2],
"session_id": row[3],
"room_id": row[4],
"algorithm": row[5],
})) for row in rows
}
def load_encrypted_rooms(self):
"""Load the set of encrypted rooms for this account.
Returns:
``Set`` containing room ids of encrypted rooms.
"""
account = self.account_id
if not account:
return set()
with self.conn.cursor() as cur:
cur.execute(
"SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
(account,)
)
rows = cur.fetchall()
return {row[0] for row in rows}
def save_sync_token(self, token):
"""Save the given token"""
account = self.account_id
assert account
with self.conn.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
(account, token)
)
self.conn.commit()
def save_encrypted_rooms(self, rooms):
"""Save the set of room ids for this account."""
account = self.account_id
assert account
data = [(room_id, account) for room_id in rooms]
with self.conn.cursor() as cur:
for idx in range(0, len(data), 400):
rows = data[idx: idx + 400]
cur.executemany(
"INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
rows
)
self.conn.commit()
def save_session(self, sender_key, session):
"""Save the provided Olm session to the database.
Args:
sender_key (str): The curve key that owns the Olm session.
session (Session): The Olm session that will be pickled and
saved in the database.
"""
account = self.account_id
assert account
pickled_session = session.pickle(self.pickle_key)
with self.conn.cursor() as cur:
cur.execute(
"INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
(account, sender_key, pickled_session, session.id,
session.creation_time, session.use_time)
)
self.conn.commit()
def save_inbound_group_session(self, session):
"""Save the provided Megolm inbound group session to the database.
Args:
session (InboundGroupSession): The session to save.
"""
account = self.account_id
assert account
with self.conn.cursor() as cur:
# Insert a new session or update the existing one
query = """
INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (account_id, sender_key, fp_key, room_id)
DO UPDATE SET session = excluded.session
"""
cur.execute(query, (account, session.sender_key,
session.ed25519, session.room_id, session.pickle(self.pickle_key)))
# Delete existing forwarded chains for the session
delete_query = """
DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
"""
cur.execute(
delete_query, (account, session.sender_key, session.ed25519, session.room_id))
# Insert new forwarded chains for the session
insert_query = """
INSERT INTO forwarded_chains (session_id, sender_key)
VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
"""
for chain in session.forwarding_chain:
cur.execute(
insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
def add_outgoing_key_request(self, key_request):
"""Add a new outgoing key request to the database.
Args:
key_request (OutgoingKeyRequest): The key request to add.
"""
account_id = self.account_id
with self.conn.cursor() as cursor:
cursor.execute(
"""
SELECT MAX(id) FROM outgoing_key_requests
"""
)
row = cursor.fetchone()
request_id = row[0] + 1 if row[0] else 1
cursor.execute(
"""
INSERT INTO outgoing_key_requests (id, account_id, request_id, session_id, room_id, algorithm)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (account_id, request_id) DO NOTHING
""",
(
request_id,
account_id,
key_request.request_id,
key_request.session_id,
key_request.room_id,
key_request.algorithm,
)
)
def load_account(self):
# type: () -> Optional[OlmAccount]
"""Load the Olm account from the database.
Returns:
``OlmAccount`` object, or ``None`` if it wasn't found for the
current device_id.
"""
cursor = self.conn.cursor()
query = """
SELECT pickle, shared_account
FROM accounts
WHERE device_id = ?;
"""
cursor.execute(query, (self.device_id,))
result = cursor.fetchone()
if not result:
return None
account_pickle, shared = result
return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)