Create a bot class

Make everything compatible with that
This commit is contained in:
Kumi 2023-04-25 11:25:53 +00:00
parent 2bbc6a33ca
commit 1dc0378853
Signed by: kumi
GPG key ID: ECBCC9082395383F
19 changed files with 706 additions and 610 deletions

0
__init__.py Normal file
View file

29
callbacks/__init__.py Normal file
View file

@ -0,0 +1,29 @@
from nio import (
    RoomMessageText,
    MegolmEvent,
    InviteEvent,
    Event,
    SyncResponse,
    JoinResponse,
    OlmEvent,
)

from .test import test_callback
from .sync import sync_callback
from .invite import room_invite_callback
from .join import join_callback
from .message import message_callback

# Maps nio response types to the coroutine that handles them.
RESPONSE_CALLBACKS = {
    SyncResponse: sync_callback,
    JoinResponse: join_callback,
}

# Maps nio event types to the coroutine that handles them.
# Dispatch uses isinstance(), so an event may match several entries.
# Fix: InviteEvent and MegolmEvent were each imported twice in the original.
EVENT_CALLBACKS = {
    Event: test_callback,
    InviteEvent: room_invite_callback,
    RoomMessageText: message_callback,
    MegolmEvent: message_callback,
}

10
callbacks/invite.py Normal file
View file

@ -0,0 +1,10 @@
from nio import InviteEvent, MatrixRoom
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, bot):
    """Accept an invite to a room, unless the bot is already a member.

    Args:
        room (MatrixRoom): The room the invite is for.
        event (InviteEvent): The invite event.
        bot: The GPTBot instance handling the invite.
    """
    if room.room_id in bot.matrix_client.rooms:
        # Fix: the original called the undefined module-level name `logging` here,
        # which raised NameError instead of logging.
        bot.logger.log(f"Already in room {room.room_id} - ignoring invite")
        return

    bot.logger.log(f"Received invite to room {room.room_id} - joining...")

    # The join response was previously bound to an unused local; the result
    # is not inspected here (failures surface via nio's response types).
    await bot.matrix_client.join(room.room_id)

7
callbacks/join.py Normal file
View file

@ -0,0 +1,7 @@
async def join_callback(response, bot):
    """Greet a room once the bot has successfully joined it.

    Args:
        response (JoinResponse): The join response, carrying the room ID.
        bot: The GPTBot instance that joined the room.
    """
    bot.logger.log(
        f"Join response received for room {response.room_id}", "debug")

    # Fix: joined_rooms() is a coroutine and was never awaited in the
    # original, so the client's room list was never actually refreshed.
    await bot.matrix_client.joined_rooms()

    await bot.send_message(
        bot.matrix_client.rooms[response.room_id],
        "Hello! Thanks for inviting me! How can I help you today?")

30
callbacks/message.py Normal file
View file

@ -0,0 +1,30 @@
from nio import MatrixRoom, RoomMessageText, MegolmEvent
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, bot):
    """Dispatch an incoming room message.

    Decrypts Megolm events first; then ignores the bot's own messages,
    routes "!gptbot ..." to the command handler, ignores other "!" commands,
    and treats everything else as a query for the GPT backend.

    Args:
        room (MatrixRoom): The room the message was sent in.
        event (RoomMessageText | MegolmEvent): The (possibly encrypted) message.
        bot: The GPTBot instance handling the message.
    """
    bot.logger.log(f"Received message from {event.sender} in room {room.room_id}")

    if isinstance(event, MegolmEvent):
        try:
            event = await bot.matrix_client.decrypt_event(event)
        except Exception as e:
            try:
                bot.logger.log("Requesting new encryption keys...")
                await bot.matrix_client.request_room_key(event)
            # Fix: narrowed the original bare `except:` — a bare except also
            # swallows KeyboardInterrupt/SystemExit. The key request stays
            # best-effort; its failure must not mask the decryption error.
            except Exception:
                pass

            bot.logger.log(f"Error decrypting message: {e}", "error")
            await bot.send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True)
            return

    if event.sender == bot.matrix_client.user_id:
        bot.logger.log("Message is from bot itself - ignoring")
    elif event.body.startswith("!gptbot"):
        await bot.process_command(room, event)
    elif event.body.startswith("!"):
        bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring")
    else:
        await bot.process_query(room, event)

6
callbacks/sync.py Normal file
View file

@ -0,0 +1,6 @@
async def sync_callback(response, bot):
    """Record the sync token from a sync response.

    Args:
        response (SyncResponse): The sync response; its next_batch token is
            the starting point for later room-history fetches.
        bot: The GPTBot instance to store the token on.
    """
    bot.logger.log(
        f"Sync response received (next batch: {response.next_batch})", "debug")

    # Fix: the original routed the token through an UPPER_CASE local
    # (SYNC_TOKEN), which suggested a module constant but was a throwaway.
    bot.sync_token = response.next_batch

11
callbacks/test.py Normal file
View file

@ -0,0 +1,11 @@
from nio import MatrixRoom, Event
async def test_callback(room: MatrixRoom, event: Event, bot):
    """Test callback for debugging purposes.

    Args:
        room (MatrixRoom): The room the event was sent in.
        event (Event): The event that was sent.
        bot: The GPTBot instance (only its logger is used).
    """
    details = (room.room_id, event.event_id, event.sender, event.__class__)
    bot.logger.log("Test callback called: {} {} {} {}".format(*details))

481
classes/bot.py Normal file
View file

@ -0,0 +1,481 @@
import openai
import markdown2
import duckdb
import tiktoken
import asyncio
from nio import (
AsyncClient,
AsyncClientConfig,
WhoamiResponse,
DevicesResponse,
Event,
Response,
MatrixRoom,
Api,
RoomMessagesError,
MegolmEvent,
GroupEncryptionError,
EncryptionError,
RoomMessageText,
RoomSendResponse,
SyncResponse
)
from nio.crypto import Olm
from typing import Optional, List, Dict
from configparser import ConfigParser
from datetime import datetime
import uuid
from .logging import Logger
from migrations import migrate
from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from commands import COMMANDS
from .store import DuckDBStore
class GPTBot:
    """A Matrix bot that answers room messages using the OpenAI chat API.

    Configuration is normally supplied via from_config(); state (tokens used,
    per-room system messages) is persisted in a DuckDB database when one is
    configured.
    """

    # Default values (overridable via from_config)
    database: Optional[duckdb.DuckDBPyConnection] = None
    default_room_name: str = "GPTBot"  # Default name of rooms created by the bot
    default_system_message: str = "You are a helpful assistant."
    # Force default system message to be included even if a custom room message is set
    force_system_message: bool = False
    max_tokens: int = 3000  # Maximum number of input tokens
    max_messages: int = 30  # Maximum number of messages to consider as input
    model: str = "gpt-3.5-turbo"  # OpenAI chat model to use
    matrix_client: Optional[AsyncClient] = None
    sync_token: Optional[str] = None
    logger: Optional[Logger] = Logger()
    openai_api_key: Optional[str] = None

    @classmethod
    def from_config(cls, config: ConfigParser):
        """Create a new GPTBot instance from a config file.

        Args:
            config (ConfigParser): ConfigParser instance with the bot's config.

        Returns:
            GPTBot: The new GPTBot instance.
        """
        bot = cls()

        # Set the database connection (optional section)
        bot.database = duckdb.connect(
            config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None

        # Override default values
        if "GPTBot" in config:
            bot.default_room_name = config["GPTBot"].get(
                "DefaultRoomName", bot.default_room_name)
            bot.default_system_message = config["GPTBot"].get(
                "SystemMessage", bot.default_system_message)
            bot.force_system_message = config["GPTBot"].getboolean(
                "ForceSystemMessage", bot.force_system_message)

        # OpenAI settings — the APIKey is required, the rest have defaults
        bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
        bot.max_messages = config["OpenAI"].getint(
            "MaxMessages", bot.max_messages)
        bot.model = config["OpenAI"].get("Model", bot.model)
        bot.openai_api_key = config["OpenAI"]["APIKey"]

        # Set up the Matrix client
        assert "Matrix" in config, "Matrix config not found"

        homeserver = config["Matrix"]["Homeserver"]
        bot.matrix_client = AsyncClient(homeserver)
        bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
        bot.matrix_client.user_id = config["Matrix"].get("UserID")
        bot.matrix_client.device_id = config["Matrix"].get("DeviceID")

        return bot

    async def _get_user_id(self) -> str:
        """Get the user ID of the bot from the whoami endpoint.
        Requires an access token to be set up.

        Returns:
            str: The user ID of the bot.

        Raises:
            Exception: If the whoami request fails.
        """
        assert self.matrix_client, "Matrix client not set up"

        user_id = self.matrix_client.user_id

        if not user_id:
            assert self.matrix_client.access_token, "Access token not set up"

            response = await self.matrix_client.whoami()

            if isinstance(response, WhoamiResponse):
                user_id = response.user_id
            else:
                raise Exception(f"Could not get user ID: {response}")

        return user_id

    async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
        """Fetch up to n plain text messages from a room, oldest first.

        Fetches 2*n raw events (to allow for non-text events and commands),
        decrypts Megolm events, skips "!" commands, and stops early at a
        "!gptbot ignoreolder" marker.

        Args:
            room (str | MatrixRoom): The room (or its ID) to fetch from.
            n (Optional[int]): Message limit; defaults to self.max_messages.

        Returns:
            list: RoomMessageText events in chronological order.

        Raises:
            Exception: If the room history request fails.
        """
        messages = []
        # Fix: the original read `bot.max_messages`, but `bot` is undefined
        # inside this method — it must be `self`.
        n = n or self.max_messages
        room_id = room.room_id if isinstance(room, MatrixRoom) else room

        self.logger.log(
            f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")

        response = await self.matrix_client.room_messages(
            room_id=room_id,
            start=self.sync_token,
            limit=2*n,
        )

        if isinstance(response, RoomMessagesError):
            raise Exception(
                f"Error fetching messages: {response.message} (status code {response.status_code})", "error")

        for event in response.chunk:
            if len(messages) >= n:
                break
            if isinstance(event, MegolmEvent):
                try:
                    event = await self.matrix_client.decrypt_event(event)
                except (GroupEncryptionError, EncryptionError):
                    self.logger.log(
                        f"Could not decrypt message {event.event_id} in room {room_id}", "error")
                    continue
            if isinstance(event, RoomMessageText):
                if event.body.startswith("!gptbot ignoreolder"):
                    break
                if not event.body.startswith("!"):
                    messages.append(event)

        self.logger.log(f"Found {len(messages)} messages (limit: {n})")

        # Reverse the list so that messages are in chronological order
        return messages[::-1]

    def _truncate(self, messages: list, max_tokens: Optional[int] = None,
                  model: Optional[str] = None, system_message: Optional[str] = None):
        """Drop the oldest messages until the conversation fits the token budget.

        The first entry (assumed to be the system message) is always kept;
        remaining messages are considered newest-first so recent context
        survives truncation.

        Args:
            messages (list): Chat messages as {"role": ..., "content": ...} dicts.
            max_tokens (Optional[int]): Token budget; defaults to self.max_tokens.
            model (Optional[str]): Model used to pick the tokenizer.
            system_message (Optional[str]): System message counted against the
                budget; defaults to self.default_system_message.

        Returns:
            list: The truncated message list in original order, or [] if even
                the system message does not fit.
        """
        max_tokens = max_tokens or self.max_tokens
        model = model or self.model
        system_message = self.default_system_message if system_message is None else system_message

        # Robustness: nothing to truncate
        if not messages:
            return []

        encoding = tiktoken.encoding_for_model(model)

        total_tokens = 0

        system_message_tokens = 0 if not system_message else (
            len(encoding.encode(system_message)) + 1)

        if system_message_tokens > max_tokens:
            self.logger.log(
                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
            return []

        total_tokens += system_message_tokens
        # Fix: the original then overwrote total_tokens with
        # len(system_message) + 1 — the *character* count — corrupting
        # the token budget for the loop below.

        truncated_messages = []

        for message in [messages[0]] + list(reversed(messages[1:])):
            content = message["content"]
            tokens = len(encoding.encode(content)) + 1
            if total_tokens + tokens > max_tokens:
                break
            total_tokens += tokens
            truncated_messages.append(message)

        # Robustness: even the first message may not have fit
        if not truncated_messages:
            return []

        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))

    async def _get_device_id(self) -> str:
        """Guess the device ID of the bot.
        Requires an access token to be set up.

        Returns:
            str: The guessed device ID (first device reported by the server).
        """
        assert self.matrix_client, "Matrix client not set up"

        device_id = self.matrix_client.device_id

        if not device_id:
            assert self.matrix_client.access_token, "Access token not set up"

            devices = await self.matrix_client.devices()

            if isinstance(devices, DevicesResponse):
                device_id = devices.devices[0].id

        return device_id

    async def process_command(self, room: MatrixRoom, event: RoomMessageText):
        """Dispatch a "!gptbot <command>" message to its handler.

        Unknown (or missing) commands fall through to COMMANDS[None].

        Args:
            room (MatrixRoom): The room the command was sent in.
            event (RoomMessageText): The command message.
        """
        self.logger.log(
            f"Received command {event.body} from {event.sender} in room {room.room_id}")
        command = event.body.split()[1] if event.body.split()[1:] else None
        await COMMANDS.get(command, COMMANDS[None])(room, event, self)

    async def event_callback(self, room: MatrixRoom, event: Event):
        """Run every registered callback whose event type matches `event`.

        Dispatch uses isinstance(), so one event may trigger several callbacks.
        """
        self.logger.log("Received event: " + str(event), "debug")
        for eventtype, callback in EVENT_CALLBACKS.items():
            if isinstance(event, eventtype):
                await callback(room, event, self)

    async def response_callback(self, response: Response):
        """Run every registered callback whose response type matches `response`."""
        for response_type, callback in RESPONSE_CALLBACKS.items():
            if isinstance(response, response_type):
                await callback(response, self)

    async def accept_pending_invites(self):
        """Accept all pending invites."""
        assert self.matrix_client, "Matrix client not set up"

        invites = self.matrix_client.invited_rooms

        for invite in invites.keys():
            await self.matrix_client.join(invite)

    async def send_message(self, room: MatrixRoom, message: str, notice: bool = False):
        """Send a markdown message to a room, encrypting it when possible.

        Args:
            room (MatrixRoom): The room to send to.
            message (str): Markdown message body.
            notice (bool): Send as m.notice instead of m.text.

        Returns:
            RoomSendResponse: The server's response to the send.
        """
        markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
        formatted_body = markdowner.convert(message)

        msgtype = "m.notice" if notice else "m.text"

        msgcontent = {"msgtype": msgtype, "body": message,
                      "format": "org.matrix.custom.html", "formatted_body": formatted_body}

        content = None

        if self.matrix_client.olm and room.encrypted:
            try:
                if not room.members_synced:
                    # Member list must be known before sharing a group session.
                    # (The original collected this response into an unused list.)
                    await self.matrix_client.joined_members(room.room_id)

                if self.matrix_client.olm.should_share_group_session(room.room_id):
                    try:
                        # Another task is already sharing — wait for it
                        event = self.matrix_client.sharing_session[room.room_id]
                        await event.wait()
                    except KeyError:
                        await self.matrix_client.share_group_session(
                            room.room_id,
                            ignore_unverified_devices=True,
                        )

                if msgtype != "m.reaction":
                    response = self.matrix_client.encrypt(
                        room.room_id, "m.room.message", msgcontent)
                    msgtype, content = response

            except Exception as e:
                # NOTE(review): the log message promises an unencrypted
                # fallback, but the re-raise prevents it — confirm which
                # behavior is intended before changing either.
                self.logger.log(
                    f"Error encrypting message: {e} - sending unencrypted", "error")
                raise

        if not content:
            msgtype = "m.room.message"
            content = msgcontent

        method, path, data = Api.room_send(
            self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
        )

        return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))

    async def run(self):
        """Start the bot.

        Resolves the bot's identity, prepares the database and (when
        persistent) the encryption store, performs an initial sync, registers
        callbacks, accepts pending invites, and then syncs forever.
        """
        # Set up the Matrix client
        assert self.matrix_client, "Matrix client not set up"
        assert self.matrix_client.access_token, "Access token not set up"

        if not self.matrix_client.user_id:
            self.matrix_client.user_id = await self._get_user_id()

        if not self.matrix_client.device_id:
            self.matrix_client.device_id = await self._get_device_id()

        # Set up database
        IN_MEMORY = False
        if not self.database:
            self.logger.log(
                "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
            IN_MEMORY = True
            # Fix: the original called the undefined name
            # DuckDBPyConnection(":memory:") — duckdb.connect is the API.
            self.database = duckdb.connect(":memory:")

        self.logger.log("Running migrations...")
        before, after = migrate(self.database)
        if before != after:
            self.logger.log(f"Migrated from version {before} to {after}.")
        else:
            self.logger.log(f"Already at latest version {after}.")

        if IN_MEMORY:
            # No persistent store available, so end-to-end encryption is off.
            client_config = AsyncClientConfig(
                store_sync_tokens=True, encryption_enabled=False)
            self.matrix_client.config = client_config
        else:
            matrix_store = DuckDBStore
            client_config = AsyncClientConfig(
                store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
            self.matrix_client.config = client_config
            # Fix: the store/Olm setup below ran unconditionally in the
            # original and raised NameError on matrix_store when IN_MEMORY.
            self.matrix_client.store = matrix_store(
                self.matrix_client.user_id,
                self.matrix_client.device_id,
                self.database
            )
            self.matrix_client.olm = Olm(
                self.matrix_client.user_id,
                self.matrix_client.device_id,
                self.matrix_client.store
            )
            self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()

        # Run initial sync
        sync = await self.matrix_client.sync(timeout=30000)
        if isinstance(sync, SyncResponse):
            await self.response_callback(sync)
        else:
            self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
            return

        # Set up callbacks
        self.matrix_client.add_event_callback(self.event_callback, Event)
        self.matrix_client.add_response_callback(self.response_callback, Response)

        # Accept pending invites
        self.logger.log("Accepting pending invites...")
        await self.accept_pending_invites()

        # Start syncing events
        self.logger.log("Starting sync loop...")
        try:
            await self.matrix_client.sync_forever(timeout=30000)
        finally:
            # Flush one final sync so the stored token is current on shutdown
            self.logger.log("Syncing one last time...")
            await self.matrix_client.sync(timeout=30000)

    async def process_query(self, room: MatrixRoom, event: RoomMessageText):
        """Answer a user message with a GPT-generated response.

        Builds the chat context from recent room history plus the room's
        system message, truncates it to the token budget, queries the API,
        sends the answer, and records token usage in the database.

        Args:
            room (MatrixRoom): The room the query was sent in.
            event (RoomMessageText): The user's message.
        """
        await self.matrix_client.room_typing(room.room_id, True)
        await self.matrix_client.room_read_markers(room.room_id, event.event_id)

        try:
            last_messages = await self._last_n_messages(room.room_id, 20)
        except Exception as e:
            self.logger.log(f"Error getting last messages: {e}", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        system_message = self.get_system_message(room)

        chat_messages = [{"role": "system", "content": system_message}]

        for message in last_messages:
            role = "assistant" if message.sender == self.matrix_client.user_id else "user"
            # The triggering event is appended separately below
            if not message.event_id == event.event_id:
                chat_messages.append({"role": role, "content": message.body})

        chat_messages.append({"role": "user", "content": event.body})

        # Truncate messages to fit within the token limit
        truncated_messages = self._truncate(
            chat_messages, self.max_tokens - 1, system_message=system_message)

        try:
            response, tokens_used = await self.generate_chat_response(truncated_messages)
        except Exception as e:
            self.logger.log(f"Error generating response: {e}", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        if response:
            self.logger.log(f"Sending response to room {room.room_id}...")

            message = await self.send_message(room, response)

            if self.database:
                self.logger.log("Storing record of tokens used...")
                with self.database.cursor() as cursor:
                    cursor.execute(
                        "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
                        (message.event_id, room.room_id, tokens_used, datetime.now()))
                    self.database.commit()
        else:
            # Send a notice to the room if there was an error
            self.logger.log("Didn't get a response from GPT API", "error")
            # Fix: the original called a bare `send_message(...)` — an
            # undefined name — instead of awaiting self.send_message.
            await self.send_message(
                room, "Something went wrong. Please try again.", True)

        await self.matrix_client.room_typing(room.room_id, False)

    async def generate_chat_response(self, messages: List[Dict[str, str]]) -> tuple:
        """Generate a response to a chat message.

        Args:
            messages (List[Dict[str, str]]): A list of messages to use as context.

        Returns:
            tuple: (response text, total tokens used). The original annotated
                this as `-> str`, but it has always returned a 2-tuple.
        """
        self.logger.log(f"Generating response to {len(messages)} messages...")

        response = openai.ChatCompletion.create(
            model=self.model,
            messages=messages,
            api_key=self.openai_api_key
        )

        result_text = response.choices[0].message['content']
        tokens_used = response.usage["total_tokens"]
        self.logger.log(f"Generated response with {tokens_used} tokens.")
        return result_text, tokens_used

    def get_system_message(self, room: MatrixRoom | str | int) -> str:
        """Return the effective system message for a room.

        The most recent stored per-room system message is used; the default
        system message is prepended when there is none or when
        force_system_message is set.

        Args:
            room (MatrixRoom | str | int): The room or its ID. (Generalized:
                the original only special-cased int IDs, so a plain string
                room ID crashed on the .room_id attribute access.)

        Returns:
            str: The system message to use.
        """
        default = self.default_system_message

        if isinstance(room, (str, int)):
            room_id = room
        else:
            room_id = room.room_id

        with self.database.cursor() as cur:
            cur.execute(
                "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
                (room_id,)
            )
            system_message = cur.fetchone()

        complete = ((default if ((not system_message) or self.force_system_message) else "") + (
            "\n\n" + system_message[0] if system_message else "")).strip()

        return complete

    def __del__(self):
        """Close the bot."""
        if self.matrix_client:
            # NOTE(review): asyncio.run raises if called while a loop is
            # already running — confirm this destructor only fires after
            # the sync loop has exited.
            asyncio.run(self.matrix_client.close())

        if self.database:
            self.database.close()

10
classes/logging.py Normal file
View file

@ -0,0 +1,10 @@
import inspect
from datetime import datetime
class Logger:
    """Minimal stdout logger prefixing each message with a timestamp,
    the calling function's name, and the log level."""

    def log(self, message: str, log_level: str = "info"):
        """Print message, annotated with time, caller, and level."""
        origin = inspect.currentframe().f_back.f_code.co_name
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")
        print(f"[{stamp}] - {origin} - [{log_level.upper()}] {message}")

View file

@ -2,21 +2,22 @@ from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
async def command_botinfo(room: MatrixRoom, event: RoomMessageText, bot):
logging("Showing bot info...")
return room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": f"""GPT Info:
body = f"""GPT Info:
Model: {context["model"]}
Maximum context tokens: {context["max_tokens"]}
Maximum context messages: {context["max_messages"]}
System message: {context["system_message"]}
Model: {bot.model}
Maximum context tokens: {bot.max_tokens}
Maximum context messages: {bot.max_messages}
Room info:
Bot user ID: {context["client"].user_id}
Bot user ID: {bot.matrix_client.user_id}
Current room ID: {room.room_id}
System message: {bot.get_system_message(room)}
For usage statistics, run !gptbot stats
"""}
"""
await bot.send_message(room, body, True)

View file

@ -4,10 +4,10 @@ from nio.rooms import MatrixRoom
from random import SystemRandom
async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
context["logger"]("Flipping a coin...")
async def command_coin(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log("Flipping a coin...")
heads = SystemRandom().choice([True, False])
body = "Flipping a coin... It's " + ("heads!" if heads else "tails!")
return room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": "Heads!" if heads else "Tails!"}
await bot.send_message(room, body, True)

View file

@ -2,9 +2,8 @@ from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom
async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice",
"body": """Available commands:
async def command_help(room: MatrixRoom, event: RoomMessageText, bot):
body = """Available commands:
!gptbot help - Show this message
!gptbot newroom <room name> - Create a new room and invite yourself to it
@ -12,4 +11,7 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
!gptbot botinfo - Show information about the bot
!gptbot coin - Flip a coin (heads or tails)
!gptbot ignoreolder - Ignore messages before this point as context
"""}
!gptbot systemmessage - Get or set the system message for this room
"""
await bot.send_message(room, body, True)

View file

@ -1,8 +1,9 @@
from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
return room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": """Alright, messages before this point will not be processed as context anymore.
async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, bot):
body = """Alright, messages before this point will not be processed as context anymore.
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""
await bot.send_message(room, body, True)

View file

@ -1,17 +1,31 @@
from nio.events.room_events import RoomMessageText
from nio import RoomCreateError, RoomInviteError
from nio.rooms import MatrixRoom
async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot):
room_name = " ".join(event.body.split()[
2:]) or context["default_room_name"]
context["logger"]("Creating new room...")
new_room = await context["client"].room_create(name=room_name)
bot.logger.log("Creating new room...")
new_room = await bot.matrix_client.room_create(name=room_name)
if isinstance(new_room, RoomCreateError):
bot.logger.log(f"Failed to create room: {new_room.message}")
await bot.send_message(room, f"Sorry, I was unable to create a new room. Please try again later, or create a room manually.", True)
return
bot.logger.log(f"Inviting {event.sender} to new room...")
invite = await bot.matrix_client.room_invite(new_room.room_id, event.sender)
if isinstance(invite, RoomInviteError):
bot.logger.log(f"Failed to invite user: {invite.message}")
await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True)
return
context["logger"](f"Inviting {event.sender} to new room...")
await context["client"].room_invite(new_room.room_id, event.sender)
await context["client"].room_put_state(
new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}
await bot.matrix_client.joined_rooms()
await bot.send_message(room, f"Alright, I've created a new room called '{room_name}' and invited you to it. You can find it at {new_room.room_id}", True)
await bot.send_message(new_room.room_id, f"Welcome to the new room! What can I do for you?")

View file

@ -2,18 +2,17 @@ from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom
async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
context["logger"]("Showing stats...")
async def command_stats(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log("Showing stats...")
if not (database := context.get("database")):
context["logger"]("No database connection - cannot show stats")
return room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
if not bot.database:
bot.logger.log("No database connection - cannot show stats")
bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True)
return
with database.cursor() as cursor:
with bot.database.cursor() as cursor:
cursor.execute(
"SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
total_tokens = cursor.fetchone()[0] or 0
return room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": f"Total tokens used: {total_tokens}"}
bot.send_message(room, f"Total tokens used: {total_tokens}", True)

View file

@ -2,33 +2,24 @@ from nio.events.room_events import RoomMessageText
from nio.rooms import MatrixRoom
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot):
system_message = " ".join(event.body.split()[2:])
if system_message:
context["logger"]("Adding system message...")
bot.logger.log("Adding system message...")
with context["database"].cursor() as cur:
with bot.database.cursor() as cur:
cur.execute(
"INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
(room.room_id, event.event_id, event.sender,
system_message, event.server_timestamp)
)
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
bot.send_message(room, f"Alright, I've stored the system message: '{system_message}'.", True)
return
context["logger"]("Retrieving system message...")
bot.logger.log("Retrieving system message...")
with context["database"].cursor() as cur:
cur.execute(
"SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
(room.room_id,)
)
system_message = cur.fetchone()
system_message = bot.get_system_message(room)
if system_message is None:
system_message = context.get("system_message", "No system message set")
elif context.get("force_system_message") and context.get("system_message"):
system_message = system_message + "\n\n" + context["system_message"]
return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}
bot.send_message(room, f"The current system message is: '{system_message}'.", True)

View file

@ -3,7 +3,6 @@ from nio.rooms import MatrixRoom
async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
context["logger"]("Unknown command")
bot.logger.log("Unknown command")
return room.room_id, "m.room.message", {"msgtype": "m.notice",
"body": "Unknown command - try !gptbot help"}
bot.send_message(room, "Unknown command - try !gptbot help", True)

569
gptbot.py
View file

@ -1,554 +1,14 @@
import os
import inspect
import logging
import signal
import random
import uuid
from classes.bot import GPTBot
import openai
import asyncio
import markdown2
import tiktoken
import duckdb
from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api
from nio.api import MessageDirection
from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse
from nio.crypto import Olm
from configparser import ConfigParser
from datetime import datetime
from argparse import ArgumentParser
from typing import List, Dict, Union, Optional
from configparser import ConfigParser
from commands import COMMANDS
from classes import DuckDBStore
from migrations import MIGRATIONS
import signal
import asyncio
def logging(message: str, log_level: str = "info"):
caller = inspect.currentframe().f_back.f_code.co_name
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")
print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}")
CONTEXT = {
"database": False,
"default_room_name": "GPTBot",
"system_message": "You are a helpful assistant.",
"force_system_message": False,
"max_tokens": 3000,
"max_messages": 20,
"model": "gpt-3.5-turbo",
"client": None,
"sync_token": None,
"logger": logging
}
async def gpt_query(messages: list, model: Optional[str] = None):
model = model or CONTEXT["model"]
logging(f"Querying GPT with {len(messages)} messages")
logging(messages, "debug")
try:
response = openai.ChatCompletion.create(
model=model,
messages=messages
)
result_text = response.choices[0].message['content']
tokens_used = response.usage["total_tokens"]
logging(f"Used {tokens_used} tokens")
return result_text, tokens_used
except Exception as e:
logging(f"Error during GPT API call: {e}", "error")
return None, 0
async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
client: Optional[AsyncClient] = None, sync_token: Optional[str] = None):
messages = []
n = n or CONTEXT["max_messages"]
client = client or CONTEXT["client"]
sync_token = sync_token or CONTEXT["sync_token"]
logging(
f"Fetching last {2*n} messages from room {room_id} (starting at {sync_token})...")
response = await client.room_messages(
room_id=room_id,
start=sync_token,
limit=2*n,
)
if isinstance(response, RoomMessagesError):
logging(
f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
return []
for event in response.chunk:
if len(messages) >= n:
break
if isinstance(event, MegolmEvent):
try:
event = await client.decrypt_event(event)
except (GroupEncryptionError, EncryptionError):
logging(
f"Could not decrypt message {event.event_id} in room {room_id}", "error")
continue
if isinstance(event, RoomMessageText):
if event.body.startswith("!gptbot ignoreolder"):
break
if not event.body.startswith("!"):
messages.append(event)
logging(f"Found {len(messages)} messages (limit: {n})")
# Reverse the list so that messages are in chronological order
return messages[::-1]
def truncate_messages_to_fit_tokens(messages: list, max_tokens: Optional[int] = None,
model: Optional[str] = None, system_message: Optional[str] = None):
max_tokens = max_tokens or CONTEXT["max_tokens"]
model = model or CONTEXT["model"]
system_message = system_message or CONTEXT["system_message"]
encoding = tiktoken.encoding_for_model(model)
total_tokens = 0
system_message_tokens = len(encoding.encode(system_message)) + 1
if system_message_tokens > max_tokens:
logging(
f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
return []
total_tokens += system_message_tokens
total_tokens = len(system_message) + 1
truncated_messages = []
for message in [messages[0]] + list(reversed(messages[1:])):
content = message["content"]
tokens = len(encoding.encode(content)) + 1
if total_tokens + tokens > max_tokens:
break
total_tokens += tokens
truncated_messages.append(message)
return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
client = kwargs.get("client") or CONTEXT["client"]
database = kwargs.get("database") or CONTEXT["database"]
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
system_message = kwargs.get("system_message") or CONTEXT["system_message"]
force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
await client.room_typing(room.room_id, True)
await client.room_read_markers(room.room_id, event.event_id)
last_messages = await fetch_last_n_messages(room.room_id, 20)
system_message = get_system_message(room, {
"database": database,
"system_message": system_message,
"force_system_message": force_system_message,
})
chat_messages = [{"role": "system", "content": system_message}]
for message in last_messages:
role = "assistant" if message.sender == client.user_id else "user"
if not message.event_id == event.event_id:
chat_messages.append({"role": role, "content": message.body})
chat_messages.append({"role": "user", "content": event.body})
# Truncate messages to fit within the token limit
truncated_messages = truncate_messages_to_fit_tokens(
chat_messages, max_tokens - 1, system_message=system_message)
response, tokens_used = await gpt_query(truncated_messages)
if response:
logging(f"Sending response to room {room.room_id}...")
# Convert markdown to HTML
message = await send_message(room, response)
if database:
logging("Logging tokens used...")
with database.cursor() as cursor:
cursor.execute(
"INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
(message.event_id, room.room_id, tokens_used, datetime.now()))
database.commit()
else:
# Send a notice to the room if there was an error
logging("Error during GPT API call - sending notice to room")
send_message(
room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
print("No response from GPT API")
await client.room_typing(room.room_id, False)
async def process_command(room: MatrixRoom, event: RoomMessageText, context: Optional[dict] = None):
    """Dispatch a "!gptbot <command>" message to the matching handler.

    Args:
        room (MatrixRoom): Room the command was received in.
        event (RoomMessageText): The command message event.
        context (Optional[dict]): Bot context (client, database, ...);
            falls back to the global CONTEXT.
    """
    context = context or CONTEXT

    logging(
        f"Received command {event.body} from {event.sender} in room {room.room_id}")
    # Command word is the second token of "!gptbot <command> ...";
    # COMMANDS[None] is the fallback handler for a bare "!gptbot".
    command = event.body.split()[1] if event.body.split()[1:] else None
    # Handlers may return a (room_id, event, content) tuple to send a reply
    message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)

    if message:
        room_id, event, content = message
        # NOTE(review): the joined_rooms() result is unused - presumably
        # called to refresh client.rooms before the lookup below; verify.
        rooms = await context["client"].joined_rooms()
        await send_message(context["client"].rooms[room_id], content["body"],
                           True if content["msgtype"] == "m.notice" else False, context["client"])
def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
    """Build the effective system message for a room.

    A room-specific system message stored in the database overrides the
    configured default; when force_system_message is set, the default is
    prepended to the room-specific one.

    Args:
        room (MatrixRoom): Room to look up a custom system message for.
        context (Optional[dict]): Context providing "database",
            "system_message" and "force_system_message"; falls back to
            the global CONTEXT.

    Returns:
        str: The combined, stripped system message.
    """
    context = context or CONTEXT
    default = context.get("system_message")

    # Most recent room-specific system message, if any
    with context["database"].cursor() as cursor:
        cursor.execute(
            "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
            (room.room_id,)
        )
        row = cursor.fetchone()

    parts = []
    if (not row) or context["force_system_message"]:
        parts.append(default)
    if row:
        parts.append("\n\n" + row[0])

    return "".join(parts).strip()
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
    """Entry point for incoming room messages.

    Decrypts encrypted (Megolm) events, ignores the bot's own messages,
    routes "!gptbot ..." messages to process_command, other "!"-prefixed
    messages to nowhere, and everything else to process_query.

    Args:
        room (MatrixRoom): Room the message arrived in.
        event (RoomMessageText | MegolmEvent): The incoming event.
        **kwargs: May contain "context" to override the global CONTEXT.
    """
    context = kwargs.get("context") or CONTEXT

    logging(f"Received message from {event.sender} in room {room.room_id}")

    if isinstance(event, MegolmEvent):
        try:
            event = await context["client"].decrypt_event(event)
        except Exception as e:
            try:
                logging("Requesting new encryption keys...")
                await context["client"].request_room_key(event)
            except Exception:
                # Best effort - a failed key request must not mask the
                # original decryption error reported below.
                pass

            logging(f"Error decrypting message: {e}", "error")
            await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
            return

    if event.sender == context["client"].user_id:
        logging("Message is from bot itself - ignoring")

    elif event.body.startswith("!gptbot"):
        # BUG FIX: propagate the active context so commands don't silently
        # fall back to the global CONTEXT when an override was supplied.
        await process_command(room, event, context=context)

    elif event.body.startswith("!"):
        logging("Might be a command, but not for this bot - ignoring")

    else:
        await process_query(room, event, context=context)
async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
    """Join a room when invited, unless the bot is already a member.

    Args:
        room (MatrixRoom): The room the invite is for.
        event (InviteEvent): The invite event.
        **kwargs: May contain "client" to override CONTEXT["client"].
    """
    client: AsyncClient = kwargs.get("client") or CONTEXT["client"]

    if room.room_id in client.rooms:
        logging(f"Already in room {room.room_id} - ignoring invite")
        return

    logging(f"Received invite to room {room.room_id} - joining...")

    response = await client.join(room.room_id)

    if isinstance(response, JoinResponse):
        # BUG FIX: client was previously passed as the third positional
        # argument, which is send_message's "notice" flag - the greeting
        # became a notice and the client override was silently dropped.
        await send_message(
            room, "Hello! I'm a helpful assistant. How can I help you today?",
            client=client)
    else:
        logging(f"Error joining room {room.room_id}: {response}", "error")
async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None):
    """Send a markdown message to a room, encrypting it when required.

    Args:
        room (MatrixRoom): Target room.
        message (str): Message body (markdown source).
        notice (bool): Send as m.notice instead of m.text.
        client (Optional[AsyncClient]): Client to use (default: CONTEXT).

    Returns:
        RoomSendResponse: Response of the underlying room_send request.
    """
    client = client or CONTEXT["client"]

    # Render markdown to HTML for the formatted_body field
    markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
    formatted_body = markdowner.convert(message)

    msgtype = "m.notice" if notice else "m.text"

    msgcontent = {"msgtype": msgtype, "body": message,
                  "format": "org.matrix.custom.html", "formatted_body": formatted_body}

    content = None

    # Encrypt the payload when the client has Olm set up and the room is E2EE
    if client.olm and room.encrypted:
        try:
            if not room.members_synced:
                # NOTE(review): responses is collected but never read -
                # presumably joined_members() is called for its side effect
                # of syncing the member list before encrypting; verify.
                responses = []
                responses.append(await client.joined_members(room.room_id))

            if client.olm.should_share_group_session(room.room_id):
                try:
                    # Wait for an in-flight group-session share, if any
                    event = client.sharing_session[room.room_id]
                    await event.wait()
                except KeyError:
                    # No share in progress - start one ourselves
                    await client.share_group_session(
                        room.room_id,
                        ignore_unverified_devices=True,
                    )

            if msgtype != "m.reaction":
                # encrypt() returns the wrapped event type plus the
                # encrypted content, replacing the plaintext payload
                response = client.encrypt(
                    room.room_id, "m.room.message", msgcontent)
                msgtype, content = response

        except Exception as e:
            # NOTE(review): the log claims "sending unencrypted", but the
            # re-raise below prevents any unencrypted fallback send -
            # confirm which behavior is actually intended.
            logging(
                f"Error encrypting message: {e} - sending unencrypted", "error")
            raise

    # No encryption happened (unencrypted room or no Olm) - send plaintext
    if not content:
        msgtype = "m.room.message"
        content = msgcontent

    method, path, data = Api.room_send(
        client.access_token, room.room_id, msgtype, content, uuid.uuid4()
    )

    return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
async def accept_pending_invites(client: Optional[AsyncClient] = None):
    """Join every room the bot has a pending invite for, greeting each.

    Args:
        client (Optional[AsyncClient]): Client to use (default: CONTEXT).
    """
    client = client or CONTEXT["client"]

    logging("Accepting pending invites...")

    # Iterate over a copy of the keys - joining mutates invited_rooms
    for room_id in list(client.invited_rooms.keys()):
        logging(f"Joining room {room_id}...")

        response = await client.join(room_id)

        if isinstance(response, JoinResponse):
            logging(response, "debug")
            rooms = await client.joined_rooms()
            # BUG FIX: client was previously passed as the third positional
            # argument, which is send_message's "notice" flag - the greeting
            # became a notice and the client override was silently dropped.
            await send_message(
                client.rooms[room_id],
                "Hello! I'm a helpful assistant. How can I help you today?",
                client=client)
        else:
            logging(f"Error joining room {room_id}: {response}", "error")
async def sync_cb(response, write_global: bool = True):
    """Record the sync token from a SyncResponse.

    Args:
        response: The nio SyncResponse carrying next_batch.
        write_global (bool): Also store the token in the global CONTEXT.
    """
    logging(
        f"Sync response received (next batch: {response.next_batch})", "debug")

    # NOTE(review): SYNC_TOKEN is a plain local here (no `global`
    # declaration for it), so this assignment is invisible outside this
    # function - confirm whether a module-level SYNC_TOKEN was intended.
    SYNC_TOKEN = response.next_batch

    if write_global:
        global CONTEXT
        CONTEXT["sync_token"] = SYNC_TOKEN
async def test_callback(room: MatrixRoom, event: Event, **kwargs):
    """Catch-all callback: debug-log every event received in a room."""
    debug_line = f"Received event {event.__class__.__name__} in room {room.room_id}"
    logging(debug_line, "debug")
async def init(config: ConfigParser):
    """Initialize the global CONTEXT from the parsed configuration.

    Sets up the Matrix client (with E2EE when a database is configured),
    the database connection, and the OpenAI API key/options. Exits the
    process when required configuration is missing.

    Args:
        config (ConfigParser): Parsed configuration file.
    """
    # Set up Matrix client
    try:
        assert "Matrix" in config
        assert "Homeserver" in config["Matrix"]
        assert "AccessToken" in config["Matrix"]
    except:
        logging("Matrix config not found or incomplete", "critical")
        exit(1)

    homeserver = config["Matrix"]["Homeserver"]
    access_token = config["Matrix"]["AccessToken"]

    # Ask the homeserver who owns the token; explicit config values
    # (DeviceID / UserID) take precedence over the discovered ones
    device_id, user_id = await get_device_id(access_token, homeserver)

    device_id = config["Matrix"].get("DeviceID", device_id)
    user_id = config["Matrix"].get("UserID", user_id)

    # Set up database
    if "Database" in config and config["Database"].get("Path"):
        database = CONTEXT["database"] = initialize_database(
            config["Database"]["Path"])
        # Encryption needs a persistent store, so E2EE is only enabled
        # when a database path is configured
        matrix_store = DuckDBStore
        client_config = AsyncClientConfig(
            store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
    else:
        client_config = AsyncClientConfig(
            store_sync_tokens=True, encryption_enabled=False)

    client = AsyncClient(
        config["Matrix"]["Homeserver"], config=client_config)

    if client.config.encryption_enabled:
        # `database` is always bound here, since encryption is only
        # enabled on the branch that created it above
        client.store = client.config.store(
            user_id,
            device_id,
            database
        )
        assert client.store

        client.olm = Olm(client.user_id, client.device_id, client.store)
        client.encrypted_rooms = client.store.load_encrypted_rooms()

    CONTEXT["client"] = client

    CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
    CONTEXT["client"].user_id = user_id
    CONTEXT["client"].device_id = device_id

    # Set up GPT API
    try:
        assert "OpenAI" in config
        assert "APIKey" in config["OpenAI"]
    except:
        logging("OpenAI config not found or incomplete", "critical")
        exit(1)

    openai.api_key = config["OpenAI"]["APIKey"]

    if "Model" in config["OpenAI"]:
        CONTEXT["model"] = config["OpenAI"]["Model"]

    if "MaxTokens" in config["OpenAI"]:
        CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])

    if "MaxMessages" in config["OpenAI"]:
        CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])

    # Override defaults with config
    if "GPTBot" in config:
        if "SystemMessage" in config["GPTBot"]:
            CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
        if "DefaultRoomName" in config["GPTBot"]:
            CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
        if "ForceSystemMessage" in config["GPTBot"]:
            CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
                "ForceSystemMessage")
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
    """Run the bot: initialize, sync, register callbacks, loop forever.

    Args:
        config (Optional[ConfigParser]): Configuration used to initialize
            the global CONTEXT when no client exists yet.
        client (Optional[AsyncClient]): Client to use (default: CONTEXT).
    """
    if not client and not CONTEXT.get("client"):
        await init(config)

    client = client or CONTEXT["client"]

    try:
        assert client.user_id
    except AssertionError:
        logging(
            "Failed to get user ID - check your access token or try setting it manually", "critical")
        await client.close()
        return

    # Listen for SIGTERM
    def sigterm_handler(_signo, _stack_frame):
        logging("Received SIGTERM - exiting...")
        exit()

    signal.signal(signal.SIGTERM, sigterm_handler)

    logging("Starting bot...")

    client.add_response_callback(sync_cb, SyncResponse)

    logging("Syncing...")

    # Initial sync so room state is known before handling live events
    await client.sync(timeout=30000)

    client.add_event_callback(message_callback, RoomMessageText)
    client.add_event_callback(message_callback, MegolmEvent)
    client.add_event_callback(room_invite_callback, InviteEvent)
    client.add_event_callback(test_callback, Event)

    await accept_pending_invites()  # Accept pending invites

    logging("Bot started")

    try:
        # Continue syncing events
        await client.sync_forever(timeout=30000)
    finally:
        # One final sync flushes outstanding state before shutdown
        logging("Syncing one last time...")
        await client.sync(timeout=30000)
        await client.close()  # Properly close the aiohttp client session
        logging("Bot stopped")
def initialize_database(path: os.PathLike):
    """Open the DuckDB database at *path* and apply pending migrations.

    Args:
        path (os.PathLike): Path of the database file.

    Returns:
        The open DuckDB connection, migrated to the latest version.
    """
    logging("Initializing database...")
    conn = duckdb.connect(path)

    with conn.cursor() as cursor:
        # Get the latest migration ID if the migrations table exists
        try:
            cursor.execute(
                """
                SELECT MAX(id) FROM migrations
                """
            )
            latest_migration = int(cursor.fetchone()[0])
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Exception still covers both the
        # missing-table error and int(None) on an empty table.
        except Exception:
            latest_migration = 0

        # Apply every migration newer than the recorded one, in order
        for migration, function in MIGRATIONS.items():
            if latest_migration < migration:
                logging(f"Running migration {migration}...")
                function(conn)
                latest_migration = migration

    return conn
async def get_device_id(access_token, homeserver):
    """Look up the user ID and a device ID for an access token.

    Args:
        access_token (str): Matrix access token.
        homeserver (str): Homeserver URL.

    Returns:
        tuple: (device_id, user_id), or (None, None) if whoami failed.
    """
    client = AsyncClient(homeserver)
    client.access_token = access_token

    # SECURITY FIX: previously the raw access token was interpolated into
    # this log line; never write secrets to the logs.
    logging("Obtaining device ID for the configured access token...", "debug")

    response = await client.whoami()

    if isinstance(response, WhoamiResponse):
        logging(
            f"Authenticated as {response.user_id}.")
        user_id = response.user_id

        # NOTE(review): assumes the account has at least one device;
        # devices.devices[0] raises IndexError otherwise - confirm.
        devices = await client.devices()
        device_id = devices.devices[0].id

        await client.close()
        return device_id, user_id
    else:
        logging(f"Failed to obtain device ID: {response}", "error")
        await client.close()
        return None, None
def sigterm_handler(_signo, _stack_frame):
    """Terminate the process cleanly when SIGTERM is received."""
    raise SystemExit()
if __name__ == "__main__":
    # NOTE(review): reconstructed from diff residue - a hunk header and
    # interleaved pre-/post-commit lines were embedded here. This is the
    # post-commit ("Create a bot class") version; confirm against the repo.
    config = ConfigParser()
    config.read(args.config)

    # Create bot
    bot = GPTBot.from_config(config)

    # Listen for SIGTERM
    signal.signal(signal.SIGTERM, sigterm_handler)

    # Start bot
    try:
        asyncio.run(bot.run())
    except KeyboardInterrupt:
        print("Received KeyboardInterrupt - exiting...")
    except SystemExit:
        print("Received SIGTERM - exiting...")

View file

from collections import OrderedDict
from typing import Optional

from duckdb import DuckDBPyConnection

from .migration_1 import migration as migration_1
from .migration_2 import migration as migration_2
# NOTE(review): migration_3 is assigned below, but its import was lost in
# the diff residue - confirm the module path.
from .migration_3 import migration as migration_3

# Ordered mapping of migration ID -> migration function; applied in
# ascending order by migrate().
MIGRATIONS = OrderedDict()

MIGRATIONS[1] = migration_1
MIGRATIONS[2] = migration_2
MIGRATIONS[3] = migration_3
def get_version(db: "DuckDBPyConnection") -> int:
    """Get the current database version.

    Args:
        db (DuckDBPyConnection): Database connection.

    Returns:
        int: Current database version, or 0 if the migrations table does
        not exist yet or is empty.
    """
    try:
        return int(db.execute("SELECT MAX(id) FROM migrations").fetchone()[0])
    # BUG FIX: was a bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit. Exception still covers both the
    # missing-table error and int(None) on an empty table.
    except Exception:
        return 0
def migrate(db: "DuckDBPyConnection", from_version: Optional[int] = None, to_version: Optional[int] = None) -> tuple:
    """Migrate the database to a specific version.

    Args:
        db (DuckDBPyConnection): Database connection.
        from_version (Optional[int]): Version to migrate from. If None, the
            current version is used.
        to_version (Optional[int]): Version to migrate to. If None, the
            latest registered version is used.

    Returns:
        tuple: The (from_version, to_version) pair actually used.
        (The previous "-> None" annotation was wrong - the function has
        always returned this tuple.)

    Raises:
        ValueError: If from_version is greater than to_version.
    """
    if from_version is None:
        from_version = get_version(db)

    if to_version is None:
        to_version = max(MIGRATIONS.keys())

    if from_version > to_version:
        raise ValueError("Cannot migrate from a higher version to a lower version.")

    # Apply migrations from_version+1 .. to_version, in order.
    # BUG FIX: the guard previously tested `version in MIGRATIONS` while
    # applying MIGRATIONS[version + 1], so migration 1 was skipped when
    # starting from a fresh database (version 0).
    for version in range(from_version, to_version):
        if version + 1 in MIGRATIONS:
            MIGRATIONS[version + 1](db)

    return from_version, to_version