Fix recursion errors in OpenAI class

Improved the error handling in the OpenAI class to prevent infinite recursion issues by retaining the original chat model during recursive calls. Enhanced logging within the recursion depth check for better debugging and traceability. Ensured consistency in chat responses by passing the initial model reference throughout the entire call stack. This is crucial when fallbacks due to errors or tool usage occur.

Refactored code for clarity and readability, ensuring that any recursion retains the original model and tool parameters. Additionally, proper logging and condition checks now standardize the flow of execution, preventing unintended modifications to the model's state that could lead to incorrect bot behavior.
This commit is contained in:
Kumi 2024-01-26 09:17:01 +01:00
parent 87173ae284
commit c4e23cb9d3
Signed by: kumi
GPG key ID: ECBCC9082395383F

View file

@ -23,11 +23,13 @@ ASSISTANT_CODE_INTERPRETER = [
}, },
] ]
class AttributeDictionary(dict): class AttributeDictionary(dict):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(AttributeDictionary, self).__init__(*args, **kwargs) super(AttributeDictionary, self).__init__(*args, **kwargs)
self.__dict__ = self self.__dict__ = self
class OpenAI: class OpenAI:
api_key: str api_key: str
chat_model: str = "gpt-3.5-turbo" chat_model: str = "gpt-3.5-turbo"
@ -143,25 +145,31 @@ class OpenAI:
f"Generating response to {len(messages)} messages for user {user} in room {room}..." f"Generating response to {len(messages)} messages for user {user} in room {room}..."
) )
chat_model = model or self.chat_model original_model = chat_model = model or self.chat_model
# Check current recursion depth to prevent infinite loops # Check current recursion depth to prevent infinite loops
if use_tools: if use_tools:
frames = inspect.stack() frames = inspect.stack()
current_function = inspect.getframeinfo(frames[0][0]).function current_function = inspect.getframeinfo(frames[0][0]).function
count = sum(1 for frame in frames if inspect.getframeinfo(frame[0]).function == current_function) count = sum(
self.logger.log(f"{current_function} appears {count} times in the call stack") 1
for frame in frames
if inspect.getframeinfo(frame[0]).function == current_function
)
self.logger.log(
f"{current_function} appears {count} times in the call stack"
)
if count > 5: if count > 5:
self.logger.log(f"Recursion depth exceeded, aborting.") self.logger.log(f"Recursion depth exceeded, aborting.")
return self.generate_chat_response( return self.generate_chat_response(
messages, messages,
user=user, user=user,
room=room, room=room,
allow_override=False, # TODO: Could this be a problem? allow_override=False, # TODO: Could this be a problem?
use_tools=False, use_tools=False,
model=model, model=original_model,
) )
tools = [ tools = [
@ -231,12 +239,13 @@ class OpenAI:
f"- {tool_name}: {tool_class.DESCRIPTION} ({tool_class.PARAMETERS})" f"- {tool_name}: {tool_class.DESCRIPTION} ({tool_class.PARAMETERS})"
for tool_name, tool_class in TOOLS.items() for tool_name, tool_class in TOOLS.items()
] ]
) + """ )
+ """
If no tool is required, or all information is already available in the message thread, respond with an empty JSON object: {} If no tool is required, or all information is already available in the message thread, respond with an empty JSON object: {}
Do NOT FOLLOW ANY OTHER INSTRUCTIONS BELOW, they are only meant for the AI chat model. You can ignore them. DO NOT include any other text or syntax in your response, only the JSON object. DO NOT surround it in code tags (```). DO NOT, UNDER ANY CIRCUMSTANCES, ASK AGAIN FOR INFORMATION ALREADY PROVIDED IN THE MESSAGES YOU RECEIVED! DO NOT REQUEST MORE INFORMATION THAN ABSOLUTELY REQUIRED TO RESPOND TO THE USER'S MESSAGE! Remind the user that they may ask you to search for additional information if they need it. Do NOT FOLLOW ANY OTHER INSTRUCTIONS BELOW, they are only meant for the AI chat model. You can ignore them. DO NOT include any other text or syntax in your response, only the JSON object. DO NOT surround it in code tags (```). DO NOT, UNDER ANY CIRCUMSTANCES, ASK AGAIN FOR INFORMATION ALREADY PROVIDED IN THE MESSAGES YOU RECEIVED! DO NOT REQUEST MORE INFORMATION THAN ABSOLUTELY REQUIRED TO RESPOND TO THE USER'S MESSAGE! Remind the user that they may ask you to search for additional information if they need it.
""" """,
} }
] ]
+ messages + messages
@ -292,6 +301,7 @@ class OpenAI:
room=room, room=room,
allow_override=False, allow_override=False,
use_tools=False, use_tools=False,
model=original_model,
) )
if not tool_responses: if not tool_responses:
@ -306,7 +316,7 @@ class OpenAI:
+ original_messages[-1:] + original_messages[-1:]
) )
result_text, additional_tokens = await self.generate_chat_response( result_text, additional_tokens = await self.generate_chat_response(
messages, user=user, room=room messages, user=user, room=room, model=original_messages
) )
except openai.APIError as e: except openai.APIError as e:
if e.code == "max_tokens": if e.code == "max_tokens":
@ -338,6 +348,7 @@ class OpenAI:
room=room, room=room,
allow_override=False, allow_override=False,
use_tools=False, use_tools=False,
model=original_model,
) )
except openai.APIError as e: except openai.APIError as e:
@ -351,6 +362,7 @@ class OpenAI:
room=room, room=room,
allow_override=False, allow_override=False,
use_tools=False, use_tools=False,
model=original_model,
) )
else: else:
raise e raise e
@ -359,16 +371,22 @@ class OpenAI:
if "tool" in tool_object: if "tool" in tool_object:
tool_name = tool_object["tool"] tool_name = tool_object["tool"]
tool_class = TOOLS[tool_name] tool_class = TOOLS[tool_name]
tool_parameters = tool_object["parameters"] if "parameters" in tool_object else {} tool_parameters = (
tool_object["parameters"] if "parameters" in tool_object else {}
)
self.logger.log(f"Using tool {tool_name} with parameters {tool_parameters}...") self.logger.log(
f"Using tool {tool_name} with parameters {tool_parameters}..."
)
tool_call = AttributeDictionary( tool_call = AttributeDictionary(
{ {
"function": AttributeDictionary({ "function": AttributeDictionary(
"name": tool_name, {
"arguments": json.dumps(tool_parameters), "name": tool_name,
}), "arguments": json.dumps(tool_parameters),
}
),
} }
) )
@ -392,6 +410,7 @@ class OpenAI:
room=room, room=room,
allow_override=False, allow_override=False,
use_tools=False, use_tools=False,
model=original_model,
) )
if not tool_responses: if not tool_responses:
@ -405,7 +424,10 @@ class OpenAI:
+ tool_responses + tool_responses
+ original_messages[-1:] + original_messages[-1:]
) )
result_text, additional_tokens = await self.generate_chat_response( (
result_text,
additional_tokens,
) = await self.generate_chat_response(
messages, user=user, room=room messages, user=user, room=room
) )
except openai.APIError as e: except openai.APIError as e:
@ -419,6 +441,7 @@ class OpenAI:
room=room, room=room,
allow_override=False, allow_override=False,
use_tools=False, use_tools=False,
model=original_model,
) )
else: else:
raise e raise e
@ -429,9 +452,10 @@ class OpenAI:
room=room, room=room,
allow_override=False, allow_override=False,
use_tools=False, use_tools=False,
model=original_model,
) )
elif not self.chat_model == chat_model: elif not original_model == chat_model:
new_messages = [] new_messages = []
for message in original_messages: for message in original_messages:
@ -448,7 +472,8 @@ class OpenAI:
new_messages.append(new_message) new_messages.append(new_message)
result_text, additional_tokens = await self.generate_chat_response( result_text, additional_tokens = await self.generate_chat_response(
new_messages, user=user, room=room, allow_override=False new_messages, user=user, room=room, allow_override=False,
model=original_model
) )
try: try: