Enhanced tool handling and image generation
Refactored `call_tool` to pass `room` and `user`, providing better context during tool execution. Introduced the `Handover` and `StopProcessing` exceptions to better control flow when a tool call needs to hand over to, or halt, text generation. Enabled flexibility by allowing the `room` parameter of the image- and file-sending methods to accept both `MatrixRoom` and `str` types. Updated `generate_chat_response` in the OpenAI class to accept a tool-usage flag and to prune message content more carefully when preparing tool responses. Introduced an `orientation` option for image generation to request landscape or portrait output. Implemented two new tool classes, `Imagine` and `Imagedescription`, to streamline image creation and image description respectively. The improved error handling and the additional granularity in tool invocation make the bot behave more predictably and transparently, particularly when interacting with generative AI and handling dialogue. The added flexibility in both response and image generation caters to a wider range of user inputs and scenarios, ultimately enhancing the bot's user experience.
This commit is contained in:
parent
54dd80ed50
commit
4d64593e89
6 changed files with 167 additions and 45 deletions
|
@ -56,12 +56,13 @@ import json
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import traceback
|
||||||
|
|
||||||
from .logging import Logger
|
from .logging import Logger
|
||||||
from ..migrations import migrate
|
from ..migrations import migrate
|
||||||
from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
||||||
from ..commands import COMMANDS
|
from ..commands import COMMANDS
|
||||||
from ..tools import TOOLS
|
from ..tools import TOOLS, Handover, StopProcessing
|
||||||
from .openai import OpenAI
|
from .openai import OpenAI
|
||||||
from .wolframalpha import WolframAlpha
|
from .wolframalpha import WolframAlpha
|
||||||
from .trackingmore import TrackingMore
|
from .trackingmore import TrackingMore
|
||||||
|
@ -347,11 +348,13 @@ class GPTBot:
|
||||||
|
|
||||||
return device_id
|
return device_id
|
||||||
|
|
||||||
async def call_tool(self, tool_call: dict):
|
async def call_tool(self, tool_call: dict, room: str, user: str, **kwargs):
|
||||||
"""Call a tool.
|
"""Call a tool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call (dict): The tool call to make.
|
tool_call (dict): The tool call to make.
|
||||||
|
room (str): The room to call the tool in.
|
||||||
|
user (str): The user to call the tool as.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tool = tool_call.function.name
|
tool = tool_call.function.name
|
||||||
|
@ -361,15 +364,19 @@ class GPTBot:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_class = TOOLS[tool]
|
tool_class = TOOLS[tool]
|
||||||
result = await tool_class(**args, bot=self).run()
|
result = await tool_class(**args, room=room, bot=self, user=user).run()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except KeyError:
|
except (Handover, StopProcessing):
|
||||||
|
raise
|
||||||
|
|
||||||
|
except KeyError as e:
|
||||||
self.logger.log(f"Tool {tool} not found", "error")
|
self.logger.log(f"Tool {tool} not found", "error")
|
||||||
return "Error: Tool not found"
|
return "Error: Tool not found"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.log(f"Error calling tool {tool}: {e}", "error")
|
self.logger.log(f"Error calling tool {tool}: {e}", "error")
|
||||||
|
traceback.print_exc()
|
||||||
return f"Error: Something went wrong calling tool {tool}"
|
return f"Error: Something went wrong calling tool {tool}"
|
||||||
|
|
||||||
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
|
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
|
||||||
|
@ -568,13 +575,16 @@ class GPTBot:
|
||||||
"""Send an image to a room.
|
"""Send an image to a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room to send the image to.
|
room (MatrixRoom|str): The room to send the image to.
|
||||||
image (bytes): The image to send.
|
image (bytes): The image to send.
|
||||||
message (str, optional): The message to send with the image. Defaults to None.
|
message (str, optional): The message to send with the image. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(room, MatrixRoom):
|
||||||
|
room = room.room_id
|
||||||
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
f"Sending image of size {len(image)} bytes to room {room.room_id}", "debug"
|
f"Sending image of size {len(image)} bytes to room {room}", "debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
bio = BytesIO(image)
|
bio = BytesIO(image)
|
||||||
|
@ -605,7 +615,7 @@ class GPTBot:
|
||||||
}
|
}
|
||||||
|
|
||||||
status = await self.matrix_client.room_send(
|
status = await self.matrix_client.room_send(
|
||||||
room.room_id, "m.room.message", content
|
room, "m.room.message", content
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.log("Sent image", "debug")
|
self.logger.log("Sent image", "debug")
|
||||||
|
@ -616,14 +626,17 @@ class GPTBot:
|
||||||
"""Send a file to a room.
|
"""Send a file to a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room to send the file to.
|
room (MatrixRoom|str): The room to send the file to.
|
||||||
file (bytes): The file to send.
|
file (bytes): The file to send.
|
||||||
filename (str): The name of the file.
|
filename (str): The name of the file.
|
||||||
mime (str): The MIME type of the file.
|
mime (str): The MIME type of the file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(room, MatrixRoom):
|
||||||
|
room = room.room_id
|
||||||
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
f"Sending file of size {len(file)} bytes to room {room.room_id}", "debug"
|
f"Sending file of size {len(file)} bytes to room {room}", "debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
content_uri = await self.upload_file(file, filename, mime)
|
content_uri = await self.upload_file(file, filename, mime)
|
||||||
|
@ -638,7 +651,7 @@ class GPTBot:
|
||||||
}
|
}
|
||||||
|
|
||||||
status = await self.matrix_client.room_send(
|
status = await self.matrix_client.room_send(
|
||||||
room.room_id, "m.room.message", content
|
room, "m.room.message", content
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.log("Sent file", "debug")
|
self.logger.log("Sent file", "debug")
|
||||||
|
@ -1108,7 +1121,7 @@ class GPTBot:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, tokens_used = await self.chat_api.generate_chat_response(
|
response, tokens_used = await self.chat_api.generate_chat_response(
|
||||||
chat_messages, user=room.room_id
|
chat_messages, user=room.room_id, room=room.room_id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.log(f"Error generating response: {e}", "error")
|
self.logger.log(f"Error generating response: {e}", "error")
|
||||||
|
@ -1144,13 +1157,6 @@ class GPTBot:
|
||||||
|
|
||||||
message = await self.send_message(room, response)
|
message = await self.send_message(room, response)
|
||||||
|
|
||||||
else:
|
|
||||||
# Send a notice to the room if there was an error
|
|
||||||
self.logger.log("Didn't get a response from GPT API", "error")
|
|
||||||
await self.send_message(
|
|
||||||
room, "Something went wrong. Please try again.", True
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.matrix_client.room_typing(room.room_id, False)
|
await self.matrix_client.room_typing(room.room_id, False)
|
||||||
|
|
||||||
def download_file(self, mxc) -> Optional[bytes]:
|
def download_file(self, mxc) -> Optional[bytes]:
|
||||||
|
|
|
@ -13,7 +13,7 @@ from io import BytesIO
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
from .logging import Logger
|
from .logging import Logger
|
||||||
from ..tools import TOOLS
|
from ..tools import TOOLS, Handover, StopProcessing
|
||||||
|
|
||||||
ASSISTANT_CODE_INTERPRETER = [
|
ASSISTANT_CODE_INTERPRETER = [
|
||||||
{
|
{
|
||||||
|
@ -201,7 +201,7 @@ class OpenAI:
|
||||||
|
|
||||||
return result is not None
|
return result is not None
|
||||||
|
|
||||||
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None, room: Optional[str] = None, allow_override: bool = True) -> Tuple[str, int]:
|
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None, room: Optional[str] = None, allow_override: bool = True, use_tools: bool = True) -> Tuple[str, int]:
|
||||||
"""Generate a response to a chat message.
|
"""Generate a response to a chat message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -209,6 +209,7 @@ class OpenAI:
|
||||||
user (Optional[str], optional): The user to use the assistant for. Defaults to None.
|
user (Optional[str], optional): The user to use the assistant for. Defaults to None.
|
||||||
room (Optional[str], optional): The room to use the assistant for. Defaults to None.
|
room (Optional[str], optional): The room to use the assistant for. Defaults to None.
|
||||||
allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True.
|
allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True.
|
||||||
|
use_tools (bool, optional): Whether to use tools. Defaults to True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, int]: The response text and the number of tokens used.
|
Tuple[str, int]: The response text and the number of tokens used.
|
||||||
|
@ -230,12 +231,32 @@ class OpenAI:
|
||||||
for tool_name, tool_class in TOOLS.items()]
|
for tool_name, tool_class in TOOLS.items()]
|
||||||
|
|
||||||
chat_model = self.chat_model
|
chat_model = self.chat_model
|
||||||
|
original_messages = messages
|
||||||
|
|
||||||
if allow_override and not "gpt-3.5-turbo" in self.chat_model:
|
if allow_override and not "gpt-3.5-turbo" in self.chat_model:
|
||||||
if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
|
if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
|
||||||
self.logger.log(f"Overriding chat model to use tools")
|
self.logger.log(f"Overriding chat model to use tools")
|
||||||
chat_model = "gpt-3.5-turbo-1106"
|
chat_model = "gpt-3.5-turbo-1106"
|
||||||
|
|
||||||
|
out_messages = []
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, dict):
|
||||||
|
if isinstance(message["content"], str):
|
||||||
|
out_messages.append(message)
|
||||||
|
else:
|
||||||
|
message_content = []
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
message_content.append(content)
|
||||||
|
if message_content:
|
||||||
|
message["content"] = message_content
|
||||||
|
out_messages.append(message)
|
||||||
|
else:
|
||||||
|
out_messages.append(message)
|
||||||
|
|
||||||
|
messages = out_messages
|
||||||
|
|
||||||
self.logger.log(f"Generating response with model {chat_model}...")
|
self.logger.log(f"Generating response with model {chat_model}...")
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
@ -244,7 +265,7 @@ class OpenAI:
|
||||||
"user": user,
|
"user": user,
|
||||||
}
|
}
|
||||||
|
|
||||||
if "gpt-3.5-turbo" in chat_model:
|
if "gpt-3.5-turbo" in chat_model and use_tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
|
|
||||||
if "gpt-4" in chat_model:
|
if "gpt-4" in chat_model:
|
||||||
|
@ -264,21 +285,31 @@ class OpenAI:
|
||||||
if (not result_text) and choice.message.tool_calls:
|
if (not result_text) and choice.message.tool_calls:
|
||||||
tool_responses = []
|
tool_responses = []
|
||||||
for tool_call in choice.message.tool_calls:
|
for tool_call in choice.message.tool_calls:
|
||||||
tool_response = await self.bot.call_tool(tool_call)
|
try:
|
||||||
tool_responses.append({
|
tool_response = await self.bot.call_tool(tool_call, room=room, user=user)
|
||||||
"role": "tool",
|
if tool_response != False:
|
||||||
"tool_call_id": tool_call.id,
|
tool_responses.append({
|
||||||
"content": str(tool_response)
|
"role": "tool",
|
||||||
})
|
"tool_call_id": tool_call.id,
|
||||||
|
"content": str(tool_response)
|
||||||
|
})
|
||||||
|
except StopProcessing:
|
||||||
|
return False, 0
|
||||||
|
except Handover:
|
||||||
|
return await self.generate_chat_response(original_messages, user, room, allow_override=False, use_tools=False)
|
||||||
|
|
||||||
messages = messages + [choice.message] + tool_responses
|
if not tool_responses:
|
||||||
|
self.logger.log(f"No more responses received, aborting.")
|
||||||
|
result_text = False
|
||||||
|
else:
|
||||||
|
messages = original_messages + [choice.message] + tool_responses
|
||||||
|
|
||||||
result_text, additional_tokens = await self.generate_chat_response(messages, user, room)
|
result_text, additional_tokens = await self.generate_chat_response(messages, user, room)
|
||||||
|
|
||||||
elif not self.chat_model == chat_model:
|
elif not self.chat_model == chat_model:
|
||||||
new_messages = []
|
new_messages = []
|
||||||
|
|
||||||
for message in messages:
|
for message in original_messages:
|
||||||
new_message = message
|
new_message = message
|
||||||
|
|
||||||
if isinstance(message, dict):
|
if isinstance(message, dict):
|
||||||
|
@ -291,9 +322,13 @@ class OpenAI:
|
||||||
|
|
||||||
new_messages.append(new_message)
|
new_messages.append(new_message)
|
||||||
|
|
||||||
result_text, additional_tokens = await self.generate_chat_response(new_messages, user, room, False)
|
result_text, additional_tokens = await self.generate_chat_response(new_messages, user, room, allow_override=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tokens_used = response.usage.total_tokens
|
||||||
|
except:
|
||||||
|
tokens_used = 0
|
||||||
|
|
||||||
tokens_used = response.usage.total_tokens
|
|
||||||
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
||||||
return result_text, tokens_used + additional_tokens
|
return result_text, tokens_used + additional_tokens
|
||||||
|
|
||||||
|
@ -384,11 +419,13 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
async def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
|
async def generate_image(self, prompt: str, user: Optional[str] = None, orientation: str = "square") -> Generator[bytes, None, None]:
|
||||||
"""Generate an image from a prompt.
|
"""Generate an image from a prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (str): The prompt to use.
|
prompt (str): The prompt to use.
|
||||||
|
user (Optional[str], optional): The user to use the assistant for. Defaults to None.
|
||||||
|
orientation (str, optional): The orientation of the image. Defaults to "square".
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
bytes: The image data.
|
bytes: The image data.
|
||||||
|
@ -396,27 +433,34 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
self.logger.log(f"Generating image from prompt '{prompt}'...")
|
self.logger.log(f"Generating image from prompt '{prompt}'...")
|
||||||
|
|
||||||
split_prompt = prompt.split()
|
split_prompt = prompt.split()
|
||||||
|
delete_first = False
|
||||||
|
|
||||||
size = "1024x1024"
|
size = "1024x1024"
|
||||||
|
|
||||||
if self.image_model == "dall-e-3":
|
if self.image_model == "dall-e-3":
|
||||||
if split_prompt[0] == "--portrait":
|
if orientation == "portrait" or (delete_first := split_prompt[0] == "--portrait"):
|
||||||
size = "1024x1792"
|
size = "1024x1792"
|
||||||
prompt = " ".join(split_prompt[1:])
|
elif orientation == "landscape" or (delete_first := split_prompt[0] == "--landscape"):
|
||||||
elif split_prompt[0] == "--landscape":
|
|
||||||
size = "1792x1024"
|
size = "1792x1024"
|
||||||
prompt = " ".join(split_prompt[1:])
|
|
||||||
|
if delete_first:
|
||||||
|
prompt = " ".join(split_prompt[1:])
|
||||||
|
|
||||||
self.logger.log(f"Generating image with size {size} using model {self.image_model}...")
|
self.logger.log(f"Generating image with size {size} using model {self.image_model}...")
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"model": self.image_model,
|
||||||
|
"quality": "standard" if self.image_model != "dall-e-3" else "hd",
|
||||||
|
"prompt": prompt,
|
||||||
|
"n": 1,
|
||||||
|
"size": size,
|
||||||
|
}
|
||||||
|
|
||||||
|
if user:
|
||||||
|
args["user"] = user
|
||||||
|
|
||||||
image_partial = partial(
|
image_partial = partial(
|
||||||
self.openai_api.images.generate,
|
self.openai_api.images.generate, **args
|
||||||
model=self.image_model,
|
|
||||||
quality="standard" if self.image_model != "dall-e-3" else "hd",
|
|
||||||
prompt=prompt,
|
|
||||||
n=1,
|
|
||||||
size=size,
|
|
||||||
user=user,
|
|
||||||
)
|
)
|
||||||
response = await self._request_with_retries(image_partial)
|
response = await self._request_with_retries(image_partial)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
|
from .base import BaseTool, StopProcessing, Handover
|
||||||
|
|
||||||
TOOLS = {}
|
TOOLS = {}
|
||||||
|
|
||||||
for tool in [
|
for tool in [
|
||||||
|
@ -8,6 +10,8 @@ for tool in [
|
||||||
"dice",
|
"dice",
|
||||||
"websearch",
|
"websearch",
|
||||||
"webrequest",
|
"webrequest",
|
||||||
|
"imagine",
|
||||||
|
"imagedescription",
|
||||||
]:
|
]:
|
||||||
tool_class = getattr(import_module(
|
tool_class = getattr(import_module(
|
||||||
"." + tool, "gptbot.tools"), tool.capitalize())
|
"." + tool, "gptbot.tools"), tool.capitalize())
|
||||||
|
|
|
@ -1,10 +1,20 @@
|
||||||
class BaseTool:
|
class BaseTool:
|
||||||
DESCRIPTION: str
|
DESCRIPTION: str
|
||||||
PARAMETERS: list
|
PARAMETERS: dict
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
self.bot = kwargs["bot"]
|
self.bot = kwargs["bot"]
|
||||||
|
self.room = kwargs["room"]
|
||||||
|
self.user = kwargs["user"]
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class StopProcessing(Exception):
|
||||||
|
"""Stop processing the message."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Handover(Exception):
|
||||||
|
"""Handover to the original model, if applicable. Stop using tools."""
|
||||||
|
pass
|
24
src/gptbot/tools/imagedescription.py
Normal file
24
src/gptbot/tools/imagedescription.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
from .base import BaseTool, Handover
|
||||||
|
|
||||||
|
class Imagedescription(BaseTool):
|
||||||
|
DESCRIPTION = "Describe the content of an image."
|
||||||
|
PARAMETERS = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"image": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The image to describe.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["image"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Describe an image.
|
||||||
|
|
||||||
|
This tool only hands over to the original model, if applicable.
|
||||||
|
It is intended to handle the case where GPT-3 thinks it is asked to
|
||||||
|
*generate* an image, but the user actually wants to *describe* an
|
||||||
|
image...
|
||||||
|
"""
|
||||||
|
raise Handover()
|
34
src/gptbot/tools/imagine.py
Normal file
34
src/gptbot/tools/imagine.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
from .base import BaseTool, StopProcessing
|
||||||
|
|
||||||
|
class Imagine(BaseTool):
|
||||||
|
DESCRIPTION = "Use generative AI to create images from text prompts."
|
||||||
|
PARAMETERS = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The prompt to use.",
|
||||||
|
},
|
||||||
|
"orientation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The orientation of the image.",
|
||||||
|
"enum": ["square", "landscape", "portrait"],
|
||||||
|
"default": "square",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["prompt"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Use generative AI to create images from text prompts."""
|
||||||
|
if not (prompt := self.kwargs.get("prompt")):
|
||||||
|
raise Exception('No prompt provided.')
|
||||||
|
|
||||||
|
api = self.bot.image_api
|
||||||
|
orientation = self.kwargs.get("orientation", "square")
|
||||||
|
images, tokens = await api.generate_image(prompt, self.room, orientation=orientation)
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
await self.bot.send_image(self.room, image, prompt)
|
||||||
|
|
||||||
|
raise StopProcessing()
|
Loading…
Reference in a new issue