feat: add video file support and integrate Google AI
All checks were successful
Docker CI/CD / Docker Build and Push to Docker Hub (push) Successful in 8m44s
Python Package CI/CD / Setup and Test (push) Successful in 1m22s
Python Package CI/CD / Publish to PyPI (push) Successful in 40s

Introduced the capability to handle video files as input for AI models that support it, enhancing the bot's versatility in processing media. This update includes a new configuration option to enable or disable video input, catering to different model capabilities. Additionally, integrated Google's Generative AI through the addition of a Google dependency and a corresponding AI class implementation. This move broadens the AI options available, providing users with more flexibility in choosing their desired AI backend. The update involves refactoring and simplifying message preparation and handling, ensuring compatibility and extending functionality to include the new video input feature and Google AI support.

- Added `ForceVideoInput` configuration option to toggle video file processing.
- Integrated Google Generative AI as an optional dependency and included it in the bot's AI choices.
- Implemented a unified method for preparing messages for AI processing, streamlining how the bot handles various message types.
- Removed obsolete code related to message truncation and specialized handling for images, files, and audio, reflecting a shift towards a more flexible and generalized message processing approach.
This commit is contained in:
Kumi 2024-05-25 17:35:22 +02:00
parent 05ba26d540
commit c06da55d5d
Signed by: kumi
GPG key ID: ECBCC9082395383F
5 changed files with 114 additions and 228 deletions

View file

@ -149,6 +149,15 @@ APIKey = sk-yoursecretkey
# MaxImageLongSide = 2000
# MaxImageShortSide = 768
# Whether the used model supports video files as input
#
# If you are using a model that supports video files as input, set this to 1.
# This will make the bot send video files to the model as well as images.
# This may be possible with some self-hosted models, but is not supported by
# the OpenAI API at this time.
#
# ForceVideoInput = 0
# Advanced settings for the OpenAI API
#
# These settings are not required for normal operation, but can be used to

View file

@ -7,7 +7,7 @@ allow-direct-references = true
[project]
name = "matrix-gptbot"
version = "0.3.14"
version = "0.3.15"
authors = [
{ name = "Kumi Mitterer", email = "gptbot@kumi.email" },
@ -39,12 +39,14 @@ dependencies = [
[project.optional-dependencies]
openai = ["openai>=1.2", "pydub"]
google = ["google-generativeai"]
wolframalpha = ["wolframalpha"]
trackingmore = ["trackingmore-api-tool"]
all = [
"matrix-gptbot[openai,wolframalpha,trackingmore]",
"matrix-gptbot[openai,wolframalpha,trackingmore,google]",
"geopy",
"beautifulsoup4",
]

View file

@ -4,6 +4,8 @@ import asyncio
from functools import partial
from typing import Any, AsyncGenerator, Dict, Optional, Mapping
from nio import Event
class AttributeDictionary(dict):
def __init__(self, *args, **kwargs):
@ -24,6 +26,29 @@ class BaseAI:
def chat_api(self) -> str:
return self.chat_model
async def prepare_messages(
self, event: Event, messages: list[Any], system_message: Optional[str] = None
) -> list[Any]:
"""A helper method to prepare messages for the AI.
This converts a list of Matrix messages into whatever format the AI requires.
Args:
event (Event): The event that triggered the message generation. Generally a text message from a user.
messages (list[Dict[str, str]]): The messages to prepare. Generally of type RoomMessage*.
system_message (Optional[str], optional): A system message to include. Defaults to None.
Returns:
list[Any]: The prepared messages in the format the AI requires.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
raise NotImplementedError(
"Implementations of BaseAI must implement prepare_messages."
)
async def _request_with_retries(
self, request: partial, attempts: int = 5, retry_interval: int = 2
) -> AsyncGenerator[Any | list | Dict, None]:

View file

@ -0,0 +1,73 @@
from .base import BaseAI
from ..logging import Logger
from typing import Optional, Mapping, List, Dict, Tuple
import google.generativeai as genai
class GeminiAI(BaseAI):
api_code: str = "google"
@property
def chat_api(self) -> str:
return self.chat_model
google_api: genai.GenerativeModel
operator: str = "Google (https://ai.google)"
def __init__(
self,
bot,
config: Mapping,
logger: Optional[Logger] = None,
):
super().__init__(bot, config, logger)
genai.configure(api_key=self.api_key)
self.gemini_api = genai.GenerativeModel(self.chat_model)
@property
def api_key(self):
return self._config["APIKey"]
@property
def chat_model(self):
return self._config.get("Model", fallback="gemini-pro")
def prepare_messages(event, messages: List[Dict[str, str]], ) -> List[str]:
return [message["content"] for message in messages]
async def generate_chat_response(
self,
messages: List[Dict[str, str]],
user: Optional[str] = None,
room: Optional[str] = None,
use_tools: bool = True,
model: Optional[str] = None,
) -> Tuple[str, int]:
"""Generate a response to a chat message.
Args:
messages (List[Dict[str, str]]): A list of messages to use as context.
user (Optional[str], optional): The user to use the assistant for. Defaults to None.
room (Optional[str], optional): The room to use the assistant for. Defaults to None.
use_tools (bool, optional): Whether to use tools. Defaults to True.
model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
Returns:
Tuple[str, int]: The response text and the number of tokens used.
"""
self.logger.log(
f"Generating response to {len(messages)} messages for user {user} in room {room}..."
)
messages = self.prepare_messages(messages)
return self.gemini_api.generate_content(
messages=messages,
user=user,
room=room,
use_tools=use_tools,
model=model,
)

View file

@ -24,9 +24,6 @@ from nio import (
RoomVisibility,
RoomCreateError,
RoomMessageMedia,
RoomMessageImage,
RoomMessageFile,
RoomMessageAudio,
DownloadError,
RoomGetStateError,
DiskDownloadResponse,
@ -43,7 +40,6 @@ from io import BytesIO
from pathlib import Path
from contextlib import closing
import base64
import uuid
import traceback
import json
@ -362,61 +358,6 @@ class GPTBot:
# Reverse the list so that messages are in chronological order
return messages[::-1]
def _truncate(
self,
messages: list,
max_tokens: Optional[int] = None,
model: Optional[str] = None,
system_message: Optional[str] = None,
):
max_tokens = max_tokens or self.chat_api.max_tokens
model = model or self.chat_api.chat_model
system_message = (
self.default_system_message if system_message is None else system_message
)
try:
encoding = tiktoken.encoding_for_model(model)
except Exception:
# TODO: Handle this more gracefully
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
total_tokens = 0
system_message_tokens = (
0 if not system_message else (len(encoding.encode(system_message)) + 1)
)
if system_message_tokens > max_tokens:
self.logger.log(
f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
"error",
)
return []
total_tokens += system_message_tokens
total_tokens = len(system_message) + 1
truncated_messages = []
for message in [messages[0]] + list(reversed(messages[1:])):
content = (
message["content"]
if isinstance(message["content"], str)
else (
message["content"][0]["text"]
if isinstance(message["content"][0].get("text"), str)
else ""
)
)
tokens = len(encoding.encode(content)) + 1
if total_tokens + tokens > max_tokens:
break
total_tokens += tokens
truncated_messages.append(message)
return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
async def _get_device_id(self) -> str:
"""Guess the device ID of the bot.
Requires an access token to be set up.
@ -1171,172 +1112,8 @@ class GPTBot:
system_message = self.get_system_message(room)
chat_messages = [{"role": "system", "content": system_message}]
last_messages = last_messages + [event]
for message in last_messages:
if isinstance(message, (RoomMessageNotice, RoomMessageText)):
role = (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
)
if message == event or (not message.event_id == event.event_id):
message_body = (
message.body
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": role, "content": message_body})
elif isinstance(message, RoomMessageAudio) or (
isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
):
role = (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
)
if message == event or (not message.event_id == event.event_id):
if self.room_uses_stt(room):
try:
download = await self.download_file(
message.url, raise_error=True
)
message_text = await self.stt_api.speech_to_text(
download.body
)
except Exception as e:
self.logger.log(
f"Error generating text from audio: {e}", "error"
)
message_text = message.body
else:
message_text = message.body
message_body = (
message_text
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message_text}]
)
chat_messages.append({"role": role, "content": message_body})
elif isinstance(message, RoomMessageFile):
try:
download = await self.download_file(message.url, raise_error=True)
if download:
try:
text = download.body.decode("utf-8")
except UnicodeDecodeError:
text = None
if text:
role = (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
)
if message == event or (
not message.event_id == event.event_id
):
message_body = (
text
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": text}]
)
chat_messages.append(
{"role": role, "content": message_body}
)
except Exception as e:
self.logger.log(f"Error generating text from file: {e}", "error")
message_body = (
message.body
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body})
elif self.chat_api.supports_chat_images() and isinstance(
message, RoomMessageImage
):
try:
image_url = message.url
download = await self.download_file(image_url, raise_error=True)
if download:
pil_image = Image.open(BytesIO(download.body))
file_format = pil_image.format or "PNG"
max_long_side = self.chat_api.max_image_long_side
max_short_side = self.chat_api.max_image_short_side
if max_long_side and max_short_side:
if pil_image.width > pil_image.height:
if pil_image.width > max_long_side:
pil_image.thumbnail((max_long_side, max_short_side))
else:
if pil_image.height > max_long_side:
pil_image.thumbnail((max_short_side, max_long_side))
bio = BytesIO()
pil_image.save(bio, format=file_format)
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(bio.getvalue()).decode('utf-8')}"
parent = (
chat_messages[-1]
if chat_messages
and chat_messages[-1]["role"]
== (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
)
else None
)
if not parent:
chat_messages.append(
{
"role": (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
),
"content": [],
}
)
parent = chat_messages[-1]
parent["content"].append(
{"type": "image_url", "image_url": {"url": encoded_url}}
)
except Exception as e:
if isinstance(e, DownloadException):
self.send_message(
room,
f"Could not process image due to download error: {e.args[0]}",
True,
)
self.logger.log(f"Error generating image from file: {e}", "error")
message_body = (
message.body
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body})
# Truncate messages to fit within the token limit
self._truncate(
chat_messages[1:],
self.chat_api.max_tokens - 1,
system_message=system_message,
chat_messages = await self.chat_api.prepare_messages(
last_messages, system_message
)
# Check for a model override
@ -1382,7 +1159,7 @@ class GPTBot:
room, "Something went wrong generating audio file.", True
)
message = await self.send_message(room, response)
await self.send_message(room, response)
await self.matrix_client.room_typing(room.room_id, False)