From 5fef1ab59ce0a09c194ab74b34d0d52e8793df19 Mon Sep 17 00:00:00 2001 From: Kumi Date: Wed, 6 Nov 2024 16:18:30 +0100 Subject: [PATCH] fix(truncation): correct message handling and token calc Updated message truncation logic to correctly return a system message dictionary and adjust token calculations. Improved model encoding fallback strategy to utilize "gpt-4o" instead of "gpt-3.5-turbo" for greater compatibility. This addresses message mishandling and ensures more robust operation. More graceful error handling for the encoding fallback is still left as a TODO. --- src/gptbot/classes/ai/openai.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/gptbot/classes/ai/openai.py b/src/gptbot/classes/ai/openai.py index 2aabbb0..777d399 100644 --- a/src/gptbot/classes/ai/openai.py +++ b/src/gptbot/classes/ai/openai.py @@ -402,7 +402,7 @@ class OpenAI(BaseAI): self.logger.log(f"Prepared messages: {chat_messages}", "debug") # Truncate messages to fit within the token limit - self._truncate( + chat_messages = self._truncate( messages=chat_messages, max_tokens=self.max_tokens - 1, system_message=system_message, @@ -441,7 +441,7 @@ class OpenAI(BaseAI): encoding = tiktoken.encoding_for_model(model) except Exception: # TODO: Handle this more gracefully - encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + encoding = tiktoken.encoding_for_model("gpt-4o") total_tokens = 0 @@ -458,7 +458,6 @@ class OpenAI(BaseAI): total_tokens += system_message_tokens - total_tokens = len(system_message) + 1 truncated_messages = [] self.logger.log(f"Messages: {messages}", "debug") @@ -479,7 +478,20 @@ class OpenAI(BaseAI): total_tokens += tokens truncated_messages.append(message) - return [truncated_messages[0]] + list(reversed(truncated_messages[1:])) + system_message_dict = { + "role": "system", + "content": ( + system_message + if isinstance(messages[0]["content"], str) + else [{"type": "text", "text": system_message}] + ), + } + + return ( + 
system_message_dict + + [truncated_messages[0]] + + list(reversed(truncated_messages[1:])) + ) async def generate_chat_response( self,