Compare commits

...

3 commits

Author SHA1 Message Date
ef3118cbe3
Dall-E model selection 2023-11-07 13:58:25 +01:00
9abea6e3f8
Fix event handling in GPTBot class and add missing callback entries
- Fix handling of events in the `handle_event` method of the `GPTBot` class.
- Add missing callback entries in the `EVENT_CALLBACKS` dictionary.
2023-09-12 09:17:59 +02:00
94b2457a39
A version that does not crash immediately! (I hope) 2023-07-13 16:04:41 +02:00
11 changed files with 153 additions and 232 deletions

View file

@ -67,6 +67,10 @@ LogLevel = info
#
# Model = gpt-3.5-turbo
# The Image Generation model you want to use.
#
# ImageModel = dall-e-2
# Your OpenAI API key
#
# Find this in your OpenAI account:

View file

@ -7,7 +7,7 @@ allow-direct-references = true
[project]
name = "matrix-gptbot"
version = "0.1.1"
version = "0.2.0-dev"
authors = [
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
@ -18,10 +18,6 @@ readme = "README.md"
license = { file="LICENSE" }
requires-python = ">=3.10"
packages = [
"src/gptbot"
]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
@ -29,13 +25,17 @@ classifiers = [
]
dependencies = [
"matrix-nio[e2e]",
"mautrix[all]",
"markdown2[all]",
"tiktoken",
"python-magic",
"pillow",
]
packages = [
{ include = "gptbot", where = "src" },
]
[project.optional-dependencies]
openai = [
"openai",
@ -54,7 +54,6 @@ all = [
]
dev = [
"matrix-gptbot[all]",
"black",
]
@ -63,7 +62,10 @@ dev = [
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
[project.scripts]
gptbot = "gptbot:main"
gptbot = "gptbot.__main___:main"
[tool.hatch.build.targets.wheel]
packages = ["src/gptbot"]
only-include = ["src/gptbot"]
[tool.hatch.build.targets.wheel.sources]
"src" = ""

View file

@ -1,5 +1,5 @@
openai
matrix-nio[e2e]
mautrix
markdown2[all]
tiktoken
duckdb

View file

@ -10,8 +10,7 @@ import asyncio
def sigterm_handler(_signo, _stack_frame):
exit()
if __name__ == "__main__":
def main():
# Parse command line arguments
parser = ArgumentParser()
parser.add_argument(
@ -46,3 +45,7 @@ if __name__ == "__main__":
print("Received KeyboardInterrupt - exiting...")
except SystemExit:
print("Received SIGTERM - exiting...")
if __name__ == "__main__":
main()

View file

@ -12,6 +12,8 @@ from nio import (
Response,
)
from mautrix.types import Event, MessageEvent, StateEvent
from .test import test_callback
from .sync import sync_callback
from .invite import room_invite_callback
@ -27,9 +29,5 @@ RESPONSE_CALLBACKS = {
}
EVENT_CALLBACKS = {
Event: test_callback,
InviteEvent: room_invite_callback,
RoomMessageText: message_callback,
MegolmEvent: message_callback,
RoomMemberEvent: roommember_callback,
MessageEvent: message_callback,
}

View file

@ -0,0 +1,36 @@
from ..classes.bot import GPTBot
from nio import Event
class BaseEventCallback:
EVENTS = [] # List of events that this callback should be called for
def __init__(self, bot: GPTBot):
"""Initialize the callback with the bot instance
Args:
bot (GPTBot): GPTBot instance
"""
self.bot = bot
async def process(self, event: Event, *args, **kwargs):
raise NotImplementedError(
"BaseEventCallback.process() must be implemented by subclasses"
)
class BaseResponseCallback:
RESPONSES = [] # List of responses that this callback should be called for
def __init__(self, bot: GPTBot):
"""Initialize the callback with the bot instance
Args:
bot (GPTBot): GPTBot instance
"""
self.bot = bot
async def process(self, response: Response, *args, **kwargs):
raise NotImplementedError(
"BaseResponseCallback.process() must be implemented by subclasses"
)

View file

@ -1,9 +1,9 @@
from nio import MatrixRoom, RoomMessageText, MegolmEvent, RoomKeyRequestError, RoomKeyRequestResponse
from mautrix.types import MessageEvent
from datetime import datetime
async def message_callback(room: MatrixRoom | str, event: RoomMessageText | MegolmEvent, bot):
bot.logger.log(f"Received message from {event.sender} in room {room.room_id}")
async def message_callback(event: MessageEvent, bot):
bot.logger.log(f"Received message from {event.sender} in room {event.room_id}")
sent = datetime.fromtimestamp(event.server_timestamp / 1000)
received = datetime.now()
@ -34,18 +34,18 @@ async def message_callback(room: MatrixRoom | str, event: RoomMessageText | Mego
bot.logger.log("Message is from bot itself - ignoring")
elif event.body.startswith("!gptbot"):
await bot.process_command(room, event)
await bot.process_command(event)
elif event.body.startswith("!"):
bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring")
else:
await bot.process_query(room, event)
await bot.process_query(event)
processed = datetime.now()
processing_time = processed - received
bot.logger.log(f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)")
if bot.room_uses_timing(room):
await bot.send_message(room, f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)", True)
if bot.room_uses_timing(event.room_id):
await bot.send_message(event.room_id, f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)", True)

View file

@ -1,11 +1,10 @@
from nio import MatrixRoom, Event
from mautrix.types import Event
async def test_callback(room: MatrixRoom, event: Event, bot):
async def test_callback(event: Event, bot):
"""Test callback for debugging purposes.
Args:
room (MatrixRoom): The room the event was sent in.
event (Event): The event that was sent.
"""
bot.logger.log(f"Test callback called: {room.room_id} {event.event_id} {event.sender} {event.__class__}")
bot.logger.log(f"Test callback called: {event.room_id} {event.event_id} {event.sender} {event.__class__}")

View file

@ -5,31 +5,22 @@ import functools
from PIL import Image
from nio import (
AsyncClient,
AsyncClientConfig,
WhoamiResponse,
DevicesResponse,
Event,
Response,
MatrixRoom,
Api,
RoomMessagesError,
MegolmEvent,
GroupEncryptionError,
EncryptionError,
RoomMessageText,
RoomSendResponse,
SyncResponse,
RoomMessageNotice,
JoinError,
RoomLeaveError,
RoomSendError,
RoomVisibility,
RoomCreateError,
from mautrix.client import Client
from mautrix.types import (
RoomID,
UserID,
EventType,
MessageType,
MessageEvent,
RoomDirectoryVisibility,
)
from mautrix.errors import (
MForbidden,
MNotFound,
MUnknownToken,
MForbidden,
MatrixError,
)
from nio.crypto import Olm
from nio.store import SqliteStore
from typing import Optional, List
from configparser import ConfigParser
@ -65,7 +56,7 @@ class GPTBot:
force_system_message: bool = False
max_tokens: int = 3000 # Maximum number of input tokens
max_messages: int = 30 # Maximum number of messages to consider as input
matrix_client: Optional[AsyncClient] = None
matrix_client: Optional[Client] = None
sync_token: Optional[str] = None
logger: Optional[Logger] = Logger()
chat_api: Optional[OpenAI] = None
@ -138,7 +129,8 @@ class GPTBot:
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
config["OpenAI"].get("ImageModel"), bot.logger
)
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
@ -161,13 +153,11 @@ class GPTBot:
assert "Matrix" in config, "Matrix config not found"
homeserver = config["Matrix"]["Homeserver"]
bot.matrix_client = AsyncClient(homeserver)
bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
bot.matrix_client.user_id = config["Matrix"].get("UserID")
bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
bot.homeserver = config["Matrix"]["Homeserver"]
bot.access_token = config["Matrix"]["AccessToken"]
bot.user_id = config["Matrix"].get("UserID")
bot.device_id = config["Matrix"].get("DeviceID")
# Return the new GPTBot instance
return bot
async def _get_user_id(self) -> str:
@ -178,68 +168,12 @@ class GPTBot:
str: The user ID of the bot.
"""
assert self.matrix_client, "Matrix client not set up"
pass
# TODO: Implement
user_id = self.matrix_client.user_id
if not user_id:
assert self.matrix_client.access_token, "Access token not set up"
response = await self.matrix_client.whoami()
if isinstance(response, WhoamiResponse):
user_id = response.user_id
else:
raise Exception(f"Could not get user ID: {response}")
return user_id
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
messages = []
n = n or self.max_messages
room_id = room.room_id if isinstance(room, MatrixRoom) else room
self.logger.log(
f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...",
"debug",
)
response = await self.matrix_client.room_messages(
room_id=room_id,
start=self.sync_token,
limit=2 * n,
)
if isinstance(response, RoomMessagesError):
raise Exception(
f"Error fetching messages: {response.message} (status code {response.status_code})",
"error",
)
for event in response.chunk:
if len(messages) >= n:
break
if isinstance(event, MegolmEvent):
try:
event = await self.matrix_client.decrypt_event(event)
except (GroupEncryptionError, EncryptionError):
self.logger.log(
f"Could not decrypt message {event.event_id} in room {room_id}",
"error",
)
continue
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
if event.body.startswith("!gptbot ignoreolder"):
break
if (not event.body.startswith("!")) or (
event.body.startswith("!gptbot")
):
messages.append(event)
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
# Reverse the list so that messages are in chronological order
return messages[::-1]
async def _last_n_messages(self, room: str | RoomID, n: Optional[int]):
pass
# TODO: Implement
def _truncate(
self,
@ -298,20 +232,17 @@ class GPTBot:
if not device_id:
assert self.matrix_client.access_token, "Access token not set up"
devices = await self.matrix_client.devices()
if isinstance(devices, DevicesResponse):
device_id = devices.devices[0].id
# TODO: Implement
return device_id
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
async def process_command(self, room: RoomID, event: MessageEvent):
"""Process a command. Called from the event_callback() method.
Delegates to the appropriate command handler.
Args:
room (MatrixRoom): The room the command was sent in.
event (RoomMessageText): The event containing the command.
room (RoomID): The room the command was sent in.
event (MessageEvent): The event containing the command.
"""
self.logger.log(
@ -322,11 +253,11 @@ class GPTBot:
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
def room_uses_classification(self, room: MatrixRoom | str) -> bool:
def room_uses_classification(self, room: RoomID | str) -> bool:
"""Check if a room uses classification.
Args:
room (MatrixRoom | str): The room to check.
room (RoomID | str): The room to check.
Returns:
bool: Whether the room uses classification.
@ -342,7 +273,7 @@ class GPTBot:
return False if not result else bool(int(result[0]))
async def _event_callback(self, room: MatrixRoom, event: Event):
async def _event_callback(self, room: RoomID, event: MessageEvent):
self.logger.log("Received event: " + str(event.event_id), "debug")
try:
for eventtype, callback in EVENT_CALLBACKS.items():
@ -378,12 +309,12 @@ class GPTBot:
else True
)
async def event_callback(self, room: MatrixRoom, event: Event):
async def event_callback(self, room: RoomID, event: MessageEvent):
"""Callback for events.
Args:
room (MatrixRoom): The room the event was sent in.
event (Event): The event.
room (RoomID): The room the event was sent in.
event (MessageEvent): The event.
"""
if event.sender == self.matrix_client.user_id:
@ -403,11 +334,11 @@ class GPTBot:
task = asyncio.create_task(self._event_callback(room, event))
def room_uses_timing(self, room: MatrixRoom):
def room_uses_timing(self, room: RoomID):
"""Check if a room uses timing.
Args:
room (MatrixRoom): The room to check.
room (RoomID): The room to check.
Returns:
bool: Whether the room uses timing.
@ -423,14 +354,6 @@ class GPTBot:
return False if not result else bool(int(result[0]))
async def _response_callback(self, response: Response):
for response_type, callback in RESPONSE_CALLBACKS.items():
if isinstance(response, response_type):
await callback(response, self)
async def response_callback(self, response: Response):
task = asyncio.create_task(self._response_callback(response))
async def accept_pending_invites(self):
"""Accept all pending invites."""
@ -492,12 +415,12 @@ class GPTBot:
return response.content_uri
async def send_image(
self, room: MatrixRoom, image: bytes, message: Optional[str] = None
self, room: RoomID, image: bytes, message: Optional[str] = None
):
"""Send an image to a room.
Args:
room (MatrixRoom): The room to send the image to.
room (RoomID): The room to send the image to.
image (bytes): The image to send.
message (str, optional): The message to send with the image. Defaults to None.
"""
@ -539,13 +462,21 @@ class GPTBot:
self.logger.log("Sent image", "debug")
async def handle_event(self, event):
"""Handle an event."""
for event_type, callback in EVENT_CALLBACKS.items():
if isinstance(event, event_type):
print(event_type, callback)
await callback(event, self)
async def send_message(
self, room: MatrixRoom | str, message: str, notice: bool = False
self, room: RoomID | str, message: str, notice: bool = False
):
"""Send a message to a room.
Args:
room (MatrixRoom): The room to send the message to.
room (RoomID): The room to send the message to.
message (str): The message to send.
notice (bool): Whether to send the message as a notice. Defaults to False.
"""
@ -618,13 +549,13 @@ class GPTBot:
return
def log_api_usage(
self, message: Event | str, room: MatrixRoom | str, api: str, tokens: int
self, message: MessageEvent | str, room: RoomID | str, api: str, tokens: int
):
"""Log API usage to the database.
Args:
message (Event): The event that triggered the API usage.
room (MatrixRoom | str): The room the event was sent in.
message (MessageEvent): The event that triggered the API usage.
room (RoomID | str): The room the event was sent in.
api (str): The API that was used.
tokens (int): The number of tokens used.
"""
@ -648,14 +579,11 @@ class GPTBot:
# Set up the Matrix client
assert self.matrix_client, "Matrix client not set up"
assert self.matrix_client.access_token, "Access token not set up"
self.matrix_client: Client = self.matrix_client or Client(base_url=self.homeserver, token=self.access_token)
if not self.matrix_client.user_id:
self.matrix_client.user_id = await self._get_user_id()
iam = await self.matrix_client.whoami()
if not self.matrix_client.device_id:
self.matrix_client.device_id = await self._get_device_id()
self.logger.log(f"Logged in as {iam.user_id} (device ID: {iam.device_id})", "info")
# Set up database
@ -686,84 +614,33 @@ class GPTBot:
else:
self.logger.log(f"Already at latest version {after}.")
if IN_MEMORY:
client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=False
)
else:
matrix_store = SqliteStore
client_config = AsyncClientConfig(
store_sync_tokens=True, encryption_enabled=True, store=matrix_store
)
self.matrix_client.config = client_config
self.matrix_client.store = matrix_store(
self.matrix_client.user_id,
self.matrix_client.device_id,
'.', #store path
database_name=self.crypto_store_path or "",
)
self.matrix_client.olm = Olm(
self.matrix_client.user_id,
self.matrix_client.device_id,
self.matrix_client.store,
)
self.matrix_client.encrypted_rooms = (
self.matrix_client.store.load_encrypted_rooms()
)
# Set up event handlers
self.matrix_client.add_event_handler(EventType.ALL, self.handle_event)
# Run initial sync (now includes joining rooms)
sync = await self.matrix_client.sync(timeout=30000)
if isinstance(sync, SyncResponse):
await self.response_callback(sync)
else:
self.logger.log(f"Initial sync failed, aborting: {sync}", "critical")
exit(1)
# Set up callbacks
self.matrix_client.add_event_callback(self.event_callback, Event)
self.matrix_client.add_response_callback(self.response_callback, Response)
# Set custom name / logo
if self.display_name:
self.logger.log(f"Setting display name to {self.display_name}", "debug")
asyncio.create_task(self.matrix_client.set_displayname(self.display_name))
if self.logo:
self.logger.log("Setting avatar...")
logo_bio = BytesIO()
self.logo.save(logo_bio, format=self.logo.format)
uri = await self.upload_file(
logo_bio.getvalue(), "logo", Image.MIME[self.logo.format]
)
self.logo_uri = uri
asyncio.create_task(self.matrix_client.set_avatar(uri))
for room in self.matrix_client.rooms.keys():
self.logger.log(f"Setting avatar for {room}...", "debug")
asyncio.create_task(
self.matrix_client.room_put_state(
room, "m.room.avatar", {"url": uri}, ""
)
)
# TODO: Implement
# Start syncing events
self.logger.log("Starting sync loop...", "warning")
try:
await self.matrix_client.sync_forever(timeout=30000)
await self.matrix_client.start(None)
finally:
self.logger.log("Syncing one last time...", "warning")
await self.matrix_client.sync(timeout=30000)
async def create_space(self, name, visibility=RoomVisibility.private) -> str:
async def create_space(
self, name, visibility=RoomDirectoryVisibility.PRIVATE
) -> str:
"""Create a space.
Args:
name (str): The name of the space.
visibility (RoomVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private.
visibility (RoomDirectoryVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private.
Returns:
MatrixRoom: The created space.
@ -780,13 +657,13 @@ class GPTBot:
return response.room_id
async def add_rooms_to_space(
self, space: MatrixRoom | str, rooms: List[MatrixRoom | str]
self, space: RoomID | str, rooms: List[RoomID | str]
):
"""Add rooms to a space.
Args:
space (MatrixRoom | str): The space to add the rooms to.
rooms (List[MatrixRoom | str]): The rooms to add to the space.
space (RoomID | str): The space to add the rooms to.
rooms (List[RoomID | str]): The rooms to add to the space.
"""
if isinstance(space, MatrixRoom):
@ -818,17 +695,17 @@ class GPTBot:
space,
)
def respond_to_room_messages(self, room: MatrixRoom | str) -> bool:
def respond_to_room_messages(self, room: RoomID | str) -> bool:
"""Check whether the bot should respond to all messages sent in a room.
Args:
room (MatrixRoom | str): The room to check.
room (RoomID | str): The room to check.
Returns:
bool: Whether the bot should respond to all messages sent in the room.
"""
if isinstance(room, MatrixRoom):
if isinstance(room, RoomID):
room = room.room_id
with closing(self.database.cursor()) as cursor:
@ -841,26 +718,26 @@ class GPTBot:
return True if not result else bool(int(result[0]))
async def process_query(
self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False
self, room: RoomID, event: MessageEvent, from_chat_command: bool = False
):
"""Process a query message. Generates a response and sends it to the room.
Args:
room (MatrixRoom): The room the message was sent in.
event (RoomMessageText): The event that triggered the query.
room (RoomID): The room the message was sent in.
event (MessageEvent): The event that triggered the query.
from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False.
"""
if not (
from_chat_command
or self.respond_to_room_messages(room)
or self.matrix_client.user_id in event.body
or self.matrix_client.whoami().user_id in event.body
):
return
await self.matrix_client.room_typing(room.room_id, True)
# TODO: await self.matrix_client.room_typing(room.room_id, True)
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
# TODO: await self.matrix_client.room_read_markers(room.room_id, event.event_id)
if (not from_chat_command) and self.room_uses_classification(room):
try:
@ -949,11 +826,11 @@ class GPTBot:
await self.matrix_client.room_typing(room.room_id, False)
def get_system_message(self, room: MatrixRoom | str) -> str:
def get_system_message(self, room: RoomID | str) -> str:
"""Get the system message for a room.
Args:
room (MatrixRoom | str): The room to get the system message for.
room (RoomID | str): The room to get the system message for.
Returns:
str: The system message.

View file

@ -21,13 +21,14 @@ class OpenAI:
return self.chat_model
classification_api = chat_api
image_api: str = "dalle"
image_model: str = "dall-e-2"
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
def __init__(self, api_key, chat_model=None, logger=None):
def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
self.api_key = api_key
self.chat_model = chat_model or self.chat_model
self.image_model = image_model or self.image_model
self.logger = logger or Logger()
self.base_url = openai.api_base
@ -146,6 +147,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
image_partial = partial(
openai.Image.acreate,
model=self.image_model,
prompt=prompt,
n=1,
api_key=self.api_key,