Add image input on models that support it, fix some bugs, bump required OpenAI version

This commit is contained in:
Kumi 2023-11-11 12:27:19 +01:00
parent c238da9b99
commit 4113a02232
Signed by: kumi
GPG key ID: ECBCC9082395383F
5 changed files with 84 additions and 26 deletions

View file

@ -7,7 +7,7 @@ allow-direct-references = true
[project] [project]
name = "matrix-gptbot" name = "matrix-gptbot"
version = "0.1.1" version = "0.2.0"
authors = [ authors = [
{ name="Kumi Mitterer", email="gptbot@kumi.email" }, { name="Kumi Mitterer", email="gptbot@kumi.email" },
@ -38,7 +38,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
openai = [ openai = [
"openai", "openai>=1.2",
] ]
wolframalpha = [ wolframalpha = [

View file

@ -1,4 +1,4 @@
openai openai>=1.2
matrix-nio[e2e] matrix-nio[e2e]
markdown2[all] markdown2[all]
tiktoken tiktoken

View file

@ -27,6 +27,12 @@ from nio import (
RoomSendError, RoomSendError,
RoomVisibility, RoomVisibility,
RoomCreateError, RoomCreateError,
RoomMessageMedia,
RoomMessageImage,
RoomMessageFile,
RoomMessageAudio,
DownloadError,
DownloadResponse,
) )
from nio.crypto import Olm from nio.crypto import Olm
from nio.store import SqliteStore from nio.store import SqliteStore
@ -38,6 +44,7 @@ from io import BytesIO
from pathlib import Path from pathlib import Path
from contextlib import closing from contextlib import closing
import base64
import uuid import uuid
import traceback import traceback
import json import json
@ -139,7 +146,7 @@ class GPTBot:
bot.chat_api = bot.image_api = bot.classification_api = OpenAI( bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
config["OpenAI"].get("ImageModel"), bot.logger config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"), bot.logger
) )
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens) bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages) bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
@ -220,6 +227,7 @@ class GPTBot:
for event in response.chunk: for event in response.chunk:
if len(messages) >= n: if len(messages) >= n:
break break
if isinstance(event, MegolmEvent): if isinstance(event, MegolmEvent):
try: try:
event = await self.matrix_client.decrypt_event(event) event = await self.matrix_client.decrypt_event(event)
@ -229,14 +237,22 @@ class GPTBot:
"error", "error",
) )
continue continue
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
if isinstance(event, RoomMessageText):
if event.body.startswith("!gptbot ignoreolder"): if event.body.startswith("!gptbot ignoreolder"):
break break
if (not event.body.startswith("!")) or ( if (not event.body.startswith("!")) or (not ignore_bot_commands):
event.body.startswith("!gptbot") and not ignore_bot_commands
):
messages.append(event) messages.append(event)
if isinstance(event, RoomMessageNotice):
if not ignore_bot_commands:
messages.append(event)
if isinstance(event, RoomMessageMedia):
if event.sender != self.matrix_client.user_id:
if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
messages.append(event)
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug") self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
# Reverse the list so that messages are in chronological order # Reverse the list so that messages are in chronological order
@ -275,7 +291,7 @@ class GPTBot:
truncated_messages = [] truncated_messages = []
for message in [messages[0]] + list(reversed(messages[1:])): for message in [messages[0]] + list(reversed(messages[1:])):
content = message["content"] content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
tokens = len(encoding.encode(content)) + 1 tokens = len(encoding.encode(content)) + 1
if total_tokens + tokens > max_tokens: if total_tokens + tokens > max_tokens:
break break
@ -906,14 +922,39 @@ class GPTBot:
chat_messages = [{"role": "system", "content": system_message}] chat_messages = [{"role": "system", "content": system_message}]
for message in last_messages: text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
for message in text_messages:
role = ( role = (
"assistant" if message.sender == self.matrix_client.user_id else "user" "assistant" if message.sender == self.matrix_client.user_id else "user"
) )
if not message.event_id == event.event_id: if not message.event_id == event.event_id:
chat_messages.append({"role": role, "content": message.body}) chat_messages.append({"role": role, "content": message.body})
chat_messages.append({"role": "user", "content": event.body}) if not self.chat_api.supports_chat_images():
event_body = event.body
else:
event_body = [
{
"type": "text",
"text": event.body
}
]
for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
image_url = m.url
download = await self.download_file(image_url)
if download:
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
event_body.append({
"type": "image_url",
"image_url": {
"url": encoded_url
}
})
chat_messages.append({"role": "user", "content": event_body})
# Truncate messages to fit within the token limit # Truncate messages to fit within the token limit
truncated_messages = self._truncate( truncated_messages = self._truncate(
@ -926,6 +967,7 @@ class GPTBot:
) )
except Exception as e: except Exception as e:
self.logger.log(f"Error generating response: {e}", "error") self.logger.log(f"Error generating response: {e}", "error")
await self.send_message( await self.send_message(
room, "Something went wrong. Please try again.", True room, "Something went wrong. Please try again.", True
) )
@ -954,6 +996,24 @@ class GPTBot:
await self.matrix_client.room_typing(room.room_id, False) await self.matrix_client.room_typing(room.room_id, False)
def download_file(self, mxc) -> Optional[bytes]:
"""Download a file from the homeserver.
Args:
mxc (str): The MXC URI of the file to download.
Returns:
Optional[bytes]: The downloaded file, or None if there was an error.
"""
download = self.matrix_client.download(mxc)
if isinstance(download, DownloadError):
self.logger.log(f"Error downloading file: {download.message}", "error")
return
return download
def get_system_message(self, room: MatrixRoom | str) -> str: def get_system_message(self, room: MatrixRoom | str) -> str:
"""Get the system message for a room. """Get the system message for a room.

View file

@ -25,12 +25,13 @@ class OpenAI:
operator: str = "OpenAI ([https://openai.com](https://openai.com))" operator: str = "OpenAI ([https://openai.com](https://openai.com))"
def __init__(self, api_key, chat_model=None, image_model=None, logger=None): def __init__(self, api_key, chat_model=None, image_model=None, base_url=None, logger=None):
self.api_key = api_key self.api_key = api_key
self.chat_model = chat_model or self.chat_model self.chat_model = chat_model or self.chat_model
self.image_model = image_model or self.image_model self.image_model = image_model or self.image_model
self.logger = logger or Logger() self.logger = logger or Logger()
self.base_url = openai.api_base self.base_url = base_url or openai.base_url
self.openai_api = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
def supports_chat_images(self): def supports_chat_images(self):
return "vision" in self.chat_model return "vision" in self.chat_model
@ -74,18 +75,20 @@ class OpenAI:
chat_partial = partial( chat_partial = partial(
openai.ChatCompletion.acreate, self.openai_api.chat.completions.create,
model=self.chat_model, model=self.chat_model,
messages=messages, messages=messages,
api_key=self.api_key,
user=user, user=user,
api_base=self.base_url, max_tokens=4096
) )
response = await self._request_with_retries(chat_partial) response = await self._request_with_retries(chat_partial)
self.logger.log(response, "error")
self.logger.log(response.choices, "error")
self.logger.log(response.choices[0].message, "error")
result_text = response.choices[0].message['content'] result_text = response.choices[0].message.content
tokens_used = response.usage["total_tokens"] tokens_used = response.usage.total_tokens
self.logger.log(f"Generated response with {tokens_used} tokens.") self.logger.log(f"Generated response with {tokens_used} tokens.")
return result_text, tokens_used return result_text, tokens_used
@ -117,13 +120,10 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Classifying message '{query}'...") self.logger.log(f"Classifying message '{query}'...")
chat_partial = partial( chat_partial = partial(
openai.ChatCompletion.acreate, self.openai_api.chat.completions.create,
model=self.chat_model, model=self.chat_model,
messages=messages, messages=messages,
api_key=self.api_key,
user=user, user=user,
api_base=self.base_url,
quality=("hd" if model == "dall-e-3" else "normal")
) )
response = await self._request_with_retries(chat_partial) response = await self._request_with_retries(chat_partial)
@ -150,14 +150,12 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Generating image from prompt '{prompt}'...") self.logger.log(f"Generating image from prompt '{prompt}'...")
image_partial = partial( image_partial = partial(
openai.Image.acreate, self.openai_api.images.generate,
model=self.image_model, model=self.image_model,
prompt=prompt, prompt=prompt,
n=1, n=1,
api_key=self.api_key,
size="1024x1024", size="1024x1024",
user=user, user=user,
api_base=self.base_url,
) )
response = await self._request_with_retries(image_partial) response = await self._request_with_retries(image_partial)

View file

@ -19,7 +19,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log(f"Sending image...") bot.logger.log(f"Sending image...")
await bot.send_image(room, image) await bot.send_image(room, image)
bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_api}", tokens_used) bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_model}", tokens_used)
return return