Moving migrations to subdirectory
Add option for custom system messages per room Fixing some methods in store
This commit is contained in:
parent
f20b762558
commit
2bbc6a33ca
10 changed files with 358 additions and 207 deletions
|
@ -6,6 +6,9 @@ to generate responses to messages in a Matrix room.
|
||||||
It will also save a log of the spent tokens to a DuckDB database
|
It will also save a log of the spent tokens to a DuckDB database
|
||||||
(database.db in the working directory, by default).
|
(database.db in the working directory, by default).
|
||||||
|
|
||||||
|
Note that this bot does not yet support encryption - this is still work in
|
||||||
|
progress.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
Simply clone this repository and install the requirements.
|
Simply clone this repository and install the requirements.
|
||||||
|
|
206
classes/store.py
206
classes/store.py
|
@ -1,6 +1,6 @@
|
||||||
import duckdb
|
import duckdb
|
||||||
|
|
||||||
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
|
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
|
||||||
from nio.crypto import OlmAccount, OlmDevice
|
from nio.crypto import OlmAccount, OlmDevice
|
||||||
|
|
||||||
from random import SystemRandom
|
from random import SystemRandom
|
||||||
|
@ -24,150 +24,6 @@ class DuckDBStore(MatrixStore):
|
||||||
self.conn = duckdb_conn
|
self.conn = duckdb_conn
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.device_id = device_id
|
self.device_id = device_id
|
||||||
self._create_tables()
|
|
||||||
|
|
||||||
def _create_tables(self):
|
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute("""
|
|
||||||
DROP TABLE IF EXISTS sync_tokens CASCADE;
|
|
||||||
DROP TABLE IF EXISTS encrypted_rooms CASCADE;
|
|
||||||
DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
|
|
||||||
DROP TABLE IF EXISTS forwarded_chains CASCADE;
|
|
||||||
DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
|
|
||||||
DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
|
|
||||||
DROP TABLE IF EXISTS olm_sessions CASCADE;
|
|
||||||
DROP TABLE IF EXISTS device_trust_state CASCADE;
|
|
||||||
DROP TABLE IF EXISTS keys CASCADE;
|
|
||||||
DROP TABLE IF EXISTS device_keys_key CASCADE;
|
|
||||||
DROP TABLE IF EXISTS device_keys CASCADE;
|
|
||||||
DROP TABLE IF EXISTS accounts CASCADE;
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create accounts table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS accounts (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
user_id VARCHAR NOT NULL,
|
|
||||||
device_id VARCHAR NOT NULL,
|
|
||||||
shared_account INTEGER NOT NULL,
|
|
||||||
pickle VARCHAR NOT NULL
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create device_keys table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS device_keys (
|
|
||||||
device_id TEXT PRIMARY KEY,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
user_id TEXT NOT NULL,
|
|
||||||
display_name TEXT,
|
|
||||||
deleted BOOLEAN NOT NULL DEFAULT 0,
|
|
||||||
UNIQUE (account_id, user_id, device_id),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS keys (
|
|
||||||
key_type TEXT NOT NULL,
|
|
||||||
key TEXT NOT NULL,
|
|
||||||
device_id VARCHAR NOT NULL,
|
|
||||||
UNIQUE (key_type, device_id),
|
|
||||||
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create device_trust_state table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS device_trust_state (
|
|
||||||
device_id VARCHAR PRIMARY KEY,
|
|
||||||
state INTEGER NOT NULL,
|
|
||||||
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create olm_sessions table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS olm_sessions (
|
|
||||||
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
sender_key TEXT NOT NULL,
|
|
||||||
session BLOB NOT NULL,
|
|
||||||
session_id VARCHAR NOT NULL,
|
|
||||||
creation_time TIMESTAMP NOT NULL,
|
|
||||||
last_usage_date TIMESTAMP NOT NULL,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create inbound_group_sessions table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
|
||||||
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
session TEXT NOT NULL,
|
|
||||||
fp_key TEXT NOT NULL,
|
|
||||||
sender_key TEXT NOT NULL,
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
UNIQUE (account_id, sender_key, fp_key, room_id),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
session_id INTEGER NOT NULL,
|
|
||||||
sender_key TEXT NOT NULL,
|
|
||||||
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create outbound_group_sessions table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
room_id VARCHAR NOT NULL,
|
|
||||||
session_id VARCHAR NOT NULL UNIQUE,
|
|
||||||
session BLOB NOT NULL,
|
|
||||||
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create outgoing_key_requests table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
request_id TEXT NOT NULL,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
algorithm TEXT NOT NULL,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
|
||||||
UNIQUE (account_id, request_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create encrypted_rooms table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
PRIMARY KEY (room_id, account_id),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create sync_tokens table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS sync_tokens (
|
|
||||||
account_id INTEGER PRIMARY KEY,
|
|
||||||
token TEXT NOT NULL,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
def _get_account(self):
|
def _get_account(self):
|
||||||
cursor = self.conn.cursor()
|
cursor = self.conn.cursor()
|
||||||
|
@ -387,18 +243,18 @@ class DuckDBStore(MatrixStore):
|
||||||
for d in device_keys:
|
for d in device_keys:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT * FROM keys WHERE device_id = ?",
|
"SELECT * FROM keys WHERE device_id = ?",
|
||||||
(d["id"],)
|
(d[0],)
|
||||||
)
|
)
|
||||||
keys = cur.fetchall()
|
keys = cur.fetchall()
|
||||||
key_dict = {k["key_type"]: k["key"] for k in keys}
|
key_dict = {k[0]: k[1] for k in keys}
|
||||||
|
|
||||||
store.add(
|
store.add(
|
||||||
OlmDevice(
|
OlmDevice(
|
||||||
d["user_id"],
|
d[2],
|
||||||
d["device_id"],
|
d[0],
|
||||||
key_dict,
|
key_dict,
|
||||||
display_name=d["display_name"],
|
display_name=d[3],
|
||||||
deleted=d["deleted"],
|
deleted=d[4],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -561,18 +417,21 @@ class DuckDBStore(MatrixStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in cursor.fetchall():
|
for row in cursor.fetchall():
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
||||||
|
(row[1],),
|
||||||
|
)
|
||||||
|
chains = cursor.fetchall()
|
||||||
|
|
||||||
session = InboundGroupSession.from_pickle(
|
session = InboundGroupSession.from_pickle(
|
||||||
row["session"],
|
row[2].encode(),
|
||||||
row["fp_key"],
|
row[3],
|
||||||
row["sender_key"],
|
row[4],
|
||||||
row["room_id"],
|
row[5],
|
||||||
self.pickle_key,
|
self.pickle_key,
|
||||||
[
|
[
|
||||||
chain["sender_key"]
|
chain[0]
|
||||||
for chain in cursor.execute(
|
for chain in chains
|
||||||
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
|
||||||
(row["id"],),
|
|
||||||
)
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
store.add(session)
|
store.add(session)
|
||||||
|
@ -621,7 +480,7 @@ class DuckDBStore(MatrixStore):
|
||||||
)
|
)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
|
|
||||||
return {row["room_id"] for row in rows}
|
return {row[0] for row in rows}
|
||||||
|
|
||||||
def save_sync_token(self, token):
|
def save_sync_token(self, token):
|
||||||
"""Save the given token"""
|
"""Save the given token"""
|
||||||
|
@ -728,3 +587,28 @@ class DuckDBStore(MatrixStore):
|
||||||
key_request.algorithm,
|
key_request.algorithm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def load_account(self):
|
||||||
|
# type: () -> Optional[OlmAccount]
|
||||||
|
"""Load the Olm account from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``OlmAccount`` object, or ``None`` if it wasn't found for the
|
||||||
|
current device_id.
|
||||||
|
|
||||||
|
"""
|
||||||
|
cursor = self.conn.cursor()
|
||||||
|
query = """
|
||||||
|
SELECT pickle, shared_account
|
||||||
|
FROM accounts
|
||||||
|
WHERE device_id = ?;
|
||||||
|
"""
|
||||||
|
cursor.execute(query, (self.device_id,))
|
||||||
|
|
||||||
|
result = cursor.fetchone()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
account_pickle, shared = result
|
||||||
|
return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from .botinfo import command_botinfo
|
||||||
from .unknown import command_unknown
|
from .unknown import command_unknown
|
||||||
from .coin import command_coin
|
from .coin import command_coin
|
||||||
from .ignoreolder import command_ignoreolder
|
from .ignoreolder import command_ignoreolder
|
||||||
|
from .systemmessage import command_systemmessage
|
||||||
|
|
||||||
COMMANDS = {
|
COMMANDS = {
|
||||||
"help": command_help,
|
"help": command_help,
|
||||||
|
@ -13,5 +14,6 @@ COMMANDS = {
|
||||||
"botinfo": command_botinfo,
|
"botinfo": command_botinfo,
|
||||||
"coin": command_coin,
|
"coin": command_coin,
|
||||||
"ignoreolder": command_ignoreolder,
|
"ignoreolder": command_ignoreolder,
|
||||||
|
"systemmessage": command_systemmessage,
|
||||||
None: command_unknown,
|
None: command_unknown,
|
||||||
}
|
}
|
||||||
|
|
34
commands/systemmessage.py
Normal file
34
commands/systemmessage.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
from nio.events.room_events import RoomMessageText
|
||||||
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
|
||||||
|
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||||
|
system_message = " ".join(event.body.split()[2:])
|
||||||
|
|
||||||
|
if system_message:
|
||||||
|
context["logger"]("Adding system message...")
|
||||||
|
|
||||||
|
with context["database"].cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(room.room_id, event.event_id, event.sender,
|
||||||
|
system_message, event.server_timestamp)
|
||||||
|
)
|
||||||
|
|
||||||
|
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
|
||||||
|
|
||||||
|
context["logger"]("Retrieving system message...")
|
||||||
|
|
||||||
|
with context["database"].cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||||
|
(room.room_id,)
|
||||||
|
)
|
||||||
|
system_message = cur.fetchone()
|
||||||
|
|
||||||
|
if system_message is None:
|
||||||
|
system_message = context.get("system_message", "No system message set")
|
||||||
|
elif context.get("force_system_message") and context.get("system_message"):
|
||||||
|
system_message = system_message + "\n\n" + context["system_message"]
|
||||||
|
|
||||||
|
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}
|
|
@ -69,10 +69,15 @@ AccessToken = syt_yoursynapsetoken
|
||||||
#
|
#
|
||||||
# SystemMessage = You are a helpful bot.
|
# SystemMessage = You are a helpful bot.
|
||||||
|
|
||||||
|
# Force inclusion of the SystemMessage defined above if one is defined on per-room level
|
||||||
|
# If no custom message is defined for the room, SystemMessage is always included
|
||||||
|
#
|
||||||
|
# ForceSystemMessage = 0
|
||||||
|
|
||||||
[Database]
|
[Database]
|
||||||
|
|
||||||
# Settings for the DuckDB database.
|
# Settings for the DuckDB database.
|
||||||
# Currently only used to store details on spent tokens per room.
|
# If not defined, the bot will not be able to remember anything, and will not support encryption
|
||||||
# If not defined, the bot will not store this data.
|
# N.B.: Encryption doesn't work as it is supposed to anyway.
|
||||||
|
|
||||||
Path = database.db
|
Path = database.db
|
104
gptbot.py
104
gptbot.py
|
@ -23,6 +23,7 @@ from typing import List, Dict, Union, Optional
|
||||||
|
|
||||||
from commands import COMMANDS
|
from commands import COMMANDS
|
||||||
from classes import DuckDBStore
|
from classes import DuckDBStore
|
||||||
|
from migrations import MIGRATIONS
|
||||||
|
|
||||||
|
|
||||||
def logging(message: str, log_level: str = "info"):
|
def logging(message: str, log_level: str = "info"):
|
||||||
|
@ -35,6 +36,7 @@ CONTEXT = {
|
||||||
"database": False,
|
"database": False,
|
||||||
"default_room_name": "GPTBot",
|
"default_room_name": "GPTBot",
|
||||||
"system_message": "You are a helpful assistant.",
|
"system_message": "You are a helpful assistant.",
|
||||||
|
"force_system_message": False,
|
||||||
"max_tokens": 3000,
|
"max_tokens": 3000,
|
||||||
"max_messages": 20,
|
"max_messages": 20,
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
|
@ -48,6 +50,8 @@ async def gpt_query(messages: list, model: Optional[str] = None):
|
||||||
model = model or CONTEXT["model"]
|
model = model or CONTEXT["model"]
|
||||||
|
|
||||||
logging(f"Querying GPT with {len(messages)} messages")
|
logging(f"Querying GPT with {len(messages)} messages")
|
||||||
|
logging(messages, "debug")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -143,8 +147,9 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||||
|
|
||||||
client = kwargs.get("client") or CONTEXT["client"]
|
client = kwargs.get("client") or CONTEXT["client"]
|
||||||
database = kwargs.get("database") or CONTEXT["database"]
|
database = kwargs.get("database") or CONTEXT["database"]
|
||||||
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
|
||||||
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
|
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
|
||||||
|
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
||||||
|
force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
|
||||||
|
|
||||||
await client.room_typing(room.room_id, True)
|
await client.room_typing(room.room_id, True)
|
||||||
|
|
||||||
|
@ -152,6 +157,12 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||||
|
|
||||||
last_messages = await fetch_last_n_messages(room.room_id, 20)
|
last_messages = await fetch_last_n_messages(room.room_id, 20)
|
||||||
|
|
||||||
|
system_message = get_system_message(room, {
|
||||||
|
"database": database,
|
||||||
|
"system_message": system_message,
|
||||||
|
"force_system_message": force_system_message,
|
||||||
|
})
|
||||||
|
|
||||||
chat_messages = [{"role": "system", "content": system_message}]
|
chat_messages = [{"role": "system", "content": system_message}]
|
||||||
|
|
||||||
for message in last_messages:
|
for message in last_messages:
|
||||||
|
@ -163,8 +174,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||||
|
|
||||||
# Truncate messages to fit within the token limit
|
# Truncate messages to fit within the token limit
|
||||||
truncated_messages = truncate_messages_to_fit_tokens(
|
truncated_messages = truncate_messages_to_fit_tokens(
|
||||||
chat_messages, max_tokens - 1)
|
chat_messages, max_tokens - 1, system_message=system_message)
|
||||||
|
|
||||||
response, tokens_used = await gpt_query(truncated_messages)
|
response, tokens_used = await gpt_query(truncated_messages)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
|
@ -204,10 +214,29 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
room_id, event, content = message
|
room_id, event, content = message
|
||||||
|
rooms = await context["client"].joined_rooms()
|
||||||
await send_message(context["client"].rooms[room_id], content["body"],
|
await send_message(context["client"].rooms[room_id], content["body"],
|
||||||
True if content["msgtype"] == "m.notice" else False, context["client"])
|
True if content["msgtype"] == "m.notice" else False, context["client"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
|
||||||
|
context = context or CONTEXT
|
||||||
|
|
||||||
|
default = context.get("system_message")
|
||||||
|
|
||||||
|
with context["database"].cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||||
|
(room.room_id,)
|
||||||
|
)
|
||||||
|
system_message = cur.fetchone()
|
||||||
|
|
||||||
|
complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
|
||||||
|
"\n\n" + system_message[0] if system_message else "")).strip()
|
||||||
|
|
||||||
|
return complete
|
||||||
|
|
||||||
|
|
||||||
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
||||||
context = kwargs.get("context") or CONTEXT
|
context = kwargs.get("context") or CONTEXT
|
||||||
|
|
||||||
|
@ -286,7 +315,8 @@ async def send_message(room: MatrixRoom, message: str, notice: bool = False, cli
|
||||||
)
|
)
|
||||||
|
|
||||||
if msgtype != "m.reaction":
|
if msgtype != "m.reaction":
|
||||||
response = client.encrypt(room.room_id, "m.room.message", msgcontent)
|
response = client.encrypt(
|
||||||
|
room.room_id, "m.room.message", msgcontent)
|
||||||
msgtype, content = response
|
msgtype, content = response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -318,7 +348,7 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
||||||
if isinstance(response, JoinResponse):
|
if isinstance(response, JoinResponse):
|
||||||
logging(response, "debug")
|
logging(response, "debug")
|
||||||
rooms = await client.joined_rooms()
|
rooms = await client.joined_rooms()
|
||||||
await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||||
else:
|
else:
|
||||||
logging(f"Error joining room {room_id}: {response}", "error")
|
logging(f"Error joining room {room_id}: {response}", "error")
|
||||||
|
|
||||||
|
@ -408,13 +438,16 @@ async def init(config: ConfigParser):
|
||||||
if "MaxMessages" in config["OpenAI"]:
|
if "MaxMessages" in config["OpenAI"]:
|
||||||
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
||||||
|
|
||||||
# Listen for SIGTERM
|
# Override defaults with config
|
||||||
|
|
||||||
def sigterm_handler(_signo, _stack_frame):
|
if "GPTBot" in config:
|
||||||
logging("Received SIGTERM - exiting...")
|
if "SystemMessage" in config["GPTBot"]:
|
||||||
exit()
|
CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
|
||||||
|
if "DefaultRoomName" in config["GPTBot"]:
|
||||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
|
||||||
|
if "ForceSystemMessage" in config["GPTBot"]:
|
||||||
|
CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
|
||||||
|
"ForceSystemMessage")
|
||||||
|
|
||||||
|
|
||||||
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
||||||
|
@ -431,6 +464,14 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
|
||||||
await client.close()
|
await client.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Listen for SIGTERM
|
||||||
|
|
||||||
|
def sigterm_handler(_signo, _stack_frame):
|
||||||
|
logging("Received SIGTERM - exiting...")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||||
|
|
||||||
logging("Starting bot...")
|
logging("Starting bot...")
|
||||||
|
|
||||||
client.add_response_callback(sync_cb, SyncResponse)
|
client.add_response_callback(sync_cb, SyncResponse)
|
||||||
|
@ -460,9 +501,9 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
|
||||||
|
|
||||||
def initialize_database(path: os.PathLike):
|
def initialize_database(path: os.PathLike):
|
||||||
logging("Initializing database...")
|
logging("Initializing database...")
|
||||||
database = duckdb.connect(path)
|
conn = duckdb.connect(path)
|
||||||
|
|
||||||
with database.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
# Get the latest migration ID if the migrations table exists
|
# Get the latest migration ID if the migrations table exists
|
||||||
try:
|
try:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
|
@ -472,40 +513,17 @@ def initialize_database(path: os.PathLike):
|
||||||
)
|
)
|
||||||
|
|
||||||
latest_migration = int(cursor.fetchone()[0])
|
latest_migration = int(cursor.fetchone()[0])
|
||||||
|
|
||||||
except:
|
except:
|
||||||
latest_migration = 0
|
latest_migration = 0
|
||||||
|
|
||||||
# Version 1
|
for migration, function in MIGRATIONS.items():
|
||||||
|
if latest_migration < migration:
|
||||||
|
logging(f"Running migration {migration}...")
|
||||||
|
function(conn)
|
||||||
|
latest_migration = migration
|
||||||
|
|
||||||
if latest_migration < 1:
|
return conn
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS token_usage (
|
|
||||||
message_id TEXT PRIMARY KEY,
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
tokens INTEGER NOT NULL,
|
|
||||||
timestamp TIMESTAMP NOT NULL
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS migrations (
|
|
||||||
id INTEGER NOT NULL,
|
|
||||||
timestamp TIMESTAMP NOT NULL
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
|
|
||||||
(datetime.now(),)
|
|
||||||
)
|
|
||||||
|
|
||||||
database.commit()
|
|
||||||
|
|
||||||
return database
|
|
||||||
|
|
||||||
|
|
||||||
async def get_device_id(access_token, homeserver):
|
async def get_device_id(access_token, homeserver):
|
||||||
|
|
11
migrations/__init__.py
Normal file
11
migrations/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from .migration_1 import migration as migration_1
|
||||||
|
from .migration_2 import migration as migration_2
|
||||||
|
from .migration_3 import migration as migration_3
|
||||||
|
|
||||||
|
MIGRATIONS = OrderedDict()
|
||||||
|
|
||||||
|
MIGRATIONS[1] = migration_1
|
||||||
|
MIGRATIONS[2] = migration_2
|
||||||
|
MIGRATIONS[3] = migration_3
|
32
migrations/migration_1.py
Normal file
32
migrations/migration_1.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Initial migration, token usage logging
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def migration(conn):
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS token_usage (
|
||||||
|
message_id TEXT PRIMARY KEY,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
tokens INTEGER NOT NULL,
|
||||||
|
timestamp TIMESTAMP NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS migrations (
|
||||||
|
id INTEGER NOT NULL,
|
||||||
|
timestamp TIMESTAMP NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
|
||||||
|
(datetime.now(),)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
138
migrations/migration_2.py
Normal file
138
migrations/migration_2.py
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
# Migration for Matrix Store
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def migration(conn):
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
# Create accounts table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
user_id VARCHAR NOT NULL,
|
||||||
|
device_id VARCHAR NOT NULL,
|
||||||
|
shared_account INTEGER NOT NULL,
|
||||||
|
pickle VARCHAR NOT NULL
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create device_keys table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS device_keys (
|
||||||
|
device_id TEXT PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
|
deleted BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
UNIQUE (account_id, user_id, device_id),
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS keys (
|
||||||
|
key_type TEXT NOT NULL,
|
||||||
|
key TEXT NOT NULL,
|
||||||
|
device_id VARCHAR NOT NULL,
|
||||||
|
UNIQUE (key_type, device_id),
|
||||||
|
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create device_trust_state table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS device_trust_state (
|
||||||
|
device_id VARCHAR PRIMARY KEY,
|
||||||
|
state INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create olm_sessions table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS olm_sessions (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
sender_key TEXT NOT NULL,
|
||||||
|
session BLOB NOT NULL,
|
||||||
|
session_id VARCHAR NOT NULL,
|
||||||
|
creation_time TIMESTAMP NOT NULL,
|
||||||
|
last_usage_date TIMESTAMP NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create inbound_group_sessions table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
||||||
|
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
session TEXT NOT NULL,
|
||||||
|
fp_key TEXT NOT NULL,
|
||||||
|
sender_key TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
UNIQUE (account_id, sender_key, fp_key, room_id),
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
session_id INTEGER NOT NULL,
|
||||||
|
sender_key TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create outbound_group_sessions table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
room_id VARCHAR NOT NULL,
|
||||||
|
session_id VARCHAR NOT NULL UNIQUE,
|
||||||
|
session BLOB NOT NULL,
|
||||||
|
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create outgoing_key_requests table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
request_id TEXT NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE (account_id, request_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create encrypted_rooms table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
account_id INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (room_id, account_id),
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create sync_tokens table
|
||||||
|
cursor.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sync_tokens (
|
||||||
|
account_id INTEGER PRIMARY KEY,
|
||||||
|
token TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
|
||||||
|
(datetime.now(),)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
24
migrations/migration_3.py
Normal file
24
migrations/migration_3.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Migration for custom system messages
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def migration(conn):
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS system_messages (
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
message_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
body TEXT NOT NULL,
|
||||||
|
timestamp BIGINT NOT NULL,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO migrations (id, timestamp) VALUES (3, ?)",
|
||||||
|
(datetime.now(),)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
Loading…
Reference in a new issue