Moving migrations to subdirectory
Add option for custom system messages per room Fixing some methods in store
This commit is contained in:
parent
f20b762558
commit
2bbc6a33ca
10 changed files with 358 additions and 207 deletions
|
@ -6,6 +6,9 @@ to generate responses to messages in a Matrix room.
|
|||
It will also save a log of the spent tokens to a DuckDB database
|
||||
(database.db in the working directory, by default).
|
||||
|
||||
Note that this bot does not yet support encryption - this is still work in
|
||||
progress.
|
||||
|
||||
## Installation
|
||||
|
||||
Simply clone this repository and install the requirements.
|
||||
|
|
208
classes/store.py
208
classes/store.py
|
@ -1,6 +1,6 @@
|
|||
import duckdb
|
||||
|
||||
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
|
||||
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
|
||||
from nio.crypto import OlmAccount, OlmDevice
|
||||
|
||||
from random import SystemRandom
|
||||
|
@ -24,150 +24,6 @@ class DuckDBStore(MatrixStore):
|
|||
self.conn = duckdb_conn
|
||||
self.user_id = user_id
|
||||
self.device_id = device_id
|
||||
self._create_tables()
|
||||
|
||||
def _create_tables(self):
|
||||
with self.conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
DROP TABLE IF EXISTS sync_tokens CASCADE;
|
||||
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
|
||||
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
|
||||
DROP TABLE IF EXISTS forwarded_chains CASCADE;
|
||||
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
|
||||
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
|
||||
DROP TABLE IF EXISTS olm_sessions CASCADE;
|
||||
DROP TABLE IF EXISTS device_trust_state CASCADE;
|
||||
DROP TABLE IF EXISTS keys CASCADE;
|
||||
DROP TABLE IF EXISTS device_keys_key CASCADE;
|
||||
DROP TABLE IF EXISTS device_keys CASCADE;
|
||||
DROP TABLE IF EXISTS accounts CASCADE;
|
||||
""")
|
||||
|
||||
# Create accounts table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id VARCHAR NOT NULL,
|
||||
device_id VARCHAR NOT NULL,
|
||||
shared_account INTEGER NOT NULL,
|
||||
pickle VARCHAR NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
# Create device_keys table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS device_keys (
|
||||
device_id TEXT PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
deleted BOOLEAN NOT NULL DEFAULT 0,
|
||||
UNIQUE (account_id, user_id, device_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS keys (
|
||||
key_type TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
device_id VARCHAR NOT NULL,
|
||||
UNIQUE (key_type, device_id),
|
||||
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create device_trust_state table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS device_trust_state (
|
||||
device_id VARCHAR PRIMARY KEY,
|
||||
state INTEGER NOT NULL,
|
||||
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create olm_sessions table
|
||||
cursor.execute("""
|
||||
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS olm_sessions (
|
||||
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
||||
account_id INTEGER NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
session BLOB NOT NULL,
|
||||
session_id VARCHAR NOT NULL,
|
||||
creation_time TIMESTAMP NOT NULL,
|
||||
last_usage_date TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create inbound_group_sessions table
|
||||
cursor.execute("""
|
||||
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
||||
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
||||
account_id INTEGER NOT NULL,
|
||||
session TEXT NOT NULL,
|
||||
fp_key TEXT NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
UNIQUE (account_id, sender_key, fp_key, room_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
||||
id INTEGER PRIMARY KEY,
|
||||
session_id INTEGER NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create outbound_group_sessions table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
room_id VARCHAR NOT NULL,
|
||||
session_id VARCHAR NOT NULL UNIQUE,
|
||||
session BLOB NOT NULL,
|
||||
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create outgoing_key_requests table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
request_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
UNIQUE (account_id, request_id)
|
||||
);
|
||||
|
||||
""")
|
||||
|
||||
# Create encrypted_rooms table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
||||
room_id TEXT NOT NULL,
|
||||
account_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (room_id, account_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create sync_tokens table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sync_tokens (
|
||||
account_id INTEGER PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
def _get_account(self):
|
||||
cursor = self.conn.cursor()
|
||||
|
@ -387,18 +243,18 @@ class DuckDBStore(MatrixStore):
|
|||
for d in device_keys:
|
||||
cur.execute(
|
||||
"SELECT * FROM keys WHERE device_id = ?",
|
||||
(d["id"],)
|
||||
(d[0],)
|
||||
)
|
||||
keys = cur.fetchall()
|
||||
key_dict = {k["key_type"]: k["key"] for k in keys}
|
||||
key_dict = {k[0]: k[1] for k in keys}
|
||||
|
||||
store.add(
|
||||
OlmDevice(
|
||||
d["user_id"],
|
||||
d["device_id"],
|
||||
d[2],
|
||||
d[0],
|
||||
key_dict,
|
||||
display_name=d["display_name"],
|
||||
deleted=d["deleted"],
|
||||
display_name=d[3],
|
||||
deleted=d[4],
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -561,18 +417,21 @@ class DuckDBStore(MatrixStore):
|
|||
)
|
||||
|
||||
for row in cursor.fetchall():
|
||||
cursor.execute(
|
||||
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
||||
(row[1],),
|
||||
)
|
||||
chains = cursor.fetchall()
|
||||
|
||||
session = InboundGroupSession.from_pickle(
|
||||
row["session"],
|
||||
row["fp_key"],
|
||||
row["sender_key"],
|
||||
row["room_id"],
|
||||
row[2].encode(),
|
||||
row[3],
|
||||
row[4],
|
||||
row[5],
|
||||
self.pickle_key,
|
||||
[
|
||||
chain["sender_key"]
|
||||
for chain in cursor.execute(
|
||||
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
||||
(row["id"],),
|
||||
)
|
||||
chain[0]
|
||||
for chain in chains
|
||||
],
|
||||
)
|
||||
store.add(session)
|
||||
|
@ -621,7 +480,7 @@ class DuckDBStore(MatrixStore):
|
|||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
return {row["room_id"] for row in rows}
|
||||
return {row[0] for row in rows}
|
||||
|
||||
def save_sync_token(self, token):
|
||||
"""Save the given token"""
|
||||
|
@ -727,4 +586,29 @@ class DuckDBStore(MatrixStore):
|
|||
key_request.room_id,
|
||||
key_request.algorithm,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def load_account(self):
|
||||
# type: () -> Optional[OlmAccount]
|
||||
"""Load the Olm account from the database.
|
||||
|
||||
Returns:
|
||||
``OlmAccount`` object, or ``None`` if it wasn't found for the
|
||||
current device_id.
|
||||
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
query = """
|
||||
SELECT pickle, shared_account
|
||||
FROM accounts
|
||||
WHERE device_id = ?;
|
||||
"""
|
||||
cursor.execute(query, (self.device_id,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
account_pickle, shared = result
|
||||
return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)
|
||||
|
|
|
@ -5,6 +5,7 @@ from .botinfo import command_botinfo
|
|||
from .unknown import command_unknown
|
||||
from .coin import command_coin
|
||||
from .ignoreolder import command_ignoreolder
|
||||
from .systemmessage import command_systemmessage
|
||||
|
||||
COMMANDS = {
|
||||
"help": command_help,
|
||||
|
@ -13,5 +14,6 @@ COMMANDS = {
|
|||
"botinfo": command_botinfo,
|
||||
"coin": command_coin,
|
||||
"ignoreolder": command_ignoreolder,
|
||||
"systemmessage": command_systemmessage,
|
||||
None: command_unknown,
|
||||
}
|
||||
|
|
34
commands/systemmessage.py
Normal file
34
commands/systemmessage.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
system_message = " ".join(event.body.split()[2:])
|
||||
|
||||
if system_message:
|
||||
context["logger"]("Adding system message...")
|
||||
|
||||
with context["database"].cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||
(room.room_id, event.event_id, event.sender,
|
||||
system_message, event.server_timestamp)
|
||||
)
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
|
||||
|
||||
context["logger"]("Retrieving system message...")
|
||||
|
||||
with context["database"].cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||
(room.room_id,)
|
||||
)
|
||||
system_message = cur.fetchone()
|
||||
|
||||
if system_message is None:
|
||||
system_message = context.get("system_message", "No system message set")
|
||||
elif context.get("force_system_message") and context.get("system_message"):
|
||||
system_message = system_message + "\n\n" + context["system_message"]
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}
|
|
@ -69,10 +69,15 @@ AccessToken = syt_yoursynapsetoken
|
|||
#
|
||||
# SystemMessage = You are a helpful bot.
|
||||
|
||||
# Force inclusion of the SystemMessage defined above if one is defined on per-room level
|
||||
# If no custom message is defined for the room, SystemMessage is always included
|
||||
#
|
||||
# ForceSystemMessage = 0
|
||||
|
||||
[Database]
|
||||
|
||||
# Settings for the DuckDB database.
|
||||
# Currently only used to store details on spent tokens per room.
|
||||
# If not defined, the bot will not store this data.
|
||||
# If not defined, the bot will not be able to remember anything, and will not support encryption
|
||||
# N.B.: Encryption doesn't work as it is supposed to anyway.
|
||||
|
||||
Path = database.db
|
104
gptbot.py
104
gptbot.py
|
@ -23,6 +23,7 @@ from typing import List, Dict, Union, Optional
|
|||
|
||||
from commands import COMMANDS
|
||||
from classes import DuckDBStore
|
||||
from migrations import MIGRATIONS
|
||||
|
||||
|
||||
def logging(message: str, log_level: str = "info"):
|
||||
|
@ -35,6 +36,7 @@ CONTEXT = {
|
|||
"database": False,
|
||||
"default_room_name": "GPTBot",
|
||||
"system_message": "You are a helpful assistant.",
|
||||
"force_system_message": False,
|
||||
"max_tokens": 3000,
|
||||
"max_messages": 20,
|
||||
"model": "gpt-3.5-turbo",
|
||||
|
@ -48,6 +50,8 @@ async def gpt_query(messages: list, model: Optional[str] = None):
|
|||
model = model or CONTEXT["model"]
|
||||
|
||||
logging(f"Querying GPT with {len(messages)} messages")
|
||||
logging(messages, "debug")
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
|
@ -143,8 +147,9 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|||
|
||||
client = kwargs.get("client") or CONTEXT["client"]
|
||||
database = kwargs.get("database") or CONTEXT["database"]
|
||||
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
||||
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
|
||||
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
||||
force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
|
||||
|
||||
await client.room_typing(room.room_id, True)
|
||||
|
||||
|
@ -152,6 +157,12 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|||
|
||||
last_messages = await fetch_last_n_messages(room.room_id, 20)
|
||||
|
||||
system_message = get_system_message(room, {
|
||||
"database": database,
|
||||
"system_message": system_message,
|
||||
"force_system_message": force_system_message,
|
||||
})
|
||||
|
||||
chat_messages = [{"role": "system", "content": system_message}]
|
||||
|
||||
for message in last_messages:
|
||||
|
@ -163,8 +174,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|||
|
||||
# Truncate messages to fit within the token limit
|
||||
truncated_messages = truncate_messages_to_fit_tokens(
|
||||
chat_messages, max_tokens - 1)
|
||||
|
||||
chat_messages, max_tokens - 1, system_message=system_message)
|
||||
response, tokens_used = await gpt_query(truncated_messages)
|
||||
|
||||
if response:
|
||||
|
@ -204,10 +214,29 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
|
|||
|
||||
if message:
|
||||
room_id, event, content = message
|
||||
rooms = await context["client"].joined_rooms()
|
||||
await send_message(context["client"].rooms[room_id], content["body"],
|
||||
True if content["msgtype"] == "m.notice" else False, context["client"])
|
||||
|
||||
|
||||
def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
|
||||
context = context or CONTEXT
|
||||
|
||||
default = context.get("system_message")
|
||||
|
||||
with context["database"].cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||
(room.room_id,)
|
||||
)
|
||||
system_message = cur.fetchone()
|
||||
|
||||
complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
|
||||
"\n\n" + system_message[0] if system_message else "")).strip()
|
||||
|
||||
return complete
|
||||
|
||||
|
||||
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
||||
context = kwargs.get("context") or CONTEXT
|
||||
|
||||
|
@ -286,7 +315,8 @@ async def send_message(room: MatrixRoom, message: str, notice: bool = False, cli
|
|||
)
|
||||
|
||||
if msgtype != "m.reaction":
|
||||
response = client.encrypt(room.room_id, "m.room.message", msgcontent)
|
||||
response = client.encrypt(
|
||||
room.room_id, "m.room.message", msgcontent)
|
||||
msgtype, content = response
|
||||
|
||||
except Exception as e:
|
||||
|
@ -318,7 +348,7 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
|||
if isinstance(response, JoinResponse):
|
||||
logging(response, "debug")
|
||||
rooms = await client.joined_rooms()
|
||||
await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||
await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||
else:
|
||||
logging(f"Error joining room {room_id}: {response}", "error")
|
||||
|
||||
|
@ -408,13 +438,16 @@ async def init(config: ConfigParser):
|
|||
if "MaxMessages" in config["OpenAI"]:
|
||||
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
||||
|
||||
# Listen for SIGTERM
|
||||
# Override defaults with config
|
||||
|
||||
def sigterm_handler(_signo, _stack_frame):
|
||||
logging("Received SIGTERM - exiting...")
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
if "GPTBot" in config:
|
||||
if "SystemMessage" in config["GPTBot"]:
|
||||
CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
|
||||
if "DefaultRoomName" in config["GPTBot"]:
|
||||
CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
|
||||
if "ForceSystemMessage" in config["GPTBot"]:
|
||||
CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
|
||||
"ForceSystemMessage")
|
||||
|
||||
|
||||
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
||||
|
@ -431,6 +464,14 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
|
|||
await client.close()
|
||||
return
|
||||
|
||||
# Listen for SIGTERM
|
||||
|
||||
def sigterm_handler(_signo, _stack_frame):
|
||||
logging("Received SIGTERM - exiting...")
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
logging("Starting bot...")
|
||||
|
||||
client.add_response_callback(sync_cb, SyncResponse)
|
||||
|
@ -460,9 +501,9 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
|
|||
|
||||
def initialize_database(path: os.PathLike):
|
||||
logging("Initializing database...")
|
||||
database = duckdb.connect(path)
|
||||
conn = duckdb.connect(path)
|
||||
|
||||
with database.cursor() as cursor:
|
||||
with conn.cursor() as cursor:
|
||||
# Get the latest migration ID if the migrations table exists
|
||||
try:
|
||||
cursor.execute(
|
||||
|
@ -472,40 +513,17 @@ def initialize_database(path: os.PathLike):
|
|||
)
|
||||
|
||||
latest_migration = int(cursor.fetchone()[0])
|
||||
|
||||
except:
|
||||
latest_migration = 0
|
||||
|
||||
# Version 1
|
||||
for migration, function in MIGRATIONS.items():
|
||||
if latest_migration < migration:
|
||||
logging(f"Running migration {migration}...")
|
||||
function(conn)
|
||||
latest_migration = migration
|
||||
|
||||
if latest_migration < 1:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS token_usage (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
tokens INTEGER NOT NULL,
|
||||
timestamp TIMESTAMP NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id INTEGER NOT NULL,
|
||||
timestamp TIMESTAMP NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
|
||||
(datetime.now(),)
|
||||
)
|
||||
|
||||
database.commit()
|
||||
|
||||
return database
|
||||
return conn
|
||||
|
||||
|
||||
async def get_device_id(access_token, homeserver):
|
||||
|
|
11
migrations/__init__.py
Normal file
11
migrations/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
from .migration_1 import migration as migration_1
|
||||
from .migration_2 import migration as migration_2
|
||||
from .migration_3 import migration as migration_3
|
||||
|
||||
MIGRATIONS = OrderedDict()
|
||||
|
||||
MIGRATIONS[1] = migration_1
|
||||
MIGRATIONS[2] = migration_2
|
||||
MIGRATIONS[3] = migration_3
|
32
migrations/migration_1.py
Normal file
32
migrations/migration_1.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Initial migration, token usage logging
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
def migration(conn):
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS token_usage (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
tokens INTEGER NOT NULL,
|
||||
timestamp TIMESTAMP NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id INTEGER NOT NULL,
|
||||
timestamp TIMESTAMP NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
|
||||
(datetime.now(),)
|
||||
)
|
||||
|
||||
conn.commit()
|
138
migrations/migration_2.py
Normal file
138
migrations/migration_2.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
# Migration for Matrix Store
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
def migration(conn):
|
||||
with conn.cursor() as cursor:
|
||||
# Create accounts table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id VARCHAR NOT NULL,
|
||||
device_id VARCHAR NOT NULL,
|
||||
shared_account INTEGER NOT NULL,
|
||||
pickle VARCHAR NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
# Create device_keys table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS device_keys (
|
||||
device_id TEXT PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
deleted BOOLEAN NOT NULL DEFAULT 0,
|
||||
UNIQUE (account_id, user_id, device_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS keys (
|
||||
key_type TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
device_id VARCHAR NOT NULL,
|
||||
UNIQUE (key_type, device_id),
|
||||
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create device_trust_state table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS device_trust_state (
|
||||
device_id VARCHAR PRIMARY KEY,
|
||||
state INTEGER NOT NULL,
|
||||
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create olm_sessions table
|
||||
cursor.execute("""
|
||||
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS olm_sessions (
|
||||
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
||||
account_id INTEGER NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
session BLOB NOT NULL,
|
||||
session_id VARCHAR NOT NULL,
|
||||
creation_time TIMESTAMP NOT NULL,
|
||||
last_usage_date TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create inbound_group_sessions table
|
||||
cursor.execute("""
|
||||
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
||||
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
||||
account_id INTEGER NOT NULL,
|
||||
session TEXT NOT NULL,
|
||||
fp_key TEXT NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
UNIQUE (account_id, sender_key, fp_key, room_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
||||
id INTEGER PRIMARY KEY,
|
||||
session_id INTEGER NOT NULL,
|
||||
sender_key TEXT NOT NULL,
|
||||
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create outbound_group_sessions table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
room_id VARCHAR NOT NULL,
|
||||
session_id VARCHAR NOT NULL UNIQUE,
|
||||
session BLOB NOT NULL,
|
||||
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create outgoing_key_requests table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL,
|
||||
request_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
UNIQUE (account_id, request_id)
|
||||
);
|
||||
|
||||
""")
|
||||
|
||||
# Create encrypted_rooms table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
||||
room_id TEXT NOT NULL,
|
||||
account_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (room_id, account_id),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
# Create sync_tokens table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sync_tokens (
|
||||
account_id INTEGER PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
""")
|
||||
|
||||
cursor.execute(
|
||||
"INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
|
||||
(datetime.now(),)
|
||||
)
|
||||
|
||||
conn.commit()
|
24
migrations/migration_3.py
Normal file
24
migrations/migration_3.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# Migration for custom system messages
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
def migration(conn):
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS system_messages (
|
||||
room_id TEXT NOT NULL,
|
||||
message_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
body TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"INSERT INTO migrations (id, timestamp) VALUES (3, ?)",
|
||||
(datetime.now(),)
|
||||
)
|
||||
|
||||
conn.commit()
|
Loading…
Reference in a new issue