Create a bot class
Make everything compatible with that
This commit is contained in:
parent
2bbc6a33ca
commit
1dc0378853
19 changed files with 706 additions and 610 deletions
0
__init__.py
Normal file
0
__init__.py
Normal file
29
callbacks/__init__.py
Normal file
29
callbacks/__init__.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from nio import (
|
||||
RoomMessageText,
|
||||
MegolmEvent,
|
||||
InviteEvent,
|
||||
Event,
|
||||
SyncResponse,
|
||||
JoinResponse,
|
||||
InviteEvent,
|
||||
OlmEvent,
|
||||
MegolmEvent
|
||||
)
|
||||
|
||||
from .test import test_callback
|
||||
from .sync import sync_callback
|
||||
from .invite import room_invite_callback
|
||||
from .join import join_callback
|
||||
from .message import message_callback
|
||||
|
||||
RESPONSE_CALLBACKS = {
|
||||
SyncResponse: sync_callback,
|
||||
JoinResponse: join_callback,
|
||||
}
|
||||
|
||||
EVENT_CALLBACKS = {
|
||||
Event: test_callback,
|
||||
InviteEvent: room_invite_callback,
|
||||
RoomMessageText: message_callback,
|
||||
MegolmEvent: message_callback,
|
||||
}
|
10
callbacks/invite.py
Normal file
10
callbacks/invite.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from nio import InviteEvent, MatrixRoom
|
||||
|
||||
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, bot):
|
||||
if room.room_id in bot.matrix_client.rooms:
|
||||
logging(f"Already in room {room.room_id} - ignoring invite")
|
||||
return
|
||||
|
||||
bot.logger.log(f"Received invite to room {room.room_id} - joining...")
|
||||
|
||||
response = await bot.matrix_client.join(room.room_id)
|
7
callbacks/join.py
Normal file
7
callbacks/join.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
async def join_callback(response, bot):
|
||||
bot.logger.log(
|
||||
f"Join response received for room {response.room_id}", "debug")
|
||||
|
||||
bot.matrix_client.joined_rooms()
|
||||
|
||||
await bot.send_message(bot.matrix_client.rooms[response.room_id], "Hello! Thanks for inviting me! How can I help you today?")
|
30
callbacks/message.py
Normal file
30
callbacks/message.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from nio import MatrixRoom, RoomMessageText, MegolmEvent
|
||||
|
||||
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, bot):
|
||||
bot.logger.log(f"Received message from {event.sender} in room {room.room_id}")
|
||||
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await bot.matrix_client.decrypt_event(event)
|
||||
except Exception as e:
|
||||
try:
|
||||
bot.logger.log("Requesting new encryption keys...")
|
||||
await bot.matrix_client.request_room_key(event)
|
||||
except:
|
||||
pass
|
||||
|
||||
bot.logger.log(f"Error decrypting message: {e}", "error")
|
||||
await bot.send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True)
|
||||
return
|
||||
|
||||
if event.sender == bot.matrix_client.user_id:
|
||||
bot.logger.log("Message is from bot itself - ignoring")
|
||||
|
||||
elif event.body.startswith("!gptbot"):
|
||||
await bot.process_command(room, event)
|
||||
|
||||
elif event.body.startswith("!"):
|
||||
bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring")
|
||||
|
||||
else:
|
||||
await bot.process_query(room, event)
|
6
callbacks/sync.py
Normal file
6
callbacks/sync.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
async def sync_callback(response, bot):
|
||||
bot.logger.log(
|
||||
f"Sync response received (next batch: {response.next_batch})", "debug")
|
||||
SYNC_TOKEN = response.next_batch
|
||||
|
||||
bot.sync_token = SYNC_TOKEN
|
11
callbacks/test.py
Normal file
11
callbacks/test.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from nio import MatrixRoom, Event
|
||||
|
||||
async def test_callback(room: MatrixRoom, event: Event, bot):
|
||||
"""Test callback for debugging purposes.
|
||||
|
||||
Args:
|
||||
room (MatrixRoom): The room the event was sent in.
|
||||
event (Event): The event that was sent.
|
||||
"""
|
||||
|
||||
bot.logger.log(f"Test callback called: {room.room_id} {event.event_id} {event.sender} {event.__class__}")
|
481
classes/bot.py
Normal file
481
classes/bot.py
Normal file
|
@ -0,0 +1,481 @@
|
|||
import openai
|
||||
import markdown2
|
||||
import duckdb
|
||||
import tiktoken
|
||||
|
||||
import asyncio
|
||||
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
WhoamiResponse,
|
||||
DevicesResponse,
|
||||
Event,
|
||||
Response,
|
||||
MatrixRoom,
|
||||
Api,
|
||||
RoomMessagesError,
|
||||
MegolmEvent,
|
||||
GroupEncryptionError,
|
||||
EncryptionError,
|
||||
RoomMessageText,
|
||||
RoomSendResponse,
|
||||
SyncResponse
|
||||
)
|
||||
from nio.crypto import Olm
|
||||
|
||||
from typing import Optional, List, Dict
|
||||
from configparser import ConfigParser
|
||||
from datetime import datetime
|
||||
|
||||
import uuid
|
||||
|
||||
from .logging import Logger
|
||||
from migrations import migrate
|
||||
from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
||||
from commands import COMMANDS
|
||||
from .store import DuckDBStore
|
||||
|
||||
|
||||
class GPTBot:
|
||||
# Default values
|
||||
database: Optional[duckdb.DuckDBPyConnection] = None
|
||||
default_room_name: str = "GPTBot" # Default name of rooms created by the bot
|
||||
default_system_message: str = "You are a helpful assistant."
|
||||
# Force default system message to be included even if a custom room message is set
|
||||
force_system_message: bool = False
|
||||
max_tokens: int = 3000 # Maximum number of input tokens
|
||||
max_messages: int = 30 # Maximum number of messages to consider as input
|
||||
model: str = "gpt-3.5-turbo" # OpenAI chat model to use
|
||||
matrix_client: Optional[AsyncClient] = None
|
||||
sync_token: Optional[str] = None
|
||||
logger: Optional[Logger] = Logger()
|
||||
openai_api_key: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigParser):
|
||||
"""Create a new GPTBot instance from a config file.
|
||||
|
||||
Args:
|
||||
config (ConfigParser): ConfigParser instance with the bot's config.
|
||||
|
||||
Returns:
|
||||
GPTBot: The new GPTBot instance.
|
||||
"""
|
||||
|
||||
# Create a new GPTBot instance
|
||||
bot = cls()
|
||||
|
||||
# Set the database connection
|
||||
bot.database = duckdb.connect(
|
||||
config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
|
||||
|
||||
# Override default values
|
||||
if "GPTBot" in config:
|
||||
bot.default_room_name = config["GPTBot"].get(
|
||||
"DefaultRoomName", bot.default_room_name)
|
||||
bot.default_system_message = config["GPTBot"].get(
|
||||
"SystemMessage", bot.default_system_message)
|
||||
bot.force_system_message = config["GPTBot"].getboolean(
|
||||
"ForceSystemMessage", bot.force_system_message)
|
||||
|
||||
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
||||
bot.max_messages = config["OpenAI"].getint(
|
||||
"MaxMessages", bot.max_messages)
|
||||
bot.model = config["OpenAI"].get("Model", bot.model)
|
||||
|
||||
bot.openai_api_key = config["OpenAI"]["APIKey"]
|
||||
|
||||
# Set up the Matrix client
|
||||
|
||||
assert "Matrix" in config, "Matrix config not found"
|
||||
|
||||
homeserver = config["Matrix"]["Homeserver"]
|
||||
bot.matrix_client = AsyncClient(homeserver)
|
||||
bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
|
||||
bot.matrix_client.user_id = config["Matrix"].get("UserID")
|
||||
bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
|
||||
|
||||
# Return the new GPTBot instance
|
||||
return bot
|
||||
|
||||
async def _get_user_id(self) -> str:
|
||||
"""Get the user ID of the bot from the whoami endpoint.
|
||||
Requires an access token to be set up.
|
||||
|
||||
Returns:
|
||||
str: The user ID of the bot.
|
||||
"""
|
||||
|
||||
assert self.matrix_client, "Matrix client not set up"
|
||||
|
||||
user_id = self.matrix_client.user_id
|
||||
|
||||
if not user_id:
|
||||
assert self.matrix_client.access_token, "Access token not set up"
|
||||
|
||||
response = await self.matrix_client.whoami()
|
||||
|
||||
if isinstance(response, WhoamiResponse):
|
||||
user_id = response.user_id
|
||||
else:
|
||||
raise Exception(f"Could not get user ID: {response}")
|
||||
|
||||
return user_id
|
||||
|
||||
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
|
||||
messages = []
|
||||
n = n or bot.max_messages
|
||||
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
||||
|
||||
self.logger.log(
|
||||
f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")
|
||||
|
||||
response = await self.matrix_client.room_messages(
|
||||
room_id=room_id,
|
||||
start=self.sync_token,
|
||||
limit=2*n,
|
||||
)
|
||||
|
||||
if isinstance(response, RoomMessagesError):
|
||||
raise Exception(
|
||||
f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
|
||||
|
||||
for event in response.chunk:
|
||||
if len(messages) >= n:
|
||||
break
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await self.matrix_client.decrypt_event(event)
|
||||
except (GroupEncryptionError, EncryptionError):
|
||||
self.logger.log(
|
||||
f"Could not decrypt message {event.event_id} in room {room_id}", "error")
|
||||
continue
|
||||
if isinstance(event, RoomMessageText):
|
||||
if event.body.startswith("!gptbot ignoreolder"):
|
||||
break
|
||||
if not event.body.startswith("!"):
|
||||
messages.append(event)
|
||||
|
||||
self.logger.log(f"Found {len(messages)} messages (limit: {n})")
|
||||
|
||||
# Reverse the list so that messages are in chronological order
|
||||
return messages[::-1]
|
||||
|
||||
def _truncate(self, messages: list, max_tokens: Optional[int] = None,
|
||||
model: Optional[str] = None, system_message: Optional[str] = None):
|
||||
max_tokens = max_tokens or self.max_tokens
|
||||
model = model or self.model
|
||||
system_message = self.default_system_message if system_message is None else system_message
|
||||
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
total_tokens = 0
|
||||
|
||||
system_message_tokens = 0 if not system_message else (
|
||||
len(encoding.encode(system_message)) + 1)
|
||||
|
||||
if system_message_tokens > max_tokens:
|
||||
self.logger.log(
|
||||
f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
|
||||
return []
|
||||
|
||||
total_tokens += system_message_tokens
|
||||
|
||||
total_tokens = len(system_message) + 1
|
||||
truncated_messages = []
|
||||
|
||||
for message in [messages[0]] + list(reversed(messages[1:])):
|
||||
content = message["content"]
|
||||
tokens = len(encoding.encode(content)) + 1
|
||||
if total_tokens + tokens > max_tokens:
|
||||
break
|
||||
total_tokens += tokens
|
||||
truncated_messages.append(message)
|
||||
|
||||
return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
|
||||
|
||||
async def _get_device_id(self) -> str:
|
||||
"""Guess the device ID of the bot.
|
||||
Requires an access token to be set up.
|
||||
|
||||
Returns:
|
||||
str: The guessed device ID.
|
||||
"""
|
||||
|
||||
assert self.matrix_client, "Matrix client not set up"
|
||||
|
||||
device_id = self.matrix_client.device_id
|
||||
|
||||
if not device_id:
|
||||
assert self.matrix_client.access_token, "Access token not set up"
|
||||
|
||||
devices = await self.matrix_client.devices()
|
||||
|
||||
if isinstance(devices, DevicesResponse):
|
||||
device_id = devices.devices[0].id
|
||||
|
||||
return device_id
|
||||
|
||||
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
|
||||
self.logger.log(
|
||||
f"Received command {event.body} from {event.sender} in room {room.room_id}")
|
||||
command = event.body.split()[1] if event.body.split()[1:] else None
|
||||
|
||||
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
|
||||
|
||||
async def event_callback(self,room: MatrixRoom, event: Event):
|
||||
self.logger.log("Received event: " + str(event), "debug")
|
||||
for eventtype, callback in EVENT_CALLBACKS.items():
|
||||
if isinstance(event, eventtype):
|
||||
await callback(room, event, self)
|
||||
|
||||
async def response_callback(self, response: Response):
|
||||
for response_type, callback in RESPONSE_CALLBACKS.items():
|
||||
if isinstance(response, response_type):
|
||||
await callback(response, self)
|
||||
|
||||
async def accept_pending_invites(self):
|
||||
"""Accept all pending invites."""
|
||||
|
||||
assert self.matrix_client, "Matrix client not set up"
|
||||
|
||||
invites = self.matrix_client.invited_rooms
|
||||
|
||||
for invite in invites.keys():
|
||||
await self.matrix_client.join(invite)
|
||||
|
||||
async def send_message(self, room: MatrixRoom, message: str, notice: bool = False):
|
||||
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
|
||||
formatted_body = markdowner.convert(message)
|
||||
|
||||
msgtype = "m.notice" if notice else "m.text"
|
||||
|
||||
msgcontent = {"msgtype": msgtype, "body": message,
|
||||
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
||||
|
||||
content = None
|
||||
|
||||
if self.matrix_client.olm and room.encrypted:
|
||||
try:
|
||||
if not room.members_synced:
|
||||
responses = []
|
||||
responses.append(await self.matrix_client.joined_members(room.room_id))
|
||||
|
||||
if self.matrix_client.olm.should_share_group_session(room.room_id):
|
||||
try:
|
||||
event = self.matrix_client.sharing_session[room.room_id]
|
||||
await event.wait()
|
||||
except KeyError:
|
||||
await self.matrix_client.share_group_session(
|
||||
room.room_id,
|
||||
ignore_unverified_devices=True,
|
||||
)
|
||||
|
||||
if msgtype != "m.reaction":
|
||||
response = self.matrix_client.encrypt(
|
||||
room.room_id, "m.room.message", msgcontent)
|
||||
msgtype, content = response
|
||||
|
||||
except Exception as e:
|
||||
self.logger.log(
|
||||
f"Error encrypting message: {e} - sending unencrypted", "error")
|
||||
raise
|
||||
|
||||
if not content:
|
||||
msgtype = "m.room.message"
|
||||
content = msgcontent
|
||||
|
||||
method, path, data = Api.room_send(
|
||||
self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
|
||||
)
|
||||
|
||||
return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))
|
||||
|
||||
async def run(self):
|
||||
"""Start the bot."""
|
||||
|
||||
# Set up the Matrix client
|
||||
|
||||
assert self.matrix_client, "Matrix client not set up"
|
||||
assert self.matrix_client.access_token, "Access token not set up"
|
||||
|
||||
if not self.matrix_client.user_id:
|
||||
self.matrix_client.user_id = await self._get_user_id()
|
||||
|
||||
if not self.matrix_client.device_id:
|
||||
self.matrix_client.device_id = await self._get_device_id()
|
||||
|
||||
# Set up database
|
||||
|
||||
IN_MEMORY = False
|
||||
if not self.database:
|
||||
self.logger.log(
|
||||
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
|
||||
IN_MEMORY = True
|
||||
self.database = DuckDBPyConnection(":memory:")
|
||||
|
||||
self.logger.log("Running migrations...")
|
||||
before, after = migrate(self.database)
|
||||
if before != after:
|
||||
self.logger.log(f"Migrated from version {before} to {after}.")
|
||||
else:
|
||||
self.logger.log(f"Already at latest version {after}.")
|
||||
|
||||
if IN_MEMORY:
|
||||
client_config = AsyncClientConfig(
|
||||
store_sync_tokens=True, encryption_enabled=False)
|
||||
else:
|
||||
matrix_store = DuckDBStore
|
||||
client_config = AsyncClientConfig(
|
||||
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
||||
self.matrix_client.config = client_config
|
||||
self.matrix_client.store = matrix_store(
|
||||
self.matrix_client.user_id,
|
||||
self.matrix_client.device_id,
|
||||
self.database
|
||||
)
|
||||
|
||||
self.matrix_client.olm = Olm(
|
||||
self.matrix_client.user_id,
|
||||
self.matrix_client.device_id,
|
||||
self.matrix_client.store
|
||||
)
|
||||
|
||||
self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()
|
||||
|
||||
# Run initial sync
|
||||
sync = await self.matrix_client.sync(timeout=30000)
|
||||
if isinstance(sync, SyncResponse):
|
||||
await self.response_callback(sync)
|
||||
else:
|
||||
self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
|
||||
return
|
||||
|
||||
# Set up callbacks
|
||||
|
||||
self.matrix_client.add_event_callback(self.event_callback, Event)
|
||||
self.matrix_client.add_response_callback(self.response_callback, Response)
|
||||
|
||||
# Accept pending invites
|
||||
|
||||
self.logger.log("Accepting pending invites...")
|
||||
await self.accept_pending_invites()
|
||||
|
||||
# Start syncing events
|
||||
self.logger.log("Starting sync loop...")
|
||||
try:
|
||||
await self.matrix_client.sync_forever(timeout=30000)
|
||||
finally:
|
||||
self.logger.log("Syncing one last time...")
|
||||
await self.matrix_client.sync(timeout=30000)
|
||||
|
||||
async def process_query(self, room: MatrixRoom, event: RoomMessageText):
|
||||
await self.matrix_client.room_typing(room.room_id, True)
|
||||
|
||||
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
||||
|
||||
try:
|
||||
last_messages = await self._last_n_messages(room.room_id, 20)
|
||||
except Exception as e:
|
||||
self.logger.log(f"Error getting last messages: {e}", "error")
|
||||
await self.send_message(
|
||||
room, "Something went wrong. Please try again.", True)
|
||||
return
|
||||
|
||||
system_message = self.get_system_message(room)
|
||||
|
||||
chat_messages = [{"role": "system", "content": system_message}]
|
||||
|
||||
for message in last_messages:
|
||||
role = "assistant" if message.sender == self.matrix_client.user_id else "user"
|
||||
if not message.event_id == event.event_id:
|
||||
chat_messages.append({"role": role, "content": message.body})
|
||||
|
||||
chat_messages.append({"role": "user", "content": event.body})
|
||||
|
||||
# Truncate messages to fit within the token limit
|
||||
truncated_messages = self._truncate(
|
||||
chat_messages, self.max_tokens - 1, system_message=system_message)
|
||||
|
||||
try:
|
||||
response, tokens_used = await self.generate_chat_response(truncated_messages)
|
||||
except Exception as e:
|
||||
self.logger.log(f"Error generating response: {e}", "error")
|
||||
await self.send_message(
|
||||
room, "Something went wrong. Please try again.", True)
|
||||
return
|
||||
|
||||
if response:
|
||||
self.logger.log(f"Sending response to room {room.room_id}...")
|
||||
|
||||
# Convert markdown to HTML
|
||||
|
||||
message = await self.send_message(room, response)
|
||||
|
||||
if self.database:
|
||||
self.logger.log("Storing record of tokens used...")
|
||||
|
||||
with self.database.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
|
||||
(message.event_id, room.room_id, tokens_used, datetime.now()))
|
||||
self.database.commit()
|
||||
else:
|
||||
# Send a notice to the room if there was an error
|
||||
self.logger.log("Didn't get a response from GPT API", "error")
|
||||
send_message(
|
||||
room, "Something went wrong. Please try again.", True)
|
||||
|
||||
await self.matrix_client.room_typing(room.room_id, False)
|
||||
|
||||
async def generate_chat_response(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Generate a response to a chat message.
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, str]]): A list of messages to use as context.
|
||||
|
||||
Returns:
|
||||
str: The response to the chat.
|
||||
"""
|
||||
|
||||
self.logger.log(f"Generating response to {len(messages)} messages...")
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
api_key=self.openai_api_key
|
||||
)
|
||||
|
||||
result_text = response.choices[0].message['content']
|
||||
tokens_used = response.usage["total_tokens"]
|
||||
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
||||
return result_text, tokens_used
|
||||
|
||||
def get_system_message(self, room: MatrixRoom | int) -> str:
|
||||
default = self.default_system_message
|
||||
|
||||
if isinstance(room, int):
|
||||
room_id = room
|
||||
else:
|
||||
room_id = room.room_id
|
||||
|
||||
with self.database.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||
(room_id,)
|
||||
)
|
||||
system_message = cur.fetchone()
|
||||
|
||||
complete = ((default if ((not system_message) or self.force_system_message) else "") + (
|
||||
"\n\n" + system_message[0] if system_message else "")).strip()
|
||||
|
||||
return complete
|
||||
|
||||
def __del__(self):
|
||||
"""Close the bot."""
|
||||
|
||||
if self.matrix_client:
|
||||
asyncio.run(self.matrix_client.close())
|
||||
|
||||
if self.database:
|
||||
self.database.close()
|
10
classes/logging.py
Normal file
10
classes/logging.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import inspect
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Logger:
|
||||
def log(self, message: str, log_level: str = "info"):
|
||||
caller = inspect.currentframe().f_back.f_code.co_name
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")
|
||||
print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}")
|
|
@ -2,21 +2,22 @@ from nio.events.room_events import RoomMessageText
|
|||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
logging("Showing bot info...")
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": f"""GPT Info:
|
||||
body = f"""GPT Info:
|
||||
|
||||
Model: {context["model"]}
|
||||
Maximum context tokens: {context["max_tokens"]}
|
||||
Maximum context messages: {context["max_messages"]}
|
||||
System message: {context["system_message"]}
|
||||
Model: {bot.model}
|
||||
Maximum context tokens: {bot.max_tokens}
|
||||
Maximum context messages: {bot.max_messages}
|
||||
|
||||
Room info:
|
||||
|
||||
Bot user ID: {context["client"].user_id}
|
||||
Bot user ID: {bot.matrix_client.user_id}
|
||||
Current room ID: {room.room_id}
|
||||
System message: {bot.get_system_message(room)}
|
||||
|
||||
For usage statistics, run !gptbot stats
|
||||
"""}
|
||||
"""
|
||||
|
||||
await bot.send_message(room, body, True)
|
||||
|
|
|
@ -4,10 +4,10 @@ from nio.rooms import MatrixRoom
|
|||
from random import SystemRandom
|
||||
|
||||
|
||||
async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
context["logger"]("Flipping a coin...")
|
||||
async def command_coin(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
bot.logger.log("Flipping a coin...")
|
||||
|
||||
heads = SystemRandom().choice([True, False])
|
||||
body = "Flipping a coin... It's " + ("heads!" if heads else "tails!")
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Heads!" if heads else "Tails!"}
|
||||
await bot.send_message(room, body, True)
|
|
@ -2,9 +2,8 @@ from nio.events.room_events import RoomMessageText
|
|||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": """Available commands:
|
||||
async def command_help(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
body = """Available commands:
|
||||
|
||||
!gptbot help - Show this message
|
||||
!gptbot newroom <room name> - Create a new room and invite yourself to it
|
||||
|
@ -12,4 +11,7 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
|
|||
!gptbot botinfo - Show information about the bot
|
||||
!gptbot coin - Flip a coin (heads or tails)
|
||||
!gptbot ignoreolder - Ignore messages before this point as context
|
||||
"""}
|
||||
!gptbot systemmessage - Get or set the system message for this room
|
||||
"""
|
||||
|
||||
await bot.send_message(room, body, True)
|
|
@ -1,8 +1,9 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": """Alright, messages before this point will not be processed as context anymore.
|
||||
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
body = """Alright, messages before this point will not be processed as context anymore.
|
||||
|
||||
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
|
||||
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""
|
||||
|
||||
await bot.send_message(room, body, True)
|
|
@ -1,17 +1,31 @@
|
|||
from nio.events.room_events import RoomMessageText
|
||||
from nio import RoomCreateError, RoomInviteError
|
||||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
room_name = " ".join(event.body.split()[
|
||||
2:]) or context["default_room_name"]
|
||||
|
||||
context["logger"]("Creating new room...")
|
||||
new_room = await context["client"].room_create(name=room_name)
|
||||
bot.logger.log("Creating new room...")
|
||||
new_room = await bot.matrix_client.room_create(name=room_name)
|
||||
|
||||
if isinstance(new_room, RoomCreateError):
|
||||
bot.logger.log(f"Failed to create room: {new_room.message}")
|
||||
await bot.send_message(room, f"Sorry, I was unable to create a new room. Please try again later, or create a room manually.", True)
|
||||
return
|
||||
|
||||
bot.logger.log(f"Inviting {event.sender} to new room...")
|
||||
invite = await bot.matrix_client.room_invite(new_room.room_id, event.sender)
|
||||
|
||||
if isinstance(invite, RoomInviteError):
|
||||
bot.logger.log(f"Failed to invite user: {invite.message}")
|
||||
await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True)
|
||||
return
|
||||
|
||||
context["logger"](f"Inviting {event.sender} to new room...")
|
||||
await context["client"].room_invite(new_room.room_id, event.sender)
|
||||
await context["client"].room_put_state(
|
||||
new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
|
||||
|
||||
return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}
|
||||
await bot.matrix_client.joined_rooms()
|
||||
await bot.send_message(room, f"Alright, I've created a new room called '{room_name}' and invited you to it. You can find it at {new_room.room_id}", True)
|
||||
await bot.send_message(new_room.room_id, f"Welcome to the new room! What can I do for you?")
|
|
@ -2,18 +2,17 @@ from nio.events.room_events import RoomMessageText
|
|||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
context["logger"]("Showing stats...")
|
||||
async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
bot.logger.log("Showing stats...")
|
||||
|
||||
if not (database := context.get("database")):
|
||||
context["logger"]("No database connection - cannot show stats")
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
|
||||
if not bot.database:
|
||||
bot.logger.log("No database connection - cannot show stats")
|
||||
bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True)
|
||||
return
|
||||
|
||||
with database.cursor() as cursor:
|
||||
with bot.database.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
|
||||
total_tokens = cursor.fetchone()[0] or 0
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": f"Total tokens used: {total_tokens}"}
|
||||
bot.send_message(room, f"Total tokens used: {total_tokens}", True)
|
||||
|
|
|
@ -2,33 +2,24 @@ from nio.events.room_events import RoomMessageText
|
|||
from nio.rooms import MatrixRoom
|
||||
|
||||
|
||||
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
|
||||
system_message = " ".join(event.body.split()[2:])
|
||||
|
||||
if system_message:
|
||||
context["logger"]("Adding system message...")
|
||||
bot.logger.log("Adding system message...")
|
||||
|
||||
with context["database"].cursor() as cur:
|
||||
with bot.database.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
|
||||
(room.room_id, event.event_id, event.sender,
|
||||
system_message, event.server_timestamp)
|
||||
)
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
|
||||
bot.send_message(room, f"Alright, I've stored the system message: '{system_message}'.", True)
|
||||
return
|
||||
|
||||
context["logger"]("Retrieving system message...")
|
||||
bot.logger.log("Retrieving system message...")
|
||||
|
||||
with context["database"].cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||
(room.room_id,)
|
||||
)
|
||||
system_message = cur.fetchone()
|
||||
system_message = bot.get_system_message(room)
|
||||
|
||||
if system_message is None:
|
||||
system_message = context.get("system_message", "No system message set")
|
||||
elif context.get("force_system_message") and context.get("system_message"):
|
||||
system_message = system_message + "\n\n" + context["system_message"]
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}
|
||||
bot.send_message(room, f"The current system message is: '{system_message}'.", True)
|
||||
|
|
|
@ -3,7 +3,6 @@ from nio.rooms import MatrixRoom
|
|||
|
||||
|
||||
async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
|
||||
context["logger"]("Unknown command")
|
||||
bot.logger.log("Unknown command")
|
||||
|
||||
return room.room_id, "m.room.message", {"msgtype": "m.notice",
|
||||
"body": "Unknown command - try !gptbot help"}
|
||||
bot.send_message(room, "Unknown command - try !gptbot help", True)
|
569
gptbot.py
569
gptbot.py
|
@ -1,554 +1,14 @@
|
|||
import os
|
||||
import inspect
|
||||
import logging
|
||||
import signal
|
||||
import random
|
||||
import uuid
|
||||
from classes.bot import GPTBot
|
||||
|
||||
import openai
|
||||
import asyncio
|
||||
import markdown2
|
||||
import tiktoken
|
||||
import duckdb
|
||||
|
||||
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api
|
||||
from nio.api import MessageDirection
|
||||
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse
|
||||
from nio.crypto import Olm
|
||||
|
||||
from configparser import ConfigParser
|
||||
from datetime import datetime
|
||||
from argparse import ArgumentParser
|
||||
from typing import List, Dict, Union, Optional
|
||||
from configparser import ConfigParser
|
||||
|
||||
from commands import COMMANDS
|
||||
from classes import DuckDBStore
|
||||
from migrations import MIGRATIONS
|
||||
import signal
|
||||
import asyncio
|
||||
|
||||
|
||||
def logging(message: str, log_level: str = "info"):
|
||||
caller = inspect.currentframe().f_back.f_code.co_name
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")
|
||||
print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}")
|
||||
|
||||
|
||||
CONTEXT = {
|
||||
"database": False,
|
||||
"default_room_name": "GPTBot",
|
||||
"system_message": "You are a helpful assistant.",
|
||||
"force_system_message": False,
|
||||
"max_tokens": 3000,
|
||||
"max_messages": 20,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"client": None,
|
||||
"sync_token": None,
|
||||
"logger": logging
|
||||
}
|
||||
|
||||
|
||||
async def gpt_query(messages: list, model: Optional[str] = None):
|
||||
model = model or CONTEXT["model"]
|
||||
|
||||
logging(f"Querying GPT with {len(messages)} messages")
|
||||
logging(messages, "debug")
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=messages
|
||||
)
|
||||
result_text = response.choices[0].message['content']
|
||||
tokens_used = response.usage["total_tokens"]
|
||||
logging(f"Used {tokens_used} tokens")
|
||||
return result_text, tokens_used
|
||||
|
||||
except Exception as e:
|
||||
logging(f"Error during GPT API call: {e}", "error")
|
||||
return None, 0
|
||||
|
||||
|
||||
async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
|
||||
client: Optional[AsyncClient] = None, sync_token: Optional[str] = None):
|
||||
messages = []
|
||||
|
||||
n = n or CONTEXT["max_messages"]
|
||||
client = client or CONTEXT["client"]
|
||||
sync_token = sync_token or CONTEXT["sync_token"]
|
||||
|
||||
logging(
|
||||
f"Fetching last {2*n} messages from room {room_id} (starting at {sync_token})...")
|
||||
|
||||
response = await client.room_messages(
|
||||
room_id=room_id,
|
||||
start=sync_token,
|
||||
limit=2*n,
|
||||
)
|
||||
|
||||
if isinstance(response, RoomMessagesError):
|
||||
logging(
|
||||
f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
|
||||
return []
|
||||
|
||||
for event in response.chunk:
|
||||
if len(messages) >= n:
|
||||
break
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await client.decrypt_event(event)
|
||||
except (GroupEncryptionError, EncryptionError):
|
||||
logging(
|
||||
f"Could not decrypt message {event.event_id} in room {room_id}", "error")
|
||||
continue
|
||||
if isinstance(event, RoomMessageText):
|
||||
if event.body.startswith("!gptbot ignoreolder"):
|
||||
break
|
||||
if not event.body.startswith("!"):
|
||||
messages.append(event)
|
||||
|
||||
logging(f"Found {len(messages)} messages (limit: {n})")
|
||||
|
||||
# Reverse the list so that messages are in chronological order
|
||||
return messages[::-1]
|
||||
|
||||
|
||||
def truncate_messages_to_fit_tokens(messages: list, max_tokens: Optional[int] = None,
|
||||
model: Optional[str] = None, system_message: Optional[str] = None):
|
||||
max_tokens = max_tokens or CONTEXT["max_tokens"]
|
||||
model = model or CONTEXT["model"]
|
||||
system_message = system_message or CONTEXT["system_message"]
|
||||
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
total_tokens = 0
|
||||
|
||||
system_message_tokens = len(encoding.encode(system_message)) + 1
|
||||
|
||||
if system_message_tokens > max_tokens:
|
||||
logging(
|
||||
f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
|
||||
return []
|
||||
|
||||
total_tokens += system_message_tokens
|
||||
|
||||
total_tokens = len(system_message) + 1
|
||||
truncated_messages = []
|
||||
|
||||
for message in [messages[0]] + list(reversed(messages[1:])):
|
||||
content = message["content"]
|
||||
tokens = len(encoding.encode(content)) + 1
|
||||
if total_tokens + tokens > max_tokens:
|
||||
break
|
||||
total_tokens += tokens
|
||||
truncated_messages.append(message)
|
||||
|
||||
return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
|
||||
|
||||
|
||||
async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
||||
|
||||
client = kwargs.get("client") or CONTEXT["client"]
|
||||
database = kwargs.get("database") or CONTEXT["database"]
|
||||
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
|
||||
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
||||
force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
|
||||
|
||||
await client.room_typing(room.room_id, True)
|
||||
|
||||
await client.room_read_markers(room.room_id, event.event_id)
|
||||
|
||||
last_messages = await fetch_last_n_messages(room.room_id, 20)
|
||||
|
||||
system_message = get_system_message(room, {
|
||||
"database": database,
|
||||
"system_message": system_message,
|
||||
"force_system_message": force_system_message,
|
||||
})
|
||||
|
||||
chat_messages = [{"role": "system", "content": system_message}]
|
||||
|
||||
for message in last_messages:
|
||||
role = "assistant" if message.sender == client.user_id else "user"
|
||||
if not message.event_id == event.event_id:
|
||||
chat_messages.append({"role": role, "content": message.body})
|
||||
|
||||
chat_messages.append({"role": "user", "content": event.body})
|
||||
|
||||
# Truncate messages to fit within the token limit
|
||||
truncated_messages = truncate_messages_to_fit_tokens(
|
||||
chat_messages, max_tokens - 1, system_message=system_message)
|
||||
response, tokens_used = await gpt_query(truncated_messages)
|
||||
|
||||
if response:
|
||||
logging(f"Sending response to room {room.room_id}...")
|
||||
|
||||
# Convert markdown to HTML
|
||||
|
||||
message = await send_message(room, response)
|
||||
|
||||
if database:
|
||||
logging("Logging tokens used...")
|
||||
|
||||
with database.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
|
||||
(message.event_id, room.room_id, tokens_used, datetime.now()))
|
||||
database.commit()
|
||||
else:
|
||||
# Send a notice to the room if there was an error
|
||||
|
||||
logging("Error during GPT API call - sending notice to room")
|
||||
send_message(
|
||||
room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
|
||||
print("No response from GPT API")
|
||||
|
||||
await client.room_typing(room.room_id, False)
|
||||
|
||||
|
||||
async def process_command(room: MatrixRoom, event: RoomMessageText, context: Optional[dict] = None):
|
||||
context = context or CONTEXT
|
||||
|
||||
logging(
|
||||
f"Received command {event.body} from {event.sender} in room {room.room_id}")
|
||||
command = event.body.split()[1] if event.body.split()[1:] else None
|
||||
|
||||
message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)
|
||||
|
||||
if message:
|
||||
room_id, event, content = message
|
||||
rooms = await context["client"].joined_rooms()
|
||||
await send_message(context["client"].rooms[room_id], content["body"],
|
||||
True if content["msgtype"] == "m.notice" else False, context["client"])
|
||||
|
||||
|
||||
def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
|
||||
context = context or CONTEXT
|
||||
|
||||
default = context.get("system_message")
|
||||
|
||||
with context["database"].cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
||||
(room.room_id,)
|
||||
)
|
||||
system_message = cur.fetchone()
|
||||
|
||||
complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
|
||||
"\n\n" + system_message[0] if system_message else "")).strip()
|
||||
|
||||
return complete
|
||||
|
||||
|
||||
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
||||
context = kwargs.get("context") or CONTEXT
|
||||
|
||||
logging(f"Received message from {event.sender} in room {room.room_id}")
|
||||
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await context["client"].decrypt_event(event)
|
||||
except Exception as e:
|
||||
try:
|
||||
logging("Requesting new encryption keys...")
|
||||
await context["client"].request_room_key(event)
|
||||
except:
|
||||
pass
|
||||
|
||||
logging(f"Error decrypting message: {e}", "error")
|
||||
await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
|
||||
return
|
||||
|
||||
if event.sender == context["client"].user_id:
|
||||
logging("Message is from bot itself - ignoring")
|
||||
|
||||
elif event.body.startswith("!gptbot"):
|
||||
await process_command(room, event)
|
||||
|
||||
elif event.body.startswith("!"):
|
||||
logging("Might be a command, but not for this bot - ignoring")
|
||||
|
||||
else:
|
||||
await process_query(room, event, context=context)
|
||||
|
||||
|
||||
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
|
||||
client: AsyncClient = kwargs.get("client") or CONTEXT["client"]
|
||||
|
||||
if room.room_id in client.rooms:
|
||||
logging(f"Already in room {room.room_id} - ignoring invite")
|
||||
return
|
||||
|
||||
logging(f"Received invite to room {room.room_id} - joining...")
|
||||
|
||||
response = await client.join(room.room_id)
|
||||
if isinstance(response, JoinResponse):
|
||||
await send_message(room, "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||
else:
|
||||
logging(f"Error joining room {room.room_id}: {response}", "error")
|
||||
|
||||
|
||||
async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None):
|
||||
client = client or CONTEXT["client"]
|
||||
|
||||
markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
|
||||
formatted_body = markdowner.convert(message)
|
||||
|
||||
msgtype = "m.notice" if notice else "m.text"
|
||||
|
||||
msgcontent = {"msgtype": msgtype, "body": message,
|
||||
"format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
||||
|
||||
content = None
|
||||
|
||||
if client.olm and room.encrypted:
|
||||
try:
|
||||
if not room.members_synced:
|
||||
responses = []
|
||||
responses.append(await client.joined_members(room.room_id))
|
||||
|
||||
if client.olm.should_share_group_session(room.room_id):
|
||||
try:
|
||||
event = client.sharing_session[room.room_id]
|
||||
await event.wait()
|
||||
except KeyError:
|
||||
await client.share_group_session(
|
||||
room.room_id,
|
||||
ignore_unverified_devices=True,
|
||||
)
|
||||
|
||||
if msgtype != "m.reaction":
|
||||
response = client.encrypt(
|
||||
room.room_id, "m.room.message", msgcontent)
|
||||
msgtype, content = response
|
||||
|
||||
except Exception as e:
|
||||
logging(
|
||||
f"Error encrypting message: {e} - sending unencrypted", "error")
|
||||
raise
|
||||
|
||||
if not content:
|
||||
msgtype = "m.room.message"
|
||||
content = msgcontent
|
||||
|
||||
method, path, data = Api.room_send(
|
||||
client.access_token, room.room_id, msgtype, content, uuid.uuid4()
|
||||
)
|
||||
|
||||
return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
|
||||
|
||||
|
||||
async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
||||
client = client or CONTEXT["client"]
|
||||
|
||||
logging("Accepting pending invites...")
|
||||
|
||||
for room_id in list(client.invited_rooms.keys()):
|
||||
logging(f"Joining room {room_id}...")
|
||||
|
||||
response = await client.join(room_id)
|
||||
|
||||
if isinstance(response, JoinResponse):
|
||||
logging(response, "debug")
|
||||
rooms = await client.joined_rooms()
|
||||
await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
||||
else:
|
||||
logging(f"Error joining room {room_id}: {response}", "error")
|
||||
|
||||
|
||||
async def sync_cb(response, write_global: bool = True):
|
||||
logging(
|
||||
f"Sync response received (next batch: {response.next_batch})", "debug")
|
||||
SYNC_TOKEN = response.next_batch
|
||||
|
||||
if write_global:
|
||||
global CONTEXT
|
||||
CONTEXT["sync_token"] = SYNC_TOKEN
|
||||
|
||||
|
||||
async def test_callback(room: MatrixRoom, event: Event, **kwargs):
|
||||
logging(
|
||||
f"Received event {event.__class__.__name__} in room {room.room_id}", "debug")
|
||||
|
||||
|
||||
async def init(config: ConfigParser):
|
||||
# Set up Matrix client
|
||||
try:
|
||||
assert "Matrix" in config
|
||||
assert "Homeserver" in config["Matrix"]
|
||||
assert "AccessToken" in config["Matrix"]
|
||||
except:
|
||||
logging("Matrix config not found or incomplete", "critical")
|
||||
exit(1)
|
||||
|
||||
homeserver = config["Matrix"]["Homeserver"]
|
||||
access_token = config["Matrix"]["AccessToken"]
|
||||
|
||||
device_id, user_id = await get_device_id(access_token, homeserver)
|
||||
|
||||
device_id = config["Matrix"].get("DeviceID", device_id)
|
||||
user_id = config["Matrix"].get("UserID", user_id)
|
||||
|
||||
# Set up database
|
||||
if "Database" in config and config["Database"].get("Path"):
|
||||
database = CONTEXT["database"] = initialize_database(
|
||||
config["Database"]["Path"])
|
||||
matrix_store = DuckDBStore
|
||||
|
||||
client_config = AsyncClientConfig(
|
||||
store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
||||
|
||||
else:
|
||||
client_config = AsyncClientConfig(
|
||||
store_sync_tokens=True, encryption_enabled=False)
|
||||
|
||||
client = AsyncClient(
|
||||
config["Matrix"]["Homeserver"], config=client_config)
|
||||
|
||||
if client.config.encryption_enabled:
|
||||
client.store = client.config.store(
|
||||
user_id,
|
||||
device_id,
|
||||
database
|
||||
)
|
||||
assert client.store
|
||||
|
||||
client.olm = Olm(client.user_id, client.device_id, client.store)
|
||||
client.encrypted_rooms = client.store.load_encrypted_rooms()
|
||||
|
||||
CONTEXT["client"] = client
|
||||
|
||||
CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
|
||||
CONTEXT["client"].user_id = user_id
|
||||
CONTEXT["client"].device_id = device_id
|
||||
|
||||
# Set up GPT API
|
||||
try:
|
||||
assert "OpenAI" in config
|
||||
assert "APIKey" in config["OpenAI"]
|
||||
except:
|
||||
logging("OpenAI config not found or incomplete", "critical")
|
||||
exit(1)
|
||||
|
||||
openai.api_key = config["OpenAI"]["APIKey"]
|
||||
|
||||
if "Model" in config["OpenAI"]:
|
||||
CONTEXT["model"] = config["OpenAI"]["Model"]
|
||||
|
||||
if "MaxTokens" in config["OpenAI"]:
|
||||
CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
|
||||
|
||||
if "MaxMessages" in config["OpenAI"]:
|
||||
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
||||
|
||||
# Override defaults with config
|
||||
|
||||
if "GPTBot" in config:
|
||||
if "SystemMessage" in config["GPTBot"]:
|
||||
CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
|
||||
if "DefaultRoomName" in config["GPTBot"]:
|
||||
CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
|
||||
if "ForceSystemMessage" in config["GPTBot"]:
|
||||
CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
|
||||
"ForceSystemMessage")
|
||||
|
||||
|
||||
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
||||
if not client and not CONTEXT.get("client"):
|
||||
await init(config)
|
||||
|
||||
client = client or CONTEXT["client"]
|
||||
|
||||
try:
|
||||
assert client.user_id
|
||||
except AssertionError:
|
||||
logging(
|
||||
"Failed to get user ID - check your access token or try setting it manually", "critical")
|
||||
await client.close()
|
||||
return
|
||||
|
||||
# Listen for SIGTERM
|
||||
|
||||
def sigterm_handler(_signo, _stack_frame):
|
||||
logging("Received SIGTERM - exiting...")
|
||||
exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
logging("Starting bot...")
|
||||
|
||||
client.add_response_callback(sync_cb, SyncResponse)
|
||||
|
||||
logging("Syncing...")
|
||||
|
||||
await client.sync(timeout=30000)
|
||||
|
||||
client.add_event_callback(message_callback, RoomMessageText)
|
||||
client.add_event_callback(message_callback, MegolmEvent)
|
||||
client.add_event_callback(room_invite_callback, InviteEvent)
|
||||
client.add_event_callback(test_callback, Event)
|
||||
|
||||
await accept_pending_invites() # Accept pending invites
|
||||
|
||||
logging("Bot started")
|
||||
|
||||
try:
|
||||
# Continue syncing events
|
||||
await client.sync_forever(timeout=30000)
|
||||
finally:
|
||||
logging("Syncing one last time...")
|
||||
await client.sync(timeout=30000)
|
||||
await client.close() # Properly close the aiohttp client session
|
||||
logging("Bot stopped")
|
||||
|
||||
|
||||
def initialize_database(path: os.PathLike):
|
||||
logging("Initializing database...")
|
||||
conn = duckdb.connect(path)
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Get the latest migration ID if the migrations table exists
|
||||
try:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT MAX(id) FROM migrations
|
||||
"""
|
||||
)
|
||||
|
||||
latest_migration = int(cursor.fetchone()[0])
|
||||
|
||||
except:
|
||||
latest_migration = 0
|
||||
|
||||
for migration, function in MIGRATIONS.items():
|
||||
if latest_migration < migration:
|
||||
logging(f"Running migration {migration}...")
|
||||
function(conn)
|
||||
latest_migration = migration
|
||||
|
||||
return conn
|
||||
|
||||
|
||||
async def get_device_id(access_token, homeserver):
|
||||
client = AsyncClient(homeserver)
|
||||
client.access_token = access_token
|
||||
|
||||
logging(f"Obtaining device ID for access token {access_token}...", "debug")
|
||||
response = await client.whoami()
|
||||
if isinstance(response, WhoamiResponse):
|
||||
logging(
|
||||
f"Authenticated as {response.user_id}.")
|
||||
user_id = response.user_id
|
||||
devices = await client.devices()
|
||||
device_id = devices.devices[0].id
|
||||
|
||||
await client.close()
|
||||
|
||||
return device_id, user_id
|
||||
|
||||
else:
|
||||
logging(f"Failed to obtain device ID: {response}", "error")
|
||||
|
||||
await client.close()
|
||||
|
||||
return None, None
|
||||
def sigterm_handler(_signo, _stack_frame):
|
||||
exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -562,13 +22,16 @@ if __name__ == "__main__":
|
|||
config = ConfigParser()
|
||||
config.read(args.config)
|
||||
|
||||
# Start bot loop
|
||||
# Create bot
|
||||
bot = GPTBot.from_config(config)
|
||||
|
||||
# Listen for SIGTERM
|
||||
signal.signal(signal.SIGTERM, sigterm_handler)
|
||||
|
||||
# Start bot
|
||||
try:
|
||||
asyncio.run(main(config))
|
||||
asyncio.run(bot.run())
|
||||
except KeyboardInterrupt:
|
||||
logging("Received KeyboardInterrupt - exiting...")
|
||||
print("Received KeyboardInterrupt - exiting...")
|
||||
except SystemExit:
|
||||
logging("Received SIGTERM - exiting...")
|
||||
finally:
|
||||
if CONTEXT["database"]:
|
||||
CONTEXT["database"].close()
|
||||
print("Received SIGTERM - exiting...")
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
from duckdb import DuckDBPyConnection
|
||||
|
||||
from .migration_1 import migration as migration_1
|
||||
from .migration_2 import migration as migration_2
|
||||
|
@ -9,3 +12,42 @@ MIGRATIONS = OrderedDict()
|
|||
MIGRATIONS[1] = migration_1
|
||||
MIGRATIONS[2] = migration_2
|
||||
MIGRATIONS[3] = migration_3
|
||||
|
||||
def get_version(db: DuckDBPyConnection) -> int:
|
||||
"""Get the current database version.
|
||||
|
||||
Args:
|
||||
db (DuckDBPyConnection): Database connection.
|
||||
|
||||
Returns:
|
||||
int: Current database version.
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(db.execute("SELECT MAX(id) FROM migrations").fetchone()[0])
|
||||
except:
|
||||
return 0
|
||||
|
||||
def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_version: Optional[int] = None) -> None:
|
||||
"""Migrate the database to a specific version.
|
||||
|
||||
Args:
|
||||
db (DuckDBPyConnection): Database connection.
|
||||
from_version (Optional[int]): Version to migrate from. If None, the current version is used.
|
||||
to_version (Optional[int]): Version to migrate to. If None, the latest version is used.
|
||||
"""
|
||||
|
||||
if from_version is None:
|
||||
from_version = get_version(db)
|
||||
|
||||
if to_version is None:
|
||||
to_version = max(MIGRATIONS.keys())
|
||||
|
||||
if from_version > to_version:
|
||||
raise ValueError("Cannot migrate from a higher version to a lower version.")
|
||||
|
||||
for version in range(from_version, to_version):
|
||||
if version in MIGRATIONS:
|
||||
MIGRATIONS[version + 1](db)
|
||||
|
||||
return from_version, to_version
|
Loading…
Reference in a new issue