diff --git a/config.dist.ini b/config.dist.ini
index 5eb3595..4409556 100644
--- a/config.dist.ini
+++ b/config.dist.ini
@@ -149,6 +149,15 @@ APIKey = sk-yoursecretkey
 # MaxImageLongSide = 2000
 # MaxImageShortSide = 768
 
+# Whether the model in use supports video files as input
+#
+# If you are using a model that supports video files as input, set this to 1.
+# The bot will then send video files to the model as well as images.
+# This may be possible with some self-hosted models, but is not supported by
+# the OpenAI API at this time.
+#
+# ForceVideoInput = 0
+
 # Advanced settings for the OpenAI API
 #
 # These settings are not required for normal operation, but can be used to
diff --git a/pyproject.toml b/pyproject.toml
index 46ac095..25ad060 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ allow-direct-references = true
 [project]
 name = "matrix-gptbot"
-version = "0.3.14"
+version = "0.3.15"
 
 authors = [
   { name = "Kumi Mitterer", email = "gptbot@kumi.email" },
@@ -39,12 +39,14 @@ dependencies = [
 
 [project.optional-dependencies]
 openai = ["openai>=1.2", "pydub"]
+google = ["google-generativeai"]
+
 wolframalpha = ["wolframalpha"]
 trackingmore = ["trackingmore-api-tool"]
 
 all = [
-  "matrix-gptbot[openai,wolframalpha,trackingmore]",
+  "matrix-gptbot[openai,wolframalpha,trackingmore,google]",
   "geopy",
   "beautifulsoup4",
 ]
diff --git a/src/gptbot/classes/ai/base.py b/src/gptbot/classes/ai/base.py
index 3b83b9e..03d46b4 100644
--- a/src/gptbot/classes/ai/base.py
+++ b/src/gptbot/classes/ai/base.py
@@ -4,6 +4,8 @@ import asyncio
 from functools import partial
 from typing import Any, AsyncGenerator, Dict, Optional, Mapping
 
+from nio import Event
+
 
 class AttributeDictionary(dict):
     def __init__(self, *args, **kwargs):
@@ -24,6 +26,29 @@ class BaseAI:
     def chat_api(self) -> str:
         return self.chat_model
 
+    async def prepare_messages(
+        self, event: Event, messages: list[Any], system_message: Optional[str] = None
+    ) -> list[Any]:
+        """A helper method to prepare messages for the AI.
+
+        This converts a list of Matrix events into whatever format the AI requires.
+
+        Args:
+            event (Event): The event that triggered the message generation. Generally a text message from a user.
+            messages (list[Any]): The messages to prepare. Generally of type RoomMessage*.
+            system_message (Optional[str], optional): A system message to include. Defaults to None.
+
+        Returns:
+            list[Any]: The prepared messages in the format the AI requires.
+
+        Raises:
+            NotImplementedError: If the method is not implemented in the subclass.
+        """
+
+        raise NotImplementedError(
+            "Implementations of BaseAI must implement prepare_messages."
+        )
+
     async def _request_with_retries(
         self, request: partial, attempts: int = 5, retry_interval: int = 2
     ) -> AsyncGenerator[Any | list | Dict, None]:
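For illustration, a concrete backend would override the new hook roughly like this (a minimal sketch, not part of this diff; the `EchoAI` class and its flat string format are hypothetical):

```python
from typing import Any, Optional

from nio import Event, RoomMessageText

from .base import BaseAI


class EchoAI(BaseAI):
    """Hypothetical backend that consumes plain strings."""

    async def prepare_messages(
        self, event: Event, messages: list[Any], system_message: Optional[str] = None
    ) -> list[str]:
        # Keep only text events and flatten them into the plain strings
        # this imaginary model expects, with the triggering event last.
        prepared = [
            m.body for m in messages + [event] if isinstance(m, RoomMessageText)
        ]
        if system_message:
            prepared.insert(0, system_message)
        return prepared
```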
diff --git a/src/gptbot/classes/ai/google.py b/src/gptbot/classes/ai/google.py
new file mode 100644
index 0000000..93c419a
--- /dev/null
+++ b/src/gptbot/classes/ai/google.py
@@ -0,0 +1,76 @@
+from .base import BaseAI
+from ..logging import Logger
+
+from typing import Any, Optional, Mapping, List, Tuple
+
+import google.generativeai as genai
+
+from nio import Event
+
+
+class GeminiAI(BaseAI):
+    api_code: str = "google"
+
+    @property
+    def chat_api(self) -> str:
+        return self.chat_model
+
+    gemini_api: genai.GenerativeModel
+
+    operator: str = "Google (https://ai.google)"
+
+    def __init__(
+        self,
+        bot,
+        config: Mapping,
+        logger: Optional[Logger] = None,
+    ):
+        super().__init__(bot, config, logger)
+        genai.configure(api_key=self.api_key)
+        self.gemini_api = genai.GenerativeModel(self.chat_model)
+
+    @property
+    def api_key(self):
+        return self._config["APIKey"]
+
+    @property
+    def chat_model(self):
+        return self._config.get("Model", fallback="gemini-pro")
+
+    async def prepare_messages(
+        self, event: Event, messages: List[Any], system_message: Optional[str] = None
+    ) -> List[str]:
+        # Gemini receives the conversation as a flat list of message texts.
+        return [message["content"] for message in messages]
+
+    async def generate_chat_response(
+        self,
+        messages: List[str],
+        user: Optional[str] = None,
+        room: Optional[str] = None,
+        use_tools: bool = True,
+        model: Optional[str] = None,
+    ) -> Tuple[str, int]:
+        """Generate a response to a chat message.
+
+        Args:
+            messages (List[str]): A list of messages to use as context, as returned by prepare_messages.
+            user (Optional[str], optional): The user to use the assistant for. Defaults to None.
+            room (Optional[str], optional): The room to use the assistant for. Defaults to None.
+            use_tools (bool, optional): Whether to use tools. Defaults to True. Currently ignored by this backend.
+            model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
+
+        Returns:
+            Tuple[str, int]: The response text and the number of tokens used.
+        """
+        self.logger.log(
+            f"Generating response to {len(messages)} messages for user {user} in room {room}..."
+        )
+
+        response = await self.gemini_api.generate_content_async(messages)
+
+        # Estimate usage via count_tokens, as the response object does not
+        # reliably expose token counts across SDK versions.
+        tokens_used = (await self.gemini_api.count_tokens_async(messages)).total_tokens
+
+        return response.text, tokens_used
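In isolation, the two `google-generativeai` calls the class builds on behave as follows (a minimal sketch; the API key and prompt are placeholders):

```python
import asyncio

import google.generativeai as genai

genai.configure(api_key="your-google-api-key")  # placeholder key
model = genai.GenerativeModel("gemini-pro")


async def main() -> None:
    prompt = ["Hello from matrix-gptbot!"]
    # Same call pattern GeminiAI.generate_chat_response uses above.
    response = await model.generate_content_async(prompt)
    usage = await model.count_tokens_async(prompt)
    print(response.text, usage.total_tokens)


asyncio.run(main())
```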
diff --git a/src/gptbot/classes/bot.py b/src/gptbot/classes/bot.py
index 41ea82d..c186c6e 100644
--- a/src/gptbot/classes/bot.py
+++ b/src/gptbot/classes/bot.py
@@ -24,9 +24,6 @@ from nio import (
     RoomVisibility,
     RoomCreateError,
     RoomMessageMedia,
-    RoomMessageImage,
-    RoomMessageFile,
-    RoomMessageAudio,
     DownloadError,
     RoomGetStateError,
     DiskDownloadResponse,
@@ -43,7 +40,6 @@ from io import BytesIO
 from pathlib import Path
 from contextlib import closing
 
-import base64
 import uuid
 import traceback
 import json
@@ -362,61 +358,6 @@ class GPTBot:
         # Reverse the list so that messages are in chronological order
         return messages[::-1]
 
-    def _truncate(
-        self,
-        messages: list,
-        max_tokens: Optional[int] = None,
-        model: Optional[str] = None,
-        system_message: Optional[str] = None,
-    ):
-        max_tokens = max_tokens or self.chat_api.max_tokens
-        model = model or self.chat_api.chat_model
-        system_message = (
-            self.default_system_message if system_message is None else system_message
-        )
-
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except Exception:
-            # TODO: Handle this more gracefully
-            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-        total_tokens = 0
-
-        system_message_tokens = (
-            0 if not system_message else (len(encoding.encode(system_message)) + 1)
-        )
-
-        if system_message_tokens > max_tokens:
-            self.logger.log(
-                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
-                "error",
-            )
-            return []
-
-        total_tokens += system_message_tokens
-
-        total_tokens = len(system_message) + 1
-        truncated_messages = []
-
-        for message in [messages[0]] + list(reversed(messages[1:])):
-            content = (
-                message["content"]
-                if isinstance(message["content"], str)
-                else (
-                    message["content"][0]["text"]
-                    if isinstance(message["content"][0].get("text"), str)
-                    else ""
-                )
-            )
-            tokens = len(encoding.encode(content)) + 1
-            if total_tokens + tokens > max_tokens:
-                break
-            total_tokens += tokens
-            truncated_messages.append(message)
-
-        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
-
     async def _get_device_id(self) -> str:
         """Guess the device ID of the bot.
         Requires an access token to be set up.
@@ -1171,172 +1112,8 @@ class GPTBot:
 
         system_message = self.get_system_message(room)
 
-        chat_messages = [{"role": "system", "content": system_message}]
-
-        last_messages = last_messages + [event]
-
-        for message in last_messages:
-            if isinstance(message, (RoomMessageNotice, RoomMessageText)):
-                role = (
-                    "assistant"
-                    if message.sender == self.matrix_client.user_id
-                    else "user"
-                )
-                if message == event or (not message.event_id == event.event_id):
-                    message_body = (
-                        message.body
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message.body}]
-                    )
-                    chat_messages.append({"role": role, "content": message_body})
-
-            elif isinstance(message, RoomMessageAudio) or (
-                isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
-            ):
-                role = (
-                    "assistant"
-                    if message.sender == self.matrix_client.user_id
-                    else "user"
-                )
-                if message == event or (not message.event_id == event.event_id):
-                    if self.room_uses_stt(room):
-                        try:
-                            download = await self.download_file(
-                                message.url, raise_error=True
-                            )
-                            message_text = await self.stt_api.speech_to_text(
-                                download.body
-                            )
-                        except Exception as e:
-                            self.logger.log(
-                                f"Error generating text from audio: {e}", "error"
-                            )
-                            message_text = message.body
-                    else:
-                        message_text = message.body
-
-                    message_body = (
-                        message_text
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message_text}]
-                    )
-                    chat_messages.append({"role": role, "content": message_body})
-
-            elif isinstance(message, RoomMessageFile):
-                try:
-                    download = await self.download_file(message.url, raise_error=True)
-                    if download:
-                        try:
-                            text = download.body.decode("utf-8")
-                        except UnicodeDecodeError:
-                            text = None
-
-                        if text:
-                            role = (
-                                "assistant"
-                                if message.sender == self.matrix_client.user_id
-                                else "user"
-                            )
-                            if message == event or (
-                                not message.event_id == event.event_id
-                            ):
-                                message_body = (
-                                    text
-                                    if not self.chat_api.supports_chat_images()
-                                    else [{"type": "text", "text": text}]
-                                )
-                                chat_messages.append(
-                                    {"role": role, "content": message_body}
-                                )
-
-                except Exception as e:
-                    self.logger.log(f"Error generating text from file: {e}", "error")
-                    message_body = (
-                        message.body
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message.body}]
-                    )
-                    chat_messages.append({"role": "system", "content": message_body})
-
-            elif self.chat_api.supports_chat_images() and isinstance(
-                message, RoomMessageImage
-            ):
-                try:
-                    image_url = message.url
-                    download = await self.download_file(image_url, raise_error=True)
-
-                    if download:
-                        pil_image = Image.open(BytesIO(download.body))
-
-                        file_format = pil_image.format or "PNG"
-
-                        max_long_side = self.chat_api.max_image_long_side
-                        max_short_side = self.chat_api.max_image_short_side
-
-                        if max_long_side and max_short_side:
-                            if pil_image.width > pil_image.height:
-                                if pil_image.width > max_long_side:
-                                    pil_image.thumbnail((max_long_side, max_short_side))
-
-                            else:
-                                if pil_image.height > max_long_side:
-                                    pil_image.thumbnail((max_short_side, max_long_side))
-
-                        bio = BytesIO()
-
-                        pil_image.save(bio, format=file_format)
-
-                        encoded_url = f"data:{download.content_type};base64,{base64.b64encode(bio.getvalue()).decode('utf-8')}"
-                        parent = (
-                            chat_messages[-1]
-                            if chat_messages
-                            and chat_messages[-1]["role"]
-                            == (
-                                "assistant"
-                                if message.sender == self.matrix_client.user_id
-                                else "user"
-                            )
-                            else None
-                        )
-
-                        if not parent:
-                            chat_messages.append(
-                                {
-                                    "role": (
-                                        "assistant"
-                                        if message.sender == self.matrix_client.user_id
else "user" - ), - "content": [], - } - ) - parent = chat_messages[-1] - - parent["content"].append( - {"type": "image_url", "image_url": {"url": encoded_url}} - ) - - except Exception as e: - if isinstance(e, DownloadException): - self.send_message( - room, - f"Could not process image due to download error: {e.args[0]}", - True, - ) - - self.logger.log(f"Error generating image from file: {e}", "error") - message_body = ( - message.body - if not self.chat_api.supports_chat_images() - else [{"type": "text", "text": message.body}] - ) - chat_messages.append({"role": "system", "content": message_body}) - - # Truncate messages to fit within the token limit - self._truncate( - chat_messages[1:], - self.chat_api.max_tokens - 1, - system_message=system_message, + chat_messages = await self.chat_api.prepare_messages( + last_messages, system_message ) # Check for a model override @@ -1382,7 +1159,7 @@ class GPTBot: room, "Something went wrong generating audio file.", True ) - message = await self.send_message(room, response) + await self.send_message(room, response) await self.matrix_client.room_typing(room.room_id, False)