feat(openai.py): expand message handling capabilities

Enhanced the OpenAI class to better support diverse message types in chat interactions, including image and video processing. This update introduces several key improvements:
- Added handling for image and video messages, converting them to a format compatible with the OpenAI API.
- Implemented a new method to prepare messages for OpenAI, allowing for richer interaction by including media content directly within the chat.
- Incorporated message truncation to adhere to token limits, ensuring efficient usage of OpenAI's API without sacrificing message content.
- Extended support for additional message types, such as audio and file messages, with specialized processing for each category.

This change aims to enhance user experience by allowing more dynamic and multimedia-rich interactions, aligning with modern chat functionalities. It also addresses potential issues with token limit surpassing and ensures smoother integration of different message formats into the chat flow.
This commit is contained in:
Kumi 2024-05-25 17:35:05 +02:00
parent 75e637546a
commit 05ba26d540
Signed by: kumi
GPG key ID: ECBCC9082395383F

View file

@ -1,17 +1,29 @@
import openai import openai
import requests import requests
import tiktoken
import base64
import json import json
import inspect import inspect
from functools import partial from functools import partial
from typing import Dict, List, Tuple, Generator, Optional, Mapping from typing import Dict, List, Tuple, Generator, Optional, Mapping, Any
from io import BytesIO from io import BytesIO
from pydub import AudioSegment from pydub import AudioSegment
from PIL import Image
from nio import (
RoomMessageNotice,
RoomMessageText,
RoomMessageAudio,
RoomMessageFile,
RoomMessageImage,
RoomMessageVideo,
)
from ..logging import Logger from ..logging import Logger
from ...tools import TOOLS, Handover, StopProcessing from ...tools import TOOLS, Handover, StopProcessing
from ..exceptions import DownloadException
from .base import BaseAI, AttributeDictionary from .base import BaseAI, AttributeDictionary
ASSISTANT_CODE_INTERPRETER = [ ASSISTANT_CODE_INTERPRETER = [
@ -93,6 +105,10 @@ class OpenAI(BaseAI):
def force_vision(self): def force_vision(self):
return self._config.getboolean("ForceVision", fallback=False) return self._config.getboolean("ForceVision", fallback=False)
@property
def force_video_input(self):
return self._config.getboolean("ForceVideoInput", fallback=False)
@property @property
def force_tools(self): def force_tools(self):
return self._config.getboolean("ForceTools", fallback=False) return self._config.getboolean("ForceTools", fallback=False)
@ -135,6 +151,9 @@ class OpenAI(BaseAI):
def supports_chat_images(self): def supports_chat_images(self):
return self._is_vision_model(self.chat_model) or self.force_vision return self._is_vision_model(self.chat_model) or self.force_vision
def supports_chat_videos(self):
return self.force_video_input
def json_decode(self, data): def json_decode(self, data):
if data.startswith("```json\n"): if data.startswith("```json\n"):
data = data[8:] data = data[8:]
@ -149,6 +168,306 @@ class OpenAI(BaseAI):
except Exception: except Exception:
return False return False
async def prepare_messages(
self, event, messages: List[Dict[str, str]], system_message=None
) -> List[Any]:
chat_messages = []
for message in messages:
if isinstance(message, (RoomMessageNotice, RoomMessageText)):
role = (
"assistant"
if message.sender == self.bot.matrix_client.user_id
else "user"
)
if message == event or (not message.event_id == event.event_id):
message_body = (
message.body
if not self.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": role, "content": message_body})
elif isinstance(message, RoomMessageAudio) or (
isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
):
role = (
"assistant"
if message.sender == self.bot.matrix_client.user_id
else "user"
)
if message == event or (not message.event_id == event.event_id):
if self.room_uses_stt(event.room_id):
try:
download = await self.bot.download_file(
message.url, raise_error=True
)
message_text = await self.bot.stt_api.speech_to_text(
download.body
)
except Exception as e:
self.logger.log(
f"Error generating text from audio: {e}", "error"
)
message_text = message.body
else:
message_text = message.body
message_body = (
message_text
if not self.supports_chat_images()
else [{"type": "text", "text": message_text}]
)
chat_messages.append({"role": role, "content": message_body})
elif isinstance(message, RoomMessageFile):
try:
download = await self.bot.download_file(
message.url, raise_error=True
)
if download:
try:
text = download.body.decode("utf-8")
except UnicodeDecodeError:
text = None
if text:
role = (
"assistant"
if message.sender == self.bot.matrix_client.user_id
else "user"
)
if message == event or (
not message.event_id == event.event_id
):
message_body = (
text
if not self.supports_chat_images()
else [{"type": "text", "text": text}]
)
chat_messages.append(
{"role": role, "content": message_body}
)
except Exception as e:
self.logger.log(f"Error generating text from file: {e}", "error")
message_body = (
message.body
if not self.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body})
elif self.supports_chat_images() and isinstance(message, RoomMessageImage):
try:
image_url = message.url
download = await self.bot.download_file(image_url, raise_error=True)
if download:
pil_image = Image.open(BytesIO(download.body))
file_format = pil_image.format or "PNG"
max_long_side = self.max_image_long_side
max_short_side = self.max_image_short_side
if max_long_side and max_short_side:
if pil_image.width > pil_image.height:
if pil_image.width > max_long_side:
pil_image.thumbnail((max_long_side, max_short_side))
else:
if pil_image.height > max_long_side:
pil_image.thumbnail((max_short_side, max_long_side))
bio = BytesIO()
pil_image.save(bio, format=file_format)
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(bio.getvalue()).decode('utf-8')}"
parent = (
chat_messages[-1]
if chat_messages
and chat_messages[-1]["role"]
== (
"assistant"
if message.sender == self.bot.matrix_client.user_id
else "user"
)
else None
)
if not parent:
chat_messages.append(
{
"role": (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
),
"content": [],
}
)
parent = chat_messages[-1]
parent["content"].append(
{"type": "image_url", "image_url": {"url": encoded_url}}
)
except Exception as e:
if isinstance(e, DownloadException):
self.bot.send_message(
event.room_id,
f"Could not process image due to download error: {e.args[0]}",
True,
)
self.logger.log(f"Error generating image from file: {e}", "error")
message_body = (
message.body
if not self.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body})
elif self.supports_chat_videos() and (
isinstance(message, RoomMessageVideo)
or (
isinstance(message, RoomMessageFile)
and message.body.endswith(".mp4")
)
):
try:
video_url = message.url
download = await self.bot.download_file(video_url, raise_error=True)
if download:
video = BytesIO(download.body)
video_url = f"data:{download.content_type};base64,{base64.b64encode(video.getvalue()).decode('utf-8')}"
parent = (
chat_messages[-1]
if chat_messages
and chat_messages[-1]["role"]
== (
"assistant"
if message.sender == self.bot.matrix_client.user_id
else "user"
)
else None
)
if not parent:
chat_messages.append(
{
"role": (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
),
"content": [],
}
)
parent = chat_messages[-1]
parent["content"].append(
{"type": "image_url", "image_url": {"url": video_url}}
)
except Exception as e:
if isinstance(e, DownloadException):
self.bot.send_message(
event.room_id,
f"Could not process video due to download error: {e.args[0]}",
True,
)
self.logger.log(f"Error generating video from file: {e}", "error")
message_body = (
message.body
if not self.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body})
# Truncate messages to fit within the token limit
self._truncate(
messages=chat_messages,
max_tokens=self.chat_api.max_tokens - 1,
system_message=system_message,
)
return chat_messages
def _truncate(
self,
messages: List[Any],
max_tokens: Optional[int] = None,
model: Optional[str] = None,
system_message: Optional[str] = None,
) -> List[Any]:
"""Truncate messages to fit within the token limit.
Args:
messages (List[Any]): The messages to truncate.
max_tokens (Optional[int], optional): The maximum number of tokens to use. Defaults to None, which uses the default token limit.
model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
system_message (Optional[str], optional): The system message to use. Defaults to None, which uses the default system message.
Returns:
List[Any]: The truncated messages.
"""
max_tokens = max_tokens or self.max_tokens
model = model or self.chat_model
system_message = (
self.bot.default_system_message
if system_message is None
else system_message
)
try:
encoding = tiktoken.encoding_for_model(model)
except Exception:
# TODO: Handle this more gracefully
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
total_tokens = 0
system_message_tokens = (
0 if not system_message else (len(encoding.encode(system_message)) + 1)
)
if system_message_tokens > max_tokens:
self.logger.log(
f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
"error",
)
return []
total_tokens += system_message_tokens
total_tokens = len(system_message) + 1
truncated_messages = []
for message in [messages[0]] + list(reversed(messages[1:])):
content = (
message["content"]
if isinstance(message["content"], str)
else (
message["content"][0]["text"]
if isinstance(message["content"][0].get("text"), str)
else ""
)
)
tokens = len(encoding.encode(content)) + 1
if total_tokens + tokens > max_tokens:
break
total_tokens += tokens
truncated_messages.append(message)
return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
async def generate_chat_response( async def generate_chat_response(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],