Add image input on models that support it, fix some bugs, bump required OpenAI version
This commit is contained in:
parent
c238da9b99
commit
4113a02232
5 changed files with 84 additions and 26 deletions
|
@ -7,7 +7,7 @@ allow-direct-references = true
|
|||
|
||||
[project]
|
||||
name = "matrix-gptbot"
|
||||
version = "0.1.1"
|
||||
version = "0.2.0"
|
||||
|
||||
authors = [
|
||||
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
|
||||
|
@ -38,7 +38,7 @@ dependencies = [
|
|||
|
||||
[project.optional-dependencies]
|
||||
openai = [
|
||||
"openai",
|
||||
"openai>=1.2",
|
||||
]
|
||||
|
||||
wolframalpha = [
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
openai
|
||||
openai>=1.2
|
||||
matrix-nio[e2e]
|
||||
markdown2[all]
|
||||
tiktoken
|
||||
|
|
|
@ -27,6 +27,12 @@ from nio import (
|
|||
RoomSendError,
|
||||
RoomVisibility,
|
||||
RoomCreateError,
|
||||
RoomMessageMedia,
|
||||
RoomMessageImage,
|
||||
RoomMessageFile,
|
||||
RoomMessageAudio,
|
||||
DownloadError,
|
||||
DownloadResponse,
|
||||
)
|
||||
from nio.crypto import Olm
|
||||
from nio.store import SqliteStore
|
||||
|
@ -38,6 +44,7 @@ from io import BytesIO
|
|||
from pathlib import Path
|
||||
from contextlib import closing
|
||||
|
||||
import base64
|
||||
import uuid
|
||||
import traceback
|
||||
import json
|
||||
|
@ -139,7 +146,7 @@ class GPTBot:
|
|||
|
||||
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
||||
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
|
||||
config["OpenAI"].get("ImageModel"), bot.logger
|
||||
config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"), bot.logger
|
||||
)
|
||||
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
||||
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
||||
|
@ -220,6 +227,7 @@ class GPTBot:
|
|||
for event in response.chunk:
|
||||
if len(messages) >= n:
|
||||
break
|
||||
|
||||
if isinstance(event, MegolmEvent):
|
||||
try:
|
||||
event = await self.matrix_client.decrypt_event(event)
|
||||
|
@ -229,14 +237,22 @@ class GPTBot:
|
|||
"error",
|
||||
)
|
||||
continue
|
||||
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
|
||||
|
||||
if isinstance(event, RoomMessageText):
|
||||
if event.body.startswith("!gptbot ignoreolder"):
|
||||
break
|
||||
if (not event.body.startswith("!")) or (
|
||||
event.body.startswith("!gptbot") and not ignore_bot_commands
|
||||
):
|
||||
if (not event.body.startswith("!")) or (not ignore_bot_commands):
|
||||
messages.append(event)
|
||||
|
||||
if isinstance(event, RoomMessageNotice):
|
||||
if not ignore_bot_commands:
|
||||
messages.append(event)
|
||||
|
||||
if isinstance(event, RoomMessageMedia):
|
||||
if event.sender != self.matrix_client.user_id:
|
||||
if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
|
||||
messages.append(event)
|
||||
|
||||
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
|
||||
|
||||
# Reverse the list so that messages are in chronological order
|
||||
|
@ -275,7 +291,7 @@ class GPTBot:
|
|||
truncated_messages = []
|
||||
|
||||
for message in [messages[0]] + list(reversed(messages[1:])):
|
||||
content = message["content"]
|
||||
content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
|
||||
tokens = len(encoding.encode(content)) + 1
|
||||
if total_tokens + tokens > max_tokens:
|
||||
break
|
||||
|
@ -906,14 +922,39 @@ class GPTBot:
|
|||
|
||||
chat_messages = [{"role": "system", "content": system_message}]
|
||||
|
||||
for message in last_messages:
|
||||
text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
|
||||
|
||||
for message in text_messages:
|
||||
role = (
|
||||
"assistant" if message.sender == self.matrix_client.user_id else "user"
|
||||
)
|
||||
if not message.event_id == event.event_id:
|
||||
chat_messages.append({"role": role, "content": message.body})
|
||||
|
||||
chat_messages.append({"role": "user", "content": event.body})
|
||||
if not self.chat_api.supports_chat_images():
|
||||
event_body = event.body
|
||||
else:
|
||||
event_body = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": event.body
|
||||
}
|
||||
]
|
||||
|
||||
for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
|
||||
image_url = m.url
|
||||
download = await self.download_file(image_url)
|
||||
|
||||
if download:
|
||||
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
|
||||
event_body.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": encoded_url
|
||||
}
|
||||
})
|
||||
|
||||
chat_messages.append({"role": "user", "content": event_body})
|
||||
|
||||
# Truncate messages to fit within the token limit
|
||||
truncated_messages = self._truncate(
|
||||
|
@ -926,6 +967,7 @@ class GPTBot:
|
|||
)
|
||||
except Exception as e:
|
||||
self.logger.log(f"Error generating response: {e}", "error")
|
||||
|
||||
await self.send_message(
|
||||
room, "Something went wrong. Please try again.", True
|
||||
)
|
||||
|
@ -954,6 +996,24 @@ class GPTBot:
|
|||
|
||||
await self.matrix_client.room_typing(room.room_id, False)
|
||||
|
||||
def download_file(self, mxc) -> Optional[bytes]:
|
||||
"""Download a file from the homeserver.
|
||||
|
||||
Args:
|
||||
mxc (str): The MXC URI of the file to download.
|
||||
|
||||
Returns:
|
||||
Optional[bytes]: The downloaded file, or None if there was an error.
|
||||
"""
|
||||
|
||||
download = self.matrix_client.download(mxc)
|
||||
|
||||
if isinstance(download, DownloadError):
|
||||
self.logger.log(f"Error downloading file: {download.message}", "error")
|
||||
return
|
||||
|
||||
return download
|
||||
|
||||
def get_system_message(self, room: MatrixRoom | str) -> str:
|
||||
"""Get the system message for a room.
|
||||
|
||||
|
|
|
@ -25,12 +25,13 @@ class OpenAI:
|
|||
|
||||
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
|
||||
|
||||
def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
|
||||
def __init__(self, api_key, chat_model=None, image_model=None, base_url=None, logger=None):
|
||||
self.api_key = api_key
|
||||
self.chat_model = chat_model or self.chat_model
|
||||
self.image_model = image_model or self.image_model
|
||||
self.logger = logger or Logger()
|
||||
self.base_url = openai.api_base
|
||||
self.base_url = base_url or openai.base_url
|
||||
self.openai_api = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
def supports_chat_images(self):
|
||||
return "vision" in self.chat_model
|
||||
|
@ -74,18 +75,20 @@ class OpenAI:
|
|||
|
||||
|
||||
chat_partial = partial(
|
||||
openai.ChatCompletion.acreate,
|
||||
self.openai_api.chat.completions.create,
|
||||
model=self.chat_model,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
user=user,
|
||||
api_base=self.base_url,
|
||||
max_tokens=4096
|
||||
)
|
||||
response = await self._request_with_retries(chat_partial)
|
||||
|
||||
self.logger.log(response, "error")
|
||||
self.logger.log(response.choices, "error")
|
||||
self.logger.log(response.choices[0].message, "error")
|
||||
|
||||
result_text = response.choices[0].message['content']
|
||||
tokens_used = response.usage["total_tokens"]
|
||||
result_text = response.choices[0].message.content
|
||||
tokens_used = response.usage.total_tokens
|
||||
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
||||
return result_text, tokens_used
|
||||
|
||||
|
@ -117,13 +120,10 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
|||
self.logger.log(f"Classifying message '{query}'...")
|
||||
|
||||
chat_partial = partial(
|
||||
openai.ChatCompletion.acreate,
|
||||
self.openai_api.chat.completions.create,
|
||||
model=self.chat_model,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
user=user,
|
||||
api_base=self.base_url,
|
||||
quality=("hd" if model == "dall-e-3" else "normal")
|
||||
)
|
||||
response = await self._request_with_retries(chat_partial)
|
||||
|
||||
|
@ -150,14 +150,12 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
|||
self.logger.log(f"Generating image from prompt '{prompt}'...")
|
||||
|
||||
image_partial = partial(
|
||||
openai.Image.acreate,
|
||||
self.openai_api.images.generate,
|
||||
model=self.image_model,
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
api_key=self.api_key,
|
||||
size="1024x1024",
|
||||
user=user,
|
||||
api_base=self.base_url,
|
||||
)
|
||||
response = await self._request_with_retries(image_partial)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
|
|||
bot.logger.log(f"Sending image...")
|
||||
await bot.send_image(room, image)
|
||||
|
||||
bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_api}", tokens_used)
|
||||
bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_model}", tokens_used)
|
||||
|
||||
return
|
||||
|
||||
|
|
Loading…
Reference in a new issue