fix(truncation): correct message handling and token calc

Updated message truncation logic to correctly return a system message dictionary and adjust token calculations. Improved model encoding fallback strategy to utilize "gpt-4o" instead of "gpt-3.5-turbo" for greater compatibility. This addresses message mishandling and ensures more robust operation.

Resolves the need for better error handling with encoding defaults.
This commit is contained in:
Kumi 2024-11-06 16:18:30 +01:00
parent 571031002c
commit 5fef1ab59c
Signed by: kumi
GPG key ID: ECBCC9082395383F

View file

@@ -402,7 +402,7 @@ class OpenAI(BaseAI):
         self.logger.log(f"Prepared messages: {chat_messages}", "debug")

         # Truncate messages to fit within the token limit
-        self._truncate(
+        chat_messages = self._truncate(
             messages=chat_messages,
             max_tokens=self.max_tokens - 1,
             system_message=system_message,
@@ -441,7 +441,7 @@ class OpenAI(BaseAI):
             encoding = tiktoken.encoding_for_model(model)
         except Exception:
             # TODO: Handle this more gracefully
-            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+            encoding = tiktoken.encoding_for_model("gpt-4o")

         total_tokens = 0
@@ -458,7 +458,6 @@ class OpenAI(BaseAI):
             total_tokens += system_message_tokens
-        total_tokens = len(system_message) + 1

         truncated_messages = []

         self.logger.log(f"Messages: {messages}", "debug")
@@ -479,7 +478,20 @@ class OpenAI(BaseAI):
             total_tokens += tokens
             truncated_messages.append(message)

-        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
+        system_message_dict = {
+            "role": "system",
+            "content": (
+                system_message
+                if isinstance(messages[0]["content"], str)
+                else [{"type": "text", "text": system_message}]
+            ),
+        }
+        return (
+            system_message_dict
+            + [truncated_messages[0]]
+            + list(reversed(truncated_messages[1:]))
+        )

     async def generate_chat_response(
         self,