fix(truncation): correct message handling and token calc
Updated the message truncation logic to return the system message as a proper message dictionary and to fix the token accounting. Switched the model-encoding fallback from "gpt-3.5-turbo" to "gpt-4o" for broader compatibility. This fixes the message mishandling and gives the encoding lookup a more robust default.
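For context, a minimal sketch of the corrected `_truncate` contract. The names come from the diff below; the token counting and the surrounding class are elided, and the loop body is illustrative:

    def _truncate(messages, max_tokens, system_message):
        # Keep the first message, then fill from the newest backwards
        # until max_tokens is reached (real accounting is in the diff).
        truncated_messages = [messages[0]]
        system_message_dict = {
            "role": "system",
            "content": (
                system_message
                if isinstance(messages[0]["content"], str)
                # Vision-style messages carry a list of content parts.
                else [{"type": "text", "text": system_message}]
            ),
        }
        # The system message now leads the returned list, and the call
        # site assigns the result (first hunk) instead of discarding it.
        return [system_message_dict] + truncated_messages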
parent 571031002c
commit 5fef1ab59c
1 changed file with 16 additions and 4 deletions
@@ -402,7 +402,7 @@ class OpenAI(BaseAI):
         self.logger.log(f"Prepared messages: {chat_messages}", "debug")
 
         # Truncate messages to fit within the token limit
-        self._truncate(
+        chat_messages = self._truncate(
             messages=chat_messages,
             max_tokens=self.max_tokens - 1,
             system_message=system_message,
@@ -441,7 +441,7 @@ class OpenAI(BaseAI):
             encoding = tiktoken.encoding_for_model(model)
         except Exception:
             # TODO: Handle this more gracefully
-            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+            encoding = tiktoken.encoding_for_model("gpt-4o")
 
         total_tokens = 0
 
@@ -458,7 +458,6 @@ class OpenAI(BaseAI):
 
             total_tokens += system_message_tokens
 
-        total_tokens = len(system_message) + 1
         truncated_messages = []
 
         self.logger.log(f"Messages: {messages}", "debug")
@@ -479,7 +478,20 @@ class OpenAI(BaseAI):
             total_tokens += tokens
             truncated_messages.append(message)
 
-        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
+        system_message_dict = {
+            "role": "system",
+            "content": (
+                system_message
+                if isinstance(messages[0]["content"], str)
+                else [{"type": "text", "text": system_message}]
+            ),
+        }
+
+        return (
+            [system_message_dict]
+            + [truncated_messages[0]]
+            + list(reversed(truncated_messages[1:]))
+        )
 
     async def generate_chat_response(
         self,
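On the encoding fallback in the second hunk: tiktoken's `encoding_for_model` raises `KeyError` for model names it does not recognize (the code above catches `Exception` more broadly). A self-contained sketch of the pattern, with an illustrative function name:

    import tiktoken

    def resolve_encoding(model: str) -> tiktoken.Encoding:
        try:
            return tiktoken.encoding_for_model(model)
        except KeyError:
            # Unrecognized model names land here. "gpt-4o" resolves to
            # the newer o200k_base encoding, a closer default for current
            # OpenAI chat models than gpt-3.5-turbo's cl100k_base.
            return tiktoken.encoding_for_model("gpt-4o")

The two encodings tokenize differently, so the fallback choice shifts where truncation cuts.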