Make it an importable module
Abandon DuckDB in favor of sqlite3
This commit is contained in:
parent
3a1d1ea86a
commit
55809a9a39
56 changed files with 234 additions and 828 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,3 +4,4 @@ config.ini
|
||||||
venv/
|
venv/
|
||||||
*.pyc
|
*.pyc
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
*.bak
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.formatting.provider": "black"
|
||||||
|
}
|
57
README.md
57
README.md
|
@ -26,12 +26,61 @@ probably add more in the future, so the name is a bit misleading.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
Simply clone this repository and install the requirements.
|
To run the bot, you will need Python 3.10 or newer.
|
||||||
|
|
||||||
### Requirements
|
The bot has been tested with Python 3.11 on Arch, but should work with any
|
||||||
|
current version, and should not require any special dependencies or operating
|
||||||
|
system features.
|
||||||
|
|
||||||
- Python 3.10 or later
|
### Production
|
||||||
- Requirements from `requirements.txt` (install with `pip install -r requirements.txt` in a venv)
|
|
||||||
|
The easiest way to install the bot is to use pip to install it directly from
|
||||||
|
[its Git repository](https://kumig.it/kumitterer/matrix-gptbot/):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# If desired, activate a venv first
|
||||||
|
|
||||||
|
python -m venv venv
|
||||||
|
. venv/bin/activate
|
||||||
|
|
||||||
|
# Install the bot
|
||||||
|
|
||||||
|
pip install git+https://kumig.it/kumitterer/matrix-gptbot.git
|
||||||
|
```
|
||||||
|
|
||||||
|
This will install the bot from the main branch and all required dependencies.
|
||||||
|
A release to PyPI is planned, but not yet available.
|
||||||
|
|
||||||
|
### Development
|
||||||
|
|
||||||
|
Clone the repository and install the requirements to a virtual environment.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Clone the repository
|
||||||
|
|
||||||
|
git clone https://kumig.it/kumitterer/matrix-gptbot.git
|
||||||
|
cd matrix-gptbot
|
||||||
|
|
||||||
|
# If desired, activate a venv first
|
||||||
|
|
||||||
|
python -m venv venv
|
||||||
|
. venv/bin/activate
|
||||||
|
|
||||||
|
# Install the requirements
|
||||||
|
|
||||||
|
pip install -Ur requirements.txt
|
||||||
|
|
||||||
|
# Install the bot in editable mode
|
||||||
|
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
|
# Go to the bot directory and start working
|
||||||
|
|
||||||
|
cd src/gptbot
|
||||||
|
```
|
||||||
|
|
||||||
|
Of course, you can also fork the repository on [GitHub](https://github.com/kumitterer/matrix-gptbot/)
|
||||||
|
and work on your own copy.
|
||||||
|
|
||||||
### Configuration
|
### Configuration
|
||||||
|
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
from .store import DuckDBStore
|
|
637
classes/store.py
637
classes/store.py
|
@ -1,637 +0,0 @@
|
||||||
import duckdb
|
|
||||||
|
|
||||||
from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
|
|
||||||
from nio.crypto import OlmAccount, OlmDevice
|
|
||||||
|
|
||||||
from random import SystemRandom
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from .dict import AttrDict
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
class DuckDBStore(MatrixStore):
|
|
||||||
@property
|
|
||||||
def account_id(self):
|
|
||||||
id = self._get_account()[0] if self._get_account() else None
|
|
||||||
|
|
||||||
if id is None:
|
|
||||||
id = SystemRandom().randint(0, 2**16)
|
|
||||||
|
|
||||||
return id
|
|
||||||
|
|
||||||
def __init__(self, user_id, device_id, duckdb_conn):
|
|
||||||
self.conn = duckdb_conn
|
|
||||||
self.user_id = user_id
|
|
||||||
self.device_id = device_id
|
|
||||||
|
|
||||||
def _get_account(self):
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
|
|
||||||
(self.user_id, self.device_id),
|
|
||||||
)
|
|
||||||
account = cursor.fetchone()
|
|
||||||
cursor.close()
|
|
||||||
return account
|
|
||||||
|
|
||||||
def _get_device(self, device):
|
|
||||||
acc = self._get_account()
|
|
||||||
|
|
||||||
if not acc:
|
|
||||||
return None
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
|
|
||||||
(device.user_id, device.id, acc[0]),
|
|
||||||
)
|
|
||||||
device_entry = cursor.fetchone()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
return device_entry
|
|
||||||
|
|
||||||
# Implementing methods with DuckDB equivalents
|
|
||||||
def verify_device(self, device):
|
|
||||||
if self.is_device_verified(device):
|
|
||||||
return False
|
|
||||||
|
|
||||||
d = self._get_device(device)
|
|
||||||
assert d
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
|
||||||
(d[0], TrustState.verified),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
device.trust_state = TrustState.verified
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def unverify_device(self, device):
|
|
||||||
if not self.is_device_verified(device):
|
|
||||||
return False
|
|
||||||
|
|
||||||
d = self._get_device(device)
|
|
||||||
assert d
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
|
||||||
(d[0], TrustState.unset),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
device.trust_state = TrustState.unset
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_device_verified(self, device):
|
|
||||||
d = self._get_device(device)
|
|
||||||
|
|
||||||
if not d:
|
|
||||||
return False
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
|
||||||
)
|
|
||||||
trust_state = cursor.fetchone()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
if not trust_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return trust_state[0] == TrustState.verified
|
|
||||||
|
|
||||||
def blacklist_device(self, device):
|
|
||||||
if self.is_device_blacklisted(device):
|
|
||||||
return False
|
|
||||||
|
|
||||||
d = self._get_device(device)
|
|
||||||
assert d
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
|
||||||
(d[0], TrustState.blacklisted),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
device.trust_state = TrustState.blacklisted
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def unblacklist_device(self, device):
|
|
||||||
if not self.is_device_blacklisted(device):
|
|
||||||
return False
|
|
||||||
|
|
||||||
d = self._get_device(device)
|
|
||||||
assert d
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
|
||||||
(d[0], TrustState.unset),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
device.trust_state = TrustState.unset
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_device_blacklisted(self, device):
|
|
||||||
d = self._get_device(device)
|
|
||||||
|
|
||||||
if not d:
|
|
||||||
return False
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
|
||||||
)
|
|
||||||
trust_state = cursor.fetchone()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
if not trust_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return trust_state[0] == TrustState.blacklisted
|
|
||||||
|
|
||||||
def ignore_device(self, device):
|
|
||||||
if self.is_device_ignored(device):
|
|
||||||
return False
|
|
||||||
|
|
||||||
d = self._get_device(device)
|
|
||||||
assert d
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
|
||||||
(d[0], int(TrustState.ignored.value)),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def ignore_devices(self, devices):
|
|
||||||
for device in devices:
|
|
||||||
self.ignore_device(device)
|
|
||||||
|
|
||||||
def unignore_device(self, device):
|
|
||||||
if not self.is_device_ignored(device):
|
|
||||||
return False
|
|
||||||
|
|
||||||
d = self._get_device(device)
|
|
||||||
assert d
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
|
|
||||||
(d[0], TrustState.unset),
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
device.trust_state = TrustState.unset
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_device_ignored(self, device):
|
|
||||||
d = self._get_device(device)
|
|
||||||
|
|
||||||
if not d:
|
|
||||||
return False
|
|
||||||
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
|
|
||||||
)
|
|
||||||
trust_state = cursor.fetchone()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
if not trust_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return trust_state[0] == TrustState.ignored
|
|
||||||
|
|
||||||
def load_device_keys(self):
|
|
||||||
"""Load all the device keys from the database.
|
|
||||||
|
|
||||||
Returns DeviceStore containing the OlmDevices with the device keys.
|
|
||||||
"""
|
|
||||||
store = DeviceStore()
|
|
||||||
account = self.account_id
|
|
||||||
|
|
||||||
if not account:
|
|
||||||
return store
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
cur.execute(
|
|
||||||
"SELECT * FROM device_keys WHERE account_id = ?",
|
|
||||||
(account,)
|
|
||||||
)
|
|
||||||
device_keys = cur.fetchall()
|
|
||||||
|
|
||||||
for d in device_keys:
|
|
||||||
cur.execute(
|
|
||||||
"SELECT * FROM keys WHERE device_id = ?",
|
|
||||||
(d[0],)
|
|
||||||
)
|
|
||||||
keys = cur.fetchall()
|
|
||||||
key_dict = {k[0]: k[1] for k in keys}
|
|
||||||
|
|
||||||
store.add(
|
|
||||||
OlmDevice(
|
|
||||||
d[2],
|
|
||||||
d[0],
|
|
||||||
key_dict,
|
|
||||||
display_name=d[3],
|
|
||||||
deleted=d[4],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return store
|
|
||||||
|
|
||||||
def save_device_keys(self, device_keys):
|
|
||||||
"""Save the provided device keys to the database."""
|
|
||||||
account = self.account_id
|
|
||||||
assert account
|
|
||||||
rows = []
|
|
||||||
|
|
||||||
for user_id, devices_dict in device_keys.items():
|
|
||||||
for device_id, device in devices_dict.items():
|
|
||||||
rows.append(
|
|
||||||
{
|
|
||||||
"account_id": account,
|
|
||||||
"user_id": user_id,
|
|
||||||
"device_id": device_id,
|
|
||||||
"display_name": device.display_name,
|
|
||||||
"deleted": device.deleted,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
for idx in range(0, len(rows), 100):
|
|
||||||
data = rows[idx: idx + 100]
|
|
||||||
cur.executemany(
|
|
||||||
"INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
[(r["account_id"], r["user_id"], r["device_id"],
|
|
||||||
r["display_name"], r["deleted"]) for r in data]
|
|
||||||
)
|
|
||||||
|
|
||||||
for user_id, devices_dict in device_keys.items():
|
|
||||||
for device_id, device in devices_dict.items():
|
|
||||||
cur.execute(
|
|
||||||
"UPDATE device_keys SET deleted = ? WHERE device_id = ?",
|
|
||||||
(device.deleted, device_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
for key_type, key in device.keys.items():
|
|
||||||
cur.execute("""
|
|
||||||
INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
|
|
||||||
ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
|
|
||||||
""",
|
|
||||||
(key_type, key, device_id, key)
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_group_sessions(self, sessions):
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
for session in sessions:
|
|
||||||
cur.execute("""
|
|
||||||
INSERT OR REPLACE INTO inbound_group_sessions (
|
|
||||||
session_id, sender_key, signing_key, room_id, pickle, account_id
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?)
|
|
||||||
""", (
|
|
||||||
session.id,
|
|
||||||
session.sender_key,
|
|
||||||
session.signing_key,
|
|
||||||
session.room_id,
|
|
||||||
session.pickle,
|
|
||||||
self.account_id
|
|
||||||
))
|
|
||||||
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_olm_sessions(self, sessions):
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
for session in sessions:
|
|
||||||
cur.execute("""
|
|
||||||
INSERT OR REPLACE INTO olm_sessions (
|
|
||||||
session_id, sender_key, pickle, account_id
|
|
||||||
) VALUES (?, ?, ?, ?)
|
|
||||||
""", (
|
|
||||||
session.id,
|
|
||||||
session.sender_key,
|
|
||||||
session.pickle,
|
|
||||||
self.account_id
|
|
||||||
))
|
|
||||||
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_outbound_group_sessions(self, sessions):
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
for session in sessions:
|
|
||||||
cur.execute("""
|
|
||||||
INSERT OR REPLACE INTO outbound_group_sessions (
|
|
||||||
room_id, session_id, pickle, account_id
|
|
||||||
) VALUES (?, ?, ?, ?)
|
|
||||||
""", (
|
|
||||||
session.room_id,
|
|
||||||
session.id,
|
|
||||||
session.pickle,
|
|
||||||
self.account_id
|
|
||||||
))
|
|
||||||
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_account(self, account: OlmAccount):
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
cur.execute("""
|
|
||||||
INSERT OR REPLACE INTO accounts (
|
|
||||||
id, user_id, device_id, shared_account, pickle
|
|
||||||
) VALUES (?, ?, ?, ?, ?)
|
|
||||||
""", (
|
|
||||||
self.account_id,
|
|
||||||
self.user_id,
|
|
||||||
self.device_id,
|
|
||||||
account.shared,
|
|
||||||
account.pickle(self.pickle_key),
|
|
||||||
))
|
|
||||||
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def load_sessions(self):
|
|
||||||
session_store = SessionStore()
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
cur.execute("""
|
|
||||||
SELECT
|
|
||||||
os.sender_key, os.session, os.creation_time
|
|
||||||
FROM
|
|
||||||
olm_sessions os
|
|
||||||
INNER JOIN
|
|
||||||
accounts a ON os.account_id = a.id
|
|
||||||
WHERE
|
|
||||||
a.id = ?
|
|
||||||
""", (self.account_id,))
|
|
||||||
|
|
||||||
for row in cur.fetchall():
|
|
||||||
sender_key, session_pickle, creation_time = row
|
|
||||||
session = Session.from_pickle(
|
|
||||||
session_pickle, creation_time, self.pickle_key)
|
|
||||||
session_store.add(sender_key, session)
|
|
||||||
|
|
||||||
return session_store
|
|
||||||
|
|
||||||
def load_inbound_group_sessions(self):
|
|
||||||
# type: () -> GroupSessionStore
|
|
||||||
"""Load all Olm sessions from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``GroupSessionStore`` object, containing all the loaded sessions.
|
|
||||||
|
|
||||||
"""
|
|
||||||
store = GroupSessionStore()
|
|
||||||
|
|
||||||
account = self.account_id
|
|
||||||
|
|
||||||
if not account:
|
|
||||||
return store
|
|
||||||
|
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
|
|
||||||
account,)
|
|
||||||
)
|
|
||||||
|
|
||||||
for row in cursor.fetchall():
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
|
|
||||||
(row[1],),
|
|
||||||
)
|
|
||||||
chains = cursor.fetchall()
|
|
||||||
|
|
||||||
session = InboundGroupSession.from_pickle(
|
|
||||||
row[2].encode(),
|
|
||||||
row[3],
|
|
||||||
row[4],
|
|
||||||
row[5],
|
|
||||||
self.pickle_key,
|
|
||||||
[
|
|
||||||
chain[0]
|
|
||||||
for chain in chains
|
|
||||||
],
|
|
||||||
)
|
|
||||||
store.add(session)
|
|
||||||
|
|
||||||
return store
|
|
||||||
|
|
||||||
def load_outgoing_key_requests(self):
|
|
||||||
# type: () -> dict
|
|
||||||
"""Load all outgoing key requests from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
|
|
||||||
"""
|
|
||||||
account = self.account_id
|
|
||||||
|
|
||||||
if not account:
|
|
||||||
return store
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
cur.execute(
|
|
||||||
"SELECT * FROM outgoing_key_requests WHERE account_id = ?",
|
|
||||||
(account,)
|
|
||||||
)
|
|
||||||
rows = cur.fetchall()
|
|
||||||
|
|
||||||
return {
|
|
||||||
row[1]: OutgoingKeyRequest.from_response(AttrDict({
|
|
||||||
"id": row[0],
|
|
||||||
"account_id": row[1],
|
|
||||||
"request_id": row[2],
|
|
||||||
"session_id": row[3],
|
|
||||||
"room_id": row[4],
|
|
||||||
"algorithm": row[5],
|
|
||||||
})) for row in rows
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_encrypted_rooms(self):
|
|
||||||
"""Load the set of encrypted rooms for this account.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``Set`` containing room ids of encrypted rooms.
|
|
||||||
"""
|
|
||||||
account = self.account_id
|
|
||||||
|
|
||||||
if not account:
|
|
||||||
return set()
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
cur.execute(
|
|
||||||
"SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
|
|
||||||
(account,)
|
|
||||||
)
|
|
||||||
rows = cur.fetchall()
|
|
||||||
|
|
||||||
return {row[0] for row in rows}
|
|
||||||
|
|
||||||
def save_sync_token(self, token):
|
|
||||||
"""Save the given token"""
|
|
||||||
account = self.account_id
|
|
||||||
assert account
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
cur.execute(
|
|
||||||
"INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
|
|
||||||
(account, token)
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_encrypted_rooms(self, rooms):
|
|
||||||
"""Save the set of room ids for this account."""
|
|
||||||
account = self.account_id
|
|
||||||
assert account
|
|
||||||
|
|
||||||
data = [(room_id, account) for room_id in rooms]
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
for idx in range(0, len(data), 400):
|
|
||||||
rows = data[idx: idx + 400]
|
|
||||||
cur.executemany(
|
|
||||||
"INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
|
|
||||||
rows
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_session(self, sender_key, session):
|
|
||||||
"""Save the provided Olm session to the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sender_key (str): The curve key that owns the Olm session.
|
|
||||||
session (Session): The Olm session that will be pickled and
|
|
||||||
saved in the database.
|
|
||||||
"""
|
|
||||||
account = self.account_id
|
|
||||||
assert account
|
|
||||||
|
|
||||||
pickled_session = session.pickle(self.pickle_key)
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
|
|
||||||
cur.execute(
|
|
||||||
"INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
(account, sender_key, pickled_session, session.id,
|
|
||||||
session.creation_time, session.use_time)
|
|
||||||
)
|
|
||||||
self.conn.commit()
|
|
||||||
|
|
||||||
def save_inbound_group_session(self, session):
|
|
||||||
"""Save the provided Megolm inbound group session to the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session (InboundGroupSession): The session to save.
|
|
||||||
"""
|
|
||||||
account = self.account_id
|
|
||||||
assert account
|
|
||||||
|
|
||||||
with self.conn.cursor() as cur:
|
|
||||||
|
|
||||||
# Insert a new session or update the existing one
|
|
||||||
query = """
|
|
||||||
INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT (account_id, sender_key, fp_key, room_id)
|
|
||||||
DO UPDATE SET session = excluded.session
|
|
||||||
"""
|
|
||||||
cur.execute(query, (account, session.sender_key,
|
|
||||||
session.ed25519, session.room_id, session.pickle(self.pickle_key)))
|
|
||||||
|
|
||||||
# Delete existing forwarded chains for the session
|
|
||||||
delete_query = """
|
|
||||||
DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
|
|
||||||
"""
|
|
||||||
cur.execute(
|
|
||||||
delete_query, (account, session.sender_key, session.ed25519, session.room_id))
|
|
||||||
|
|
||||||
# Insert new forwarded chains for the session
|
|
||||||
insert_query = """
|
|
||||||
INSERT INTO forwarded_chains (session_id, sender_key)
|
|
||||||
VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
|
|
||||||
"""
|
|
||||||
|
|
||||||
for chain in session.forwarding_chain:
|
|
||||||
cur.execute(
|
|
||||||
insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
|
|
||||||
|
|
||||||
def add_outgoing_key_request(self, key_request):
|
|
||||||
"""Add a new outgoing key request to the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_request (OutgoingKeyRequest): The key request to add.
|
|
||||||
"""
|
|
||||||
|
|
||||||
account_id = self.account_id
|
|
||||||
with self.conn.cursor() as cursor:
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT MAX(id) FROM outgoing_key_requests
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
|
||||||
request_id = row[0] + 1 if row[0] else 1
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO outgoing_key_requests (id, account_id, request_id, session_id, room_id, algorithm)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT (account_id, request_id) DO NOTHING
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
request_id,
|
|
||||||
account_id,
|
|
||||||
key_request.request_id,
|
|
||||||
key_request.session_id,
|
|
||||||
key_request.room_id,
|
|
||||||
key_request.algorithm,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_account(self):
|
|
||||||
# type: () -> Optional[OlmAccount]
|
|
||||||
"""Load the Olm account from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``OlmAccount`` object, or ``None`` if it wasn't found for the
|
|
||||||
current device_id.
|
|
||||||
|
|
||||||
"""
|
|
||||||
cursor = self.conn.cursor()
|
|
||||||
query = """
|
|
||||||
SELECT pickle, shared_account
|
|
||||||
FROM accounts
|
|
||||||
WHERE device_id = ?;
|
|
||||||
"""
|
|
||||||
cursor.execute(query, (self.device_id,))
|
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
|
|
||||||
account_pickle, shared = result
|
|
||||||
return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)
|
|
|
@ -106,15 +106,35 @@ Operator = Contact details not set
|
||||||
|
|
||||||
[Database]
|
[Database]
|
||||||
|
|
||||||
# Settings for the DuckDB database.
|
# Path of the main database
|
||||||
# If not defined, the bot will not be able to remember anything, and will not support encryption
|
# Used to "remember" settings, etc.
|
||||||
# N.B.: Encryption doesn't work as it is supposed to anyway.
|
#
|
||||||
|
|
||||||
Path = database.db
|
Path = database.db
|
||||||
|
|
||||||
|
# Path of the Crypto Store - required to support encrypted rooms
|
||||||
|
# (not tested/supported yet)
|
||||||
|
#
|
||||||
|
CryptoStore = store.db
|
||||||
|
|
||||||
[TrackingMore]
|
[TrackingMore]
|
||||||
|
|
||||||
# API key for TrackingMore
|
# API key for TrackingMore
|
||||||
# If not defined, the bot will not be able to provide parcel tracking
|
# If not defined, the bot will not be able to provide parcel tracking
|
||||||
#
|
#
|
||||||
# APIKey = abcde-fghij-klmnop
|
# APIKey = abcde-fghij-klmnop
|
||||||
|
|
||||||
|
[Replicate]
|
||||||
|
|
||||||
|
# API key for replicate.com
|
||||||
|
# Can be used to run lots of different AI models
|
||||||
|
# If not defined, the features that depend on it are not available
|
||||||
|
#
|
||||||
|
# APIKey = r8_alotoflettersandnumbershere
|
||||||
|
|
||||||
|
[HuggingFace]
|
||||||
|
|
||||||
|
# API key for Hugging Face
|
||||||
|
# Can be used to run lots of different AI models
|
||||||
|
# If not defined, the features that depend on it are not available
|
||||||
|
#
|
||||||
|
# APIKey = __________________________
|
BIN
database.db-journal
Normal file
BIN
database.db-journal
Normal file
Binary file not shown.
|
@ -1,5 +1,5 @@
|
||||||
[Unit]
|
[Unit]
|
||||||
Description=GPTbot - A GPT bot for Matrix
|
Description=GPTbot - A multifunctional Chatbot for Matrix
|
||||||
Requires=network.target
|
Requires=network.target
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
|
|
|
@ -1,138 +0,0 @@
|
||||||
# Migration for Matrix Store
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
def migration(conn):
|
|
||||||
with conn.cursor() as cursor:
|
|
||||||
# Create accounts table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS accounts (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
user_id VARCHAR NOT NULL,
|
|
||||||
device_id VARCHAR NOT NULL,
|
|
||||||
shared_account INTEGER NOT NULL,
|
|
||||||
pickle VARCHAR NOT NULL
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create device_keys table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS device_keys (
|
|
||||||
device_id TEXT PRIMARY KEY,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
user_id TEXT NOT NULL,
|
|
||||||
display_name TEXT,
|
|
||||||
deleted BOOLEAN NOT NULL DEFAULT 0,
|
|
||||||
UNIQUE (account_id, user_id, device_id),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS keys (
|
|
||||||
key_type TEXT NOT NULL,
|
|
||||||
key TEXT NOT NULL,
|
|
||||||
device_id VARCHAR NOT NULL,
|
|
||||||
UNIQUE (key_type, device_id),
|
|
||||||
FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create device_trust_state table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS device_trust_state (
|
|
||||||
device_id VARCHAR PRIMARY KEY,
|
|
||||||
state INTEGER NOT NULL,
|
|
||||||
FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create olm_sessions table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS olm_sessions (
|
|
||||||
id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
sender_key TEXT NOT NULL,
|
|
||||||
session BLOB NOT NULL,
|
|
||||||
session_id VARCHAR NOT NULL,
|
|
||||||
creation_time TIMESTAMP NOT NULL,
|
|
||||||
last_usage_date TIMESTAMP NOT NULL,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create inbound_group_sessions table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS inbound_group_sessions (
|
|
||||||
id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
session TEXT NOT NULL,
|
|
||||||
fp_key TEXT NOT NULL,
|
|
||||||
sender_key TEXT NOT NULL,
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
UNIQUE (account_id, sender_key, fp_key, room_id),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS forwarded_chains (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
session_id INTEGER NOT NULL,
|
|
||||||
sender_key TEXT NOT NULL,
|
|
||||||
FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create outbound_group_sessions table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS outbound_group_sessions (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
room_id VARCHAR NOT NULL,
|
|
||||||
session_id VARCHAR NOT NULL UNIQUE,
|
|
||||||
session BLOB NOT NULL,
|
|
||||||
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create outgoing_key_requests table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS outgoing_key_requests (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
request_id TEXT NOT NULL,
|
|
||||||
session_id TEXT NOT NULL,
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
algorithm TEXT NOT NULL,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
|
||||||
UNIQUE (account_id, request_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create encrypted_rooms table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS encrypted_rooms (
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
account_id INTEGER NOT NULL,
|
|
||||||
PRIMARY KEY (room_id, account_id),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create sync_tokens table
|
|
||||||
cursor.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS sync_tokens (
|
|
||||||
account_id INTEGER PRIMARY KEY,
|
|
||||||
token TEXT NOT NULL,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
""")
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
|
|
||||||
(datetime.now(),)
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.commit()
|
|
69
pyproject.toml
Normal file
69
pyproject.toml
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.metadata]
|
||||||
|
allow-direct-references = true
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "matrix-gptbot"
|
||||||
|
version = "0.1.0-alpha1"
|
||||||
|
|
||||||
|
authors = [
|
||||||
|
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
|
||||||
|
]
|
||||||
|
|
||||||
|
description = "Multifunctional Chatbot for Matrix"
|
||||||
|
readme = "README.md"
|
||||||
|
license = { file="LICENSE" }
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
|
packages = [
|
||||||
|
"src/gptbot"
|
||||||
|
]
|
||||||
|
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"matrix-nio[e2e]",
|
||||||
|
"markdown2[all]",
|
||||||
|
"tiktoken",
|
||||||
|
"python-magic",
|
||||||
|
"pillow",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
openai = [
|
||||||
|
"openai",
|
||||||
|
]
|
||||||
|
|
||||||
|
wolframalpha = [
|
||||||
|
"wolframalpha",
|
||||||
|
]
|
||||||
|
|
||||||
|
trackingmore = [
|
||||||
|
"trackingmore @ git+https://kumig.it/kumitterer/trackingmore-api-tool.git",
|
||||||
|
]
|
||||||
|
|
||||||
|
all = [
|
||||||
|
"matrix-gptbot[openai,wolframalpha,trackingmore]",
|
||||||
|
]
|
||||||
|
|
||||||
|
dev = [
|
||||||
|
"matrix-gptbot[all]",
|
||||||
|
"black",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
"Homepage" = "https://kumig.it/kumitterer/matrix-gptbot"
|
||||||
|
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
gptbot = "gptbot:main"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/gptbot"]
|
|
@ -1,4 +1,4 @@
|
||||||
from classes.bot import GPTBot
|
from .classes.bot import GPTBot
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
|
@ -15,7 +15,10 @@ if __name__ == "__main__":
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config", help="Path to config file (default: config.ini in working directory)", default="config.ini")
|
"--config",
|
||||||
|
help="Path to config file (default: config.ini in working directory)",
|
||||||
|
default="config.ini",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Read config file
|
# Read config file
|
Before Width: | Height: | Size: 186 KiB After Width: | Height: | Size: 186 KiB |
|
@ -1,10 +1,12 @@
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
async def join_callback(response, bot):
|
async def join_callback(response, bot):
|
||||||
bot.logger.log(
|
bot.logger.log(
|
||||||
f"Join response received for room {response.room_id}", "debug")
|
f"Join response received for room {response.room_id}", "debug")
|
||||||
|
|
||||||
bot.matrix_client.joined_rooms()
|
bot.matrix_client.joined_rooms()
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
||||||
space = cursor.fetchone()
|
space = cursor.fetchone()
|
0
src/gptbot/classes/__init__.py
Normal file
0
src/gptbot/classes/__init__.py
Normal file
|
@ -1,5 +1,4 @@
|
||||||
import markdown2
|
import markdown2
|
||||||
import duckdb
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
@ -30,30 +29,34 @@ from nio import (
|
||||||
RoomCreateError,
|
RoomCreateError,
|
||||||
)
|
)
|
||||||
from nio.crypto import Olm
|
from nio.crypto import Olm
|
||||||
|
from nio.store import SqliteStore
|
||||||
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import traceback
|
import traceback
|
||||||
import json
|
import json
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
from .logging import Logger
|
from .logging import Logger
|
||||||
from migrations import migrate
|
from ..migrations import migrate
|
||||||
from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
||||||
from commands import COMMANDS
|
from ..commands import COMMANDS
|
||||||
from .store import DuckDBStore
|
|
||||||
from .openai import OpenAI
|
from .openai import OpenAI
|
||||||
from .wolframalpha import WolframAlpha
|
from .wolframalpha import WolframAlpha
|
||||||
from .trackingmore import TrackingMore
|
from .trackingmore import TrackingMore
|
||||||
|
|
||||||
|
|
||||||
class GPTBot:
|
class GPTBot:
|
||||||
# Default values
|
# Default values
|
||||||
database: Optional[duckdb.DuckDBPyConnection] = None
|
database: Optional[sqlite3.Connection] = None
|
||||||
|
crypto_store_path: Optional[str|Path] = None
|
||||||
# Default name of rooms created by the bot
|
# Default name of rooms created by the bot
|
||||||
display_name = default_room_name = "GPTBot"
|
display_name = default_room_name = "GPTBot"
|
||||||
default_system_message: str = "You are a helpful assistant."
|
default_system_message: str = "You are a helpful assistant."
|
||||||
|
@ -90,9 +93,11 @@ class GPTBot:
|
||||||
bot = cls()
|
bot = cls()
|
||||||
|
|
||||||
# Set the database connection
|
# Set the database connection
|
||||||
bot.database = duckdb.connect(
|
bot.database = sqlite3.connect(
|
||||||
config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
|
config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
|
||||||
|
|
||||||
|
bot.crypto_store_path = config["Database"]["CryptoStore"] if "Database" in config and "CryptoStore" in config["Database"] else None
|
||||||
|
|
||||||
# Override default values
|
# Override default values
|
||||||
if "GPTBot" in config:
|
if "GPTBot" in config:
|
||||||
bot.operator = config["GPTBot"].get("Operator", bot.operator)
|
bot.operator = config["GPTBot"].get("Operator", bot.operator)
|
||||||
|
@ -290,7 +295,7 @@ class GPTBot:
|
||||||
"""
|
"""
|
||||||
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
||||||
|
|
||||||
with self.database.cursor() as cursor:
|
with closing(self.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
|
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
@ -362,7 +367,7 @@ class GPTBot:
|
||||||
"""
|
"""
|
||||||
room_id = room.room_id
|
room_id = room.room_id
|
||||||
|
|
||||||
with self.database.cursor() as cursor:
|
with closing(self.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_timing"))
|
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_timing"))
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
@ -584,10 +589,17 @@ class GPTBot:
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
|
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
|
||||||
IN_MEMORY = True
|
IN_MEMORY = True
|
||||||
self.database = duckdb.DuckDBPyConnection(":memory:")
|
self.database = sqlite3.connect(":memory:")
|
||||||
|
|
||||||
self.logger.log("Running migrations...")
|
self.logger.log("Running migrations...")
|
||||||
|
|
||||||
|
try:
|
||||||
before, after = migrate(self.database)
|
before, after = migrate(self.database)
|
||||||
|
except sqlite3.DatabaseError as e:
|
||||||
|
self.logger.log(f"Error migrating database: {e}", "fatal")
|
||||||
|
self.logger.log("If you have just updated the bot, the previous version of the database may be incompatible with this version. Please delete the database file and try again.", "fatal")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
if before != after:
|
if before != after:
|
||||||
self.logger.log(f"Migrated from version {before} to {after}.")
|
self.logger.log(f"Migrated from version {before} to {after}.")
|
||||||
else:
|
else:
|
||||||
|
@ -597,14 +609,14 @@ class GPTBot:
|
||||||
client_config = AsyncClientConfig(
|
client_config = AsyncClientConfig(
|
||||||
store_sync_tokens=True, encryption_enabled=False)
|
store_sync_tokens=True, encryption_enabled=False)
|
||||||
else:
|
else:
|
||||||
matrix_store = DuckDBStore
|
matrix_store = SqliteStore
|
||||||
client_config = AsyncClientConfig(
|
client_config = AsyncClientConfig(
|
||||||
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
||||||
self.matrix_client.config = client_config
|
self.matrix_client.config = client_config
|
||||||
self.matrix_client.store = matrix_store(
|
self.matrix_client.store = matrix_store(
|
||||||
self.matrix_client.user_id,
|
self.matrix_client.user_id,
|
||||||
self.matrix_client.device_id,
|
self.matrix_client.device_id,
|
||||||
self.database
|
self.crypto_store_path or ""
|
||||||
)
|
)
|
||||||
|
|
||||||
self.matrix_client.olm = Olm(
|
self.matrix_client.olm = Olm(
|
||||||
|
@ -722,7 +734,7 @@ class GPTBot:
|
||||||
if isinstance(room, MatrixRoom):
|
if isinstance(room, MatrixRoom):
|
||||||
room = room.room_id
|
room = room.room_id
|
||||||
|
|
||||||
with self.database.cursor() as cursor:
|
with closing(self.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room, "always_reply"))
|
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room, "always_reply"))
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
@ -830,7 +842,7 @@ class GPTBot:
|
||||||
else:
|
else:
|
||||||
room_id = room.room_id
|
room_id = room.room_id
|
||||||
|
|
||||||
with self.database.cursor() as cur:
|
with closing(self.database.cursor()) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
||||||
(room_id, "system_message")
|
(room_id, "system_message")
|
|
@ -24,7 +24,7 @@ for command in [
|
||||||
"space",
|
"space",
|
||||||
]:
|
]:
|
||||||
function = getattr(import_module(
|
function = getattr(import_module(
|
||||||
"commands." + command), "command_" + command)
|
"." + command, "gptbot.commands"), "command_" + command)
|
||||||
COMMANDS[command] = function
|
COMMANDS[command] = function
|
||||||
|
|
||||||
COMMANDS[None] = command_unknown
|
COMMANDS[None] = command_unknown
|
|
@ -2,6 +2,7 @@ from nio.events.room_events import RoomMessageText
|
||||||
from nio import RoomCreateError, RoomInviteError
|
from nio import RoomCreateError, RoomInviteError
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
|
async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
room_name = " ".join(event.body.split()[
|
room_name = " ".join(event.body.split()[
|
||||||
|
@ -23,7 +24,7 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True)
|
await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True)
|
||||||
return
|
return
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
||||||
space = cursor.fetchone()
|
space = cursor.fetchone()
|
|
@ -1,6 +1,8 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
|
|
||||||
async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
|
async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
setting = event.body.split()[2] if len(event.body.split()) > 2 else None
|
setting = event.body.split()[2] if len(event.body.split()) > 2 else None
|
||||||
|
@ -16,7 +18,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if value:
|
if value:
|
||||||
bot.logger.log("Adding system message...")
|
bot.logger.log("Adding system message...")
|
||||||
|
|
||||||
with bot.database.cursor() as cur:
|
with closing(bot.database.cursor()) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
|
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
|
||||||
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
|
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
|
||||||
|
@ -40,7 +42,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
|
|
||||||
bot.logger.log(f"Setting {setting} status for {room.room_id} to {value}...")
|
bot.logger.log(f"Setting {setting} status for {room.room_id} to {value}...")
|
||||||
|
|
||||||
with bot.database.cursor() as cur:
|
with closing(bot.database.cursor()) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
|
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
|
||||||
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
|
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
|
||||||
|
@ -55,7 +57,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
|
|
||||||
bot.logger.log(f"Retrieving {setting} status for {room.room_id}...")
|
bot.logger.log(f"Retrieving {setting} status for {room.room_id}...")
|
||||||
|
|
||||||
with bot.database.cursor() as cur:
|
with closing(bot.database.cursor()) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""SELECT value FROM room_settings WHERE room_id = ? AND setting = ?;""",
|
"""SELECT value FROM room_settings WHERE room_id = ? AND setting = ?;""",
|
||||||
(room.room_id, setting)
|
(room.room_id, setting)
|
|
@ -2,6 +2,8 @@ from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
from nio.responses import RoomInviteError
|
from nio.responses import RoomInviteError
|
||||||
|
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
|
|
||||||
async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if len(event.body.split()) == 3:
|
if len(event.body.split()) == 3:
|
||||||
|
@ -10,7 +12,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if request.lower() == "enable":
|
if request.lower() == "enable":
|
||||||
bot.logger.log("Enabling space...")
|
bot.logger.log("Enabling space...")
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
||||||
space = cursor.fetchone()
|
space = cursor.fetchone()
|
||||||
|
@ -25,7 +27,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
"url": bot.logo_uri
|
"url": bot.logo_uri
|
||||||
}, "")
|
}, "")
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO user_spaces (space_id, user_id) VALUES (?, ?)", (space, event.sender))
|
"INSERT INTO user_spaces (space_id, user_id) VALUES (?, ?)", (space, event.sender))
|
||||||
|
|
||||||
|
@ -48,7 +50,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
elif request.lower() == "disable":
|
elif request.lower() == "disable":
|
||||||
bot.logger.log("Disabling space...")
|
bot.logger.log("Disabling space...")
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
||||||
space = cursor.fetchone()[0]
|
space = cursor.fetchone()[0]
|
||||||
|
@ -58,7 +60,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
await bot.send_message(room, "You don't have a space enabled.", True)
|
await bot.send_message(room, "You don't have a space enabled.", True)
|
||||||
return
|
return
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"UPDATE user_spaces SET active = FALSE WHERE user_id = ?", (event.sender,))
|
"UPDATE user_spaces SET active = FALSE WHERE user_id = ?", (event.sender,))
|
||||||
|
|
||||||
|
@ -69,7 +71,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if request.lower() == "update":
|
if request.lower() == "update":
|
||||||
bot.logger.log("Updating space...")
|
bot.logger.log("Updating space...")
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
"SELECT space_id FROM user_spaces WHERE user_id = ? AND active = TRUE", (event.sender,))
|
||||||
space = cursor.fetchone()[0]
|
space = cursor.fetchone()[0]
|
||||||
|
@ -103,7 +105,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if request.lower() == "invite":
|
if request.lower() == "invite":
|
||||||
bot.logger.log("Inviting user to space...")
|
bot.logger.log("Inviting user to space...")
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT space_id FROM user_spaces WHERE user_id = ?", (event.sender,))
|
"SELECT space_id FROM user_spaces WHERE user_id = ?", (event.sender,))
|
||||||
space = cursor.fetchone()[0]
|
space = cursor.fetchone()[0]
|
||||||
|
@ -126,7 +128,7 @@ async def command_space(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
await bot.send_message(room, "Invited you to the space.", True)
|
await bot.send_message(room, "Invited you to the space.", True)
|
||||||
return
|
return
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT active FROM user_spaces WHERE user_id = ?", (event.sender,))
|
"SELECT active FROM user_spaces WHERE user_id = ?", (event.sender,))
|
||||||
status = cursor.fetchone()
|
status = cursor.fetchone()
|
|
@ -1,6 +1,8 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
|
|
||||||
async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
|
async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
bot.logger.log("Showing stats...")
|
bot.logger.log("Showing stats...")
|
||||||
|
@ -10,7 +12,7 @@ async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True)
|
bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True)
|
||||||
return
|
return
|
||||||
|
|
||||||
with bot.database.cursor() as cursor:
|
with closing(bot.database.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
|
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
|
||||||
total_tokens = cursor.fetchone()[0] or 0
|
total_tokens = cursor.fetchone()[0] or 0
|
|
@ -1,6 +1,8 @@
|
||||||
from nio.events.room_events import RoomMessageText
|
from nio.events.room_events import RoomMessageText
|
||||||
from nio.rooms import MatrixRoom
|
from nio.rooms import MatrixRoom
|
||||||
|
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
|
|
||||||
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
|
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
system_message = " ".join(event.body.split()[2:])
|
system_message = " ".join(event.body.split()[2:])
|
||||||
|
@ -8,7 +10,7 @@ async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if system_message:
|
if system_message:
|
||||||
bot.logger.log("Adding system message...")
|
bot.logger.log("Adding system message...")
|
||||||
|
|
||||||
with bot.database.cursor() as cur:
|
with closing(bot.database.cursor()) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
|
INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
|
|
@ -1,9 +1,10 @@
|
||||||
# Initial migration, token usage logging
|
# Initial migration, token usage logging
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS token_usage (
|
CREATE TABLE IF NOT EXISTS token_usage (
|
7
src/gptbot/migrations/migration_2.py
Normal file
7
src/gptbot/migrations/migration_2.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Migration for Matrix Store - No longer used
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
|
def migration(conn):
|
||||||
|
pass
|
|
@ -1,9 +1,10 @@
|
||||||
# Migration for custom system messages
|
# Migration for custom system messages
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS system_messages (
|
CREATE TABLE IF NOT EXISTS system_messages (
|
||||||
|
@ -11,7 +12,7 @@ def migration(conn):
|
||||||
message_id TEXT NOT NULL,
|
message_id TEXT NOT NULL,
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
body TEXT NOT NULL,
|
body TEXT NOT NULL,
|
||||||
timestamp BIGINT NOT NULL,
|
timestamp BIGINT NOT NULL
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
|
@ -1,9 +1,10 @@
|
||||||
# Migration to add API column to token usage table
|
# Migration to add API column to token usage table
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
ALTER TABLE token_usage ADD COLUMN api TEXT DEFAULT 'openai'
|
ALTER TABLE token_usage ADD COLUMN api TEXT DEFAULT 'openai'
|
|
@ -1,9 +1,10 @@
|
||||||
# Migration to add room settings table
|
# Migration to add room settings table
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS room_settings (
|
CREATE TABLE IF NOT EXISTS room_settings (
|
|
@ -1,9 +1,10 @@
|
||||||
# Migration to drop primary key constraint from token_usage table
|
# Migration to drop primary key constraint from token_usage table
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE token_usage_temp (
|
CREATE TABLE token_usage_temp (
|
|
@ -1,9 +1,10 @@
|
||||||
# Migration to add user_spaces table
|
# Migration to add user_spaces table
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE user_spaces (
|
CREATE TABLE user_spaces (
|
|
@ -1,9 +1,10 @@
|
||||||
# Migration to add settings table
|
# Migration to add settings table
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import closing
|
||||||
|
|
||||||
def migration(conn):
|
def migration(conn):
|
||||||
with conn.cursor() as cursor:
|
with closing(conn.cursor()) as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS settings (
|
CREATE TABLE IF NOT EXISTS settings (
|
Loading…
Reference in a new issue