diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/callbacks/__init__.py b/callbacks/__init__.py new file mode 100644 index 0000000..e2693bb --- /dev/null +++ b/callbacks/__init__.py @@ -0,0 +1,29 @@ +from nio import ( + RoomMessageText, + MegolmEvent, + InviteEvent, + Event, + SyncResponse, + JoinResponse, + InviteEvent, + OlmEvent, + MegolmEvent +) + +from .test import test_callback +from .sync import sync_callback +from .invite import room_invite_callback +from .join import join_callback +from .message import message_callback + +RESPONSE_CALLBACKS = { + SyncResponse: sync_callback, + JoinResponse: join_callback, +} + +EVENT_CALLBACKS = { + Event: test_callback, + InviteEvent: room_invite_callback, + RoomMessageText: message_callback, + MegolmEvent: message_callback, +} \ No newline at end of file diff --git a/callbacks/invite.py b/callbacks/invite.py new file mode 100644 index 0000000..67995bc --- /dev/null +++ b/callbacks/invite.py @@ -0,0 +1,10 @@ +from nio import InviteEvent, MatrixRoom + +async def room_invite_callback(room: MatrixRoom, event: InviteEvent, bot): + if room.room_id in bot.matrix_client.rooms: + logging(f"Already in room {room.room_id} - ignoring invite") + return + + bot.logger.log(f"Received invite to room {room.room_id} - joining...") + + response = await bot.matrix_client.join(room.room_id) \ No newline at end of file diff --git a/callbacks/join.py b/callbacks/join.py new file mode 100644 index 0000000..b424c13 --- /dev/null +++ b/callbacks/join.py @@ -0,0 +1,7 @@ +async def join_callback(response, bot): + bot.logger.log( + f"Join response received for room {response.room_id}", "debug") + + bot.matrix_client.joined_rooms() + + await bot.send_message(bot.matrix_client.rooms[response.room_id], "Hello! Thanks for inviting me! How can I help you today?") \ No newline at end of file diff --git a/callbacks/message.py b/callbacks/message.py new file mode 100644 index 0000000..68e0a3e --- /dev/null +++ b/callbacks/message.py @@ -0,0 +1,30 @@ +from nio import MatrixRoom, RoomMessageText, MegolmEvent + +async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, bot): + bot.logger.log(f"Received message from {event.sender} in room {room.room_id}") + + if isinstance(event, MegolmEvent): + try: + event = await bot.matrix_client.decrypt_event(event) + except Exception as e: + try: + bot.logger.log("Requesting new encryption keys...") + await bot.matrix_client.request_room_key(event) + except: + pass + + bot.logger.log(f"Error decrypting message: {e}", "error") + await bot.send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True) + return + + if event.sender == bot.matrix_client.user_id: + bot.logger.log("Message is from bot itself - ignoring") + + elif event.body.startswith("!gptbot"): + await bot.process_command(room, event) + + elif event.body.startswith("!"): + bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring") + + else: + await bot.process_query(room, event) \ No newline at end of file diff --git a/callbacks/sync.py b/callbacks/sync.py new file mode 100644 index 0000000..fc4a9e7 --- /dev/null +++ b/callbacks/sync.py @@ -0,0 +1,6 @@ +async def sync_callback(response, bot): + bot.logger.log( + f"Sync response received (next batch: {response.next_batch})", "debug") + SYNC_TOKEN = response.next_batch + + bot.sync_token = SYNC_TOKEN \ No newline at end of file diff --git a/callbacks/test.py b/callbacks/test.py new file mode 100644 index 0000000..137b61d --- /dev/null +++ b/callbacks/test.py @@ -0,0 +1,11 @@ +from nio import MatrixRoom, Event + +async def test_callback(room: MatrixRoom, event: Event, bot): + """Test callback for debugging purposes. + + Args: + room (MatrixRoom): The room the event was sent in. + event (Event): The event that was sent. + """ + + bot.logger.log(f"Test callback called: {room.room_id} {event.event_id} {event.sender} {event.__class__}") \ No newline at end of file diff --git a/classes/bot.py b/classes/bot.py new file mode 100644 index 0000000..a68c8aa --- /dev/null +++ b/classes/bot.py @@ -0,0 +1,481 @@ +import openai +import markdown2 +import duckdb +import tiktoken + +import asyncio + +from nio import ( + AsyncClient, + AsyncClientConfig, + WhoamiResponse, + DevicesResponse, + Event, + Response, + MatrixRoom, + Api, + RoomMessagesError, + MegolmEvent, + GroupEncryptionError, + EncryptionError, + RoomMessageText, + RoomSendResponse, + SyncResponse +) +from nio.crypto import Olm + +from typing import Optional, List, Dict +from configparser import ConfigParser +from datetime import datetime + +import uuid + +from .logging import Logger +from migrations import migrate +from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS +from commands import COMMANDS +from .store import DuckDBStore + + +class GPTBot: + # Default values + database: Optional[duckdb.DuckDBPyConnection] = None + default_room_name: str = "GPTBot" # Default name of rooms created by the bot + default_system_message: str = "You are a helpful assistant." + # Force default system message to be included even if a custom room message is set + force_system_message: bool = False + max_tokens: int = 3000 # Maximum number of input tokens + max_messages: int = 30 # Maximum number of messages to consider as input + model: str = "gpt-3.5-turbo" # OpenAI chat model to use + matrix_client: Optional[AsyncClient] = None + sync_token: Optional[str] = None + logger: Optional[Logger] = Logger() + openai_api_key: Optional[str] = None + + @classmethod + def from_config(cls, config: ConfigParser): + """Create a new GPTBot instance from a config file. + + Args: + config (ConfigParser): ConfigParser instance with the bot's config. + + Returns: + GPTBot: The new GPTBot instance. + """ + + # Create a new GPTBot instance + bot = cls() + + # Set the database connection + bot.database = duckdb.connect( + config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None + + # Override default values + if "GPTBot" in config: + bot.default_room_name = config["GPTBot"].get( + "DefaultRoomName", bot.default_room_name) + bot.default_system_message = config["GPTBot"].get( + "SystemMessage", bot.default_system_message) + bot.force_system_message = config["GPTBot"].getboolean( + "ForceSystemMessage", bot.force_system_message) + + bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens) + bot.max_messages = config["OpenAI"].getint( + "MaxMessages", bot.max_messages) + bot.model = config["OpenAI"].get("Model", bot.model) + + bot.openai_api_key = config["OpenAI"]["APIKey"] + + # Set up the Matrix client + + assert "Matrix" in config, "Matrix config not found" + + homeserver = config["Matrix"]["Homeserver"] + bot.matrix_client = AsyncClient(homeserver) + bot.matrix_client.access_token = config["Matrix"]["AccessToken"] + bot.matrix_client.user_id = config["Matrix"].get("UserID") + bot.matrix_client.device_id = config["Matrix"].get("DeviceID") + + # Return the new GPTBot instance + return bot + + async def _get_user_id(self) -> str: + """Get the user ID of the bot from the whoami endpoint. + Requires an access token to be set up. + + Returns: + str: The user ID of the bot. + """ + + assert self.matrix_client, "Matrix client not set up" + + user_id = self.matrix_client.user_id + + if not user_id: + assert self.matrix_client.access_token, "Access token not set up" + + response = await self.matrix_client.whoami() + + if isinstance(response, WhoamiResponse): + user_id = response.user_id + else: + raise Exception(f"Could not get user ID: {response}") + + return user_id + + async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]): + messages = [] + n = n or bot.max_messages + room_id = room.room_id if isinstance(room, MatrixRoom) else room + + self.logger.log( + f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...") + + response = await self.matrix_client.room_messages( + room_id=room_id, + start=self.sync_token, + limit=2*n, + ) + + if isinstance(response, RoomMessagesError): + raise Exception( + f"Error fetching messages: {response.message} (status code {response.status_code})", "error") + + for event in response.chunk: + if len(messages) >= n: + break + if isinstance(event, MegolmEvent): + try: + event = await self.matrix_client.decrypt_event(event) + except (GroupEncryptionError, EncryptionError): + self.logger.log( + f"Could not decrypt message {event.event_id} in room {room_id}", "error") + continue + if isinstance(event, RoomMessageText): + if event.body.startswith("!gptbot ignoreolder"): + break + if not event.body.startswith("!"): + messages.append(event) + + self.logger.log(f"Found {len(messages)} messages (limit: {n})") + + # Reverse the list so that messages are in chronological order + return messages[::-1] + + def _truncate(self, messages: list, max_tokens: Optional[int] = None, + model: Optional[str] = None, system_message: Optional[str] = None): + max_tokens = max_tokens or self.max_tokens + model = model or self.model + system_message = self.default_system_message if system_message is None else system_message + + encoding = tiktoken.encoding_for_model(model) + total_tokens = 0 + + system_message_tokens = 0 if not system_message else ( + len(encoding.encode(system_message)) + 1) + + if system_message_tokens > max_tokens: + self.logger.log( + f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error") + return [] + + total_tokens += system_message_tokens + + total_tokens = len(system_message) + 1 + truncated_messages = [] + + for message in [messages[0]] + list(reversed(messages[1:])): + content = message["content"] + tokens = len(encoding.encode(content)) + 1 + if total_tokens + tokens > max_tokens: + break + total_tokens += tokens + truncated_messages.append(message) + + return [truncated_messages[0]] + list(reversed(truncated_messages[1:])) + + async def _get_device_id(self) -> str: + """Guess the device ID of the bot. + Requires an access token to be set up. + + Returns: + str: The guessed device ID. + """ + + assert self.matrix_client, "Matrix client not set up" + + device_id = self.matrix_client.device_id + + if not device_id: + assert self.matrix_client.access_token, "Access token not set up" + + devices = await self.matrix_client.devices() + + if isinstance(devices, DevicesResponse): + device_id = devices.devices[0].id + + return device_id + + async def process_command(self, room: MatrixRoom, event: RoomMessageText): + self.logger.log( + f"Received command {event.body} from {event.sender} in room {room.room_id}") + command = event.body.split()[1] if event.body.split()[1:] else None + + await COMMANDS.get(command, COMMANDS[None])(room, event, self) + + async def event_callback(self,room: MatrixRoom, event: Event): + self.logger.log("Received event: " + str(event), "debug") + for eventtype, callback in EVENT_CALLBACKS.items(): + if isinstance(event, eventtype): + await callback(room, event, self) + + async def response_callback(self, response: Response): + for response_type, callback in RESPONSE_CALLBACKS.items(): + if isinstance(response, response_type): + await callback(response, self) + + async def accept_pending_invites(self): + """Accept all pending invites.""" + + assert self.matrix_client, "Matrix client not set up" + + invites = self.matrix_client.invited_rooms + + for invite in invites.keys(): + await self.matrix_client.join(invite) + + async def send_message(self, room: MatrixRoom, message: str, notice: bool = False): + markdowner = markdown2.Markdown(extras=["fenced-code-blocks"]) + formatted_body = markdowner.convert(message) + + msgtype = "m.notice" if notice else "m.text" + + msgcontent = {"msgtype": msgtype, "body": message, + "format": "org.matrix.custom.html", "formatted_body": formatted_body} + + content = None + + if self.matrix_client.olm and room.encrypted: + try: + if not room.members_synced: + responses = [] + responses.append(await self.matrix_client.joined_members(room.room_id)) + + if self.matrix_client.olm.should_share_group_session(room.room_id): + try: + event = self.matrix_client.sharing_session[room.room_id] + await event.wait() + except KeyError: + await self.matrix_client.share_group_session( + room.room_id, + ignore_unverified_devices=True, + ) + + if msgtype != "m.reaction": + response = self.matrix_client.encrypt( + room.room_id, "m.room.message", msgcontent) + msgtype, content = response + + except Exception as e: + self.logger.log( + f"Error encrypting message: {e} - sending unencrypted", "error") + raise + + if not content: + msgtype = "m.room.message" + content = msgcontent + + method, path, data = Api.room_send( + self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4() + ) + + return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,)) + + async def run(self): + """Start the bot.""" + + # Set up the Matrix client + + assert self.matrix_client, "Matrix client not set up" + assert self.matrix_client.access_token, "Access token not set up" + + if not self.matrix_client.user_id: + self.matrix_client.user_id = await self._get_user_id() + + if not self.matrix_client.device_id: + self.matrix_client.device_id = await self._get_device_id() + + # Set up database + + IN_MEMORY = False + if not self.database: + self.logger.log( + "No database connection set up, using in-memory database. Data will be lost on bot shutdown.") + IN_MEMORY = True + self.database = DuckDBPyConnection(":memory:") + + self.logger.log("Running migrations...") + before, after = migrate(self.database) + if before != after: + self.logger.log(f"Migrated from version {before} to {after}.") + else: + self.logger.log(f"Already at latest version {after}.") + + if IN_MEMORY: + client_config = AsyncClientConfig( + store_sync_tokens=True, encryption_enabled=False) + else: + matrix_store = DuckDBStore + client_config = AsyncClientConfig( + store_sync_tokens=True, encryption_enabled=True, store=matrix_store) + self.matrix_client.config = client_config + self.matrix_client.store = matrix_store( + self.matrix_client.user_id, + self.matrix_client.device_id, + self.database + ) + + self.matrix_client.olm = Olm( + self.matrix_client.user_id, + self.matrix_client.device_id, + self.matrix_client.store + ) + + self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms() + + # Run initial sync + sync = await self.matrix_client.sync(timeout=30000) + if isinstance(sync, SyncResponse): + await self.response_callback(sync) + else: + self.logger.log(f"Initial sync failed, aborting: {sync}", "error") + return + + # Set up callbacks + + self.matrix_client.add_event_callback(self.event_callback, Event) + self.matrix_client.add_response_callback(self.response_callback, Response) + + # Accept pending invites + + self.logger.log("Accepting pending invites...") + await self.accept_pending_invites() + + # Start syncing events + self.logger.log("Starting sync loop...") + try: + await self.matrix_client.sync_forever(timeout=30000) + finally: + self.logger.log("Syncing one last time...") + await self.matrix_client.sync(timeout=30000) + + async def process_query(self, room: MatrixRoom, event: RoomMessageText): + await self.matrix_client.room_typing(room.room_id, True) + + await self.matrix_client.room_read_markers(room.room_id, event.event_id) + + try: + last_messages = await self._last_n_messages(room.room_id, 20) + except Exception as e: + self.logger.log(f"Error getting last messages: {e}", "error") + await self.send_message( + room, "Something went wrong. Please try again.", True) + return + + system_message = self.get_system_message(room) + + chat_messages = [{"role": "system", "content": system_message}] + + for message in last_messages: + role = "assistant" if message.sender == self.matrix_client.user_id else "user" + if not message.event_id == event.event_id: + chat_messages.append({"role": role, "content": message.body}) + + chat_messages.append({"role": "user", "content": event.body}) + + # Truncate messages to fit within the token limit + truncated_messages = self._truncate( + chat_messages, self.max_tokens - 1, system_message=system_message) + + try: + response, tokens_used = await self.generate_chat_response(truncated_messages) + except Exception as e: + self.logger.log(f"Error generating response: {e}", "error") + await self.send_message( + room, "Something went wrong. Please try again.", True) + return + + if response: + self.logger.log(f"Sending response to room {room.room_id}...") + + # Convert markdown to HTML + + message = await self.send_message(room, response) + + if self.database: + self.logger.log("Storing record of tokens used...") + + with self.database.cursor() as cursor: + cursor.execute( + "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)", + (message.event_id, room.room_id, tokens_used, datetime.now())) + self.database.commit() + else: + # Send a notice to the room if there was an error + self.logger.log("Didn't get a response from GPT API", "error") + send_message( + room, "Something went wrong. Please try again.", True) + + await self.matrix_client.room_typing(room.room_id, False) + + async def generate_chat_response(self, messages: List[Dict[str, str]]) -> str: + """Generate a response to a chat message. + + Args: + messages (List[Dict[str, str]]): A list of messages to use as context. + + Returns: + str: The response to the chat. + """ + + self.logger.log(f"Generating response to {len(messages)} messages...") + + response = openai.ChatCompletion.create( + model=self.model, + messages=messages, + api_key=self.openai_api_key + ) + + result_text = response.choices[0].message['content'] + tokens_used = response.usage["total_tokens"] + self.logger.log(f"Generated response with {tokens_used} tokens.") + return result_text, tokens_used + + def get_system_message(self, room: MatrixRoom | int) -> str: + default = self.default_system_message + + if isinstance(room, int): + room_id = room + else: + room_id = room.room_id + + with self.database.cursor() as cur: + cur.execute( + "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1", + (room_id,) + ) + system_message = cur.fetchone() + + complete = ((default if ((not system_message) or self.force_system_message) else "") + ( + "\n\n" + system_message[0] if system_message else "")).strip() + + return complete + + def __del__(self): + """Close the bot.""" + + if self.matrix_client: + asyncio.run(self.matrix_client.close()) + + if self.database: + self.database.close() diff --git a/classes/logging.py b/classes/logging.py new file mode 100644 index 0000000..4b6022a --- /dev/null +++ b/classes/logging.py @@ -0,0 +1,10 @@ +import inspect + +from datetime import datetime + + +class Logger: + def log(self, message: str, log_level: str = "info"): + caller = inspect.currentframe().f_back.f_code.co_name + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f") + print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}") diff --git a/commands/botinfo.py b/commands/botinfo.py index 430912a..2f505ae 100644 --- a/commands/botinfo.py +++ b/commands/botinfo.py @@ -2,21 +2,22 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom -async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict): +async def command_botinfo(room: MatrixRoom, event: RoomMessageText, bot): logging("Showing bot info...") - return room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": f"""GPT Info: + body = f"""GPT Info: -Model: {context["model"]} -Maximum context tokens: {context["max_tokens"]} -Maximum context messages: {context["max_messages"]} -System message: {context["system_message"]} +Model: {bot.model} +Maximum context tokens: {bot.max_tokens} +Maximum context messages: {bot.max_messages} Room info: -Bot user ID: {context["client"].user_id} +Bot user ID: {bot.matrix_client.user_id} Current room ID: {room.room_id} +System message: {bot.get_system_message(room)} For usage statistics, run !gptbot stats -"""} +""" + + await bot.send_message(room, body, True) diff --git a/commands/coin.py b/commands/coin.py index 82a3741..da45236 100644 --- a/commands/coin.py +++ b/commands/coin.py @@ -4,10 +4,10 @@ from nio.rooms import MatrixRoom from random import SystemRandom -async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict): - context["logger"]("Flipping a coin...") +async def command_coin(room: MatrixRoom, event: RoomMessageText, bot): + bot.logger.log("Flipping a coin...") heads = SystemRandom().choice([True, False]) + body = "Flipping a coin... It's " + ("heads!" if heads else "tails!") - return room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": "Heads!" if heads else "Tails!"} + await bot.send_message(room, body, True) \ No newline at end of file diff --git a/commands/help.py b/commands/help.py index d5fb235..de4e58d 100644 --- a/commands/help.py +++ b/commands/help.py @@ -2,9 +2,8 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom -async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict): - return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice", - "body": """Available commands: +async def command_help(room: MatrixRoom, event: RoomMessageText, bot): + body = """Available commands: !gptbot help - Show this message !gptbot newroom - Create a new room and invite yourself to it @@ -12,4 +11,7 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict): !gptbot botinfo - Show information about the bot !gptbot coin - Flip a coin (heads or tails) !gptbot ignoreolder - Ignore messages before this point as context -"""} +!gptbot systemmessage - Get or set the system message for this room +""" + + await bot.send_message(room, body, True) \ No newline at end of file diff --git a/commands/ignoreolder.py b/commands/ignoreolder.py index 348abe8..1dfce72 100644 --- a/commands/ignoreolder.py +++ b/commands/ignoreolder.py @@ -1,8 +1,9 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom -async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict): - return room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": """Alright, messages before this point will not be processed as context anymore. +async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, bot): + body = """Alright, messages before this point will not be processed as context anymore. -If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""} \ No newline at end of file +If you ever reconsider, you can simply delete your message and I will start processing messages before it again.""" + + await bot.send_message(room, body, True) \ No newline at end of file diff --git a/commands/newroom.py b/commands/newroom.py index 823fb7a..26f54b9 100644 --- a/commands/newroom.py +++ b/commands/newroom.py @@ -1,17 +1,31 @@ from nio.events.room_events import RoomMessageText +from nio import RoomCreateError, RoomInviteError from nio.rooms import MatrixRoom -async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict): +async def command_newroom(room: MatrixRoom, event: RoomMessageText, bot): room_name = " ".join(event.body.split()[ 2:]) or context["default_room_name"] - context["logger"]("Creating new room...") - new_room = await context["client"].room_create(name=room_name) + bot.logger.log("Creating new room...") + new_room = await bot.matrix_client.room_create(name=room_name) + + if isinstance(new_room, RoomCreateError): + bot.logger.log(f"Failed to create room: {new_room.message}") + await bot.send_message(room, f"Sorry, I was unable to create a new room. Please try again later, or create a room manually.", True) + return + + bot.logger.log(f"Inviting {event.sender} to new room...") + invite = await bot.matrix_client.room_invite(new_room.room_id, event.sender) + + if isinstance(invite, RoomInviteError): + bot.logger.log(f"Failed to invite user: {invite.message}") + await bot.send_message(room, f"Sorry, I was unable to invite you to the new room. Please try again later, or create a room manually.", True) + return - context["logger"](f"Inviting {event.sender} to new room...") - await context["client"].room_invite(new_room.room_id, event.sender) await context["client"].room_put_state( new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}}) - return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"} + await bot.matrix_client.joined_rooms() + await bot.send_message(room, f"Alright, I've created a new room called '{room_name}' and invited you to it. You can find it at {new_room.room_id}", True) + await bot.send_message(new_room.room_id, f"Welcome to the new room! What can I do for you?") \ No newline at end of file diff --git a/commands/stats.py b/commands/stats.py index cffb3e4..ed543a8 100644 --- a/commands/stats.py +++ b/commands/stats.py @@ -2,18 +2,17 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom -async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict): - context["logger"]("Showing stats...") +async def command_stats(room: MatrixRoom, event: RoomMessageText, bot): + bot.logger.log("Showing stats...") - if not (database := context.get("database")): - context["logger"]("No database connection - cannot show stats") - return room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."} + if not bot.database: + bot.logger.log("No database connection - cannot show stats") + bot.send_message(room, "Sorry, I'm not connected to a database, so I don't have any statistics on your usage.", True) + return - with database.cursor() as cursor: + with bot.database.cursor() as cursor: cursor.execute( "SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,)) total_tokens = cursor.fetchone()[0] or 0 - return room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": f"Total tokens used: {total_tokens}"} + bot.send_message(room, f"Total tokens used: {total_tokens}", True) diff --git a/commands/systemmessage.py b/commands/systemmessage.py index 571a2ef..cd3d4f4 100644 --- a/commands/systemmessage.py +++ b/commands/systemmessage.py @@ -2,33 +2,24 @@ from nio.events.room_events import RoomMessageText from nio.rooms import MatrixRoom -async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict): +async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, bot): system_message = " ".join(event.body.split()[2:]) if system_message: - context["logger"]("Adding system message...") + bot.logger.log("Adding system message...") - with context["database"].cursor() as cur: + with bot.database.cursor() as cur: cur.execute( "INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)", (room.room_id, event.event_id, event.sender, system_message, event.server_timestamp) ) - return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"} + bot.send_message(room, f"Alright, I've stored the system message: '{system_message}'.", True) + return - context["logger"]("Retrieving system message...") + bot.logger.log("Retrieving system message...") - with context["database"].cursor() as cur: - cur.execute( - "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1", - (room.room_id,) - ) - system_message = cur.fetchone() + system_message = bot.get_system_message(room) - if system_message is None: - system_message = context.get("system_message", "No system message set") - elif context.get("force_system_message") and context.get("system_message"): - system_message = system_message + "\n\n" + context["system_message"] - - return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"} + bot.send_message(room, f"The current system message is: '{system_message}'.", True) diff --git a/commands/unknown.py b/commands/unknown.py index ce01eb9..5a5bdc7 100644 --- a/commands/unknown.py +++ b/commands/unknown.py @@ -3,7 +3,6 @@ from nio.rooms import MatrixRoom async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict): - context["logger"]("Unknown command") + bot.logger.log("Unknown command") - return room.room_id, "m.room.message", {"msgtype": "m.notice", - "body": "Unknown command - try !gptbot help"} + bot.send_message(room, "Unknown command - try !gptbot help", True) \ No newline at end of file diff --git a/gptbot.py b/gptbot.py index 18e51d7..8487f3a 100644 --- a/gptbot.py +++ b/gptbot.py @@ -1,554 +1,14 @@ -import os -import inspect -import logging -import signal -import random -import uuid +from classes.bot import GPTBot -import openai -import asyncio -import markdown2 -import tiktoken -import duckdb - -from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api -from nio.api import MessageDirection -from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse -from nio.crypto import Olm - -from configparser import ConfigParser -from datetime import datetime from argparse import ArgumentParser -from typing import List, Dict, Union, Optional +from configparser import ConfigParser -from commands import COMMANDS -from classes import DuckDBStore -from migrations import MIGRATIONS +import signal +import asyncio -def logging(message: str, log_level: str = "info"): - caller = inspect.currentframe().f_back.f_code.co_name - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f") - print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}") - - -CONTEXT = { - "database": False, - "default_room_name": "GPTBot", - "system_message": "You are a helpful assistant.", - "force_system_message": False, - "max_tokens": 3000, - "max_messages": 20, - "model": "gpt-3.5-turbo", - "client": None, - "sync_token": None, - "logger": logging -} - - -async def gpt_query(messages: list, model: Optional[str] = None): - model = model or CONTEXT["model"] - - logging(f"Querying GPT with {len(messages)} messages") - logging(messages, "debug") - - try: - response = openai.ChatCompletion.create( - model=model, - messages=messages - ) - result_text = response.choices[0].message['content'] - tokens_used = response.usage["total_tokens"] - logging(f"Used {tokens_used} tokens") - return result_text, tokens_used - - except Exception as e: - logging(f"Error during GPT API call: {e}", "error") - return None, 0 - - -async def fetch_last_n_messages(room_id: str, n: Optional[int] = None, - client: Optional[AsyncClient] = None, sync_token: Optional[str] = None): - messages = [] - - n = n or CONTEXT["max_messages"] - client = client or CONTEXT["client"] - sync_token = sync_token or CONTEXT["sync_token"] - - logging( - f"Fetching last {2*n} messages from room {room_id} (starting at {sync_token})...") - - response = await client.room_messages( - room_id=room_id, - start=sync_token, - limit=2*n, - ) - - if isinstance(response, RoomMessagesError): - logging( - f"Error fetching messages: {response.message} (status code {response.status_code})", "error") - return [] - - for event in response.chunk: - if len(messages) >= n: - break - if isinstance(event, MegolmEvent): - try: - event = await client.decrypt_event(event) - except (GroupEncryptionError, EncryptionError): - logging( - f"Could not decrypt message {event.event_id} in room {room_id}", "error") - continue - if isinstance(event, RoomMessageText): - if event.body.startswith("!gptbot ignoreolder"): - break - if not event.body.startswith("!"): - messages.append(event) - - logging(f"Found {len(messages)} messages (limit: {n})") - - # Reverse the list so that messages are in chronological order - return messages[::-1] - - -def truncate_messages_to_fit_tokens(messages: list, max_tokens: Optional[int] = None, - model: Optional[str] = None, system_message: Optional[str] = None): - max_tokens = max_tokens or CONTEXT["max_tokens"] - model = model or CONTEXT["model"] - system_message = system_message or CONTEXT["system_message"] - - encoding = tiktoken.encoding_for_model(model) - total_tokens = 0 - - system_message_tokens = len(encoding.encode(system_message)) + 1 - - if system_message_tokens > max_tokens: - logging( - f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error") - return [] - - total_tokens += system_message_tokens - - total_tokens = len(system_message) + 1 - truncated_messages = [] - - for message in [messages[0]] + list(reversed(messages[1:])): - content = message["content"] - tokens = len(encoding.encode(content)) + 1 - if total_tokens + tokens > max_tokens: - break - total_tokens += tokens - truncated_messages.append(message) - - return [truncated_messages[0]] + list(reversed(truncated_messages[1:])) - - -async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs): - - client = kwargs.get("client") or CONTEXT["client"] - database = kwargs.get("database") or CONTEXT["database"] - max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"] - system_message = kwargs.get("system_message") or CONTEXT["system_message"] - force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"] - - await client.room_typing(room.room_id, True) - - await client.room_read_markers(room.room_id, event.event_id) - - last_messages = await fetch_last_n_messages(room.room_id, 20) - - system_message = get_system_message(room, { - "database": database, - "system_message": system_message, - "force_system_message": force_system_message, - }) - - chat_messages = [{"role": "system", "content": system_message}] - - for message in last_messages: - role = "assistant" if message.sender == client.user_id else "user" - if not message.event_id == event.event_id: - chat_messages.append({"role": role, "content": message.body}) - - chat_messages.append({"role": "user", "content": event.body}) - - # Truncate messages to fit within the token limit - truncated_messages = truncate_messages_to_fit_tokens( - chat_messages, max_tokens - 1, system_message=system_message) - response, tokens_used = await gpt_query(truncated_messages) - - if response: - logging(f"Sending response to room {room.room_id}...") - - # Convert markdown to HTML - - message = await send_message(room, response) - - if database: - logging("Logging tokens used...") - - with database.cursor() as cursor: - cursor.execute( - "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)", - (message.event_id, room.room_id, tokens_used, datetime.now())) - database.commit() - else: - # Send a notice to the room if there was an error - - logging("Error during GPT API call - sending notice to room") - send_message( - room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True) - print("No response from GPT API") - - await client.room_typing(room.room_id, False) - - -async def process_command(room: MatrixRoom, event: RoomMessageText, context: Optional[dict] = None): - context = context or CONTEXT - - logging( - f"Received command {event.body} from {event.sender} in room {room.room_id}") - command = event.body.split()[1] if event.body.split()[1:] else None - - message = await COMMANDS.get(command, COMMANDS[None])(room, event, context) - - if message: - room_id, event, content = message - rooms = await context["client"].joined_rooms() - await send_message(context["client"].rooms[room_id], content["body"], - True if content["msgtype"] == "m.notice" else False, context["client"]) - - -def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str: - context = context or CONTEXT - - default = context.get("system_message") - - with context["database"].cursor() as cur: - cur.execute( - "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1", - (room.room_id,) - ) - system_message = cur.fetchone() - - complete = ((default if ((not system_message) or context["force_system_message"]) else "") + ( - "\n\n" + system_message[0] if system_message else "")).strip() - - return complete - - -async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs): - context = kwargs.get("context") or CONTEXT - - logging(f"Received message from {event.sender} in room {room.room_id}") - - if isinstance(event, MegolmEvent): - try: - event = await context["client"].decrypt_event(event) - except Exception as e: - try: - logging("Requesting new encryption keys...") - await context["client"].request_room_key(event) - except: - pass - - logging(f"Error decrypting message: {e}", "error") - await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"]) - return - - if event.sender == context["client"].user_id: - logging("Message is from bot itself - ignoring") - - elif event.body.startswith("!gptbot"): - await process_command(room, event) - - elif event.body.startswith("!"): - logging("Might be a command, but not for this bot - ignoring") - - else: - await process_query(room, event, context=context) - - -async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs): - client: AsyncClient = kwargs.get("client") or CONTEXT["client"] - - if room.room_id in client.rooms: - logging(f"Already in room {room.room_id} - ignoring invite") - return - - logging(f"Received invite to room {room.room_id} - joining...") - - response = await client.join(room.room_id) - if isinstance(response, JoinResponse): - await send_message(room, "Hello! I'm a helpful assistant. How can I help you today?", client) - else: - logging(f"Error joining room {room.room_id}: {response}", "error") - - -async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None): - client = client or CONTEXT["client"] - - markdowner = markdown2.Markdown(extras=["fenced-code-blocks"]) - formatted_body = markdowner.convert(message) - - msgtype = "m.notice" if notice else "m.text" - - msgcontent = {"msgtype": msgtype, "body": message, - "format": "org.matrix.custom.html", "formatted_body": formatted_body} - - content = None - - if client.olm and room.encrypted: - try: - if not room.members_synced: - responses = [] - responses.append(await client.joined_members(room.room_id)) - - if client.olm.should_share_group_session(room.room_id): - try: - event = client.sharing_session[room.room_id] - await event.wait() - except KeyError: - await client.share_group_session( - room.room_id, - ignore_unverified_devices=True, - ) - - if msgtype != "m.reaction": - response = client.encrypt( - room.room_id, "m.room.message", msgcontent) - msgtype, content = response - - except Exception as e: - logging( - f"Error encrypting message: {e} - sending unencrypted", "error") - raise - - if not content: - msgtype = "m.room.message" - content = msgcontent - - method, path, data = Api.room_send( - client.access_token, room.room_id, msgtype, content, uuid.uuid4() - ) - - return await client._send(RoomSendResponse, method, path, data, (room.room_id,)) - - -async def accept_pending_invites(client: Optional[AsyncClient] = None): - client = client or CONTEXT["client"] - - logging("Accepting pending invites...") - - for room_id in list(client.invited_rooms.keys()): - logging(f"Joining room {room_id}...") - - response = await client.join(room_id) - - if isinstance(response, JoinResponse): - logging(response, "debug") - rooms = await client.joined_rooms() - await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client) - else: - logging(f"Error joining room {room_id}: {response}", "error") - - -async def sync_cb(response, write_global: bool = True): - logging( - f"Sync response received (next batch: {response.next_batch})", "debug") - SYNC_TOKEN = response.next_batch - - if write_global: - global CONTEXT - CONTEXT["sync_token"] = SYNC_TOKEN - - -async def test_callback(room: MatrixRoom, event: Event, **kwargs): - logging( - f"Received event {event.__class__.__name__} in room {room.room_id}", "debug") - - -async def init(config: ConfigParser): - # Set up Matrix client - try: - assert "Matrix" in config - assert "Homeserver" in config["Matrix"] - assert "AccessToken" in config["Matrix"] - except: - logging("Matrix config not found or incomplete", "critical") - exit(1) - - homeserver = config["Matrix"]["Homeserver"] - access_token = config["Matrix"]["AccessToken"] - - device_id, user_id = await get_device_id(access_token, homeserver) - - device_id = config["Matrix"].get("DeviceID", device_id) - user_id = config["Matrix"].get("UserID", user_id) - - # Set up database - if "Database" in config and config["Database"].get("Path"): - database = CONTEXT["database"] = initialize_database( - config["Database"]["Path"]) - matrix_store = DuckDBStore - - client_config = AsyncClientConfig( - store_sync_tokens=True, encryption_enabled=True, store=matrix_store) - - else: - client_config = AsyncClientConfig( - store_sync_tokens=True, encryption_enabled=False) - - client = AsyncClient( - config["Matrix"]["Homeserver"], config=client_config) - - if client.config.encryption_enabled: - client.store = client.config.store( - user_id, - device_id, - database - ) - assert client.store - - client.olm = Olm(client.user_id, client.device_id, client.store) - client.encrypted_rooms = client.store.load_encrypted_rooms() - - CONTEXT["client"] = client - - CONTEXT["client"].access_token = config["Matrix"]["AccessToken"] - CONTEXT["client"].user_id = user_id - CONTEXT["client"].device_id = device_id - - # Set up GPT API - try: - assert "OpenAI" in config - assert "APIKey" in config["OpenAI"] - except: - logging("OpenAI config not found or incomplete", "critical") - exit(1) - - openai.api_key = config["OpenAI"]["APIKey"] - - if "Model" in config["OpenAI"]: - CONTEXT["model"] = config["OpenAI"]["Model"] - - if "MaxTokens" in config["OpenAI"]: - CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"]) - - if "MaxMessages" in config["OpenAI"]: - CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"]) - - # Override defaults with config - - if "GPTBot" in config: - if "SystemMessage" in config["GPTBot"]: - CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"] - if "DefaultRoomName" in config["GPTBot"]: - CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"] - if "ForceSystemMessage" in config["GPTBot"]: - CONTEXT["force_system_message"] = config["GPTBot"].getboolean( - "ForceSystemMessage") - - -async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None): - if not client and not CONTEXT.get("client"): - await init(config) - - client = client or CONTEXT["client"] - - try: - assert client.user_id - except AssertionError: - logging( - "Failed to get user ID - check your access token or try setting it manually", "critical") - await client.close() - return - - # Listen for SIGTERM - - def sigterm_handler(_signo, _stack_frame): - logging("Received SIGTERM - exiting...") - exit() - - signal.signal(signal.SIGTERM, sigterm_handler) - - logging("Starting bot...") - - client.add_response_callback(sync_cb, SyncResponse) - - logging("Syncing...") - - await client.sync(timeout=30000) - - client.add_event_callback(message_callback, RoomMessageText) - client.add_event_callback(message_callback, MegolmEvent) - client.add_event_callback(room_invite_callback, InviteEvent) - client.add_event_callback(test_callback, Event) - - await accept_pending_invites() # Accept pending invites - - logging("Bot started") - - try: - # Continue syncing events - await client.sync_forever(timeout=30000) - finally: - logging("Syncing one last time...") - await client.sync(timeout=30000) - await client.close() # Properly close the aiohttp client session - logging("Bot stopped") - - -def initialize_database(path: os.PathLike): - logging("Initializing database...") - conn = duckdb.connect(path) - - with conn.cursor() as cursor: - # Get the latest migration ID if the migrations table exists - try: - cursor.execute( - """ - SELECT MAX(id) FROM migrations - """ - ) - - latest_migration = int(cursor.fetchone()[0]) - - except: - latest_migration = 0 - - for migration, function in MIGRATIONS.items(): - if latest_migration < migration: - logging(f"Running migration {migration}...") - function(conn) - latest_migration = migration - - return conn - - -async def get_device_id(access_token, homeserver): - client = AsyncClient(homeserver) - client.access_token = access_token - - logging(f"Obtaining device ID for access token {access_token}...", "debug") - response = await client.whoami() - if isinstance(response, WhoamiResponse): - logging( - f"Authenticated as {response.user_id}.") - user_id = response.user_id - devices = await client.devices() - device_id = devices.devices[0].id - - await client.close() - - return device_id, user_id - - else: - logging(f"Failed to obtain device ID: {response}", "error") - - await client.close() - - return None, None +def sigterm_handler(_signo, _stack_frame): + exit() if __name__ == "__main__": @@ -562,13 +22,16 @@ if __name__ == "__main__": config = ConfigParser() config.read(args.config) - # Start bot loop + # Create bot + bot = GPTBot.from_config(config) + + # Listen for SIGTERM + signal.signal(signal.SIGTERM, sigterm_handler) + + # Start bot try: - asyncio.run(main(config)) + asyncio.run(bot.run()) except KeyboardInterrupt: - logging("Received KeyboardInterrupt - exiting...") + print("Received KeyboardInterrupt - exiting...") except SystemExit: - logging("Received SIGTERM - exiting...") - finally: - if CONTEXT["database"]: - CONTEXT["database"].close() + print("Received SIGTERM - exiting...") diff --git a/migrations/__init__.py b/migrations/__init__.py index 199d578..097e95b 100644 --- a/migrations/__init__.py +++ b/migrations/__init__.py @@ -1,4 +1,7 @@ from collections import OrderedDict +from typing import Optional + +from duckdb import DuckDBPyConnection from .migration_1 import migration as migration_1 from .migration_2 import migration as migration_2 @@ -8,4 +11,43 @@ MIGRATIONS = OrderedDict() MIGRATIONS[1] = migration_1 MIGRATIONS[2] = migration_2 -MIGRATIONS[3] = migration_3 \ No newline at end of file +MIGRATIONS[3] = migration_3 + +def get_version(db: DuckDBPyConnection) -> int: + """Get the current database version. + + Args: + db (DuckDBPyConnection): Database connection. + + Returns: + int: Current database version. + """ + + try: + return int(db.execute("SELECT MAX(id) FROM migrations").fetchone()[0]) + except: + return 0 + +def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_version: Optional[int] = None) -> None: + """Migrate the database to a specific version. + + Args: + db (DuckDBPyConnection): Database connection. + from_version (Optional[int]): Version to migrate from. If None, the current version is used. + to_version (Optional[int]): Version to migrate to. If None, the latest version is used. + """ + + if from_version is None: + from_version = get_version(db) + + if to_version is None: + to_version = max(MIGRATIONS.keys()) + + if from_version > to_version: + raise ValueError("Cannot migrate from a higher version to a lower version.") + + for version in range(from_version, to_version): + if version in MIGRATIONS: + MIGRATIONS[version + 1](db) + + return from_version, to_version \ No newline at end of file