feat: Enable calling tools in chat completion

This commit adds functionality to call tools within the chat completion model. By introducing the `call_tool()` method in the `GPTBot` class, tools can now be invoked with the appropriate tool call. The commit also includes the necessary changes in the `OpenAI` class to handle tool calls during response generation. Additionally, new tool classes for geocoding and dice rolling have been implemented. This enhancement aims to expand the capabilities of the bot by allowing users to leverage various tools directly within the chat conversation.
This commit is contained in:
Kumi 2023-11-28 18:15:21 +01:00
parent 155ea68e7a
commit 54dd80ed50
Signed by: kumi
GPG key ID: ECBCC9082395383F
11 changed files with 359 additions and 9 deletions

View file

@ -104,6 +104,14 @@ APIKey = sk-yoursecretkey
#
# BaseURL = https://openai.local/v1
# Whether to force the use of tools in the chat completion model
#
# Currently, only gpt-3.5-turbo supports tools. If you set this to 1, the bot
# will use that model for tools even if you have a different model set as the
# default. It will only generate the final result using the default model.
#
# ForceTools = 0
###############################################################################
[WolframAlpha]
@ -180,3 +188,12 @@ CryptoStore = store.db
# APIKey = __________________________
###############################################################################
[OpenWeatherMap]
# API key for OpenWeatherMap
# If not defined, the bot will be unable to provide weather information
#
# APIKey = __________________________
###############################################################################

View file

@ -7,7 +7,7 @@ allow-direct-references = true
[project]
name = "matrix-gptbot"
version = "0.2.1"
version = "0.2.2"
authors = [
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
@ -52,6 +52,8 @@ trackingmore = [
all = [
"matrix-gptbot[openai,wolframalpha,trackingmore]",
"geopy",
"beautifulsoup4",
]
dev = [

View file

@ -61,6 +61,7 @@ from .logging import Logger
from ..migrations import migrate
from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from ..commands import COMMANDS
from ..tools import TOOLS
from .openai import OpenAI
from .wolframalpha import WolframAlpha
from .trackingmore import TrackingMore
@ -92,6 +93,9 @@ class GPTBot:
logo: Optional[Image.Image] = None
logo_uri: Optional[str] = None
allowed_users: List[str] = []
config: ConfigParser = ConfigParser()
USER_AGENT = "matrix-gptbot/dev (+https://kumig.it/kumitterer/matrix-gptbot)"
@classmethod
def from_config(cls, config: ConfigParser):
@ -188,6 +192,7 @@ class GPTBot:
bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
# Return the new GPTBot instance
bot.config = config
return bot
async def _get_user_id(self) -> str:
@ -342,6 +347,31 @@ class GPTBot:
return device_id
async def call_tool(self, tool_call: dict):
"""Call a tool.
Args:
tool_call (dict): The tool call to make.
"""
tool = tool_call.function.name
args = json.loads(tool_call.function.arguments)
self.logger.log(f"Calling tool {tool} with args {args}", "debug")
try:
tool_class = TOOLS[tool]
result = await tool_class(**args, bot=self).run()
return result
except KeyError:
self.logger.log(f"Tool {tool} not found", "error")
return "Error: Tool not found"
except Exception as e:
self.logger.log(f"Error calling tool {tool}: {e}", "error")
return f"Error: Something went wrong calling tool {tool}"
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
"""Process a command. Called from the event_callback() method.
Delegates to the appropriate command handler.

View file

@ -1,5 +1,6 @@
import openai
import requests
import tiktoken
import asyncio
import json
@ -12,6 +13,7 @@ from io import BytesIO
from pydub import AudioSegment
from .logging import Logger
from ..tools import TOOLS
ASSISTANT_CODE_INTERPRETER = [
{
@ -199,35 +201,101 @@ class OpenAI:
return result is not None
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None, room: Optional[str] = None) -> Tuple[str, int]:
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None, room: Optional[str] = None, allow_override: bool = True) -> Tuple[str, int]:
"""Generate a response to a chat message.
Args:
messages (List[Dict[str, str]]): A list of messages to use as context.
user (Optional[str], optional): The user to use the assistant for. Defaults to None.
room (Optional[str], optional): The room to use the assistant for. Defaults to None.
allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True.
Returns:
Tuple[str, int]: The response text and the number of tokens used.
"""
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
self.logger.log(f"Generating response to {len(messages)} messages...")
if await self.room_uses_assistant(room):
return await self.generate_assistant_response(messages, room, user)
tools = [
{
"type": "function",
"function": {
"name": tool_name,
"description": tool_class.DESCRIPTION,
"parameters": tool_class.PARAMETERS
}
}
for tool_name, tool_class in TOOLS.items()]
chat_model = self.chat_model
if allow_override and not "gpt-3.5-turbo" in self.chat_model:
if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
self.logger.log(f"Overriding chat model to use tools")
chat_model = "gpt-3.5-turbo-1106"
self.logger.log(f"Generating response with model {chat_model}...")
kwargs = {
"model": chat_model,
"messages": messages,
"user": user,
}
if "gpt-3.5-turbo" in chat_model:
kwargs["tools"] = tools
if "gpt-4" in chat_model:
kwargs["max_tokens"] = self.bot.config.getint("OpenAI", "MaxTokens", fallback=4000)
chat_partial = partial(
self.openai_api.chat.completions.create,
model=self.chat_model,
messages=messages,
user=user,
max_tokens=4096
**kwargs
)
response = await self._request_with_retries(chat_partial)
result_text = response.choices[0].message.content
choice = response.choices[0]
result_text = choice.message.content
additional_tokens = 0
if (not result_text) and choice.message.tool_calls:
tool_responses = []
for tool_call in choice.message.tool_calls:
tool_response = await self.bot.call_tool(tool_call)
tool_responses.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": str(tool_response)
})
messages = messages + [choice.message] + tool_responses
result_text, additional_tokens = await self.generate_chat_response(messages, user, room)
elif not self.chat_model == chat_model:
new_messages = []
for message in messages:
new_message = message
if isinstance(message, dict):
if message["role"] == "tool":
new_message["role"] = "system"
del(new_message["tool_call_id"])
else:
continue
new_messages.append(new_message)
result_text, additional_tokens = await self.generate_chat_response(new_messages, user, room, False)
tokens_used = response.usage.total_tokens
self.logger.log(f"Generated response with {tokens_used} tokens.")
return result_text, tokens_used
return result_text, tokens_used + additional_tokens
async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:

View file

@ -0,0 +1,14 @@
from importlib import import_module
TOOLS = {}
for tool in [
"weather",
"geocode",
"dice",
"websearch",
"webrequest",
]:
tool_class = getattr(import_module(
"." + tool, "gptbot.tools"), tool.capitalize())
TOOLS[tool] = tool_class

10
src/gptbot/tools/base.py Normal file
View file

@ -0,0 +1,10 @@
class BaseTool:
DESCRIPTION: str
PARAMETERS: list
def __init__(self, **kwargs):
self.kwargs = kwargs
self.bot = kwargs["bot"]
async def run(self):
raise NotImplementedError()

26
src/gptbot/tools/dice.py Normal file
View file

@ -0,0 +1,26 @@
from .base import BaseTool
from random import SystemRandom
class Dice(BaseTool):
DESCRIPTION = "Roll dice."
PARAMETERS = {
"type": "object",
"properties": {
"dice": {
"type": "string",
"description": "The number of sides on the dice.",
"default": "6",
},
},
"required": [],
}
async def run(self):
"""Roll dice."""
dice = int(self.kwargs.get("dice", 6))
return f"""**Dice roll**
Used dice: {dice}
Result: {SystemRandom().randint(1, dice)}
"""

View file

@ -0,0 +1,34 @@
from geopy.geocoders import Nominatim
from .base import BaseTool
class Geocode(BaseTool):
DESCRIPTION = "Get location information (latitude, longitude) for a given location name."
PARAMETERS = {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location name.",
},
},
"required": ["location"],
}
async def run(self):
"""Get location information for a given location."""
if not (location := self.kwargs.get("location")):
raise Exception('No location provided.')
geolocator = Nominatim(user_agent=self.bot.USER_AGENT)
location = geolocator.geocode(location)
if location:
return f"""**Location information for {location.address}**
Latitude: {location.latitude}
Longitude: {location.longitude}
"""
raise Exception('Could not find location data for that location.')

View file

@ -0,0 +1,53 @@
import aiohttp
from datetime import datetime
from .base import BaseTool
class Weather(BaseTool):
DESCRIPTION = "Get weather information for a given location."
PARAMETERS = {
"type": "object",
"properties": {
"latitude": {
"type": "string",
"description": "The latitude of the location.",
},
"longitude": {
"type": "string",
"description": "The longitude of the location.",
},
},
"required": ["latitude", "longitude"],
}
async def run(self):
"""Get weather information for a given location."""
if not (latitude := self.kwargs.get("latitude")) or not (longitude := self.kwargs.get("longitude")):
raise Exception('No location provided.')
weather_api_key = self.bot.config.get("OpenWeatherMap", "APIKey")
if not weather_api_key:
raise Exception('Weather API key not found.')
url = f'https://api.openweathermap.org/data/3.0/onecall?lat={latitude}&lon={longitude}&appid={weather_api_key}&units=metric'
print(url)
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
return f"""**Weather report**
Current: {data['current']['temp']}°C, {data['current']['weather'][0]['description']}
Feels like: {data['current']['feels_like']}°C
Humidity: {data['current']['humidity']}%
Wind: {data['current']['wind_speed']}m/s
Sunrise: {datetime.fromtimestamp(data['current']['sunrise']).strftime('%H:%M')}
Sunset: {datetime.fromtimestamp(data['current']['sunset']).strftime('%H:%M')}
Today: {data['daily'][0]['temp']['day']}°C, {data['daily'][0]['weather'][0]['description']}, {data['daily'][0]['summary']}
Tomorrow: {data['daily'][1]['temp']['day']}°C, {data['daily'][1]['weather'][0]['description']}, {data['daily'][1]['summary']}
"""
else:
raise Exception(f'Could not connect to weather API: {response.status} {response.reason}')

View file

@ -0,0 +1,59 @@
from .base import BaseTool
import aiohttp
from bs4 import BeautifulSoup
import re
class Webrequest(BaseTool):
DESCRIPTION = "Browse an external website by URL."
PARAMETERS = {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to request.",
},
},
"required": ["url"],
}
async def html_to_text(self, html):
# Parse the HTML content of the response
soup = BeautifulSoup(html, 'html.parser')
# Format the links within the text
for link in soup.find_all('a'):
link_text = link.get_text()
link_href = link.get('href')
new_link_text = f"{link_text} ({link_href})"
link.replace_with(new_link_text)
# Extract the plain text content of the website
plain_text_content = soup.get_text()
# Remove extra whitespace
plain_text_content = re.sub('\s+', ' ', plain_text_content).strip()
# Return the formatted text content of the website
return plain_text_content
async def run(self):
"""Make a web request to a given URL."""
if not (url := self.kwargs.get("url")):
raise Exception('No URL provided.')
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
data = await response.text()
output = await self.html_to_text(data)
return f"""**Web request**
URL: {url}
Status: {response.status} {response.reason}
{output}
"""

View file

@ -0,0 +1,37 @@
from .base import BaseTool
import aiohttp
from urllib.parse import quote_plus
class Websearch(BaseTool):
DESCRIPTION = "Search the web for a given query."
PARAMETERS = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for.",
},
},
"required": ["query"],
}
async def run(self):
"""Search the web for a given query."""
if not (query := self.kwargs.get("query")):
raise Exception('No query provided.')
query = quote_plus(query)
url = f'https://librey.private.coffee/api.php?q={query}'
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
response_text = "**Search results for {query}**"
for result in data:
response_text += f"\n{result['title']}\n{result['url']}\n{result['description']}\n"
return response_text