diff --git a/src/gptbot/classes/ai/openai.py b/src/gptbot/classes/ai/openai.py
index 1ec931c..3239879 100644
--- a/src/gptbot/classes/ai/openai.py
+++ b/src/gptbot/classes/ai/openai.py
@@ -1,17 +1,29 @@
 import openai
 import requests
+import tiktoken
 
+import base64
 import json
 import inspect
 
 from functools import partial
-from typing import Dict, List, Tuple, Generator, Optional, Mapping
+from typing import Dict, List, Tuple, Generator, Optional, Mapping, Any
 from io import BytesIO
 
 from pydub import AudioSegment
+from PIL import Image
+from nio import (
+    RoomMessageNotice,
+    RoomMessageText,
+    RoomMessageAudio,
+    RoomMessageFile,
+    RoomMessageImage,
+    RoomMessageVideo,
+)
 
 from ..logging import Logger
 from ...tools import TOOLS, Handover, StopProcessing
+from ..exceptions import DownloadException
 from .base import BaseAI, AttributeDictionary
 
 ASSISTANT_CODE_INTERPRETER = [
@@ -93,6 +105,11 @@ class OpenAI(BaseAI):
     def force_vision(self):
         return self._config.getboolean("ForceVision", fallback=False)
 
+    @property
+    def force_video_input(self):
+        """Whether the "ForceVideoInput" config option is set (default: False)."""
+        return self._config.getboolean("ForceVideoInput", fallback=False)
+
     @property
     def force_tools(self):
         return self._config.getboolean("ForceTools", fallback=False)
@@ -135,6 +152,10 @@ class OpenAI(BaseAI):
     def supports_chat_images(self):
         return self._is_vision_model(self.chat_model) or self.force_vision
 
+    def supports_chat_videos(self):
+        """Return whether videos are passed to the chat model (force_video_input)."""
+        return self.force_video_input
+
     def json_decode(self, data):
         if data.startswith("```json\n"):
             data = data[8:]
@@ -149,6 +170,310 @@ class OpenAI(BaseAI):
         except Exception:
             return False
 
+    def _message_role(self, message) -> str:
+        """Return the chat role for a room event.
+
+        Events sent by the bot's own Matrix user map to "assistant", all
+        other events map to "user".
+        """
+        return (
+            "assistant"
+            if message.sender == self.bot.matrix_client.user_id
+            else "user"
+        )
+
+    def _text_content(self, text: str) -> Any:
+        """Return text in the content shape used for chat messages.
+
+        If supports_chat_images() is true, the text is wrapped in a list
+        holding one {"type": "text"} part; otherwise it is returned as is.
+        """
+        if self.supports_chat_images():
+            return [{"type": "text", "text": text}]
+        return text
+
+    @staticmethod
+    def _is_relevant(message, event) -> bool:
+        """Return False for events that reuse the trigger's ID but differ."""
+        return message == event or message.event_id != event.event_id
+
+    @staticmethod
+    def _attach_part(chat_messages: List[Any], role: str, part: Any) -> None:
+        """Add a content part to the latest message, or to a new one.
+
+        The part is appended to the last chat message if that message has
+        the given role; otherwise a new message with that role is created.
+        """
+        if chat_messages and chat_messages[-1]["role"] == role:
+            parent = chat_messages[-1]
+            # Plain-string content cannot hold parts, so convert it first.
+            if isinstance(parent["content"], str):
+                parent["content"] = [{"type": "text", "text": parent["content"]}]
+        else:
+            parent = {"role": role, "content": []}
+            chat_messages.append(parent)
+
+        parent["content"].append(part)
+
+    async def _report_download_error(self, event, kind: str, error) -> None:
+        """Tell the room that a media download failed (best effort)."""
+        reason = error.args[0] if error.args else error
+        try:
+            result = self.bot.send_message(
+                event.room_id,
+                f"Could not process {kind} due to download error: {reason}",
+                True,
+            )
+            # send_message may return an awaitable; if so it must be awaited
+            # or the notice is never sent.
+            if inspect.isawaitable(result):
+                await result
+        except Exception as e:
+            self.logger.log(f"Error sending download error notice: {e}", "error")
+
+    async def prepare_messages(
+        self, event, messages: List[Dict[str, str]], system_message=None
+    ) -> List[Any]:
+        """Convert room events into chat messages and truncate them.
+
+        Args:
+            event: The event being responded to; its room_id and event_id
+                are read here.
+            messages: The room events to convert. Events matching none of
+                the handled types are skipped.
+            system_message: Passed through to _truncate().
+
+        Returns:
+            List[Any]: The chat messages as returned by _truncate().
+        """
+        chat_messages = []
+
+        for message in messages:
+            if isinstance(message, (RoomMessageNotice, RoomMessageText)):
+                if self._is_relevant(message, event):
+                    chat_messages.append(
+                        {
+                            "role": self._message_role(message),
+                            "content": self._text_content(message.body),
+                        }
+                    )
+
+            elif isinstance(message, RoomMessageAudio) or (
+                isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
+            ):
+                if self._is_relevant(message, event):
+                    message_text = message.body
+                    if self.room_uses_stt(event.room_id):
+                        try:
+                            download = await self.bot.download_file(
+                                message.url, raise_error=True
+                            )
+                            message_text = await self.bot.stt_api.speech_to_text(
+                                download.body
+                            )
+                        except Exception as e:
+                            self.logger.log(
+                                f"Error generating text from audio: {e}", "error"
+                            )
+                            message_text = message.body
+
+                    chat_messages.append(
+                        {
+                            "role": self._message_role(message),
+                            "content": self._text_content(message_text),
+                        }
+                    )
+
+            elif self.supports_chat_videos() and (
+                isinstance(message, RoomMessageVideo)
+                or (
+                    isinstance(message, RoomMessageFile)
+                    and message.body.endswith(".mp4")
+                )
+            ):
+                # This branch precedes the generic file branch so that .mp4
+                # files sent as RoomMessageFile actually reach it.
+                try:
+                    download = await self.bot.download_file(
+                        message.url, raise_error=True
+                    )
+
+                    if download:
+                        encoded = base64.b64encode(download.body).decode("utf-8")
+                        video_url = f"data:{download.content_type};base64,{encoded}"
+                        self._attach_part(
+                            chat_messages,
+                            self._message_role(message),
+                            {"type": "image_url", "image_url": {"url": video_url}},
+                        )
+
+                except Exception as e:
+                    if isinstance(e, DownloadException):
+                        await self._report_download_error(event, "video", e)
+
+                    self.logger.log(f"Error generating video from file: {e}", "error")
+                    chat_messages.append(
+                        {"role": "system", "content": self._text_content(message.body)}
+                    )
+
+            elif isinstance(message, RoomMessageFile):
+                try:
+                    download = await self.bot.download_file(
+                        message.url, raise_error=True
+                    )
+                    if download:
+                        try:
+                            text = download.body.decode("utf-8")
+                        except UnicodeDecodeError:
+                            # Not UTF-8 text; there is nothing to add.
+                            text = None
+
+                        if text and self._is_relevant(message, event):
+                            chat_messages.append(
+                                {
+                                    "role": self._message_role(message),
+                                    "content": self._text_content(text),
+                                }
+                            )
+
+                except Exception as e:
+                    self.logger.log(f"Error generating text from file: {e}", "error")
+                    chat_messages.append(
+                        {"role": "system", "content": self._text_content(message.body)}
+                    )
+
+            elif self.supports_chat_images() and isinstance(message, RoomMessageImage):
+                try:
+                    download = await self.bot.download_file(
+                        message.url, raise_error=True
+                    )
+
+                    if download:
+                        pil_image = Image.open(BytesIO(download.body))
+                        file_format = pil_image.format or "PNG"
+
+                        max_long_side = self.max_image_long_side
+                        max_short_side = self.max_image_short_side
+
+                        if max_long_side and max_short_side:
+                            if pil_image.width > pil_image.height:
+                                if pil_image.width > max_long_side:
+                                    pil_image.thumbnail(
+                                        (max_long_side, max_short_side)
+                                    )
+                            elif pil_image.height > max_long_side:
+                                pil_image.thumbnail((max_short_side, max_long_side))
+
+                        bio = BytesIO()
+                        pil_image.save(bio, format=file_format)
+
+                        encoded = base64.b64encode(bio.getvalue()).decode("utf-8")
+                        encoded_url = f"data:{download.content_type};base64,{encoded}"
+                        self._attach_part(
+                            chat_messages,
+                            self._message_role(message),
+                            {"type": "image_url", "image_url": {"url": encoded_url}},
+                        )
+
+                except Exception as e:
+                    if isinstance(e, DownloadException):
+                        await self._report_download_error(event, "image", e)
+
+                    self.logger.log(f"Error generating image from file: {e}", "error")
+                    chat_messages.append(
+                        {"role": "system", "content": self._text_content(message.body)}
+                    )
+
+        # Truncate messages to fit within the token limit. The result must
+        # be returned: _truncate() builds a new list and leaves its input
+        # untouched.
+        return self._truncate(
+            messages=chat_messages,
+            max_tokens=self.chat_api.max_tokens - 1,
+            system_message=system_message,
+        )
+
+    @staticmethod
+    def _message_text(message) -> str:
+        """Return the text of a chat message, ignoring non-text parts."""
+        content = message["content"]
+        if isinstance(content, str):
+            return content
+        return "\n".join(
+            part["text"]
+            for part in content
+            if isinstance(part, dict) and isinstance(part.get("text"), str)
+        )
+
+    def _truncate(
+        self,
+        messages: List[Any],
+        max_tokens: Optional[int] = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
+    ) -> List[Any]:
+        """Truncate messages to fit within the token limit.
+
+        The first message is considered first, then the remaining messages
+        from newest to oldest, until the next one would exceed the limit.
+        Only text is counted; image and video parts are not.
+
+        Args:
+            messages (List[Any]): The messages to truncate.
+            max_tokens (Optional[int], optional): The maximum number of tokens to use. Defaults to None, which uses the default token limit.
+            model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
+            system_message (Optional[str], optional): The system message to use. Defaults to None, which uses the default system message.
+
+        Returns:
+            List[Any]: The truncated messages in their original order, or an
+                empty list if nothing fits.
+        """
+
+        max_tokens = max_tokens or self.max_tokens
+        model = model or self.chat_model
+        system_message = (
+            self.bot.default_system_message
+            if system_message is None
+            else system_message
+        )
+
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except Exception:
+            # TODO: Handle this more gracefully
+            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+        system_message_tokens = (
+            0 if not system_message else (len(encoding.encode(system_message)) + 1)
+        )
+
+        if system_message_tokens > max_tokens:
+            self.logger.log(
+                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
+                "error",
+            )
+            return []
+
+        if not messages:
+            return []
+
+        # Start from the system message's token count (not its character
+        # length) so the budget is measured in one unit throughout.
+        total_tokens = system_message_tokens
+        truncated_messages = []
+
+        for message in [messages[0]] + list(reversed(messages[1:])):
+            tokens = len(encoding.encode(self._message_text(message))) + 1
+            if total_tokens + tokens > max_tokens:
+                break
+            total_tokens += tokens
+            truncated_messages.append(message)
+
+        if not truncated_messages:
+            return []
+
+        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
+
     async def generate_chat_response(
         self,
         messages: List[Dict[str, str]],