feat: enhance AI integration & update models
Refactored the handling of AI providers to support multiple AI services efficiently, introducing a `BaseAI` class from which all AI providers now inherit. This change modernizes our approach to AI integration, providing a more flexible and maintainable architecture for future expansions and enhancements. - Adopted `gpt-4o` and `dall-e-3` as the default models for chat and image generation, respectively, aligning with the latest advancements in AI capabilities. - Integrated `ruff` as a development dependency to enforce coding standards and improve code quality through consistent linting. - Removed unused API keys and sections from `config.dist.ini` to streamline configuration management and clarify setup processes for new users. - Updated the command line tool for improved usability and fixed previous issues preventing its effective operation. - Enhanced OpenAI integration with advanced settings for temperature, top_p, frequency_penalty, and presence_penalty, enabling finer control over AI-generated outputs. This comprehensive update not only enhances the bot's performance and usability but also lays the groundwork for incorporating additional AI providers, ensuring the project remains at the forefront of AI-driven chatbot technologies. Resolves #13
This commit is contained in:
parent
02887b9336
commit
8e0cffe02a
8 changed files with 213 additions and 158 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -7,4 +7,5 @@ venv/
|
|||
__pycache__/
|
||||
*.bak
|
||||
dist/
|
||||
pantalaimon.conf
|
||||
pantalaimon.conf
|
||||
.ruff_cache/
|
12
CHANGELOG.md
12
CHANGELOG.md
|
@ -1,5 +1,17 @@
|
|||
# Changelog
|
||||
|
||||
### 0.3.11 (2024-05-17)
|
||||
|
||||
- Refactoring of AI provider handling in preparation for multiple AI providers: Introduced a `BaseAI` class that all AI providers must inherit from
|
||||
- Added support for temperature, top_p, frequency_penalty, and presence_penalty in `AllowedUsers`
|
||||
- Introduced ruff as a development dependency for linting and applied some linting fixes
|
||||
- Fixed `gptbot` command line tool
|
||||
- Changed default chat model to `gpt-4o`
|
||||
- Changed default image generation model to `dall-e-3`
|
||||
- Removed currently unused sections from `config.dist.ini`
|
||||
- Changed provided Pantalaimon config file to not use a key ring by default
|
||||
- Prevent bot from crashing when an unneeded dependency is missing
|
||||
|
||||
### 0.3.10 (2024-05-16)
|
||||
|
||||
- Add support for specifying room IDs in `AllowedUsers`
|
||||
|
|
|
@ -63,14 +63,11 @@ LogLevel = info
|
|||
|
||||
# The Chat Completion model you want to use.
|
||||
#
|
||||
# Unless you are in the GPT-4 beta (if you don't know - you aren't),
|
||||
# leave this as the default value (gpt-3.5-turbo)
|
||||
#
|
||||
# Model = gpt-3.5-turbo
|
||||
# Model = gpt-4o
|
||||
|
||||
# The Image Generation model you want to use.
|
||||
#
|
||||
# ImageModel = dall-e-2
|
||||
# ImageModel = dall-e-3
|
||||
|
||||
# Your OpenAI API key
|
||||
#
|
||||
|
@ -106,7 +103,7 @@ APIKey = sk-yoursecretkey
|
|||
# Setting this allows you to use a self-hosted AI model for chat completions
|
||||
# using something like llama-cpp-python or ollama
|
||||
#
|
||||
# BaseURL = https://openai.local/v1
|
||||
# BaseURL = https://api.openai.com/v1/
|
||||
|
||||
# Whether to force the use of tools in the chat completion model
|
||||
#
|
||||
|
@ -124,6 +121,22 @@ APIKey = sk-yoursecretkey
|
|||
#
|
||||
# EmulateTools = 0
|
||||
|
||||
# Advanced settings for the OpenAI API
|
||||
#
|
||||
# These settings are not required for normal operation, but can be used to
|
||||
# tweak the behavior of the bot.
|
||||
#
|
||||
# Note: These settings are not validated by the bot, so make sure they are
|
||||
# correct before setting them, or the bot may not work as expected.
|
||||
#
|
||||
# For more information, see the OpenAI documentation:
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
#
|
||||
# Temperature = 1
|
||||
# TopP = 1
|
||||
# FrequencyPenalty = 0
|
||||
# PresencePenalty = 0
|
||||
|
||||
###############################################################################
|
||||
|
||||
[WolframAlpha]
|
||||
|
@ -185,26 +198,6 @@ Path = database.db
|
|||
|
||||
###############################################################################
|
||||
|
||||
[Replicate]
|
||||
|
||||
# API key for replicate.com
|
||||
# Can be used to run lots of different AI models
|
||||
# If not defined, the features that depend on it are not available
|
||||
#
|
||||
# APIKey = r8_alotoflettersandnumbershere
|
||||
|
||||
###############################################################################
|
||||
|
||||
[HuggingFace]
|
||||
|
||||
# API key for Hugging Face
|
||||
# Can be used to run lots of different AI models
|
||||
# If not defined, the features that depend on it are not available
|
||||
#
|
||||
# APIKey = __________________________
|
||||
|
||||
###############################################################################
|
||||
|
||||
[OpenWeatherMap]
|
||||
|
||||
# API key for OpenWeatherMap
|
||||
|
|
|
@ -7,7 +7,7 @@ allow-direct-references = true
|
|||
|
||||
[project]
|
||||
name = "matrix-gptbot"
|
||||
version = "0.3.10"
|
||||
version = "0.3.11"
|
||||
|
||||
authors = [
|
||||
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
|
||||
|
@ -63,6 +63,7 @@ dev = [
|
|||
"hatchling",
|
||||
"twine",
|
||||
"build",
|
||||
"ruff",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
@ -71,7 +72,7 @@ dev = [
|
|||
"Source Code" = "https://git.private.coffee/privatecoffee/matrix-gptbot"
|
||||
|
||||
[project.scripts]
|
||||
gptbot = "gptbot.__main__:main"
|
||||
gptbot = "gptbot.__main__:main_sync"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/gptbot"]
|
0
src/gptbot/classes/ai/__init__.py
Normal file
0
src/gptbot/classes/ai/__init__.py
Normal file
52
src/gptbot/classes/ai/base.py
Normal file
52
src/gptbot/classes/ai/base.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
from ...classes.logging import Logger
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, Mapping
|
||||
|
||||
|
||||
class AttributeDictionary(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttributeDictionary, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
class BaseAI:
|
||||
bot: Any
|
||||
logger: Logger
|
||||
|
||||
def __init__(self, bot, config: Mapping, logger: Optional[Logger] = None):
|
||||
self.bot = bot
|
||||
self.logger = logger or bot.logger or Logger()
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def chat_api(self) -> str:
|
||||
return self.chat_model
|
||||
|
||||
async def _request_with_retries(
|
||||
self, request: partial, attempts: int = 5, retry_interval: int = 2
|
||||
) -> AsyncGenerator[Any | list | Dict, None]:
|
||||
"""Retry a request a set number of times if it fails.
|
||||
|
||||
Args:
|
||||
request (partial): The request to make with retries.
|
||||
attempts (int, optional): The number of attempts to make. Defaults to 5.
|
||||
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Any | list | Dict, None]: The response for the request.
|
||||
"""
|
||||
current_attempt = 1
|
||||
while current_attempt <= attempts:
|
||||
try:
|
||||
response = await request()
|
||||
return response
|
||||
except Exception as e:
|
||||
self.logger.log(f"Request failed: {e}", "error")
|
||||
self.logger.log(f"Retrying in {retry_interval} seconds...")
|
||||
await asyncio.sleep(retry_interval)
|
||||
current_attempt += 1
|
||||
|
||||
raise Exception("Request failed after all attempts.")
|
|
@ -1,21 +1,18 @@
|
|||
import openai
|
||||
import requests
|
||||
import tiktoken
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import inspect
|
||||
|
||||
from functools import partial
|
||||
from contextlib import closing
|
||||
from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
|
||||
from typing import Dict, List, Tuple, Generator, Optional, Mapping
|
||||
from io import BytesIO
|
||||
|
||||
from pydub import AudioSegment
|
||||
|
||||
from .logging import Logger
|
||||
from ..tools import TOOLS, Handover, StopProcessing
|
||||
from ..logging import Logger
|
||||
from ...tools import TOOLS, Handover, StopProcessing
|
||||
from .base import BaseAI, AttributeDictionary
|
||||
|
||||
ASSISTANT_CODE_INTERPRETER = [
|
||||
{
|
||||
|
@ -24,58 +21,81 @@ ASSISTANT_CODE_INTERPRETER = [
|
|||
]
|
||||
|
||||
|
||||
class AttributeDictionary(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttributeDictionary, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
class OpenAI:
|
||||
api_key: str
|
||||
chat_model: str = "gpt-3.5-turbo"
|
||||
logger: Logger
|
||||
|
||||
class OpenAI(BaseAI):
|
||||
api_code: str = "openai"
|
||||
|
||||
@property
|
||||
def chat_api(self) -> str:
|
||||
return self.chat_model
|
||||
|
||||
classification_api = chat_api
|
||||
image_model: str = "dall-e-2"
|
||||
tts_model: str = "tts-1-hd"
|
||||
tts_voice: str = "alloy"
|
||||
stt_model: str = "whisper-1"
|
||||
openai_api: openai.AsyncOpenAI
|
||||
|
||||
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot,
|
||||
api_key,
|
||||
chat_model=None,
|
||||
image_model=None,
|
||||
tts_model=None,
|
||||
tts_voice=None,
|
||||
stt_model=None,
|
||||
base_url=None,
|
||||
logger=None,
|
||||
config: Mapping,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
self.bot = bot
|
||||
self.api_key = api_key
|
||||
self.chat_model = chat_model or self.chat_model
|
||||
self.image_model = image_model or self.image_model
|
||||
self.logger = logger or bot.logger or Logger()
|
||||
self.base_url = base_url or openai.base_url
|
||||
super().__init__(bot, config, logger)
|
||||
self.openai_api = openai.AsyncOpenAI(
|
||||
api_key=self.api_key, base_url=self.base_url
|
||||
)
|
||||
self.tts_model = tts_model or self.tts_model
|
||||
self.tts_voice = tts_voice or self.tts_voice
|
||||
self.stt_model = stt_model or self.stt_model
|
||||
|
||||
# TODO: Add descriptions for these properties
|
||||
|
||||
@property
|
||||
def api_key(self):
|
||||
return self._config["APIKey"]
|
||||
|
||||
@property
|
||||
def chat_model(self):
|
||||
return self._config.get("Model", fallback="gpt-4o")
|
||||
|
||||
@property
|
||||
def image_model(self):
|
||||
return self._config.get("ImageModel", fallback="dall-e-3")
|
||||
|
||||
@property
|
||||
def tts_model(self):
|
||||
return self._config.get("TTSModel", fallback="tts-1-hd")
|
||||
|
||||
@property
|
||||
def tts_voice(self):
|
||||
return self._config.get("TTSVoice", fallback="alloy")
|
||||
|
||||
@property
|
||||
def stt_model(self):
|
||||
return self._config.get("STTModel", fallback="whisper-1")
|
||||
|
||||
@property
|
||||
def base_url(self):
|
||||
return self._config.get("BaseURL", fallback="https://api.openai.com/v1/")
|
||||
|
||||
@property
|
||||
def temperature(self):
|
||||
return self._config.getfloat("Temperature", fallback=1.0)
|
||||
|
||||
@property
|
||||
def top_p(self):
|
||||
return self._config.getfloat("TopP", fallback=1.0)
|
||||
|
||||
@property
|
||||
def frequency_penalty(self):
|
||||
return self._config.getfloat("FrequencyPenalty", fallback=0.0)
|
||||
|
||||
@property
|
||||
def presence_penalty(self):
|
||||
return self._config.getfloat("PresencePenalty", fallback=0.0)
|
||||
|
||||
@property
|
||||
def max_tokens(self):
|
||||
# TODO: This should be model-specific
|
||||
return self._config.getint("MaxTokens", fallback=4000)
|
||||
|
||||
def supports_chat_images(self):
|
||||
return "vision" in self.chat_model
|
||||
return "vision" in self.chat_model or self.chat_model in ("gpt-4o",)
|
||||
|
||||
def json_decode(self, data):
|
||||
if data.startswith("```json\n"):
|
||||
|
@ -88,37 +108,9 @@ class OpenAI:
|
|||
|
||||
try:
|
||||
return json.loads(data)
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _request_with_retries(
|
||||
self, request: partial, attempts: int = 5, retry_interval: int = 2
|
||||
) -> AsyncGenerator[Any | list | Dict, None]:
|
||||
"""Retry a request a set number of times if it fails.
|
||||
|
||||
Args:
|
||||
request (partial): The request to make with retries.
|
||||
attempts (int, optional): The number of attempts to make. Defaults to 5.
|
||||
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
|
||||
"""
|
||||
# call the request function and return the response if it succeeds, else retry
|
||||
current_attempt = 1
|
||||
while current_attempt <= attempts:
|
||||
try:
|
||||
response = await request()
|
||||
return response
|
||||
except Exception as e:
|
||||
self.logger.log(f"Request failed: {e}", "error")
|
||||
self.logger.log(f"Retrying in {retry_interval} seconds...")
|
||||
await asyncio.sleep(retry_interval)
|
||||
current_attempt += 1
|
||||
|
||||
# if all attempts failed, raise an exception
|
||||
raise Exception("Request failed after all attempts.")
|
||||
|
||||
async def generate_chat_response(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
@ -162,7 +154,7 @@ class OpenAI:
|
|||
)
|
||||
|
||||
if count > 5:
|
||||
self.logger.log(f"Recursion depth exceeded, aborting.")
|
||||
self.logger.log("Recursion depth exceeded, aborting.")
|
||||
return await self.generate_chat_response(
|
||||
messages=messages,
|
||||
user=user,
|
||||
|
@ -186,9 +178,10 @@ class OpenAI:
|
|||
|
||||
original_messages = messages
|
||||
|
||||
if allow_override and not "gpt-3.5-turbo" in original_model:
|
||||
# TODO: I believe more models support tools now, so this could be adapted
|
||||
if allow_override and "gpt-3.5-turbo" not in original_model:
|
||||
if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
|
||||
self.logger.log(f"Overriding chat model to use tools")
|
||||
self.logger.log("Overriding chat model to use tools")
|
||||
chat_model = "gpt-3.5-turbo"
|
||||
|
||||
out_messages = []
|
||||
|
@ -216,7 +209,7 @@ class OpenAI:
|
|||
use_tools
|
||||
and self.bot.config.getboolean("OpenAI", "EmulateTools", fallback=False)
|
||||
and not self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False)
|
||||
and not "gpt-3.5-turbo" in chat_model
|
||||
and "gpt-3.5-turbo" not in chat_model
|
||||
):
|
||||
self.bot.logger.log("Using tool emulation mode.", "debug")
|
||||
|
||||
|
@ -257,15 +250,17 @@ class OpenAI:
|
|||
"model": chat_model,
|
||||
"messages": messages,
|
||||
"user": room,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
}
|
||||
|
||||
if "gpt-3.5-turbo" in chat_model and use_tools:
|
||||
kwargs["tools"] = tools
|
||||
|
||||
if "gpt-4" in chat_model:
|
||||
kwargs["max_tokens"] = self.bot.config.getint(
|
||||
"OpenAI", "MaxTokens", fallback=4000
|
||||
)
|
||||
kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
api_url = self.base_url
|
||||
|
||||
|
@ -295,7 +290,7 @@ class OpenAI:
|
|||
tool_response = await self.bot.call_tool(
|
||||
tool_call, room=room, user=user
|
||||
)
|
||||
if tool_response != False:
|
||||
if tool_response is not False:
|
||||
tool_responses.append(
|
||||
{
|
||||
"role": "tool",
|
||||
|
@ -316,7 +311,7 @@ class OpenAI:
|
|||
)
|
||||
|
||||
if not tool_responses:
|
||||
self.logger.log(f"No more responses received, aborting.")
|
||||
self.logger.log("No more responses received, aborting.")
|
||||
result_text = False
|
||||
else:
|
||||
try:
|
||||
|
@ -332,7 +327,7 @@ class OpenAI:
|
|||
except openai.APIError as e:
|
||||
if e.code == "max_tokens":
|
||||
self.logger.log(
|
||||
f"Max tokens exceeded, falling back to no-tools response."
|
||||
"Max tokens exceeded, falling back to no-tools response."
|
||||
)
|
||||
try:
|
||||
new_messages = []
|
||||
|
@ -381,7 +376,6 @@ class OpenAI:
|
|||
elif isinstance((tool_object := self.json_decode(result_text)), dict):
|
||||
if "tool" in tool_object:
|
||||
tool_name = tool_object["tool"]
|
||||
tool_class = TOOLS[tool_name]
|
||||
tool_parameters = (
|
||||
tool_object["parameters"] if "parameters" in tool_object else {}
|
||||
)
|
||||
|
@ -405,7 +399,7 @@ class OpenAI:
|
|||
tool_response = await self.bot.call_tool(
|
||||
tool_call, room=room, user=user
|
||||
)
|
||||
if tool_response != False:
|
||||
if tool_response is not False:
|
||||
tool_responses = [
|
||||
{
|
||||
"role": "system",
|
||||
|
@ -425,7 +419,7 @@ class OpenAI:
|
|||
)
|
||||
|
||||
if not tool_responses:
|
||||
self.logger.log(f"No response received, aborting.")
|
||||
self.logger.log("No response received, aborting.")
|
||||
result_text = False
|
||||
else:
|
||||
try:
|
||||
|
@ -494,11 +488,13 @@ class OpenAI:
|
|||
)
|
||||
|
||||
if not result_text:
|
||||
self.logger.log(f"Received an empty response from the OpenAI endpoint.", "debug")
|
||||
self.logger.log(
|
||||
"Received an empty response from the OpenAI endpoint.", "debug"
|
||||
)
|
||||
|
||||
try:
|
||||
tokens_used = response.usage.total_tokens
|
||||
except:
|
||||
except Exception:
|
||||
tokens_used = 0
|
||||
|
||||
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
||||
|
@ -580,7 +576,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
|||
Returns:
|
||||
Tuple[str, int]: The text and the number of tokens used.
|
||||
"""
|
||||
self.logger.log(f"Generating text from speech...")
|
||||
self.logger.log("Generating text from speech...")
|
||||
|
||||
audio_file = BytesIO()
|
||||
AudioSegment.from_file(BytesIO(audio)).export(audio_file, format="mp3")
|
||||
|
@ -667,18 +663,20 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
|||
Returns:
|
||||
Tuple[str, int]: The description and the number of tokens used.
|
||||
"""
|
||||
self.logger.log(f"Generating description for images in conversation...")
|
||||
self.logger.log("Generating description for images in conversation...")
|
||||
|
||||
system_message = "You are an image description generator. You generate descriptions for all images in the current conversation, one after another."
|
||||
|
||||
messages = [{"role": "system", "content": system_message}] + messages[1:]
|
||||
|
||||
if not "vision" in (chat_model := self.chat_model):
|
||||
chat_model = self.chat_model + "gpt-4-vision-preview"
|
||||
if "vision" not in (chat_model := self.chat_model) and chat_model not in (
|
||||
"gpt-4o",
|
||||
):
|
||||
chat_model = "gpt-4o"
|
||||
|
||||
chat_partial = partial(
|
||||
self.openai_api.chat.completions.create,
|
||||
model=self.chat_model,
|
||||
model=chat_model,
|
||||
messages=messages,
|
||||
user=str(user),
|
||||
)
|
|
@ -1,7 +1,6 @@
|
|||
import markdown2
|
||||
import tiktoken
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -15,8 +14,6 @@ from nio import (
|
|||
MatrixRoom,
|
||||
Api,
|
||||
RoomMessagesError,
|
||||
GroupEncryptionError,
|
||||
EncryptionError,
|
||||
RoomMessageText,
|
||||
RoomSendResponse,
|
||||
SyncResponse,
|
||||
|
@ -31,15 +28,12 @@ from nio import (
|
|||
RoomMessageFile,
|
||||
RoomMessageAudio,
|
||||
DownloadError,
|
||||
DownloadResponse,
|
||||
ToDeviceEvent,
|
||||
ToDeviceError,
|
||||
RoomGetStateError,
|
||||
)
|
||||
from nio.store import SqliteStore
|
||||
|
||||
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Any
|
||||
from configparser import ConfigParser
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
@ -50,21 +44,14 @@ import base64
|
|||
import uuid
|
||||
import traceback
|
||||
import json
|
||||
import importlib.util
|
||||
import sys
|
||||
import sqlite3
|
||||
import traceback
|
||||
|
||||
from .logging import Logger
|
||||
from ..migrations import migrate
|
||||
from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
|
||||
from ..commands import COMMANDS
|
||||
from ..tools import TOOLS, Handover, StopProcessing
|
||||
|
||||
# TODO: Make these optional based on config
|
||||
from .openai import OpenAI
|
||||
from .wolframalpha import WolframAlpha
|
||||
from .trackingmore import TrackingMore
|
||||
from .ai.base import BaseAI
|
||||
|
||||
|
||||
class GPTBot:
|
||||
|
@ -74,12 +61,13 @@ class GPTBot:
|
|||
matrix_client: Optional[AsyncClient] = None
|
||||
sync_token: Optional[str] = None
|
||||
logger: Optional[Logger] = Logger()
|
||||
chat_api: Optional[OpenAI] = None
|
||||
image_api: Optional[OpenAI] = None
|
||||
classification_api: Optional[OpenAI] = None
|
||||
tts_api: Optional[OpenAI] = None
|
||||
stt_api: Optional[OpenAI] = None
|
||||
parcel_api: Optional[TrackingMore] = None
|
||||
chat_api: Optional[BaseAI] = None
|
||||
image_api: Optional[BaseAI] = None
|
||||
classification_api: Optional[BaseAI] = None
|
||||
tts_api: Optional[BaseAI] = None
|
||||
stt_api: Optional[BaseAI] = None
|
||||
parcel_api: Optional[Any] = None
|
||||
calculation_api: Optional[Any] = None
|
||||
room_ignore_list: List[str] = [] # List of rooms to ignore invites from
|
||||
logo: Optional[Image.Image] = None
|
||||
logo_uri: Optional[str] = None
|
||||
|
@ -96,7 +84,7 @@ class GPTBot:
|
|||
"""
|
||||
try:
|
||||
return json.loads(self.config["GPTBot"]["AllowedUsers"])
|
||||
except:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
@property
|
||||
|
@ -232,34 +220,41 @@ class GPTBot:
|
|||
if Path(bot.logo_path).exists() and Path(bot.logo_path).is_file():
|
||||
bot.logo = Image.open(bot.logo_path)
|
||||
|
||||
bot.chat_api = bot.image_api = bot.classification_api = bot.tts_api = (
|
||||
bot.stt_api
|
||||
) = OpenAI(
|
||||
bot=bot,
|
||||
api_key=config["OpenAI"]["APIKey"],
|
||||
chat_model=config["OpenAI"].get("Model"),
|
||||
image_model=config["OpenAI"].get("ImageModel"),
|
||||
tts_model=config["OpenAI"].get("TTSModel"),
|
||||
stt_model=config["OpenAI"].get("STTModel"),
|
||||
base_url=config["OpenAI"].get("BaseURL"),
|
||||
)
|
||||
# Set up OpenAI
|
||||
assert (
|
||||
"OpenAI" in config
|
||||
), "OpenAI config not found" # TODO: Update this to support other providers
|
||||
|
||||
if "BaseURL" in config["OpenAI"]:
|
||||
bot.chat_api.base_url = config["OpenAI"]["BaseURL"]
|
||||
bot.image_api = None
|
||||
from .ai.openai import OpenAI
|
||||
|
||||
openai_api = OpenAI(bot=bot, config=config["OpenAI"])
|
||||
|
||||
if "Model" in config["OpenAI"]:
|
||||
bot.chat_api = openai_api
|
||||
bot.classification_api = openai_api
|
||||
|
||||
if "ImageModel" in config["OpenAI"]:
|
||||
bot.image_api = openai_api
|
||||
|
||||
if "TTSModel" in config["OpenAI"]:
|
||||
bot.tts_api = openai_api
|
||||
|
||||
if "STTModel" in config["OpenAI"]:
|
||||
bot.stt_api = openai_api
|
||||
|
||||
# Set up WolframAlpha
|
||||
if "WolframAlpha" in config:
|
||||
from .wolframalpha import WolframAlpha
|
||||
bot.calculation_api = WolframAlpha(
|
||||
config["WolframAlpha"]["APIKey"], bot.logger
|
||||
)
|
||||
|
||||
# Set up TrackingMore
|
||||
if "TrackingMore" in config:
|
||||
from .trackingmore import TrackingMore
|
||||
bot.parcel_api = TrackingMore(config["TrackingMore"]["APIKey"], bot.logger)
|
||||
|
||||
# Set up the Matrix client
|
||||
|
||||
assert "Matrix" in config, "Matrix config not found"
|
||||
|
||||
homeserver = config["Matrix"]["Homeserver"]
|
||||
|
@ -339,7 +334,10 @@ class GPTBot:
|
|||
try:
|
||||
event_type = event.source["content"]["msgtype"]
|
||||
except KeyError:
|
||||
self.logger.log(f"Could not process event: {event}", "debug")
|
||||
if event.__class__.__name__ in ("RoomMemberEvent", ):
|
||||
self.logger.log(f"Ignoring event of type {event.__class__.__name__}", "debug")
|
||||
continue
|
||||
self.logger.log(f"Could not process event: {event}", "warning")
|
||||
continue # This is most likely not a message event
|
||||
|
||||
if event_type.startswith("gptbot"):
|
||||
|
@ -475,7 +473,7 @@ class GPTBot:
|
|||
except (Handover, StopProcessing):
|
||||
raise
|
||||
|
||||
except KeyError as e:
|
||||
except KeyError:
|
||||
self.logger.log(f"Tool {tool} not found", "error")
|
||||
return "Error: Tool not found"
|
||||
|
||||
|
|
Loading…
Reference in a new issue