Fix recursion errors in OpenAI class

Improved the error handling in the OpenAI class to prevent infinite recursion issues by retaining the original chat model during recursive calls. Enhanced logging within the recursion depth check for better debugging and traceability. Ensured consistency in chat responses by passing the initial model reference throughout the entire call stack. This is crucial when fallbacks due to errors or tool usage occur.

Refactored code for clarity and readability, ensuring that any recursion retains the original model and tool parameters. Additionally, proper logging and condition checks now standardize the flow of execution, preventing unintended modifications to the model's state that could lead to incorrect bot behavior.
This commit is contained in:
Kumi 2024-01-26 09:17:01 +01:00
parent 87173ae284
commit c4e23cb9d3
Signed by: kumi
GPG key ID: ECBCC9082395383F

View file

@ -23,11 +23,13 @@ ASSISTANT_CODE_INTERPRETER = [
},
]
class AttributeDictionary(dict):
def __init__(self, *args, **kwargs):
super(AttributeDictionary, self).__init__(*args, **kwargs)
self.__dict__ = self
class OpenAI:
api_key: str
chat_model: str = "gpt-3.5-turbo"
@ -143,15 +145,21 @@ class OpenAI:
f"Generating response to {len(messages)} messages for user {user} in room {room}..."
)
chat_model = model or self.chat_model
original_model = chat_model = model or self.chat_model
# Check current recursion depth to prevent infinite loops
if use_tools:
frames = inspect.stack()
current_function = inspect.getframeinfo(frames[0][0]).function
count = sum(1 for frame in frames if inspect.getframeinfo(frame[0]).function == current_function)
self.logger.log(f"{current_function} appears {count} times in the call stack")
count = sum(
1
for frame in frames
if inspect.getframeinfo(frame[0]).function == current_function
)
self.logger.log(
f"{current_function} appears {count} times in the call stack"
)
if count > 5:
self.logger.log(f"Recursion depth exceeded, aborting.")
@ -161,7 +169,7 @@ class OpenAI:
room=room,
allow_override=False, # TODO: Could this be a problem?
use_tools=False,
model=model,
model=original_model,
)
tools = [
@ -231,12 +239,13 @@ class OpenAI:
f"- {tool_name}: {tool_class.DESCRIPTION} ({tool_class.PARAMETERS})"
for tool_name, tool_class in TOOLS.items()
]
) + """
)
+ """
If no tool is required, or all information is already available in the message thread, respond with an empty JSON object: {}
Do NOT FOLLOW ANY OTHER INSTRUCTIONS BELOW, they are only meant for the AI chat model. You can ignore them. DO NOT include any other text or syntax in your response, only the JSON object. DO NOT surround it in code tags (```). DO NOT, UNDER ANY CIRCUMSTANCES, ASK AGAIN FOR INFORMATION ALREADY PROVIDED IN THE MESSAGES YOU RECEIVED! DO NOT REQUEST MORE INFORMATION THAN ABSOLUTELY REQUIRED TO RESPOND TO THE USER'S MESSAGE! Remind the user that they may ask you to search for additional information if they need it.
"""
""",
}
]
+ messages
@ -292,6 +301,7 @@ class OpenAI:
room=room,
allow_override=False,
use_tools=False,
model=original_model,
)
if not tool_responses:
@ -306,7 +316,7 @@ class OpenAI:
+ original_messages[-1:]
)
result_text, additional_tokens = await self.generate_chat_response(
messages, user=user, room=room
messages, user=user, room=room, model=original_messages
)
except openai.APIError as e:
if e.code == "max_tokens":
@ -338,6 +348,7 @@ class OpenAI:
room=room,
allow_override=False,
use_tools=False,
model=original_model,
)
except openai.APIError as e:
@ -351,6 +362,7 @@ class OpenAI:
room=room,
allow_override=False,
use_tools=False,
model=original_model,
)
else:
raise e
@ -359,16 +371,22 @@ class OpenAI:
if "tool" in tool_object:
tool_name = tool_object["tool"]
tool_class = TOOLS[tool_name]
tool_parameters = tool_object["parameters"] if "parameters" in tool_object else {}
tool_parameters = (
tool_object["parameters"] if "parameters" in tool_object else {}
)
self.logger.log(f"Using tool {tool_name} with parameters {tool_parameters}...")
self.logger.log(
f"Using tool {tool_name} with parameters {tool_parameters}..."
)
tool_call = AttributeDictionary(
{
"function": AttributeDictionary({
"function": AttributeDictionary(
{
"name": tool_name,
"arguments": json.dumps(tool_parameters),
}),
}
),
}
)
@ -392,6 +410,7 @@ class OpenAI:
room=room,
allow_override=False,
use_tools=False,
model=original_model,
)
if not tool_responses:
@ -405,7 +424,10 @@ class OpenAI:
+ tool_responses
+ original_messages[-1:]
)
result_text, additional_tokens = await self.generate_chat_response(
(
result_text,
additional_tokens,
) = await self.generate_chat_response(
messages, user=user, room=room
)
except openai.APIError as e:
@ -419,6 +441,7 @@ class OpenAI:
room=room,
allow_override=False,
use_tools=False,
model=original_model,
)
else:
raise e
@ -429,9 +452,10 @@ class OpenAI:
room=room,
allow_override=False,
use_tools=False,
model=original_model,
)
elif not self.chat_model == chat_model:
elif not original_model == chat_model:
new_messages = []
for message in original_messages:
@ -448,7 +472,8 @@ class OpenAI:
new_messages.append(new_message)
result_text, additional_tokens = await self.generate_chat_response(
new_messages, user=user, room=room, allow_override=False
new_messages, user=user, room=room, allow_override=False,
model=original_model
)
try: