import markdown2
import duckdb
import tiktoken
import magic
import asyncio

from PIL import Image

from nio import (
    AsyncClient,
    AsyncClientConfig,
    WhoamiResponse,
    DevicesResponse,
    Event,
    Response,
    MatrixRoom,
    Api,
    RoomMessagesError,
    MegolmEvent,
    GroupEncryptionError,
    EncryptionError,
    RoomMessageText,
    RoomSendResponse,
    SyncResponse
)
from nio.crypto import Olm

from typing import Optional, List, Dict, Tuple
from configparser import ConfigParser
from datetime import datetime
from io import BytesIO

import uuid

from .logging import Logger
from migrations import migrate
from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from commands import COMMANDS
from .store import DuckDBStore
from .openai import OpenAI
from .wolframalpha import WolframAlpha


class GPTBot:
    # Default values
    database: Optional[duckdb.DuckDBPyConnection] = None
    default_room_name: str = "GPTBot"  # Default name of rooms created by the bot
    default_system_message: str = "You are a helpful assistant."
    # Force default system message to be included even if a custom room message is set
    force_system_message: bool = False
    max_tokens: int = 3000  # Maximum number of input tokens
    max_messages: int = 30  # Maximum number of messages to consider as input
    matrix_client: Optional[AsyncClient] = None
    sync_token: Optional[str] = None
    logger: Optional[Logger] = Logger()
    chat_api: Optional[OpenAI] = None
    image_api: Optional[OpenAI] = None
    calculation_api: Optional[WolframAlpha] = None  # Set in from_config if a [WolframAlpha] section is present

    @classmethod
    def from_config(cls, config: ConfigParser):
        """Create a new GPTBot instance from a config file.

        Args:
            config (ConfigParser): ConfigParser instance with the bot's config.

        Returns:
            GPTBot: The new GPTBot instance.
        """

        # Create a new GPTBot instance
        bot = cls()

        # Set the database connection
        bot.database = duckdb.connect(
            config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None

        # Override default values
        if "GPTBot" in config:
            bot.default_room_name = config["GPTBot"].get(
                "DefaultRoomName", bot.default_room_name)
            bot.default_system_message = config["GPTBot"].get(
                "SystemMessage", bot.default_system_message)
            bot.force_system_message = config["GPTBot"].getboolean(
                "ForceSystemMessage", bot.force_system_message)

        bot.chat_api = bot.image_api = OpenAI(config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger)
        bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
        bot.max_messages = config["OpenAI"].getint(
            "MaxMessages", bot.max_messages)

        # Set up WolframAlpha
        if "WolframAlpha" in config:
            bot.calculation_api = WolframAlpha(
                config["WolframAlpha"]["APIKey"], bot.logger)

        # Set up the Matrix client
        assert "Matrix" in config, "Matrix config not found"

        homeserver = config["Matrix"]["Homeserver"]
        bot.matrix_client = AsyncClient(homeserver)
        bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
        bot.matrix_client.user_id = config["Matrix"].get("UserID")
        bot.matrix_client.device_id = config["Matrix"].get("DeviceID")

        # Return the new GPTBot instance
        return bot

    async def _get_user_id(self) -> str:
        """Get the user ID of the bot from the whoami endpoint.
        Requires an access token to be set up.

        Returns:
            str: The user ID of the bot.
        """

        assert self.matrix_client, "Matrix client not set up"

        user_id = self.matrix_client.user_id

        if not user_id:
            assert self.matrix_client.access_token, "Access token not set up"

            response = await self.matrix_client.whoami()

            if isinstance(response, WhoamiResponse):
                user_id = response.user_id
            else:
                raise Exception(f"Could not get user ID: {response}")

        return user_id

    async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
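        """Fetch up to the last n text messages from a room, oldest first.

        Args:
            room (str | MatrixRoom): The room (or room ID) to fetch messages from.
            n (Optional[int]): Maximum number of messages to return; defaults to max_messages.
        """
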
        messages = []
        n = n or self.max_messages
        room_id = room.room_id if isinstance(room, MatrixRoom) else room

        self.logger.log(
            f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")

        response = await self.matrix_client.room_messages(
            room_id=room_id,
            start=self.sync_token,
            limit=2*n,
        )

        if isinstance(response, RoomMessagesError):
            raise Exception(
                f"Error fetching messages: {response.message} (status code {response.status_code})")

        for event in response.chunk:
            if len(messages) >= n:
                break
            if isinstance(event, MegolmEvent):
                try:
                    event = await self.matrix_client.decrypt_event(event)
                except (GroupEncryptionError, EncryptionError):
                    self.logger.log(
                        f"Could not decrypt message {event.event_id} in room {room_id}", "error")
                    continue
            if isinstance(event, RoomMessageText):
                if event.body.startswith("!gptbot ignoreolder"):
                    break
                if not event.body.startswith("!"):
                    messages.append(event)

        self.logger.log(f"Found {len(messages)} messages (limit: {n})")

        # Reverse the list so that messages are in chronological order
        return messages[::-1]

    def _truncate(self, messages: list, max_tokens: Optional[int] = None,
                  model: Optional[str] = None, system_message: Optional[str] = None):
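        """Truncate a list of chat messages so that they fit within the token limit.

        Keeps the first message (the system prompt slot) and as many of the most
        recent messages as fit, counting tokens with tiktoken.
        """
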
        max_tokens = max_tokens or self.max_tokens
        model = model or self.chat_api.chat_model
        system_message = self.default_system_message if system_message is None else system_message

        encoding = tiktoken.encoding_for_model(model)
        total_tokens = 0

        system_message_tokens = 0 if not system_message else (
            len(encoding.encode(system_message)) + 1)

        if system_message_tokens > max_tokens:
            self.logger.log(
                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
            return []

        total_tokens += system_message_tokens

        truncated_messages = []

        for message in [messages[0]] + list(reversed(messages[1:])):
            content = message["content"]
            tokens = len(encoding.encode(content)) + 1
            if total_tokens + tokens > max_tokens:
                break
            total_tokens += tokens
            truncated_messages.append(message)

        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))

    async def _get_device_id(self) -> str:
        """Guess the device ID of the bot.
        Requires an access token to be set up.

        Returns:
            str: The guessed device ID.
        """

        assert self.matrix_client, "Matrix client not set up"

        device_id = self.matrix_client.device_id

        if not device_id:
            assert self.matrix_client.access_token, "Access token not set up"

            devices = await self.matrix_client.devices()

            if isinstance(devices, DevicesResponse):
                device_id = devices.devices[0].id

        return device_id

    async def process_command(self, room: MatrixRoom, event: RoomMessageText):
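        """Dispatch a bot command message (e.g. "!gptbot <command>") to the matching handler in COMMANDS."""
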
        self.logger.log(
            f"Received command {event.body} from {event.sender} in room {room.room_id}")
        command = event.body.split()[1] if event.body.split()[1:] else None

        await COMMANDS.get(command, COMMANDS[None])(room, event, self)

    async def event_callback(self, room: MatrixRoom, event: Event):
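        """Pass an incoming room event to every matching handler in EVENT_CALLBACKS."""
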
        self.logger.log("Received event: " + str(event.event_id), "debug")
        try:
            for eventtype, callback in EVENT_CALLBACKS.items():
                if isinstance(event, eventtype):
                    await callback(room, event, self)
        except Exception as e:
            self.logger.log(
                f"Error in event callback for {event.__class__}: {e}", "error")

    async def response_callback(self, response: Response):
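        """Pass a client response (e.g. a sync result) to every matching handler in RESPONSE_CALLBACKS."""
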
        for response_type, callback in RESPONSE_CALLBACKS.items():
            if isinstance(response, response_type):
                await callback(response, self)

    async def accept_pending_invites(self):
        """Accept all pending invites."""

        assert self.matrix_client, "Matrix client not set up"

        invites = self.matrix_client.invited_rooms

        for invite in invites.keys():
            await self.matrix_client.join(invite)

    async def send_image(self, room: MatrixRoom, image: bytes, message: Optional[str] = None):
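        """Upload an image to the homeserver and send it to a room as an m.image event.

        Args:
            room (MatrixRoom): The room to send the image to.
            image (bytes): The raw image data.
            message (Optional[str]): Optional body text for the image event.
        """
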
        self.logger.log(
            f"Sending image of size {len(image)} bytes to room {room.room_id}")

        bio = BytesIO(image)
        img = Image.open(bio)
        mime = Image.MIME[img.format]

        (width, height) = img.size

        self.logger.log(
            f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}")

        bio.seek(0)

        response, _ = await self.matrix_client.upload(
            bio,
            content_type=mime,
            filename="image",
            filesize=len(image)
        )

        self.logger.log("Uploaded image - sending message...")

        content = {
            "body": message or "",
            "info": {
                "mimetype": mime,
                "size": len(image),
                "w": width,
                "h": height,
            },
            "msgtype": "m.image",
            "url": response.content_uri
        }

        status = await self.matrix_client.room_send(
            room.room_id,
            "m.room.message",
            content
        )

        self.logger.log(str(status), "debug")

        self.logger.log("Sent image")

    async def send_message(self, room: MatrixRoom, message: str, notice: bool = False):
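        """Render a markdown message to HTML and send it to a room, encrypting it if the room is encrypted.

        Args:
            room (MatrixRoom): The room to send the message to.
            message (str): The message text (markdown).
            notice (bool): Send as m.notice instead of m.text if True.
        """
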
        markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
        formatted_body = markdowner.convert(message)

        msgtype = "m.notice" if notice else "m.text"

        msgcontent = {"msgtype": msgtype, "body": message,
                      "format": "org.matrix.custom.html", "formatted_body": formatted_body}

        content = None

        if self.matrix_client.olm and room.encrypted:
            try:
                if not room.members_synced:
                    responses = []
                    responses.append(await self.matrix_client.joined_members(room.room_id))

                if self.matrix_client.olm.should_share_group_session(room.room_id):
                    try:
                        event = self.matrix_client.sharing_session[room.room_id]
                        await event.wait()
                    except KeyError:
                        await self.matrix_client.share_group_session(
                            room.room_id,
                            ignore_unverified_devices=True,
                        )

                if msgtype != "m.reaction":
                    response = self.matrix_client.encrypt(
                        room.room_id, "m.room.message", msgcontent)
                    msgtype, content = response

            except Exception as e:
                self.logger.log(
                    f"Error encrypting message: {e} - sending unencrypted", "error")

        if not content:
            msgtype = "m.room.message"
            content = msgcontent

        method, path, data = Api.room_send(
            self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
        )

        return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))

    async def run(self):
        """Start the bot."""

        # Set up the Matrix client
        assert self.matrix_client, "Matrix client not set up"
        assert self.matrix_client.access_token, "Access token not set up"

        if not self.matrix_client.user_id:
            self.matrix_client.user_id = await self._get_user_id()

        if not self.matrix_client.device_id:
            self.matrix_client.device_id = await self._get_device_id()

        # Set up database
        IN_MEMORY = False
        if not self.database:
            self.logger.log(
                "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
            IN_MEMORY = True
            self.database = duckdb.connect(":memory:")

        self.logger.log("Running migrations...")
        before, after = migrate(self.database)
        if before != after:
            self.logger.log(f"Migrated from version {before} to {after}.")
        else:
            self.logger.log(f"Already at latest version {after}.")

        if IN_MEMORY:
            client_config = AsyncClientConfig(
                store_sync_tokens=True, encryption_enabled=False)
        else:
            matrix_store = DuckDBStore
            client_config = AsyncClientConfig(
                store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
            self.matrix_client.config = client_config
            self.matrix_client.store = matrix_store(
                self.matrix_client.user_id,
                self.matrix_client.device_id,
                self.database
            )

            self.matrix_client.olm = Olm(
                self.matrix_client.user_id,
                self.matrix_client.device_id,
                self.matrix_client.store
            )

            self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()

        # Run initial sync
        sync = await self.matrix_client.sync(timeout=30000)
        if isinstance(sync, SyncResponse):
            await self.response_callback(sync)
        else:
            self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
            return

        # Set up callbacks
        self.matrix_client.add_event_callback(self.event_callback, Event)
        self.matrix_client.add_response_callback(self.response_callback, Response)

        # Accept pending invites
        self.logger.log("Accepting pending invites...")
        await self.accept_pending_invites()

        # Start syncing events
        self.logger.log("Starting sync loop...")
        try:
            await self.matrix_client.sync_forever(timeout=30000)
        finally:
            self.logger.log("Syncing one last time...")
            await self.matrix_client.sync(timeout=30000)

    async def process_query(self, room: MatrixRoom, event: RoomMessageText):
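        """Build a chat context from the room history, ask the chat API for a response, and send it back to the room."""
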
        await self.matrix_client.room_typing(room.room_id, True)

        await self.matrix_client.room_read_markers(room.room_id, event.event_id)

        try:
            last_messages = await self._last_n_messages(room.room_id, 20)
        except Exception as e:
            self.logger.log(f"Error getting last messages: {e}", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        system_message = self.get_system_message(room)

        chat_messages = [{"role": "system", "content": system_message}]

        for message in last_messages:
            role = "assistant" if message.sender == self.matrix_client.user_id else "user"
            if not message.event_id == event.event_id:
                chat_messages.append({"role": role, "content": message.body})

        chat_messages.append({"role": "user", "content": event.body})

        # Truncate messages to fit within the token limit
        truncated_messages = self._truncate(
            chat_messages, self.max_tokens - 1, system_message=system_message)

        try:
            response, tokens_used = await self.generate_chat_response(truncated_messages)
        except Exception as e:
            self.logger.log(f"Error generating response: {e}", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        if response:
            self.logger.log(f"Sending response to room {room.room_id}...")

            # Convert markdown to HTML
            message = await self.send_message(room, response)

            if self.database:
                self.logger.log("Storing record of tokens used...")

                with self.database.cursor() as cursor:
                    cursor.execute(
                        "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
                        (message.event_id, room.room_id, tokens_used, datetime.now()))
                    self.database.commit()
        else:
            # Send a notice to the room if there was an error
            self.logger.log("Didn't get a response from GPT API", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)

        await self.matrix_client.room_typing(room.room_id, False)

    async def generate_chat_response(self, messages: List[Dict[str, str]]) -> Tuple[str, int]:
        """Generate a response to a chat message.

        Args:
            messages (List[Dict[str, str]]): A list of messages to use as context.

        Returns:
            Tuple[str, int]: The response text and the number of tokens used.
        """

        return self.chat_api.generate_chat_response(messages)

    def get_system_message(self, room: MatrixRoom | int) -> str:
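        """Return the system message for a room: the room-specific message from the database, the default, or both if ForceSystemMessage is set."""
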
        default = self.default_system_message

        if isinstance(room, int):
            room_id = room
        else:
            room_id = room.room_id

        with self.database.cursor() as cur:
            cur.execute(
                "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
                (room_id,)
            )
            system_message = cur.fetchone()

        complete = ((default if ((not system_message) or self.force_system_message) else "") + (
            "\n\n" + system_message[0] if system_message else "")).strip()

        return complete

    def __del__(self):
        """Close the bot."""

        if self.matrix_client:
            asyncio.run(self.matrix_client.close())

        if self.database:
            self.database.close()