730 lines
24 KiB
Python
730 lines
24 KiB
Python
|
import duckdb
|
||
|
|
||
|
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
|
||
|
from nio.crypto import OlmAccount, OlmDevice
|
||
|
|
||
|
from random import SystemRandom
|
||
|
from collections import defaultdict
|
||
|
from typing import Dict, List, Optional, Tuple
|
||
|
|
||
|
import json
|
||
|
|
||
|
|
||
|
class DuckDBStore(MatrixStore):
|
||
|
@property
def account_id(self):
    """Return the numeric id of the account for this user/device pair.

    The id is the first column of the row returned by
    ``self._get_account()``.  When no account row exists, a random id is
    drawn instead.

    NOTE(review): the fallback id is not cached, so successive reads made
    before an account row exists may return different ids.
    """
    # Look the row up once; the previous code ran the query twice.
    row = self._get_account()
    account_id = row[0] if row else None

    if account_id is None:
        # Start at 1: an id of 0 is falsy and would fail the
        # ``assert account`` / ``if not account`` checks made by the
        # methods that consume this property.
        account_id = SystemRandom().randint(1, 2**16)

    return account_id
|
||
|
|
||
|
def __init__(self, user_id, device_id, duckdb_conn):
|
||
|
self.conn = duckdb_conn
|
||
|
self.user_id = user_id
|
||
|
self.device_id = device_id
|
||
|
self._create_tables()
|
||
|
|
||
|
def _create_tables(self):
|
||
|
with self.conn.cursor() as cursor:
|
||
|
cursor.execute("""
|
||
|
DROP TABLE IF EXISTS sync_tokens CASCADE;
|
||
|
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
|
||
|
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
|
||
|
DROP TABLE IF EXISTS forwarded_chains CASCADE;
|
||
|
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
|
||
|
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
|
||
|
DROP TABLE IF EXISTS olm_sessions CASCADE;
|
||
|
DROP TABLE IF EXISTS device_trust_state CASCADE;
|
||
|
DROP TABLE IF EXISTS keys CASCADE;
|
||
|
DROP TABLE IF EXISTS device_keys_key CASCADE;
|
||
|
DROP TABLE IF EXISTS device_keys CASCADE;
|
||
|
DROP TABLE IF EXISTS accounts CASCADE;
|
||
|
""")
|
||
|
|
||
|
# Create accounts table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS accounts (
|
||
|
id INTEGER PRIMARY KEY,
|
||
|
user_id VARCHAR NOT NULL,
|
||
|
device_id VARCHAR NOT NULL,
|
||
|
shared_account INTEGER NOT NULL,
|
||
|
pickle VARCHAR NOT NULL
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create device_keys table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS device_keys (
|
||
|
device_id TEXT PRIMARY KEY,
|
||
|
account_id INTEGER NOT NULL,
|
||
|
user_id TEXT NOT NULL,
|
||
|
display_name TEXT,
|
||
|
deleted BOOLEAN NOT NULL DEFAULT 0,
|
||
|
UNIQUE (account_id, user_id, device_id),
|
||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||
|
);
|
||
|
|
||
|
CREATE TABLE IF NOT EXISTS keys (
|
||
|
key_type TEXT NOT NULL,
|
||
|
key TEXT NOT NULL,
|
||
|
device_id VARCHAR NOT NULL,
|
||
|
UNIQUE (key_type, device_id),
|
||
|
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create device_trust_state table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS device_trust_state (
|
||
|
device_id VARCHAR PRIMARY KEY,
|
||
|
state INTEGER NOT NULL,
|
||
|
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create olm_sessions table
|
||
|
cursor.execute("""
|
||
|
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
||
|
|
||
|
CREATE TABLE IF NOT EXISTS olm_sessions (
|
||
|
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
||
|
account_id INTEGER NOT NULL,
|
||
|
sender_key TEXT NOT NULL,
|
||
|
session BLOB NOT NULL,
|
||
|
session_id VARCHAR NOT NULL,
|
||
|
creation_time TIMESTAMP NOT NULL,
|
||
|
last_usage_date TIMESTAMP NOT NULL,
|
||
|
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create inbound_group_sessions table
|
||
|
cursor.execute("""
|
||
|
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
||
|
|
||
|
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
||
|
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
||
|
account_id INTEGER NOT NULL,
|
||
|
session TEXT NOT NULL,
|
||
|
fp_key TEXT NOT NULL,
|
||
|
sender_key TEXT NOT NULL,
|
||
|
room_id TEXT NOT NULL,
|
||
|
UNIQUE (account_id, sender_key, fp_key, room_id),
|
||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||
|
);
|
||
|
|
||
|
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
||
|
id INTEGER PRIMARY KEY,
|
||
|
session_id INTEGER NOT NULL,
|
||
|
sender_key TEXT NOT NULL,
|
||
|
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create outbound_group_sessions table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
||
|
id INTEGER PRIMARY KEY,
|
||
|
account_id INTEGER NOT NULL,
|
||
|
room_id VARCHAR NOT NULL,
|
||
|
session_id VARCHAR NOT NULL UNIQUE,
|
||
|
session BLOB NOT NULL,
|
||
|
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create outgoing_key_requests table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
||
|
id INTEGER PRIMARY KEY,
|
||
|
account_id INTEGER NOT NULL,
|
||
|
request_id TEXT NOT NULL,
|
||
|
session_id TEXT NOT NULL,
|
||
|
room_id TEXT NOT NULL,
|
||
|
algorithm TEXT NOT NULL,
|
||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||
|
UNIQUE (account_id, request_id)
|
||
|
);
|
||
|
|
||
|
""")
|
||
|
|
||
|
# Create encrypted_rooms table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
||
|
room_id TEXT NOT NULL,
|
||
|
account_id INTEGER NOT NULL,
|
||
|
PRIMARY KEY (room_id, account_id),
|
||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
# Create sync_tokens table
|
||
|
cursor.execute("""
|
||
|
CREATE TABLE IF NOT EXISTS sync_tokens (
|
||
|
account_id INTEGER PRIMARY KEY,
|
||
|
token TEXT NOT NULL,
|
||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||
|
);
|
||
|
""")
|
||
|
|
||
|
def _get_account(self):
|
||
|
cursor = self.conn.cursor()
|
||
|
cursor.execute(
|
||
|
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
|
||
|
(self.user_id, self.device_id),
|
||
|
)
|
||
|
account = cursor.fetchone()
|
||
|
cursor.close()
|
||
|
return account
|
||
|
|
||
|
def _get_device(self, device):
|
||
|
acc = self._get_account()
|
||
|
|
||
|
if not acc:
|
||
|
return None
|
||
|
|
||
|
cursor = self.conn.cursor()
|
||
|
cursor.execute(
|
||
|
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
|
||
|
(device.user_id, device.id, acc[0]),
|
||
|
)
|
||
|
device_entry = cursor.fetchone()
|
||
|
cursor.close()
|
||
|
|
||
|
return device_entry
|
||
|
|
||
|
# Implementing methods with DuckDB equivalents
|
||
|
def verify_device(self, device):
    """Mark a device as verified.

    Args:
        device: Device to verify; it is looked up through
            ``self._get_device`` and its ``trust_state`` is updated.

    Returns:
        ``False`` if the device is already verified, ``True`` after the
        new trust state has been written.
    """
    if self.is_device_verified(device):
        return False

    d = self._get_device(device)
    assert d

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
            # Bind the member's value (as ignore_device does) rather than
            # the enum member itself; the column is an INTEGER.
            (d[0], TrustState.verified.value),
        )
        self.conn.commit()
    finally:
        # Release the cursor even when the statement raises.
        cursor.close()

    device.trust_state = TrustState.verified

    return True
|
||
|
|
||
|
def unverify_device(self, device):
    """Reset a verified device to the unset trust state.

    Args:
        device: Device to unverify; it is looked up through
            ``self._get_device`` and its ``trust_state`` is updated.

    Returns:
        ``False`` if the device is not verified, ``True`` after the
        trust state has been reset.
    """
    if not self.is_device_verified(device):
        return False

    d = self._get_device(device)
    assert d

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
            # Bind the member's value (as ignore_device does) rather than
            # the enum member itself; the column is an INTEGER.
            (d[0], TrustState.unset.value),
        )
        self.conn.commit()
    finally:
        # Release the cursor even when the statement raises.
        cursor.close()

    device.trust_state = TrustState.unset

    return True
|
||
|
|
||
|
def is_device_verified(self, device):
    """Check whether the stored trust state of ``device`` is verified.

    Returns:
        ``True`` only when both a device row and a trust-state row exist
        and the stored state equals ``TrustState.verified.value``.
    """
    d = self._get_device(device)

    if not d:
        return False

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
        )
        trust_state = cursor.fetchone()
    finally:
        # Release the cursor even when the query raises.
        cursor.close()

    if not trust_state:
        return False

    # The column holds an integer; compare it with the member's value,
    # because an integer does not compare equal to a plain Enum member.
    return trust_state[0] == TrustState.verified.value
|
||
|
|
||
|
def blacklist_device(self, device):
    """Mark a device as blacklisted.

    Args:
        device: Device to blacklist; it is looked up through
            ``self._get_device`` and its ``trust_state`` is updated.

    Returns:
        ``False`` if the device is already blacklisted, ``True`` after
        the new trust state has been written.
    """
    if self.is_device_blacklisted(device):
        return False

    d = self._get_device(device)
    assert d

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
            # Bind the member's value (as ignore_device does) rather than
            # the enum member itself; the column is an INTEGER.
            (d[0], TrustState.blacklisted.value),
        )
        self.conn.commit()
    finally:
        # Release the cursor even when the statement raises.
        cursor.close()

    device.trust_state = TrustState.blacklisted

    return True
|
||
|
|
||
|
def unblacklist_device(self, device):
    """Reset a blacklisted device to the unset trust state.

    Args:
        device: Device to unblacklist; it is looked up through
            ``self._get_device`` and its ``trust_state`` is updated.

    Returns:
        ``False`` if the device is not blacklisted, ``True`` after the
        trust state has been reset.
    """
    if not self.is_device_blacklisted(device):
        return False

    d = self._get_device(device)
    assert d

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
            # Bind the member's value (as ignore_device does) rather than
            # the enum member itself; the column is an INTEGER.
            (d[0], TrustState.unset.value),
        )
        self.conn.commit()
    finally:
        # Release the cursor even when the statement raises.
        cursor.close()

    device.trust_state = TrustState.unset

    return True
|
||
|
|
||
|
def is_device_blacklisted(self, device):
    """Check whether the stored trust state of ``device`` is blacklisted.

    Returns:
        ``True`` only when both a device row and a trust-state row exist
        and the stored state equals ``TrustState.blacklisted.value``.
    """
    d = self._get_device(device)

    if not d:
        return False

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
        )
        trust_state = cursor.fetchone()
    finally:
        # Release the cursor even when the query raises.
        cursor.close()

    if not trust_state:
        return False

    # The column holds an integer; compare it with the member's value,
    # because an integer does not compare equal to a plain Enum member.
    return trust_state[0] == TrustState.blacklisted.value
|
||
|
|
||
|
def ignore_device(self, device):
    """Mark a device as ignored.

    Args:
        device: Device to ignore; it is looked up through
            ``self._get_device`` and its ``trust_state`` is updated.

    Returns:
        ``False`` if the device is already ignored, ``True`` after the
        new trust state has been written.
    """
    if self.is_device_ignored(device):
        return False

    d = self._get_device(device)
    assert d

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
            (d[0], int(TrustState.ignored.value)),
        )
        self.conn.commit()
    finally:
        # Release the cursor even when the statement raises.
        cursor.close()

    # Keep the in-memory device in step with the stored state, as the
    # verify/blacklist counterparts do.
    device.trust_state = TrustState.ignored

    return True
|
||
|
|
||
|
def ignore_devices(self, devices):
    """Ignore every device in ``devices``, one at a time.

    Args:
        devices: Iterable of devices; each one is handed to
            ``self.ignore_device``.
    """
    mark_ignored = self.ignore_device
    for entry in devices:
        mark_ignored(entry)
|
||
|
|
||
|
def unignore_device(self, device):
    """Reset an ignored device to the unset trust state.

    Args:
        device: Device to unignore; it is looked up through
            ``self._get_device`` and its ``trust_state`` is updated.

    Returns:
        ``False`` if the device is not ignored, ``True`` after the trust
        state has been reset.
    """
    if not self.is_device_ignored(device):
        return False

    d = self._get_device(device)
    assert d

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
            # Bind the member's value (as ignore_device does) rather than
            # the enum member itself; the column is an INTEGER.
            (d[0], TrustState.unset.value),
        )
        self.conn.commit()
    finally:
        # Release the cursor even when the statement raises.
        cursor.close()

    device.trust_state = TrustState.unset

    return True
|
||
|
|
||
|
def is_device_ignored(self, device):
    """Check whether the stored trust state of ``device`` is ignored.

    Returns:
        ``True`` only when both a device row and a trust-state row exist
        and the stored state equals ``TrustState.ignored.value``.
    """
    d = self._get_device(device)

    if not d:
        return False

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
        )
        trust_state = cursor.fetchone()
    finally:
        # Release the cursor even when the query raises.
        cursor.close()

    if not trust_state:
        return False

    # ignore_device stores the member's integer value, and an integer
    # does not compare equal to a plain Enum member, so compare values.
    return trust_state[0] == TrustState.ignored.value
|
||
|
|
||
|
def load_device_keys(self):
    """Load all the device keys from the database.

    Returns DeviceStore containing the OlmDevices with the device keys.
    An empty store is returned when ``self.account_id`` is falsy.
    """
    store = DeviceStore()
    account = self.account_id

    if not account:
        return store

    with self.conn.cursor() as cur:
        # Name the columns so each row can be unpacked positionally:
        # rows come back as plain tuples (the rest of this class indexes
        # them by position), so they cannot be indexed by column name.
        cur.execute(
            "SELECT device_id, user_id, display_name, deleted "
            "FROM device_keys WHERE account_id = ?",
            (account,),
        )
        device_rows = cur.fetchall()

        for device_id, user_id, display_name, deleted in device_rows:
            # ``keys`` refers to devices through ``device_id``; the
            # ``device_keys`` table has no ``id`` column.
            cur.execute(
                "SELECT key_type, key FROM keys WHERE device_id = ?",
                (device_id,),
            )
            key_dict = dict(cur.fetchall())

            store.add(
                OlmDevice(
                    user_id,
                    device_id,
                    key_dict,
                    display_name=display_name,
                    deleted=deleted,
                )
            )

    return store
|
||
|
|
||
|
def save_device_keys(self, device_keys):
    """Persist the given device keys.

    Args:
        device_keys: Mapping of user id to a mapping of device id to
            device.  Each device is read for ``display_name``,
            ``deleted`` and ``keys``.

    Nothing is written (and nothing is committed) when the mapping holds
    no devices.
    """
    account = self.account_id
    assert account

    device_rows = [
        (account, user_id, device_id, device.display_name, device.deleted)
        for user_id, devices_dict in device_keys.items()
        for device_id, device in devices_dict.items()
    ]

    if not device_rows:
        return

    insert_device = "INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)"
    update_deleted = "UPDATE device_keys SET deleted = ? WHERE device_id = ?"
    upsert_key = """
        INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
        ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
    """
    batch_size = 100

    with self.conn.cursor() as cur:
        # New devices are inserted in batches; existing rows are skipped.
        for start in range(0, len(device_rows), batch_size):
            cur.executemany(insert_device, device_rows[start:start + batch_size])

        # Then refresh the deleted flag and the keys of every device.
        for devices_dict in device_keys.values():
            for device_id, device in devices_dict.items():
                cur.execute(update_deleted, (device.deleted, device_id))

                for key_type, key in device.keys.items():
                    cur.execute(upsert_key, (key_type, key, device_id, key))

    self.conn.commit()
|
||
|
|
||
|
def save_group_sessions(self, sessions):
    """Save several inbound group sessions.

    Args:
        sessions: Iterable of sessions; each one is handed to
            ``self.save_inbound_group_session``.

    The previous statement inserted into ``session_id``, ``signing_key``
    and ``pickle`` columns, none of which exist on the
    ``inbound_group_sessions`` table created by ``_create_tables``.
    Delegating to ``save_inbound_group_session`` writes the columns that
    table really has.

    NOTE(review): this assumes the sessions expose the attributes that
    ``save_inbound_group_session`` reads (``ed25519``, ``pickle()``,
    ``forwarding_chain``) — confirm against callers.
    """
    for session in sessions:
        self.save_inbound_group_session(session)

    self.conn.commit()
|
||
|
|
||
|
def save_olm_sessions(self, sessions):
    """Save several Olm sessions.

    Args:
        sessions: Iterable of sessions; each one is stored through
            ``self.save_session`` under its ``sender_key`` attribute.

    The previous statement inserted into a ``pickle`` column that the
    ``olm_sessions`` table created by ``_create_tables`` does not have,
    and left its NOT NULL timestamp columns unset.  ``save_session``
    fills every column of that table and commits after each session.

    NOTE(review): this assumes the sessions expose what ``save_session``
    reads (``pickle()``, ``id``, ``creation_time``, ``use_time``) —
    confirm against callers.
    """
    for session in sessions:
        self.save_session(session.sender_key, session)
|
||
|
|
||
|
def save_outbound_group_sessions(self, sessions):
    """Write every outbound group session in ``sessions`` to the database.

    Args:
        sessions: Iterable of sessions; each is read for ``room_id``,
            ``id`` and ``pickle``.

    NOTE(review): the ``outbound_group_sessions`` table created by
    ``_create_tables`` declares a ``session`` column and no ``pickle``
    column, and its ``id`` column has no default — confirm this statement
    matches the schema actually in use.
    """
    statement = """
        INSERT OR REPLACE INTO outbound_group_sessions (
            room_id, session_id, pickle, account_id
        ) VALUES (?, ?, ?, ?)
    """

    with self.conn.cursor() as cursor:
        for entry in sessions:
            # The account id is re-read for every session.
            values = (entry.room_id, entry.id, entry.pickle, self.account_id)
            cursor.execute(statement, values)

    self.conn.commit()
|
||
|
|
||
|
def save_account(self, account: OlmAccount):
    """Insert or replace the account row for this user and device.

    Args:
        account: Account to store; its ``shared`` flag is saved and it is
            pickled with ``self.pickle_key``.
    """
    statement = """
        INSERT OR REPLACE INTO accounts (
            id, user_id, device_id, shared_account, pickle
        ) VALUES (?, ?, ?, ?, ?)
    """

    with self.conn.cursor() as cursor:
        values = (
            self.account_id,
            self.user_id,
            self.device_id,
            account.shared,
            account.pickle(self.pickle_key),
        )
        cursor.execute(statement, values)

    self.conn.commit()
|
||
|
|
||
|
def load_sessions(self):
    """Load the Olm sessions stored for the current account.

    Returns:
        ``SessionStore`` holding one session per stored row, each
        restored with ``Session.from_pickle`` and added under the row's
        sender key.
    """
    # ``Session`` is not among the module-level imports visible in this
    # file, so the bare name raised NameError; import it where it is used.
    from nio.crypto import Session

    session_store = SessionStore()

    with self.conn.cursor() as cur:
        cur.execute("""
            SELECT
                os.sender_key, os.session, os.creation_time
            FROM
                olm_sessions os
            INNER JOIN
                accounts a ON os.account_id = a.id
            WHERE
                a.id = ?
        """, (self.account_id,))

        for sender_key, session_pickle, creation_time in cur.fetchall():
            session = Session.from_pickle(
                session_pickle, creation_time, self.pickle_key)
            session_store.add(sender_key, session)

    return session_store
|
||
|
|
||
|
def load_inbound_group_sessions(self):
    # type: () -> GroupSessionStore
    """Load all inbound group sessions from the database.

    Returns:
        ``GroupSessionStore`` object, containing all the loaded sessions.
        The store is empty when ``self.account_id`` is falsy.

    """
    store = GroupSessionStore()

    account = self.account_id

    if not account:
        return store

    with self.conn.cursor() as cursor:
        # Name the columns so each row can be unpacked positionally:
        # rows come back as plain tuples (the rest of this class indexes
        # them by position), so they cannot be indexed by column name.
        cursor.execute(
            "SELECT id, session, fp_key, sender_key, room_id "
            "FROM inbound_group_sessions WHERE account_id = ?",
            (account,),
        )

        # fetchall() materialises the outer result, so the cursor can be
        # reused for the per-session chain lookup below.
        for row_id, pickled, fp_key, sender_key, room_id in cursor.fetchall():
            cursor.execute(
                "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
                (row_id,),
            )
            forwarding_chain = [chain[0] for chain in cursor.fetchall()]

            session = InboundGroupSession.from_pickle(
                pickled,
                fp_key,
                sender_key,
                room_id,
                self.pickle_key,
                forwarding_chain,
            )
            store.add(session)

    return store
|
||
|
|
||
|
def load_outgoing_key_requests(self):
    # type: () -> dict
    """Load all outgoing key requests from the database.

    Returns:
        ``dict`` mapping each request id to the ``OutgoingKeyRequest``
        built from its row; empty when ``self.account_id`` is falsy.
    """
    from types import SimpleNamespace

    account = self.account_id

    if not account:
        # The previous code returned an undefined name here (NameError).
        return {}

    with self.conn.cursor() as cur:
        cur.execute(
            "SELECT request_id, session_id, room_id, algorithm "
            "FROM outgoing_key_requests WHERE account_id = ?",
            (account,)
        )
        rows = cur.fetchall()

    requests = {}
    for request_id, session_id, room_id, algorithm in rows:
        # Rows come back as plain tuples; wrap each one so that
        # ``from_database`` receives an object exposing the columns as
        # attributes, the access style the original code relied on.
        record = SimpleNamespace(
            request_id=request_id,
            session_id=session_id,
            room_id=room_id,
            algorithm=algorithm,
        )
        requests[request_id] = OutgoingKeyRequest.from_database(record)

    return requests
|
||
|
|
||
|
def load_encrypted_rooms(self):
    """Load the set of encrypted rooms for this account.

    Returns:
        ``Set`` containing room ids of encrypted rooms; empty when
        ``self.account_id`` is falsy.
    """
    account = self.account_id

    if not account:
        return set()

    with self.conn.cursor() as cur:
        cur.execute(
            "SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
            (account,)
        )
        rows = cur.fetchall()

    # Rows come back as plain tuples (the rest of this class indexes them
    # by position); ``room_id`` is the only selected column.
    return {row[0] for row in rows}
|
||
|
|
||
|
def save_sync_token(self, token):
    """Store ``token`` as the sync token of the current account.

    Args:
        token: Value written to the ``token`` column; it replaces any
            row already stored for the account.
    """
    account_id = self.account_id
    assert account_id

    statement = "INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)"

    with self.conn.cursor() as cursor:
        cursor.execute(statement, (account_id, token))

    self.conn.commit()
|
||
|
|
||
|
def save_encrypted_rooms(self, rooms):
    """Record the given room ids as encrypted rooms of this account.

    Args:
        rooms: Iterable of room ids.  Rooms already stored for the
            account are skipped by the ``INSERT OR IGNORE`` statement.
    """
    account_id = self.account_id
    assert account_id

    pairs = [(room_id, account_id) for room_id in rooms]
    statement = "INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)"
    chunk = 400

    with self.conn.cursor() as cursor:
        # Insert in slices of at most ``chunk`` rows.
        start = 0
        while start < len(pairs):
            cursor.executemany(statement, pairs[start:start + chunk])
            start += chunk

    self.conn.commit()
|
||
|
|
||
|
def save_session(self, sender_key, session):
    """Pickle an Olm session and write it to the ``olm_sessions`` table.

    Args:
        sender_key (str): The curve key that owns the Olm session.
        session (Session): The Olm session to store; it is pickled with
            ``self.pickle_key`` and read for ``id``, ``creation_time``
            and ``use_time``.
    """
    account_id = self.account_id
    assert account_id

    # Pickle before opening a cursor, as the statement needs the result.
    pickled = session.pickle(self.pickle_key)

    statement = "INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)"

    with self.conn.cursor() as cursor:
        values = (
            account_id,
            sender_key,
            pickled,
            session.id,
            session.creation_time,
            session.use_time,
        )
        cursor.execute(statement, values)

    self.conn.commit()
|
||
|
|
||
|
def save_inbound_group_session(self, session):
    """Store a Megolm inbound group session and its forwarding chain.

    Args:
        session (InboundGroupSession): The session to save.  It is read
            for ``sender_key``, ``ed25519``, ``room_id`` and
            ``forwarding_chain`` and pickled with ``self.pickle_key``.
    """
    account_id = self.account_id
    assert account_id

    upsert_session = """
        INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (account_id, sender_key, fp_key, room_id)
        DO UPDATE SET session = excluded.session
    """
    clear_chains = """
        DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
    """
    add_chain = """
        INSERT INTO forwarded_chains (session_id, sender_key)
        VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
    """

    with self.conn.cursor() as cursor:
        # The four values that identify the session row in every statement.
        identity = (account_id, session.sender_key, session.ed25519, session.room_id)

        # Insert the session, or overwrite the pickle of the existing row.
        cursor.execute(upsert_session, identity + (session.pickle(self.pickle_key),))

        # Replace the stored forwarding chain with the current one.
        cursor.execute(clear_chains, identity)
        for hop in session.forwarding_chain:
            cursor.execute(add_chain, identity + (hop,))
|
||
|
|
||
|
def add_outgoing_key_request(self, key_request):
|
||
|
account_id = self.account_id
|
||
|
with self.conn.cursor() as cursor:
|
||
|
cursor.execute(
|
||
|
"""
|
||
|
INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm)
|
||
|
VALUES (?, ?, ?, ?, ?)
|
||
|
ON CONFLICT (account_id, request_id) DO NOTHING
|
||
|
""",
|
||
|
(
|
||
|
account_id,
|
||
|
key_request.request_id,
|
||
|
key_request.session_id,
|
||
|
key_request.room_id,
|
||
|
key_request.algorithm,
|
||
|
)
|
||
|
)
|