Add image input on models that support it, fix some bugs, bump required OpenAI version

This commit is contained in:
Kumi 2023-11-11 12:27:19 +01:00
parent c238da9b99
commit 4113a02232
Signed by: kumi
GPG key ID: ECBCC9082395383F
5 changed files with 84 additions and 26 deletions

View file

@ -7,7 +7,7 @@ allow-direct-references = true
[project]
name = "matrix-gptbot"
version = "0.1.1"
version = "0.2.0"
authors = [
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
@ -38,7 +38,7 @@ dependencies = [
[project.optional-dependencies]
openai = [
"openai",
"openai>=1.2",
]
wolframalpha = [

View file

@ -1,4 +1,4 @@
openai
openai>=1.2
matrix-nio[e2e]
markdown2[all]
tiktoken

View file

@ -27,6 +27,12 @@ from nio import (
RoomSendError,
RoomVisibility,
RoomCreateError,
RoomMessageMedia,
RoomMessageImage,
RoomMessageFile,
RoomMessageAudio,
DownloadError,
DownloadResponse,
)
from nio.crypto import Olm
from nio.store import SqliteStore
@ -38,6 +44,7 @@ from io import BytesIO
from pathlib import Path
from contextlib import closing
import base64
import uuid
import traceback
import json
@ -139,7 +146,7 @@ class GPTBot:
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
config["OpenAI"].get("ImageModel"), bot.logger
config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"), bot.logger
)
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
@ -220,6 +227,7 @@ class GPTBot:
for event in response.chunk:
if len(messages) >= n:
break
if isinstance(event, MegolmEvent):
try:
event = await self.matrix_client.decrypt_event(event)
@ -229,12 +237,20 @@ class GPTBot:
"error",
)
continue
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
if isinstance(event, RoomMessageText):
if event.body.startswith("!gptbot ignoreolder"):
break
if (not event.body.startswith("!")) or (
event.body.startswith("!gptbot") and not ignore_bot_commands
):
if (not event.body.startswith("!")) or (not ignore_bot_commands):
messages.append(event)
if isinstance(event, RoomMessageNotice):
if not ignore_bot_commands:
messages.append(event)
if isinstance(event, RoomMessageMedia):
if event.sender != self.matrix_client.user_id:
if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
messages.append(event)
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
@ -275,7 +291,7 @@ class GPTBot:
truncated_messages = []
for message in [messages[0]] + list(reversed(messages[1:])):
content = message["content"]
content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
tokens = len(encoding.encode(content)) + 1
if total_tokens + tokens > max_tokens:
break
@ -906,14 +922,39 @@ class GPTBot:
chat_messages = [{"role": "system", "content": system_message}]
for message in last_messages:
text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
for message in text_messages:
role = (
"assistant" if message.sender == self.matrix_client.user_id else "user"
)
if not message.event_id == event.event_id:
chat_messages.append({"role": role, "content": message.body})
chat_messages.append({"role": "user", "content": event.body})
if not self.chat_api.supports_chat_images():
event_body = event.body
else:
event_body = [
{
"type": "text",
"text": event.body
}
]
for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
image_url = m.url
download = await self.download_file(image_url)
if download:
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
event_body.append({
"type": "image_url",
"image_url": {
"url": encoded_url
}
})
chat_messages.append({"role": "user", "content": event_body})
# Truncate messages to fit within the token limit
truncated_messages = self._truncate(
@ -926,6 +967,7 @@ class GPTBot:
)
except Exception as e:
self.logger.log(f"Error generating response: {e}", "error")
await self.send_message(
room, "Something went wrong. Please try again.", True
)
@ -954,6 +996,24 @@ class GPTBot:
await self.matrix_client.room_typing(room.room_id, False)
def download_file(self, mxc) -> Optional[bytes]:
"""Download a file from the homeserver.
Args:
mxc (str): The MXC URI of the file to download.
Returns:
Optional[bytes]: The downloaded file, or None if there was an error.
"""
download = self.matrix_client.download(mxc)
if isinstance(download, DownloadError):
self.logger.log(f"Error downloading file: {download.message}", "error")
return
return download
def get_system_message(self, room: MatrixRoom | str) -> str:
"""Get the system message for a room.

View file

@ -25,12 +25,13 @@ class OpenAI:
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
def __init__(self, api_key, chat_model=None, image_model=None, base_url=None, logger=None):
self.api_key = api_key
self.chat_model = chat_model or self.chat_model
self.image_model = image_model or self.image_model
self.logger = logger or Logger()
self.base_url = openai.api_base
self.base_url = base_url or openai.base_url
self.openai_api = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
def supports_chat_images(self):
return "vision" in self.chat_model
@ -74,18 +75,20 @@ class OpenAI:
chat_partial = partial(
openai.ChatCompletion.acreate,
self.openai_api.chat.completions.create,
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user=user,
api_base=self.base_url,
max_tokens=4096
)
response = await self._request_with_retries(chat_partial)
self.logger.log(response, "error")
self.logger.log(response.choices, "error")
self.logger.log(response.choices[0].message, "error")
result_text = response.choices[0].message['content']
tokens_used = response.usage["total_tokens"]
result_text = response.choices[0].message.content
tokens_used = response.usage.total_tokens
self.logger.log(f"Generated response with {tokens_used} tokens.")
return result_text, tokens_used
@ -117,13 +120,10 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Classifying message '{query}'...")
chat_partial = partial(
openai.ChatCompletion.acreate,
self.openai_api.chat.completions.create,
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user=user,
api_base=self.base_url,
quality=("hd" if model == "dall-e-3" else "normal")
)
response = await self._request_with_retries(chat_partial)
@ -150,14 +150,12 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Generating image from prompt '{prompt}'...")
image_partial = partial(
openai.Image.acreate,
self.openai_api.images.generate,
model=self.image_model,
prompt=prompt,
n=1,
api_key=self.api_key,
size="1024x1024",
user=user,
api_base=self.base_url,
)
response = await self._request_with_retries(image_partial)

View file

@ -19,7 +19,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
bot.logger.log(f"Sending image...")
await bot.send_image(room, image)
bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_api}", tokens_used)
bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_model}", tokens_used)
return