feat: add video file support and integrate Google AI
Introduced the capability to handle video files as input for AI models that support it, making the bot more versatile in processing media. A new configuration option enables or disables video input to match different model capabilities. Additionally, integrated Google's Generative AI through a new optional dependency and a corresponding AI class implementation, broadening the AI backends users can choose from. The update also refactors and simplifies message preparation and handling, extending it to cover the new video input feature and Google AI support.

- Added `ForceVideoInput` configuration option to toggle video file processing.
- Integrated Google Generative AI as an optional dependency and included it in the bot's AI choices.
- Implemented a unified method for preparing messages for AI processing, streamlining how the bot handles various message types.
- Removed obsolete code for message truncation and the specialized handling of images, files, and audio, reflecting a shift toward a more generalized message processing approach.
parent 05ba26d540
commit c06da55d5d
5 changed files with 114 additions and 228 deletions
```diff
@@ -149,6 +149,15 @@ APIKey = sk-yoursecretkey
 # MaxImageLongSide = 2000
 # MaxImageShortSide = 768
 
+# Whether the used model supports video files as input
+#
+# If you are using a model that supports video files as input, set this to 1.
+# This will make the bot send video files to the model as well as images.
+# This may be possible with some self-hosted models, but is not supported by
+# the OpenAI API at this time.
+#
+# ForceVideoInput = 0
+
 # Advanced settings for the OpenAI API
 #
 # These settings are not required for normal operation, but can be used to
```
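For illustration, this is roughly how such a flag could be read on the Python side. This is a hypothetical sketch, not code from this commit; the section name (`OpenAI`) and the config file path are assumptions based on the surrounding context:

```python
# Hypothetical sketch: reading the new flag with configparser.
# The section name and config file path are assumptions for illustration.
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")

# getboolean() understands "1"/"0"; the commented-out default maps to False
force_video_input = config.getboolean("OpenAI", "ForceVideoInput", fallback=False)
```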
```diff
@@ -7,7 +7,7 @@ allow-direct-references = true
 
 [project]
 name = "matrix-gptbot"
-version = "0.3.14"
+version = "0.3.15"
 
 authors = [
     { name = "Kumi Mitterer", email = "gptbot@kumi.email" },
```
```diff
@@ -39,12 +39,14 @@ dependencies = [
 [project.optional-dependencies]
 openai = ["openai>=1.2", "pydub"]
 
+google = ["google-generativeai"]
+
 wolframalpha = ["wolframalpha"]
 
 trackingmore = ["trackingmore-api-tool"]
 
 all = [
-    "matrix-gptbot[openai,wolframalpha,trackingmore]",
+    "matrix-gptbot[openai,wolframalpha,trackingmore,google]",
     "geopy",
     "beautifulsoup4",
 ]
```
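With the new extra in place, the Google backend can be installed explicitly with `pip install matrix-gptbot[google]`, or together with everything else via `matrix-gptbot[all]`, which now includes the `google` extra.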
```diff
@@ -4,6 +4,8 @@ import asyncio
 from functools import partial
 from typing import Any, AsyncGenerator, Dict, Optional, Mapping
 
+from nio import Event
+
 
 class AttributeDictionary(dict):
     def __init__(self, *args, **kwargs):
```
```diff
@@ -24,6 +26,29 @@ class BaseAI:
     def chat_api(self) -> str:
         return self.chat_model
 
+    async def prepare_messages(
+        self, event: Event, messages: list[Any], system_message: Optional[str] = None
+    ) -> list[Any]:
+        """A helper method to prepare messages for the AI.
+
+        This converts a list of Matrix messages into whatever format the AI requires.
+
+        Args:
+            event (Event): The event that triggered the message generation. Generally a text message from a user.
+            messages (list[Dict[str, str]]): The messages to prepare. Generally of type RoomMessage*.
+            system_message (Optional[str], optional): A system message to include. Defaults to None.
+
+        Returns:
+            list[Any]: The prepared messages in the format the AI requires.
+
+        Raises:
+            NotImplementedError: If the method is not implemented in the subclass.
+        """
+
+        raise NotImplementedError(
+            "Implementations of BaseAI must implement prepare_messages."
+        )
+
     async def _request_with_retries(
         self, request: partial, attempts: int = 5, retry_interval: int = 2
     ) -> AsyncGenerator[Any | list | Dict, None]:
```
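To make the new hook concrete, here is a minimal sketch (not part of this commit) of what a text-only subclass might look like; `self.bot` and the dict-based message shape are assumptions inferred from the surrounding code:

```python
# Hypothetical subclass sketch, assuming BaseAI stores the bot as self.bot
from typing import Any, Optional

from nio import Event, RoomMessageText


class TextOnlyAI(BaseAI):
    async def prepare_messages(
        self, event: Event, messages: list[Any], system_message: Optional[str] = None
    ) -> list[Any]:
        prepared = []
        if system_message:
            prepared.append({"role": "system", "content": system_message})
        # Include the triggering event itself, mirroring the old inline loop
        for message in messages + [event]:
            if isinstance(message, RoomMessageText):
                role = (
                    "assistant"
                    if message.sender == self.bot.matrix_client.user_id
                    else "user"
                )
                prepared.append({"role": role, "content": message.body})
        return prepared
```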
src/gptbot/classes/ai/google.py (new file, 73 lines)
```diff
@@ -0,0 +1,73 @@
+from .base import BaseAI
+from ..logging import Logger
+
+from typing import Optional, Mapping, List, Dict, Tuple
+
+import google.generativeai as genai
+
+
+class GeminiAI(BaseAI):
+    api_code: str = "google"
+
+    @property
+    def chat_api(self) -> str:
+        return self.chat_model
+
+    gemini_api: genai.GenerativeModel
+
+    operator: str = "Google (https://ai.google)"
+
+    def __init__(
+        self,
+        bot,
+        config: Mapping,
+        logger: Optional[Logger] = None,
+    ):
+        super().__init__(bot, config, logger)
+        genai.configure(api_key=self.api_key)
+        self.gemini_api = genai.GenerativeModel(self.chat_model)
+
+    @property
+    def api_key(self):
+        return self._config["APIKey"]
+
+    @property
+    def chat_model(self):
+        return self._config.get("Model", fallback="gemini-pro")
+
+    def prepare_messages(self, messages: List[Dict[str, str]]) -> List[str]:
+        return [message["content"] for message in messages]
+
+    async def generate_chat_response(
+        self,
+        messages: List[Dict[str, str]],
+        user: Optional[str] = None,
+        room: Optional[str] = None,
+        use_tools: bool = True,
+        model: Optional[str] = None,
+    ) -> Tuple[str, int]:
+        """Generate a response to a chat message.
+
+        Args:
+            messages (List[Dict[str, str]]): A list of messages to use as context.
+            user (Optional[str], optional): The user to use the assistant for. Defaults to None.
+            room (Optional[str], optional): The room to use the assistant for. Defaults to None.
+            use_tools (bool, optional): Whether to use tools. Defaults to True.
+            model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
+
+        Returns:
+            Tuple[str, int]: The response text and the number of tokens used.
+        """
+        self.logger.log(
+            f"Generating response to {len(messages)} messages for user {user} in room {room}..."
+        )
+
+        messages = self.prepare_messages(messages)
+
+        # generate_content() takes the prepared messages positionally; the
+        # user/room/use_tools/model keyword arguments are not part of its
+        # signature, so they are not forwarded here
+        response = self.gemini_api.generate_content(messages)
+
+        # No token count is available from this call; report zero tokens used
+        return response.text, 0
```
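For orientation, the `google-generativeai` calls that `GeminiAI` wraps look like this in isolation; the model name is only an example:

```python
import google.generativeai as genai

genai.configure(api_key="your-api-key")  # same call the constructor makes
model = genai.GenerativeModel("gemini-pro")

# generate_content() accepts the content positionally and returns a response
# object whose .text property holds the generated reply
response = model.generate_content("Hello from matrix-gptbot!")
print(response.text)
```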
```diff
@@ -24,9 +24,6 @@ from nio import (
     RoomVisibility,
     RoomCreateError,
     RoomMessageMedia,
-    RoomMessageImage,
-    RoomMessageFile,
-    RoomMessageAudio,
     DownloadError,
     RoomGetStateError,
     DiskDownloadResponse,
```
```diff
@@ -43,7 +40,6 @@ from io import BytesIO
 from pathlib import Path
 from contextlib import closing
 
-import base64
 import uuid
 import traceback
 import json
```
```diff
@@ -362,61 +358,6 @@ class GPTBot:
         # Reverse the list so that messages are in chronological order
         return messages[::-1]
 
-    def _truncate(
-        self,
-        messages: list,
-        max_tokens: Optional[int] = None,
-        model: Optional[str] = None,
-        system_message: Optional[str] = None,
-    ):
-        max_tokens = max_tokens or self.chat_api.max_tokens
-        model = model or self.chat_api.chat_model
-        system_message = (
-            self.default_system_message if system_message is None else system_message
-        )
-
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except Exception:
-            # TODO: Handle this more gracefully
-            encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-        total_tokens = 0
-
-        system_message_tokens = (
-            0 if not system_message else (len(encoding.encode(system_message)) + 1)
-        )
-
-        if system_message_tokens > max_tokens:
-            self.logger.log(
-                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
-                "error",
-            )
-            return []
-
-        total_tokens += system_message_tokens
-
-        total_tokens = len(system_message) + 1
-        truncated_messages = []
-
-        for message in [messages[0]] + list(reversed(messages[1:])):
-            content = (
-                message["content"]
-                if isinstance(message["content"], str)
-                else (
-                    message["content"][0]["text"]
-                    if isinstance(message["content"][0].get("text"), str)
-                    else ""
-                )
-            )
-            tokens = len(encoding.encode(content)) + 1
-            if total_tokens + tokens > max_tokens:
-                break
-            total_tokens += tokens
-            truncated_messages.append(message)
-
-        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
-
     async def _get_device_id(self) -> str:
         """Guess the device ID of the bot.
         Requires an access token to be set up.
```
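For reference, the removed helper boiled down to the following pattern. This is a simplified sketch, not a drop-in replacement; note that the original also overwrote its running total with `total_tokens = len(system_message) + 1`, a bug this sketch avoids:

```python
import tiktoken


def truncate_to_budget(messages: list[dict], max_tokens: int, model: str) -> list[dict]:
    """Keep the newest messages that fit within a token budget (simplified sketch)."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model names raise KeyError; fall back to a common encoding
        encoding = tiktoken.get_encoding("cl100k_base")

    kept: list[dict] = []
    total = 0
    # Walk from newest to oldest so recent context survives truncation
    for message in reversed(messages):
        content = message["content"] if isinstance(message["content"], str) else ""
        total += len(encoding.encode(content)) + 1
        if total > max_tokens:
            break
        kept.append(message)
    return list(reversed(kept))
```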
```diff
@@ -1171,172 +1112,8 @@ class GPTBot:
 
         system_message = self.get_system_message(room)
 
-        chat_messages = [{"role": "system", "content": system_message}]
-
-        last_messages = last_messages + [event]
-
-        for message in last_messages:
-            if isinstance(message, (RoomMessageNotice, RoomMessageText)):
-                role = (
-                    "assistant"
-                    if message.sender == self.matrix_client.user_id
-                    else "user"
-                )
-                if message == event or (not message.event_id == event.event_id):
-                    message_body = (
-                        message.body
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message.body}]
-                    )
-                    chat_messages.append({"role": role, "content": message_body})
-
-            elif isinstance(message, RoomMessageAudio) or (
-                isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
-            ):
-                role = (
-                    "assistant"
-                    if message.sender == self.matrix_client.user_id
-                    else "user"
-                )
-                if message == event or (not message.event_id == event.event_id):
-                    if self.room_uses_stt(room):
-                        try:
-                            download = await self.download_file(
-                                message.url, raise_error=True
-                            )
-                            message_text = await self.stt_api.speech_to_text(
-                                download.body
-                            )
-                        except Exception as e:
-                            self.logger.log(
-                                f"Error generating text from audio: {e}", "error"
-                            )
-                            message_text = message.body
-                    else:
-                        message_text = message.body
-
-                    message_body = (
-                        message_text
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message_text}]
-                    )
-                    chat_messages.append({"role": role, "content": message_body})
-
-            elif isinstance(message, RoomMessageFile):
-                try:
-                    download = await self.download_file(message.url, raise_error=True)
-                    if download:
-                        try:
-                            text = download.body.decode("utf-8")
-                        except UnicodeDecodeError:
-                            text = None
-
-                        if text:
-                            role = (
-                                "assistant"
-                                if message.sender == self.matrix_client.user_id
-                                else "user"
-                            )
-                            if message == event or (
-                                not message.event_id == event.event_id
-                            ):
-                                message_body = (
-                                    text
-                                    if not self.chat_api.supports_chat_images()
-                                    else [{"type": "text", "text": text}]
-                                )
-                                chat_messages.append(
-                                    {"role": role, "content": message_body}
-                                )
-
-                except Exception as e:
-                    self.logger.log(f"Error generating text from file: {e}", "error")
-                    message_body = (
-                        message.body
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message.body}]
-                    )
-                    chat_messages.append({"role": "system", "content": message_body})
-
-            elif self.chat_api.supports_chat_images() and isinstance(
-                message, RoomMessageImage
-            ):
-                try:
-                    image_url = message.url
-                    download = await self.download_file(image_url, raise_error=True)
-
-                    if download:
-                        pil_image = Image.open(BytesIO(download.body))
-
-                        file_format = pil_image.format or "PNG"
-
-                        max_long_side = self.chat_api.max_image_long_side
-                        max_short_side = self.chat_api.max_image_short_side
-
-                        if max_long_side and max_short_side:
-                            if pil_image.width > pil_image.height:
-                                if pil_image.width > max_long_side:
-                                    pil_image.thumbnail((max_long_side, max_short_side))
-
-                            else:
-                                if pil_image.height > max_long_side:
-                                    pil_image.thumbnail((max_short_side, max_long_side))
-
-                        bio = BytesIO()
-
-                        pil_image.save(bio, format=file_format)
-
-                        encoded_url = f"data:{download.content_type};base64,{base64.b64encode(bio.getvalue()).decode('utf-8')}"
-                        parent = (
-                            chat_messages[-1]
-                            if chat_messages
-                            and chat_messages[-1]["role"]
-                            == (
-                                "assistant"
-                                if message.sender == self.matrix_client.user_id
-                                else "user"
-                            )
-                            else None
-                        )
-
-                        if not parent:
-                            chat_messages.append(
-                                {
-                                    "role": (
-                                        "assistant"
-                                        if message.sender == self.matrix_client.user_id
-                                        else "user"
-                                    ),
-                                    "content": [],
-                                }
-                            )
-                            parent = chat_messages[-1]
-
-                        parent["content"].append(
-                            {"type": "image_url", "image_url": {"url": encoded_url}}
-                        )
-
-                except Exception as e:
-                    if isinstance(e, DownloadException):
-                        self.send_message(
-                            room,
-                            f"Could not process image due to download error: {e.args[0]}",
-                            True,
-                        )
-
-                    self.logger.log(f"Error generating image from file: {e}", "error")
-                    message_body = (
-                        message.body
-                        if not self.chat_api.supports_chat_images()
-                        else [{"type": "text", "text": message.body}]
-                    )
-                    chat_messages.append({"role": "system", "content": message_body})
-
-        # Truncate messages to fit within the token limit
-        self._truncate(
-            chat_messages[1:],
-            self.chat_api.max_tokens - 1,
-            system_message=system_message,
+        chat_messages = await self.chat_api.prepare_messages(
+            event, last_messages, system_message
         )
 
         # Check for a model override
```
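The most involved part of the removed loop was the image path. As a hedged sketch, the same re-encoding step could live inside a `prepare_messages()` implementation along these lines; the helper name is hypothetical and the logic mirrors the code removed above:

```python
import base64
from io import BytesIO

from PIL import Image


def encode_image_as_data_url(body: bytes, content_type: str) -> str:
    """Re-encode a downloaded image and wrap it as a base64 data URL,
    mirroring the logic removed from the inline loop above."""
    pil_image = Image.open(BytesIO(body))
    bio = BytesIO()
    # Preserve the original format where PIL can detect it, else fall back
    pil_image.save(bio, format=pil_image.format or "PNG")
    encoded = base64.b64encode(bio.getvalue()).decode("utf-8")
    return f"data:{content_type};base64,{encoded}"
```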
```diff
@@ -1382,7 +1159,7 @@ class GPTBot:
                 room, "Something went wrong generating audio file.", True
             )
 
-        message = await self.send_message(room, response)
+        await self.send_message(room, response)
 
         await self.matrix_client.room_typing(room.room_id, False)
```