From 05ba26d540610677f6e9f9a08a209fad5f030ebd Mon Sep 17 00:00:00 2001 From: Kumi Date: Sat, 25 May 2024 17:35:05 +0200 Subject: [PATCH] feat(openai.py): expand message handling capabilities Enhanced the OpenAI class to better support diverse message types in chat interactions, including image and video processing. This update introduces several key improvements: - Added handling for image and video messages, converting them to a format compatible with the OpenAI API. - Implemented a new method to prepare messages for OpenAI, allowing for richer interaction by including media content directly within the chat. - Incorporated message truncation to adhere to token limits, ensuring efficient usage of OpenAI's API without sacrificing message content. - Extended support for additional message types, such as audio and file messages, with specialized processing for each category. This change aims to enhance user experience by allowing more dynamic and multimedia-rich interactions, aligning with modern chat functionalities. It also addresses potential issues with token limit surpassing and ensures smoother integration of different message formats into the chat flow. --- src/gptbot/classes/ai/openai.py | 321 +++++++++++++++++++++++++++++++- 1 file changed, 320 insertions(+), 1 deletion(-) diff --git a/src/gptbot/classes/ai/openai.py b/src/gptbot/classes/ai/openai.py index 1ec931c..3239879 100644 --- a/src/gptbot/classes/ai/openai.py +++ b/src/gptbot/classes/ai/openai.py @@ -1,17 +1,29 @@ import openai import requests +import tiktoken +import base64 import json import inspect from functools import partial -from typing import Dict, List, Tuple, Generator, Optional, Mapping +from typing import Dict, List, Tuple, Generator, Optional, Mapping, Any from io import BytesIO from pydub import AudioSegment +from PIL import Image +from nio import ( + RoomMessageNotice, + RoomMessageText, + RoomMessageAudio, + RoomMessageFile, + RoomMessageImage, + RoomMessageVideo, +) from ..logging import Logger from ...tools import TOOLS, Handover, StopProcessing +from ..exceptions import DownloadException from .base import BaseAI, AttributeDictionary ASSISTANT_CODE_INTERPRETER = [ @@ -93,6 +105,10 @@ class OpenAI(BaseAI): def force_vision(self): return self._config.getboolean("ForceVision", fallback=False) + @property + def force_video_input(self): + return self._config.getboolean("ForceVideoInput", fallback=False) + @property def force_tools(self): return self._config.getboolean("ForceTools", fallback=False) @@ -135,6 +151,9 @@ class OpenAI(BaseAI): def supports_chat_images(self): return self._is_vision_model(self.chat_model) or self.force_vision + def supports_chat_videos(self): + return self.force_video_input + def json_decode(self, data): if data.startswith("```json\n"): data = data[8:] @@ -149,6 +168,306 @@ class OpenAI(BaseAI): except Exception: return False + async def prepare_messages( + self, event, messages: List[Dict[str, str]], system_message=None + ) -> List[Any]: + chat_messages = [] + + for message in messages: + if isinstance(message, (RoomMessageNotice, RoomMessageText)): + role = ( + "assistant" + if message.sender == self.bot.matrix_client.user_id + else "user" + ) + if message == event or (not message.event_id == event.event_id): + message_body = ( + message.body + if not self.supports_chat_images() + else [{"type": "text", "text": message.body}] + ) + chat_messages.append({"role": role, "content": message_body}) + + elif isinstance(message, RoomMessageAudio) or ( + isinstance(message, RoomMessageFile) and message.body.endswith(".mp3") + ): + role = ( + "assistant" + if message.sender == self.bot.matrix_client.user_id + else "user" + ) + if message == event or (not message.event_id == event.event_id): + if self.room_uses_stt(event.room_id): + try: + download = await self.bot.download_file( + message.url, raise_error=True + ) + message_text = await self.bot.stt_api.speech_to_text( + download.body + ) + except Exception as e: + self.logger.log( + f"Error generating text from audio: {e}", "error" + ) + message_text = message.body + else: + message_text = message.body + + message_body = ( + message_text + if not self.supports_chat_images() + else [{"type": "text", "text": message_text}] + ) + chat_messages.append({"role": role, "content": message_body}) + + elif isinstance(message, RoomMessageFile): + try: + download = await self.bot.download_file( + message.url, raise_error=True + ) + if download: + try: + text = download.body.decode("utf-8") + except UnicodeDecodeError: + text = None + + if text: + role = ( + "assistant" + if message.sender == self.bot.matrix_client.user_id + else "user" + ) + if message == event or ( + not message.event_id == event.event_id + ): + message_body = ( + text + if not self.supports_chat_images() + else [{"type": "text", "text": text}] + ) + chat_messages.append( + {"role": role, "content": message_body} + ) + + except Exception as e: + self.logger.log(f"Error generating text from file: {e}", "error") + message_body = ( + message.body + if not self.supports_chat_images() + else [{"type": "text", "text": message.body}] + ) + chat_messages.append({"role": "system", "content": message_body}) + + elif self.supports_chat_images() and isinstance(message, RoomMessageImage): + try: + image_url = message.url + download = await self.bot.download_file(image_url, raise_error=True) + + if download: + pil_image = Image.open(BytesIO(download.body)) + + file_format = pil_image.format or "PNG" + + max_long_side = self.max_image_long_side + max_short_side = self.max_image_short_side + + if max_long_side and max_short_side: + if pil_image.width > pil_image.height: + if pil_image.width > max_long_side: + pil_image.thumbnail((max_long_side, max_short_side)) + + else: + if pil_image.height > max_long_side: + pil_image.thumbnail((max_short_side, max_long_side)) + + bio = BytesIO() + + pil_image.save(bio, format=file_format) + + encoded_url = f"data:{download.content_type};base64,{base64.b64encode(bio.getvalue()).decode('utf-8')}" + parent = ( + chat_messages[-1] + if chat_messages + and chat_messages[-1]["role"] + == ( + "assistant" + if message.sender == self.bot.matrix_client.user_id + else "user" + ) + else None + ) + + if not parent: + chat_messages.append( + { + "role": ( + "assistant" + if message.sender == self.matrix_client.user_id + else "user" + ), + "content": [], + } + ) + parent = chat_messages[-1] + + parent["content"].append( + {"type": "image_url", "image_url": {"url": encoded_url}} + ) + + except Exception as e: + if isinstance(e, DownloadException): + self.bot.send_message( + event.room_id, + f"Could not process image due to download error: {e.args[0]}", + True, + ) + + self.logger.log(f"Error generating image from file: {e}", "error") + message_body = ( + message.body + if not self.supports_chat_images() + else [{"type": "text", "text": message.body}] + ) + chat_messages.append({"role": "system", "content": message_body}) + + elif self.supports_chat_videos() and ( + isinstance(message, RoomMessageVideo) + or ( + isinstance(message, RoomMessageFile) + and message.body.endswith(".mp4") + ) + ): + try: + video_url = message.url + download = await self.bot.download_file(video_url, raise_error=True) + + if download: + video = BytesIO(download.body) + video_url = f"data:{download.content_type};base64,{base64.b64encode(video.getvalue()).decode('utf-8')}" + + parent = ( + chat_messages[-1] + if chat_messages + and chat_messages[-1]["role"] + == ( + "assistant" + if message.sender == self.bot.matrix_client.user_id + else "user" + ) + else None + ) + + if not parent: + chat_messages.append( + { + "role": ( + "assistant" + if message.sender == self.matrix_client.user_id + else "user" + ), + "content": [], + } + ) + parent = chat_messages[-1] + + parent["content"].append( + {"type": "image_url", "image_url": {"url": video_url}} + ) + + except Exception as e: + if isinstance(e, DownloadException): + self.bot.send_message( + event.room_id, + f"Could not process video due to download error: {e.args[0]}", + True, + ) + + self.logger.log(f"Error generating video from file: {e}", "error") + message_body = ( + message.body + if not self.supports_chat_images() + else [{"type": "text", "text": message.body}] + ) + chat_messages.append({"role": "system", "content": message_body}) + + # Truncate messages to fit within the token limit + self._truncate( + messages=chat_messages, + max_tokens=self.chat_api.max_tokens - 1, + system_message=system_message, + ) + + return chat_messages + + def _truncate( + self, + messages: List[Any], + max_tokens: Optional[int] = None, + model: Optional[str] = None, + system_message: Optional[str] = None, + ) -> List[Any]: + """Truncate messages to fit within the token limit. + + Args: + messages (List[Any]): The messages to truncate. + max_tokens (Optional[int], optional): The maximum number of tokens to use. Defaults to None, which uses the default token limit. + model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model. + system_message (Optional[str], optional): The system message to use. Defaults to None, which uses the default system message. + + Returns: + List[Any]: The truncated messages. + """ + + max_tokens = max_tokens or self.max_tokens + model = model or self.chat_model + system_message = ( + self.bot.default_system_message + if system_message is None + else system_message + ) + + try: + encoding = tiktoken.encoding_for_model(model) + except Exception: + # TODO: Handle this more gracefully + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + + total_tokens = 0 + + system_message_tokens = ( + 0 if not system_message else (len(encoding.encode(system_message)) + 1) + ) + + if system_message_tokens > max_tokens: + self.logger.log( + f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", + "error", + ) + return [] + + total_tokens += system_message_tokens + + total_tokens = len(system_message) + 1 + truncated_messages = [] + + for message in [messages[0]] + list(reversed(messages[1:])): + content = ( + message["content"] + if isinstance(message["content"], str) + else ( + message["content"][0]["text"] + if isinstance(message["content"][0].get("text"), str) + else "" + ) + ) + tokens = len(encoding.encode(content)) + 1 + if total_tokens + tokens > max_tokens: + break + total_tokens += tokens + truncated_messages.append(message) + + return [truncated_messages[0]] + list(reversed(truncated_messages[1:])) + async def generate_chat_response( self, messages: List[Dict[str, str]],