A version that does not crash immediately! (I hope)

This commit is contained in:
Kumi 2023-07-13 16:04:41 +02:00
parent d2c6682faa
commit 94b2457a39
Signed by: kumi
GPG key ID: ECBCC9082395383F
8 changed files with 125 additions and 219 deletions

View file

@ -7,7 +7,7 @@ allow-direct-references = true
[project] [project]
name = "matrix-gptbot" name = "matrix-gptbot"
version = "0.1.1" version = "0.2.0-dev"
authors = [ authors = [
{ name="Kumi Mitterer", email="gptbot@kumi.email" }, { name="Kumi Mitterer", email="gptbot@kumi.email" },
@ -54,7 +54,6 @@ all = [
] ]
dev = [ dev = [
"matrix-gptbot[all]",
"black", "black",
] ]

View file

@ -1,5 +1,5 @@
openai openai
matrix-nio[e2e] mautrix
markdown2[all] markdown2[all]
tiktoken tiktoken
duckdb duckdb

View file

@ -12,6 +12,8 @@ from nio import (
Response, Response,
) )
from mautrix.types import Event, MessageEvent, StateEvent
from .test import test_callback from .test import test_callback
from .sync import sync_callback from .sync import sync_callback
from .invite import room_invite_callback from .invite import room_invite_callback
@ -28,8 +30,5 @@ RESPONSE_CALLBACKS = {
EVENT_CALLBACKS = { EVENT_CALLBACKS = {
Event: test_callback, Event: test_callback,
InviteEvent: room_invite_callback, MessageEvent: message_callback,
RoomMessageText: message_callback,
MegolmEvent: message_callback,
RoomMemberEvent: roommember_callback,
} }

View file

@ -0,0 +1,36 @@
from nio import Event, Response

from ..classes.bot import GPTBot
class BaseEventCallback:
    """Common parent for callbacks that react to events.

    ``EVENTS`` lists the events a callback should be called for, and
    ``process()`` is the coroutine that subclasses have to provide.
    """

    # Events this callback should be called for. The list is a class
    # attribute, so a subclass should assign its own list rather than
    # mutate this shared one.
    EVENTS = []

    def __init__(self, bot: GPTBot):
        """Remember the bot instance this callback works with.

        Args:
            bot (GPTBot): GPTBot instance
        """
        self.bot = bot

    async def process(self, event: Event, *args, **kwargs):
        """Handle ``event``; the base implementation only raises.

        Args:
            event (Event): The event to handle.
            *args: Extra positional arguments, unused here.
            **kwargs: Extra keyword arguments, unused here.

        Raises:
            NotImplementedError: Always, since subclasses must override
                this method.
        """
        message = "BaseEventCallback.process() must be implemented by subclasses"
        raise NotImplementedError(message)
class BaseResponseCallback:
    """Base class for callbacks that react to client responses.

    ``RESPONSES`` lists the responses a callback should be called for, and
    ``process()`` is the coroutine that subclasses have to override.

    The annotations are string forward references, so creating this class
    does not require ``GPTBot`` or ``Response`` to be resolvable at that
    moment. An unresolved annotation name would otherwise raise
    ``NameError`` as soon as the ``def`` statement is executed.
    """

    # Responses this callback should be called for. The list is a class
    # attribute, so a subclass should assign its own list rather than
    # mutate this shared one.
    RESPONSES = []

    def __init__(self, bot: "GPTBot"):
        """Initialize the callback with the bot instance.

        Args:
            bot (GPTBot): GPTBot instance
        """
        self.bot = bot

    async def process(self, response: "Response", *args, **kwargs):
        """Handle ``response``; the base implementation only raises.

        Args:
            response (Response): The response to handle.
            *args: Extra positional arguments, unused here.
            **kwargs: Extra keyword arguments, unused here.

        Raises:
            NotImplementedError: Always, since subclasses must override
                this method.
        """
        raise NotImplementedError(
            "BaseResponseCallback.process() must be implemented by subclasses"
        )

View file

@ -1,9 +1,9 @@
from nio import MatrixRoom, RoomMessageText, MegolmEvent, RoomKeyRequestError, RoomKeyRequestResponse from mautrix.types import MessageEvent
from datetime import datetime from datetime import datetime
async def message_callback(room: MatrixRoom | str, event: RoomMessageText | MegolmEvent, bot): async def message_callback(event: MessageEvent, bot):
bot.logger.log(f"Received message from {event.sender} in room {room.room_id}") bot.logger.log(f"Received message from {event.sender} in room {event.room_id}")
sent = datetime.fromtimestamp(event.server_timestamp / 1000) sent = datetime.fromtimestamp(event.server_timestamp / 1000)
received = datetime.now() received = datetime.now()
@ -34,18 +34,18 @@ async def message_callback(room: MatrixRoom | str, event: RoomMessageText | Mego
bot.logger.log("Message is from bot itself - ignoring") bot.logger.log("Message is from bot itself - ignoring")
elif event.body.startswith("!gptbot"): elif event.body.startswith("!gptbot"):
await bot.process_command(room, event) await bot.process_command(event)
elif event.body.startswith("!"): elif event.body.startswith("!"):
bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring") bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring")
else: else:
await bot.process_query(room, event) await bot.process_query(event)
processed = datetime.now() processed = datetime.now()
processing_time = processed - received processing_time = processed - received
bot.logger.log(f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)") bot.logger.log(f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)")
if bot.room_uses_timing(room): if bot.room_uses_timing(event.room_id):
await bot.send_message(room, f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)", True) await bot.send_message(event.room_id, f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)", True)

View file

@ -1,11 +1,10 @@
from nio import MatrixRoom, Event from mautrix.types import Event
async def test_callback(room: MatrixRoom, event: Event, bot): async def test_callback(event: Event, bot):
"""Test callback for debugging purposes. """Test callback for debugging purposes.
Args: Args:
room (MatrixRoom): The room the event was sent in.
event (Event): The event that was sent. event (Event): The event that was sent.
""" """
bot.logger.log(f"Test callback called: {room.room_id} {event.event_id} {event.sender} {event.__class__}") bot.logger.log(f"Test callback called: {event.room_id} {event.event_id} {event.sender} {event.__class__}")

View file

@ -5,31 +5,22 @@ import functools
from PIL import Image from PIL import Image
from nio import ( from mautrix.client import Client
AsyncClient, from mautrix.types import (
AsyncClientConfig, RoomID,
WhoamiResponse, UserID,
DevicesResponse, EventType,
Event, MessageType,
Response, MessageEvent,
MatrixRoom, RoomDirectoryVisibility,
Api, )
RoomMessagesError, from mautrix.errors import (
MegolmEvent, MForbidden,
GroupEncryptionError, MNotFound,
EncryptionError, MUnknownToken,
RoomMessageText, MForbidden,
RoomSendResponse, MatrixError,
SyncResponse,
RoomMessageNotice,
JoinError,
RoomLeaveError,
RoomSendError,
RoomVisibility,
RoomCreateError,
) )
from nio.crypto import Olm
from nio.store import SqliteStore
from typing import Optional, List from typing import Optional, List
from configparser import ConfigParser from configparser import ConfigParser
@ -65,7 +56,7 @@ class GPTBot:
force_system_message: bool = False force_system_message: bool = False
max_tokens: int = 3000 # Maximum number of input tokens max_tokens: int = 3000 # Maximum number of input tokens
max_messages: int = 30 # Maximum number of messages to consider as input max_messages: int = 30 # Maximum number of messages to consider as input
matrix_client: Optional[AsyncClient] = None matrix_client: Optional[Client] = None
sync_token: Optional[str] = None sync_token: Optional[str] = None
logger: Optional[Logger] = Logger() logger: Optional[Logger] = Logger()
chat_api: Optional[OpenAI] = None chat_api: Optional[OpenAI] = None
@ -161,13 +152,11 @@ class GPTBot:
assert "Matrix" in config, "Matrix config not found" assert "Matrix" in config, "Matrix config not found"
homeserver = config["Matrix"]["Homeserver"] bot.homeserver = config["Matrix"]["Homeserver"]
bot.matrix_client = AsyncClient(homeserver) bot.access_token = config["Matrix"]["AccessToken"]
bot.matrix_client.access_token = config["Matrix"]["AccessToken"] bot.user_id = config["Matrix"].get("UserID")
bot.matrix_client.user_id = config["Matrix"].get("UserID") bot.device_id = config["Matrix"].get("DeviceID")
bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
# Return the new GPTBot instance
return bot return bot
async def _get_user_id(self) -> str: async def _get_user_id(self) -> str:
@ -178,68 +167,12 @@ class GPTBot:
str: The user ID of the bot. str: The user ID of the bot.
""" """
assert self.matrix_client, "Matrix client not set up" pass
# TODO: Implement
user_id = self.matrix_client.user_id async def _last_n_messages(self, room: str | RoomID, n: Optional[int]):
pass
if not user_id: # TODO: Implement
assert self.matrix_client.access_token, "Access token not set up"
response = await self.matrix_client.whoami()
if isinstance(response, WhoamiResponse):
user_id = response.user_id
else:
raise Exception(f"Could not get user ID: {response}")
return user_id
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
messages = []
n = n or self.max_messages
room_id = room.room_id if isinstance(room, MatrixRoom) else room
self.logger.log(
f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...",
"debug",
)
response = await self.matrix_client.room_messages(
room_id=room_id,
start=self.sync_token,
limit=2 * n,
)
if isinstance(response, RoomMessagesError):
raise Exception(
f"Error fetching messages: {response.message} (status code {response.status_code})",
"error",
)
for event in response.chunk:
if len(messages) >= n:
break
if isinstance(event, MegolmEvent):
try:
event = await self.matrix_client.decrypt_event(event)
except (GroupEncryptionError, EncryptionError):
self.logger.log(
f"Could not decrypt message {event.event_id} in room {room_id}",
"error",
)
continue
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
if event.body.startswith("!gptbot ignoreolder"):
break
if (not event.body.startswith("!")) or (
event.body.startswith("!gptbot")
):
messages.append(event)
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
# Reverse the list so that messages are in chronological order
return messages[::-1]
def _truncate( def _truncate(
self, self,
@ -298,20 +231,17 @@ class GPTBot:
if not device_id: if not device_id:
assert self.matrix_client.access_token, "Access token not set up" assert self.matrix_client.access_token, "Access token not set up"
devices = await self.matrix_client.devices() # TODO: Implement
if isinstance(devices, DevicesResponse):
device_id = devices.devices[0].id
return device_id return device_id
async def process_command(self, room: MatrixRoom, event: RoomMessageText): async def process_command(self, room: RoomID, event: MessageEvent):
"""Process a command. Called from the event_callback() method. """Process a command. Called from the event_callback() method.
Delegates to the appropriate command handler. Delegates to the appropriate command handler.
Args: Args:
room (MatrixRoom): The room the command was sent in. room (RoomID): The room the command was sent in.
event (RoomMessageText): The event containing the command. event (MessageEvent): The event containing the command.
""" """
self.logger.log( self.logger.log(
@ -322,11 +252,11 @@ class GPTBot:
await COMMANDS.get(command, COMMANDS[None])(room, event, self) await COMMANDS.get(command, COMMANDS[None])(room, event, self)
def room_uses_classification(self, room: MatrixRoom | str) -> bool: def room_uses_classification(self, room: RoomID | str) -> bool:
"""Check if a room uses classification. """Check if a room uses classification.
Args: Args:
room (MatrixRoom | str): The room to check. room (RoomID | str): The room to check.
Returns: Returns:
bool: Whether the room uses classification. bool: Whether the room uses classification.
@ -342,7 +272,7 @@ class GPTBot:
return False if not result else bool(int(result[0])) return False if not result else bool(int(result[0]))
async def _event_callback(self, room: MatrixRoom, event: Event): async def _event_callback(self, room: RoomID, event: MessageEvent):
self.logger.log("Received event: " + str(event.event_id), "debug") self.logger.log("Received event: " + str(event.event_id), "debug")
try: try:
for eventtype, callback in EVENT_CALLBACKS.items(): for eventtype, callback in EVENT_CALLBACKS.items():
@ -378,12 +308,12 @@ class GPTBot:
else True else True
) )
async def event_callback(self, room: MatrixRoom, event: Event): async def event_callback(self, room: RoomID, event: MessageEvent):
"""Callback for events. """Callback for events.
Args: Args:
room (MatrixRoom): The room the event was sent in. room (RoomID): The room the event was sent in.
event (Event): The event. event (MessageEvent): The event.
""" """
if event.sender == self.matrix_client.user_id: if event.sender == self.matrix_client.user_id:
@ -403,11 +333,11 @@ class GPTBot:
task = asyncio.create_task(self._event_callback(room, event)) task = asyncio.create_task(self._event_callback(room, event))
def room_uses_timing(self, room: MatrixRoom): def room_uses_timing(self, room: RoomID):
"""Check if a room uses timing. """Check if a room uses timing.
Args: Args:
room (MatrixRoom): The room to check. room (RoomID): The room to check.
Returns: Returns:
bool: Whether the room uses timing. bool: Whether the room uses timing.
@ -423,14 +353,6 @@ class GPTBot:
return False if not result else bool(int(result[0])) return False if not result else bool(int(result[0]))
async def _response_callback(self, response: Response):
for response_type, callback in RESPONSE_CALLBACKS.items():
if isinstance(response, response_type):
await callback(response, self)
async def response_callback(self, response: Response):
task = asyncio.create_task(self._response_callback(response))
async def accept_pending_invites(self): async def accept_pending_invites(self):
"""Accept all pending invites.""" """Accept all pending invites."""
@ -492,12 +414,12 @@ class GPTBot:
return response.content_uri return response.content_uri
async def send_image( async def send_image(
self, room: MatrixRoom, image: bytes, message: Optional[str] = None self, room: RoomID, image: bytes, message: Optional[str] = None
): ):
"""Send an image to a room. """Send an image to a room.
Args: Args:
room (MatrixRoom): The room to send the image to. room (RoomID): The room to send the image to.
image (bytes): The image to send. image (bytes): The image to send.
message (str, optional): The message to send with the image. Defaults to None. message (str, optional): The message to send with the image. Defaults to None.
""" """
@ -539,13 +461,18 @@ class GPTBot:
self.logger.log("Sent image", "debug") self.logger.log("Sent image", "debug")
async def handle_event(self, *args, **kwargs):
"""Handle an event."""
self.logger.log(f"Handling event: {args} {kwargs}", "debug")
async def send_message( async def send_message(
self, room: MatrixRoom | str, message: str, notice: bool = False self, room: RoomID | str, message: str, notice: bool = False
): ):
"""Send a message to a room. """Send a message to a room.
Args: Args:
room (MatrixRoom): The room to send the message to. room (RoomID): The room to send the message to.
message (str): The message to send. message (str): The message to send.
notice (bool): Whether to send the message as a notice. Defaults to False. notice (bool): Whether to send the message as a notice. Defaults to False.
""" """
@ -618,13 +545,13 @@ class GPTBot:
return return
def log_api_usage( def log_api_usage(
self, message: Event | str, room: MatrixRoom | str, api: str, tokens: int self, message: MessageEvent | str, room: RoomID | str, api: str, tokens: int
): ):
"""Log API usage to the database. """Log API usage to the database.
Args: Args:
message (Event): The event that triggered the API usage. message (MessageEvent): The event that triggered the API usage.
room (MatrixRoom | str): The room the event was sent in. room (RoomID | str): The room the event was sent in.
api (str): The API that was used. api (str): The API that was used.
tokens (int): The number of tokens used. tokens (int): The number of tokens used.
""" """
@ -648,14 +575,11 @@ class GPTBot:
# Set up the Matrix client # Set up the Matrix client
assert self.matrix_client, "Matrix client not set up" self.matrix_client: Client = self.matrix_client or Client(base_url=self.homeserver, token=self.access_token)
assert self.matrix_client.access_token, "Access token not set up"
if not self.matrix_client.user_id: iam = await self.matrix_client.whoami()
self.matrix_client.user_id = await self._get_user_id()
if not self.matrix_client.device_id: self.logger.log(f"Logged in as {iam.user_id} (device ID: {iam.device_id})", "info")
self.matrix_client.device_id = await self._get_device_id()
# Set up database # Set up database
@ -686,84 +610,33 @@ class GPTBot:
else: else:
self.logger.log(f"Already at latest version {after}.") self.logger.log(f"Already at latest version {after}.")
if IN_MEMORY: # Set up event handlers
client_config = AsyncClientConfig( self.matrix_client.add_event_handler(EventType.ALL, self.handle_event)
store_sync_tokens=True, encryption_enabled=False
)
else:
matrix_store = SqliteStore
client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=True, store=matrix_store
)
self.matrix_client.config = client_config
self.matrix_client.store = matrix_store(
self.matrix_client.user_id,
self.matrix_client.device_id,
'.', #store path
database_name=self.crypto_store_path or "",
)
self.matrix_client.olm = Olm(
self.matrix_client.user_id,
self.matrix_client.device_id,
self.matrix_client.store,
)
self.matrix_client.encrypted_rooms = (
self.matrix_client.store.load_encrypted_rooms()
)
# Run initial sync (now includes joining rooms) # Run initial sync (now includes joining rooms)
sync = await self.matrix_client.sync(timeout=30000) sync = await self.matrix_client.sync(timeout=30000)
if isinstance(sync, SyncResponse):
await self.response_callback(sync)
else:
self.logger.log(f"Initial sync failed, aborting: {sync}", "critical")
exit(1)
# Set up callbacks
self.matrix_client.add_event_callback(self.event_callback, Event)
self.matrix_client.add_response_callback(self.response_callback, Response)
# Set custom name / logo # Set custom name / logo
if self.display_name: # TODO: Implement
self.logger.log(f"Setting display name to {self.display_name}", "debug")
asyncio.create_task(self.matrix_client.set_displayname(self.display_name))
if self.logo:
self.logger.log("Setting avatar...")
logo_bio = BytesIO()
self.logo.save(logo_bio, format=self.logo.format)
uri = await self.upload_file(
logo_bio.getvalue(), "logo", Image.MIME[self.logo.format]
)
self.logo_uri = uri
asyncio.create_task(self.matrix_client.set_avatar(uri))
for room in self.matrix_client.rooms.keys():
self.logger.log(f"Setting avatar for {room}...", "debug")
asyncio.create_task(
self.matrix_client.room_put_state(
room, "m.room.avatar", {"url": uri}, ""
)
)
# Start syncing events # Start syncing events
self.logger.log("Starting sync loop...", "warning") self.logger.log("Starting sync loop...", "warning")
try: try:
await self.matrix_client.sync_forever(timeout=30000) await self.matrix_client.start(None)
finally: finally:
self.logger.log("Syncing one last time...", "warning") self.logger.log("Syncing one last time...", "warning")
await self.matrix_client.sync(timeout=30000) await self.matrix_client.sync(timeout=30000)
async def create_space(self, name, visibility=RoomVisibility.private) -> str: async def create_space(
self, name, visibility=RoomDirectoryVisibility.PRIVATE
) -> str:
"""Create a space. """Create a space.
Args: Args:
name (str): The name of the space. name (str): The name of the space.
visibility (RoomVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private. visibility (RoomDirectoryVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private.
Returns: Returns:
MatrixRoom: The created space. MatrixRoom: The created space.
@ -780,13 +653,13 @@ class GPTBot:
return response.room_id return response.room_id
async def add_rooms_to_space( async def add_rooms_to_space(
self, space: MatrixRoom | str, rooms: List[MatrixRoom | str] self, space: RoomID | str, rooms: List[RoomID | str]
): ):
"""Add rooms to a space. """Add rooms to a space.
Args: Args:
space (MatrixRoom | str): The space to add the rooms to. space (RoomID | str): The space to add the rooms to.
rooms (List[MatrixRoom | str]): The rooms to add to the space. rooms (List[RoomID | str]): The rooms to add to the space.
""" """
if isinstance(space, MatrixRoom): if isinstance(space, MatrixRoom):
@ -818,17 +691,17 @@ class GPTBot:
space, space,
) )
def respond_to_room_messages(self, room: MatrixRoom | str) -> bool: def respond_to_room_messages(self, room: RoomID | str) -> bool:
"""Check whether the bot should respond to all messages sent in a room. """Check whether the bot should respond to all messages sent in a room.
Args: Args:
room (MatrixRoom | str): The room to check. room (RoomID | str): The room to check.
Returns: Returns:
bool: Whether the bot should respond to all messages sent in the room. bool: Whether the bot should respond to all messages sent in the room.
""" """
if isinstance(room, MatrixRoom): if isinstance(room, RoomID):
room = room.room_id room = room.room_id
with closing(self.database.cursor()) as cursor: with closing(self.database.cursor()) as cursor:
@ -841,26 +714,26 @@ class GPTBot:
return True if not result else bool(int(result[0])) return True if not result else bool(int(result[0]))
async def process_query( async def process_query(
self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False self, room: RoomID, event: MessageEvent, from_chat_command: bool = False
): ):
"""Process a query message. Generates a response and sends it to the room. """Process a query message. Generates a response and sends it to the room.
Args: Args:
room (MatrixRoom): The room the message was sent in. room (RoomID): The room the message was sent in.
event (RoomMessageText): The event that triggered the query. event (MessageEvent): The event that triggered the query.
from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False. from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False.
""" """
if not ( if not (
from_chat_command from_chat_command
or self.respond_to_room_messages(room) or self.respond_to_room_messages(room)
or self.matrix_client.user_id in event.body or self.matrix_client.whoami().user_id in event.body
): ):
return return
await self.matrix_client.room_typing(room.room_id, True) # TODO: await self.matrix_client.room_typing(room.room_id, True)
await self.matrix_client.room_read_markers(room.room_id, event.event_id) # TODO: await self.matrix_client.room_read_markers(room.room_id, event.event_id)
if (not from_chat_command) and self.room_uses_classification(room): if (not from_chat_command) and self.room_uses_classification(room):
try: try:
@ -949,11 +822,11 @@ class GPTBot:
await self.matrix_client.room_typing(room.room_id, False) await self.matrix_client.room_typing(room.room_id, False)
def get_system_message(self, room: MatrixRoom | str) -> str: def get_system_message(self, room: RoomID | str) -> str:
"""Get the system message for a room. """Get the system message for a room.
Args: Args:
room (MatrixRoom | str): The room to get the system message for. room (RoomID | str): The room to get the system message for.
Returns: Returns:
str: The system message. str: The system message.