feat: Enable calling tools in chat completion
This commit adds functionality to call tools within the chat completion model. By introducing the `call_tool()` method in the `GPTBot` class, tools can now be invoked with the appropriate tool call. The commit also includes the necessary changes in the `OpenAI` class to handle tool calls during response generation. Additionally, new tool classes for geocoding and dice rolling have been implemented. This enhancement aims to expand the capabilities of the bot by allowing users to leverage various tools directly within the chat conversation.
This commit is contained in:
parent
155ea68e7a
commit
54dd80ed50
11 changed files with 359 additions and 9 deletions
|
@ -104,6 +104,14 @@ APIKey = sk-yoursecretkey
|
|||
#
|
||||
# BaseURL = https://openai.local/v1
|
||||
|
||||
# Whether to force the use of tools in the chat completion model
|
||||
#
|
||||
# Currently, only gpt-3.5-turbo supports tools. If you set this to 1, the bot
|
||||
# will use that model for tools even if you have a different model set as the
|
||||
# default. It will only generate the final result using the default model.
|
||||
#
|
||||
# ForceTools = 0
|
||||
|
||||
###############################################################################
|
||||
|
||||
[WolframAlpha]
|
||||
|
@ -180,3 +188,12 @@ CryptoStore = store.db
|
|||
# APIKey = __________________________
|
||||
|
||||
###############################################################################
|
||||
|
||||
[OpenWeatherMap]
|
||||
|
||||
# API key for OpenWeatherMap
|
||||
# If not defined, the bot will be unable to provide weather information
|
||||
#
|
||||
# APIKey = __________________________
|
||||
|
||||
###############################################################################
|
|
@ -7,7 +7,7 @@ allow-direct-references = true
|
|||
|
||||
[project]
|
||||
name = "matrix-gptbot"
|
||||
version = "0.2.1"
|
||||
version = "0.2.2"
|
||||
|
||||
authors = [
|
||||
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
|
||||
|
@ -52,6 +52,8 @@ trackingmore = [
|
|||
|
||||
all = [
|
||||
"matrix-gptbot[openai,wolframalpha,trackingmore]",
|
||||
"geopy",
|
||||
"beautifulsoup4",
|
||||
]
|
||||
|
||||
dev = [
|
||||
|
|
|
@ -61,6 +61,7 @@ from .logging import Logger
|
|||
from ..migrations import migrate
|
||||
from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
||||
from ..commands import COMMANDS
|
||||
from ..tools import TOOLS
|
||||
from .openai import OpenAI
|
||||
from .wolframalpha import WolframAlpha
|
||||
from .trackingmore import TrackingMore
|
||||
|
@ -92,6 +93,9 @@ class GPTBot:
|
|||
logo: Optional[Image.Image] = None
|
||||
logo_uri: Optional[str] = None
|
||||
allowed_users: List[str] = []
|
||||
config: ConfigParser = ConfigParser()
|
||||
|
||||
USER_AGENT = "matrix-gptbot/dev (+https://kumig.it/kumitterer/matrix-gptbot)"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigParser):
|
||||
|
@ -188,6 +192,7 @@ class GPTBot:
|
|||
bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
|
||||
|
||||
# Return the new GPTBot instance
|
||||
bot.config = config
|
||||
return bot
|
||||
|
||||
async def _get_user_id(self) -> str:
|
||||
|
@ -342,6 +347,31 @@ class GPTBot:
|
|||
|
||||
return device_id
|
||||
|
||||
async def call_tool(self, tool_call: dict):
    """Call a tool requested by the chat model.

    Args:
        tool_call (dict): The tool call to make. Expected to expose
            ``.function.name`` and ``.function.arguments`` (a JSON string),
            as produced by the OpenAI chat completions API.

    Returns:
        The tool's result, or an error string if the tool is unknown or
        its execution fails.
    """

    tool = tool_call.function.name
    args = json.loads(tool_call.function.arguments)

    self.logger.log(f"Calling tool {tool} with args {args}", "debug")

    # Resolve the tool first, in its own try block: only the registry
    # lookup means "unknown tool". Previously a KeyError raised *inside*
    # the tool's run() was misreported as a missing tool.
    try:
        tool_class = TOOLS[tool]
    except KeyError:
        self.logger.log(f"Tool {tool} not found", "error")
        return "Error: Tool not found"

    try:
        result = await tool_class(**args, bot=self).run()
        return result

    except Exception as e:
        self.logger.log(f"Error calling tool {tool}: {e}", "error")
        return f"Error: Something went wrong calling tool {tool}"
|
||||
|
||||
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
|
||||
"""Process a command. Called from the event_callback() method.
|
||||
Delegates to the appropriate command handler.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import openai
|
||||
import requests
|
||||
import tiktoken
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
@ -12,6 +13,7 @@ from io import BytesIO
|
|||
from pydub import AudioSegment
|
||||
|
||||
from .logging import Logger
|
||||
from ..tools import TOOLS
|
||||
|
||||
ASSISTANT_CODE_INTERPRETER = [
|
||||
{
|
||||
|
@ -199,35 +201,101 @@ class OpenAI:
|
|||
|
||||
return result is not None
|
||||
|
||||
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None, room: Optional[str] = None, allow_override: bool = True) -> Tuple[str, int]:
    """Generate a response to a chat message.

    Args:
        messages (List[Dict[str, str]]): A list of messages to use as context.
        user (Optional[str], optional): The user to use the assistant for. Defaults to None.
        room (Optional[str], optional): The room to use the assistant for. Defaults to None.
        allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True.

    Returns:
        Tuple[str, int]: The response text and the number of tokens used.
    """
    self.logger.log(f"Generating response to {len(messages)} messages...")

    if await self.room_uses_assistant(room):
        return await self.generate_assistant_response(messages, room, user)

    # Advertise every registered tool to the model in the OpenAI
    # function-calling format.
    tools = [
        {
            "type": "function",
            "function": {
                "name": tool_name,
                "description": tool_class.DESCRIPTION,
                "parameters": tool_class.PARAMETERS,
            },
        }
        for tool_name, tool_class in TOOLS.items()]

    chat_model = self.chat_model

    # Only gpt-3.5-turbo supports tools; if ForceTools is set, run the
    # tool round on that model and produce the final answer with the
    # configured model afterwards (see the elif branch below).
    if allow_override and "gpt-3.5-turbo" not in self.chat_model:
        if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
            self.logger.log("Overriding chat model to use tools")
            chat_model = "gpt-3.5-turbo-1106"

    self.logger.log(f"Generating response with model {chat_model}...")

    kwargs = {
        "model": chat_model,
        "messages": messages,
        "user": user,
    }

    if "gpt-3.5-turbo" in chat_model:
        kwargs["tools"] = tools

    if "gpt-4" in chat_model:
        kwargs["max_tokens"] = self.bot.config.getint("OpenAI", "MaxTokens", fallback=4000)

    chat_partial = partial(
        self.openai_api.chat.completions.create,
        **kwargs
    )
    response = await self._request_with_retries(chat_partial)

    choice = response.choices[0]
    result_text = choice.message.content

    additional_tokens = 0

    if (not result_text) and choice.message.tool_calls:
        # The model asked for tool calls instead of answering: run each
        # tool and feed the results back for a follow-up completion.
        tool_responses = []
        for tool_call in choice.message.tool_calls:
            tool_response = await self.bot.call_tool(tool_call)
            tool_responses.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": str(tool_response),
            })

        messages = messages + [choice.message] + tool_responses

        result_text, additional_tokens = await self.generate_chat_response(messages, user, room)

    elif self.chat_model != chat_model:
        # The tool round ran on an overridden model. Rewrite tool
        # messages as system messages for the configured model and
        # regenerate, disallowing another override.
        new_messages = []

        for message in messages:
            if not isinstance(message, dict):
                # Drop non-dict entries (e.g. the raw API message object
                # from the tool round) before retrying.
                continue

            # Copy before mutating — previously the caller's message
            # dicts were modified in place.
            new_message = dict(message)

            if new_message["role"] == "tool":
                new_message["role"] = "system"
                del new_message["tool_call_id"]

            new_messages.append(new_message)

        result_text, additional_tokens = await self.generate_chat_response(new_messages, user, room, False)

    tokens_used = response.usage.total_tokens
    self.logger.log(f"Generated response with {tokens_used} tokens.")
    return result_text, tokens_used + additional_tokens
|
||||
|
||||
async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
|
||||
system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
|
||||
|
|
14
src/gptbot/tools/__init__.py
Normal file
14
src/gptbot/tools/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from importlib import import_module
|
||||
|
||||
# Registry mapping tool name -> tool class. Each tool lives in a module
# of the same name and exposes a class named after the module,
# capitalized (e.g. "dice" -> Dice).
TOOLS = {
    name: getattr(import_module("." + name, "gptbot.tools"), name.capitalize())
    for name in [
        "weather",
        "geocode",
        "dice",
        "websearch",
        "webrequest",
    ]
}
|
10
src/gptbot/tools/base.py
Normal file
10
src/gptbot/tools/base.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
class BaseTool:
    """Common base class for chat tools.

    Subclasses define DESCRIPTION and PARAMETERS (the parameter schema
    advertised to the chat model) and implement run().
    """

    # Set by each concrete tool.
    DESCRIPTION: str
    PARAMETERS: list

    def __init__(self, **kwargs):
        """Store the raw keyword arguments and the owning bot.

        The "bot" keyword argument is required; all kwargs are kept on
        the instance so subclasses can read their own parameters.
        """
        self.bot = kwargs["bot"]
        self.kwargs = kwargs

    async def run(self):
        """Execute the tool. Must be overridden by subclasses."""
        raise NotImplementedError()
|
26
src/gptbot/tools/dice.py
Normal file
26
src/gptbot/tools/dice.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from .base import BaseTool
|
||||
|
||||
from random import SystemRandom
|
||||
|
||||
class Dice(BaseTool):
    """Roll a single die with a configurable number of sides."""

    DESCRIPTION = "Roll dice."
    PARAMETERS = {
        "type": "object",
        "properties": {
            "dice": {
                "type": "string",
                "description": "The number of sides on the dice.",
                "default": "6",
            },
        },
        "required": [],
    }

    async def run(self):
        """Roll dice."""
        sides = int(self.kwargs.get("dice", 6))
        # SystemRandom uses the OS entropy source rather than the
        # default Mersenne Twister.
        roll = SystemRandom().randint(1, sides)

        return f"""**Dice roll**
Used dice: {sides}
Result: {roll}
"""
|
34
src/gptbot/tools/geocode.py
Normal file
34
src/gptbot/tools/geocode.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from geopy.geocoders import Nominatim
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
class Geocode(BaseTool):
    """Resolve a place name to coordinates via Nominatim (OpenStreetMap)."""

    DESCRIPTION = "Get location information (latitude, longitude) for a given location name."
    PARAMETERS = {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The location name.",
            },
        },
        "required": ["location"],
    }

    async def run(self):
        """Get location information for a given location."""
        query = self.kwargs.get("location")
        if not query:
            raise Exception('No location provided.')

        geolocator = Nominatim(user_agent=self.bot.USER_AGENT)
        place = geolocator.geocode(query)

        if not place:
            raise Exception('Could not find location data for that location.')

        return f"""**Location information for {place.address}**
Latitude: {place.latitude}
Longitude: {place.longitude}
"""
|
53
src/gptbot/tools/weather.py
Normal file
53
src/gptbot/tools/weather.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import aiohttp
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
class Weather(BaseTool):
    """Fetch weather for a coordinate pair from the OpenWeatherMap One Call API."""

    DESCRIPTION = "Get weather information for a given location."
    PARAMETERS = {
        "type": "object",
        "properties": {
            "latitude": {
                "type": "string",
                "description": "The latitude of the location.",
            },
            "longitude": {
                "type": "string",
                "description": "The longitude of the location.",
            },
        },
        "required": ["latitude", "longitude"],
    }

    async def run(self):
        """Get weather information for a given location."""
        if not (latitude := self.kwargs.get("latitude")) or not (longitude := self.kwargs.get("longitude")):
            raise Exception('No location provided.')

        # fallback=None so a missing key reaches our own check below
        # instead of raising a configparser NoSectionError/NoOptionError.
        weather_api_key = self.bot.config.get("OpenWeatherMap", "APIKey", fallback=None)

        if not weather_api_key:
            raise Exception('Weather API key not found.')

        url = f'https://api.openweathermap.org/data/3.0/onecall?lat={latitude}&lon={longitude}&appid={weather_api_key}&units=metric'
        # (Removed a leftover debug print of the URL — it leaked the
        # API key to stdout.)

        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status == 200:
                    data = await response.json()
                    return f"""**Weather report**
Current: {data['current']['temp']}°C, {data['current']['weather'][0]['description']}
Feels like: {data['current']['feels_like']}°C
Humidity: {data['current']['humidity']}%
Wind: {data['current']['wind_speed']}m/s
Sunrise: {datetime.fromtimestamp(data['current']['sunrise']).strftime('%H:%M')}
Sunset: {datetime.fromtimestamp(data['current']['sunset']).strftime('%H:%M')}

Today: {data['daily'][0]['temp']['day']}°C, {data['daily'][0]['weather'][0]['description']}, {data['daily'][0]['summary']}
Tomorrow: {data['daily'][1]['temp']['day']}°C, {data['daily'][1]['weather'][0]['description']}, {data['daily'][1]['summary']}
"""
                else:
                    raise Exception(f'Could not connect to weather API: {response.status} {response.reason}')
|
59
src/gptbot/tools/webrequest.py
Normal file
59
src/gptbot/tools/webrequest.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from .base import BaseTool
|
||||
|
||||
import aiohttp
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
import re
|
||||
|
||||
class Webrequest(BaseTool):
    """Fetch a URL and return its content as plain text with link targets kept."""

    DESCRIPTION = "Browse an external website by URL."
    PARAMETERS = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "The URL to request.",
            },
        },
        "required": ["url"],
    }

    async def html_to_text(self, html):
        """Convert an HTML document to whitespace-normalized plain text.

        Args:
            html: The HTML source to convert.

        Returns:
            The text content, with each link rendered as "text (href)".
        """
        # Parse the HTML content of the response
        soup = BeautifulSoup(html, 'html.parser')

        # Format the links within the text
        for link in soup.find_all('a'):
            link_text = link.get_text()
            link_href = link.get('href')
            new_link_text = f"{link_text} ({link_href})"
            link.replace_with(new_link_text)

        # Extract the plain text content of the website
        plain_text_content = soup.get_text()

        # Collapse runs of whitespace. Raw string for the pattern —
        # '\s' in a plain string is an invalid escape sequence and
        # warns on modern Python.
        plain_text_content = re.sub(r'\s+', ' ', plain_text_content).strip()

        # Return the formatted text content of the website
        return plain_text_content

    async def run(self):
        """Make a web request to a given URL."""
        if not (url := self.kwargs.get("url")):
            raise Exception('No URL provided.')

        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                # NOTE(review): non-200 responses currently fall through
                # and return None — consider raising like Weather does.
                if response.status == 200:
                    data = await response.text()

                    output = await self.html_to_text(data)

                    return f"""**Web request**
URL: {url}
Status: {response.status} {response.reason}

{output}
"""
|
37
src/gptbot/tools/websearch.py
Normal file
37
src/gptbot/tools/websearch.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from .base import BaseTool
|
||||
|
||||
import aiohttp
|
||||
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
class Websearch(BaseTool):
    """Search the web via the LibreY search API and format the results."""

    DESCRIPTION = "Search the web for a given query."
    PARAMETERS = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The query to search for.",
            },
        },
        "required": ["query"],
    }

    async def run(self):
        """Search the web for a given query."""
        if not (query := self.kwargs.get("query")):
            raise Exception('No query provided.')

        query = quote_plus(query)

        url = f'https://librey.private.coffee/api.php?q={query}'

        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                # NOTE(review): non-200 responses currently fall through
                # and return None — consider raising instead.
                if response.status == 200:
                    data = await response.json()
                    # Bug fix: the heading was a plain string, so the
                    # literal text "{query}" appeared in the output
                    # instead of the search term.
                    response_text = f"**Search results for {query}**"
                    for result in data:
                        response_text += f"\n{result['title']}\n{result['url']}\n{result['description']}\n"

                    return response_text
|
Loading…
Reference in a new issue