feat(openai.py): expand message handling capabilities
Enhanced the OpenAI class to better support diverse message types in chat interactions, including image and video processing. This update introduces several key improvements:

- Added handling for image and video messages, converting them to a format compatible with the OpenAI API.
- Implemented a new method to prepare messages for OpenAI, allowing for richer interaction by including media content directly within the chat.
- Incorporated message truncation to adhere to token limits, ensuring efficient use of OpenAI's API without sacrificing message content.
- Extended support for additional message types, such as audio and file messages, with specialized processing for each category.

This change aims to enhance the user experience by allowing more dynamic and multimedia-rich interactions, aligning with modern chat functionalities. It also addresses potential issues with exceeding token limits and ensures smoother integration of different message formats into the chat flow.
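For illustration, a text-plus-image turn in the content-array format that prepare_messages() builds for vision-capable models looks roughly like this (a minimal sketch; the prompt text and the truncated base64 payload are invented):

# Sketch of one multimodal chat message in the OpenAI content-array format.
# Real data URLs carry the full base64-encoded image and are much longer.
chat_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
            },
        ],
    }
]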
parent 75e637546a
commit 05ba26d540
1 changed file with 320 additions and 1 deletion
openai.py

@@ -1,17 +1,29 @@
 import openai
 import requests
+import tiktoken

+import base64
 import json
 import inspect

 from functools import partial
-from typing import Dict, List, Tuple, Generator, Optional, Mapping
+from typing import Dict, List, Tuple, Generator, Optional, Mapping, Any
 from io import BytesIO

 from pydub import AudioSegment
+from PIL import Image
+from nio import (
+    RoomMessageNotice,
+    RoomMessageText,
+    RoomMessageAudio,
+    RoomMessageFile,
+    RoomMessageImage,
+    RoomMessageVideo,
+)

 from ..logging import Logger
 from ...tools import TOOLS, Handover, StopProcessing
+from ..exceptions import DownloadException
 from .base import BaseAI, AttributeDictionary

 ASSISTANT_CODE_INTERPRETER = [
@@ -93,6 +105,10 @@ class OpenAI(BaseAI):
     def force_vision(self):
         return self._config.getboolean("ForceVision", fallback=False)

+    @property
+    def force_video_input(self):
+        return self._config.getboolean("ForceVideoInput", fallback=False)
+
     @property
     def force_tools(self):
         return self._config.getboolean("ForceTools", fallback=False)
@@ -135,6 +151,9 @@ class OpenAI(BaseAI):
     def supports_chat_images(self):
         return self._is_vision_model(self.chat_model) or self.force_vision

+    def supports_chat_videos(self):
+        return self.force_video_input
+
     def json_decode(self, data):
         if data.startswith("```json\n"):
             data = data[8:]
@@ -149,6 +168,306 @@ class OpenAI(BaseAI):
         except Exception:
             return False

+    async def prepare_messages(
+        self, event, messages: List[Dict[str, str]], system_message=None
+    ) -> List[Any]:
+        chat_messages = []
+
+        for message in messages:
+            if isinstance(message, (RoomMessageNotice, RoomMessageText)):
+                role = (
+                    "assistant"
+                    if message.sender == self.bot.matrix_client.user_id
+                    else "user"
+                )
+                if message == event or (not message.event_id == event.event_id):
+                    message_body = (
+                        message.body
+                        if not self.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
+                    chat_messages.append({"role": role, "content": message_body})
+
+            elif isinstance(message, RoomMessageAudio) or (
+                isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
+            ):
+                role = (
+                    "assistant"
+                    if message.sender == self.bot.matrix_client.user_id
+                    else "user"
+                )
+                if message == event or (not message.event_id == event.event_id):
+                    if self.room_uses_stt(event.room_id):
+                        try:
+                            download = await self.bot.download_file(
+                                message.url, raise_error=True
+                            )
+                            message_text = await self.bot.stt_api.speech_to_text(
+                                download.body
+                            )
+                        except Exception as e:
+                            self.logger.log(
+                                f"Error generating text from audio: {e}", "error"
+                            )
+                            message_text = message.body
+                    else:
+                        message_text = message.body
+
+                    message_body = (
+                        message_text
+                        if not self.supports_chat_images()
+                        else [{"type": "text", "text": message_text}]
+                    )
+                    chat_messages.append({"role": role, "content": message_body})
+
+            elif isinstance(message, RoomMessageFile):
+                try:
+                    download = await self.bot.download_file(
+                        message.url, raise_error=True
+                    )
+                    if download:
+                        try:
+                            text = download.body.decode("utf-8")
+                        except UnicodeDecodeError:
+                            text = None
+
+                        if text:
+                            role = (
+                                "assistant"
+                                if message.sender == self.bot.matrix_client.user_id
+                                else "user"
+                            )
+                            if message == event or (
+                                not message.event_id == event.event_id
+                            ):
+                                message_body = (
+                                    text
+                                    if not self.supports_chat_images()
+                                    else [{"type": "text", "text": text}]
+                                )
+                                chat_messages.append(
+                                    {"role": role, "content": message_body}
+                                )
+
+                except Exception as e:
+                    self.logger.log(f"Error generating text from file: {e}", "error")
+                    message_body = (
+                        message.body
+                        if not self.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
+                    chat_messages.append({"role": "system", "content": message_body})
+
+            elif self.supports_chat_images() and isinstance(message, RoomMessageImage):
+                try:
+                    image_url = message.url
+                    download = await self.bot.download_file(image_url, raise_error=True)
+
+                    if download:
+                        pil_image = Image.open(BytesIO(download.body))
+
+                        file_format = pil_image.format or "PNG"
+
+                        max_long_side = self.max_image_long_side
+                        max_short_side = self.max_image_short_side
+
+                        if max_long_side and max_short_side:
+                            if pil_image.width > pil_image.height:
+                                if pil_image.width > max_long_side:
+                                    pil_image.thumbnail((max_long_side, max_short_side))
+
+                            else:
+                                if pil_image.height > max_long_side:
+                                    pil_image.thumbnail((max_short_side, max_long_side))
+
+                        bio = BytesIO()
+
+                        pil_image.save(bio, format=file_format)
+
+                        encoded_url = f"data:{download.content_type};base64,{base64.b64encode(bio.getvalue()).decode('utf-8')}"
+                        parent = (
+                            chat_messages[-1]
+                            if chat_messages
+                            and chat_messages[-1]["role"]
+                            == (
+                                "assistant"
+                                if message.sender == self.bot.matrix_client.user_id
+                                else "user"
+                            )
+                            else None
+                        )
+
+                        if not parent:
+                            chat_messages.append(
+                                {
+                                    "role": (
+                                        "assistant"
+                                        if message.sender == self.matrix_client.user_id
+                                        else "user"
+                                    ),
+                                    "content": [],
+                                }
+                            )
+                            parent = chat_messages[-1]
+
+                        parent["content"].append(
+                            {"type": "image_url", "image_url": {"url": encoded_url}}
+                        )
+
+                except Exception as e:
+                    if isinstance(e, DownloadException):
+                        self.bot.send_message(
+                            event.room_id,
+                            f"Could not process image due to download error: {e.args[0]}",
+                            True,
+                        )
+
+                    self.logger.log(f"Error generating image from file: {e}", "error")
+                    message_body = (
+                        message.body
+                        if not self.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
+                    chat_messages.append({"role": "system", "content": message_body})
+
+            elif self.supports_chat_videos() and (
+                isinstance(message, RoomMessageVideo)
+                or (
+                    isinstance(message, RoomMessageFile)
+                    and message.body.endswith(".mp4")
+                )
+            ):
+                try:
+                    video_url = message.url
+                    download = await self.bot.download_file(video_url, raise_error=True)
+
+                    if download:
+                        video = BytesIO(download.body)
+                        video_url = f"data:{download.content_type};base64,{base64.b64encode(video.getvalue()).decode('utf-8')}"
+
+                        parent = (
+                            chat_messages[-1]
+                            if chat_messages
+                            and chat_messages[-1]["role"]
+                            == (
+                                "assistant"
+                                if message.sender == self.bot.matrix_client.user_id
+                                else "user"
+                            )
+                            else None
+                        )
+
+                        if not parent:
+                            chat_messages.append(
+                                {
+                                    "role": (
+                                        "assistant"
+                                        if message.sender == self.matrix_client.user_id
+                                        else "user"
+                                    ),
+                                    "content": [],
+                                }
+                            )
+                            parent = chat_messages[-1]
+
+                        parent["content"].append(
+                            {"type": "image_url", "image_url": {"url": video_url}}
+                        )
+
+                except Exception as e:
+                    if isinstance(e, DownloadException):
+                        self.bot.send_message(
+                            event.room_id,
+                            f"Could not process video due to download error: {e.args[0]}",
+                            True,
+                        )
+
+                    self.logger.log(f"Error generating video from file: {e}", "error")
+                    message_body = (
+                        message.body
+                        if not self.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
+                    chat_messages.append({"role": "system", "content": message_body})
+
+        # Truncate messages to fit within the token limit
+        self._truncate(
+            messages=chat_messages,
+            max_tokens=self.chat_api.max_tokens - 1,
+            system_message=system_message,
+        )
+
+        return chat_messages
+
+    def _truncate(
+        self,
+        messages: List[Any],
+        max_tokens: Optional[int] = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
+    ) -> List[Any]:
+        """Truncate messages to fit within the token limit.
+
+        Args:
+            messages (List[Any]): The messages to truncate.
+            max_tokens (Optional[int], optional): The maximum number of tokens to use. Defaults to None, which uses the default token limit.
+            model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
+            system_message (Optional[str], optional): The system message to use. Defaults to None, which uses the default system message.
+
+        Returns:
+            List[Any]: The truncated messages.
+        """
+
+        max_tokens = max_tokens or self.max_tokens
+        model = model or self.chat_model
+        system_message = (
+            self.bot.default_system_message
+            if system_message is None
+            else system_message
+        )
+
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except Exception:
+            # TODO: Handle this more gracefully
+            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+        total_tokens = 0
+
+        system_message_tokens = (
+            0 if not system_message else (len(encoding.encode(system_message)) + 1)
+        )
+
+        if system_message_tokens > max_tokens:
+            self.logger.log(
+                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
+                "error",
+            )
+            return []
+
+        total_tokens += system_message_tokens
+
+        total_tokens = len(system_message) + 1
+        truncated_messages = []
+
+        for message in [messages[0]] + list(reversed(messages[1:])):
+            content = (
+                message["content"]
+                if isinstance(message["content"], str)
+                else (
+                    message["content"][0]["text"]
+                    if isinstance(message["content"][0].get("text"), str)
+                    else ""
+                )
+            )
+            tokens = len(encoding.encode(content)) + 1
+            if total_tokens + tokens > max_tokens:
+                break
+            total_tokens += tokens
+            truncated_messages.append(message)
+
+        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
+
     async def generate_chat_response(
         self,
         messages: List[Dict[str, str]],
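For reference, the tiktoken-based counting that the new _truncate helper performs can be sketched standalone (assumes the tiktoken package is installed; the model name and token budget are illustrative, and the one-extra-token-per-message heuristic mirrors the diff above):

import tiktoken

try:
    encoding = tiktoken.encoding_for_model("gpt-4")
except KeyError:
    # Unknown model names raise KeyError; fall back to a known encoding,
    # mirroring the broad fallback in _truncate.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

max_tokens = 4096  # illustrative budget
total_tokens = 0
kept = []
for message in messages:
    # One token of overhead per message, as in _truncate.
    tokens = len(encoding.encode(message["content"])) + 1
    if total_tokens + tokens > max_tokens:
        break
    total_tokens += tokens
    kept.append(message)

print(f"{len(kept)} messages, {total_tokens} tokens")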