Add image input on models that support it, fix some bugs, bump required OpenAI version
This commit is contained in:
parent c238da9b99
commit 4113a02232

5 changed files with 84 additions and 26 deletions
pyproject.toml

@@ -7,7 +7,7 @@ allow-direct-references = true
 
 [project]
 name = "matrix-gptbot"
-version = "0.1.1"
+version = "0.2.0"
 
 authors = [
   { name="Kumi Mitterer", email="gptbot@kumi.email" },

@@ -38,7 +38,7 @@ dependencies = [
 
 [project.optional-dependencies]
 openai = [
-  "openai",
+  "openai>=1.2",
]
 
 wolframalpha = [
requirements.txt

@@ -1,4 +1,4 @@
-openai
+openai>=1.2
 matrix-nio[e2e]
 markdown2[all]
 tiktoken
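The `>=1.2` pin matters because the 1.x releases of the `openai` package removed the module-level `openai.ChatCompletion` / `openai.Image` interface that the old code used; requests now go through an explicit client object, which is exactly the switch the hunks below make. A minimal sketch of the difference (model name and key are placeholders):

```python
import openai

# Pre-1.0 style, removed in openai>=1.0 (raises APIRemovedInV1 there):
# response = await openai.ChatCompletion.acreate(model=..., messages=..., api_key=...)

# 1.x style, as adopted in this commit: a per-instance async client that
# carries the API key and base URL, instead of per-call kwargs.
client = openai.AsyncOpenAI(api_key="sk-...", base_url="https://api.openai.com/v1/")

async def ask(prompt: str) -> str:
    response = await client.chat.completions.create(
        model="gpt-4-vision-preview",  # placeholder; the bot reads this from config
        messages=[{"role": "user", "content": prompt}],
    )
    # Responses are typed objects now, not dicts:
    return response.choices[0].message.content
```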
GPTBot class

@@ -27,6 +27,12 @@ from nio import (
     RoomSendError,
     RoomVisibility,
     RoomCreateError,
+    RoomMessageMedia,
+    RoomMessageImage,
+    RoomMessageFile,
+    RoomMessageAudio,
+    DownloadError,
+    DownloadResponse,
 )
 from nio.crypto import Olm
 from nio.store import SqliteStore
@@ -38,6 +44,7 @@ from io import BytesIO
 from pathlib import Path
 from contextlib import closing
 
+import base64
 import uuid
 import traceback
 import json
@@ -139,7 +146,7 @@ class GPTBot:
 
         bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
             config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
-            config["OpenAI"].get("ImageModel"), bot.logger
+            config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"), bot.logger
         )
         bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
         bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
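The wrapper's constructor now takes the optional BaseURL value between ImageModel and the logger, so OpenAI-compatible endpoints (self-hosted proxies, alternative providers) can be configured without code changes. A sketch of the wiring, where `OpenAI` is the project's wrapper class from this diff; the key names come from the hunk above, but the values are placeholders, not project defaults:

```python
# Hypothetical config.ini section:
#
# [OpenAI]
# APIKey = sk-...
# Model = gpt-4-vision-preview
# ImageModel = dall-e-3
# BaseURL = https://api.openai.com/v1/

from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")

chat_api = OpenAI(
    config["OpenAI"]["APIKey"],
    config["OpenAI"].get("Model"),       # SectionProxy.get() returns None if unset
    config["OpenAI"].get("ImageModel"),
    config["OpenAI"].get("BaseURL"),     # new in this commit; None keeps the SDK default
    None,                                # logger; the wrapper falls back to its own Logger()
)
```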
@@ -220,6 +227,7 @@ class GPTBot:
         for event in response.chunk:
             if len(messages) >= n:
                 break
 
             if isinstance(event, MegolmEvent):
                 try:
                     event = await self.matrix_client.decrypt_event(event)
@@ -229,14 +237,22 @@ class GPTBot:
                         "error",
                     )
                     continue
-            if isinstance(event, (RoomMessageText, RoomMessageNotice)):
+            if isinstance(event, RoomMessageText):
                 if event.body.startswith("!gptbot ignoreolder"):
                     break
-                if (not event.body.startswith("!")) or (
-                    event.body.startswith("!gptbot") and not ignore_bot_commands
-                ):
+                if (not event.body.startswith("!")) or (not ignore_bot_commands):
                     messages.append(event)
 
+            if isinstance(event, RoomMessageNotice):
+                if not ignore_bot_commands:
+                    messages.append(event)
+
+            if isinstance(event, RoomMessageMedia):
+                if event.sender != self.matrix_client.user_id:
+                    if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
+                        messages.append(event)
+
         self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
 
         # Reverse the list so that messages are in chronological order
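The new RoomMessageMedia branch is the interesting one: history is collected newest-first here (the list is reversed afterwards, per the comment), and the guard only accepts media while the tail of the collected history is still media. In effect, only images posted immediately before the triggering message are attached. Annotated restatement of the condition, as I read it:

```python
if isinstance(event, RoomMessageMedia):
    # Skip media the bot itself sent (e.g. earlier !gptbot imagine output).
    if event.sender != self.matrix_client.user_id:
        # Newest-first collection: accept this media event only if almost
        # nothing has been collected yet, or the previously collected event
        # was also media, so only images directly preceding the prompt count.
        if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
            messages.append(event)
```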
@@ -275,7 +291,7 @@ class GPTBot:
         truncated_messages = []
 
         for message in [messages[0]] + list(reversed(messages[1:])):
-            content = message["content"]
+            content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
             tokens = len(encoding.encode(content)) + 1
             if total_tokens + tokens > max_tokens:
                 break
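With image input, `content` is no longer always a string: for vision-capable models it becomes a list of parts, with the text part first and `image_url` parts after it (see the GPTBot hunk further down). The chained conditional above only counts the first part's text toward the token budget and ignores image parts entirely. An equivalent, easier-to-read helper might look like this, a sketch assuming the message shapes built by this commit:

```python
def _textual_content(message: dict) -> str:
    """Return the text to count tokens for; image parts are not counted."""
    content = message["content"]
    if isinstance(content, str):  # plain text message
        return content
    # Vision-style message: a list of parts; this commit always puts the
    # {"type": "text", ...} part at index 0, followed by image_url parts.
    first = content[0].get("text")
    return first if isinstance(first, str) else ""
```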
@@ -906,14 +922,39 @@ class GPTBot:
 
             chat_messages = [{"role": "system", "content": system_message}]
 
-            for message in last_messages:
+            text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
+
+            for message in text_messages:
                 role = (
                     "assistant" if message.sender == self.matrix_client.user_id else "user"
                 )
                 if not message.event_id == event.event_id:
                     chat_messages.append({"role": role, "content": message.body})
 
-            chat_messages.append({"role": "user", "content": event.body})
+            if not self.chat_api.supports_chat_images():
+                event_body = event.body
+            else:
+                event_body = [
+                    {
+                        "type": "text",
+                        "text": event.body
+                    }
+                ]
+
+                for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
+                    image_url = m.url
+                    download = await self.download_file(image_url)
+
+                    if download:
+                        encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
+                        event_body.append({
+                            "type": "image_url",
+                            "image_url": {
+                                "url": encoded_url
+                            }
+                        })
+
+            chat_messages.append({"role": "user", "content": event_body})
 
             # Truncate messages to fit within the token limit
             truncated_messages = self._truncate(
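For a vision-capable model, the final user message therefore carries a list of parts, with each attached Matrix image downloaded and inlined as a base64 data URL in OpenAI's `image_url` format. Schematically, the payload handed to the truncation step looks like this (all values illustrative):

```python
chat_messages = [
    {"role": "system", "content": "You are a helpful assistant. ..."},
    {"role": "user", "content": "Earlier text message from the room history"},
    {"role": "user", "content": [                      # the triggering message
        {"type": "text", "text": "Describe the attached image."},
        {"type": "image_url", "image_url": {
            # data URL built from the downloaded Matrix media:
            "url": "data:image/png;base64,iVBORw0KGgo...",
        }},
    ]},
]
```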
@@ -926,6 +967,7 @@ class GPTBot:
             )
         except Exception as e:
             self.logger.log(f"Error generating response: {e}", "error")
 
             await self.send_message(
                 room, "Something went wrong. Please try again.", True
             )
@@ -954,6 +996,24 @@ class GPTBot:
 
         await self.matrix_client.room_typing(room.room_id, False)
 
+    def download_file(self, mxc) -> Optional[bytes]:
+        """Download a file from the homeserver.
+
+        Args:
+            mxc (str): The MXC URI of the file to download.
+
+        Returns:
+            Optional[bytes]: The downloaded file, or None if there was an error.
+        """
+
+        download = self.matrix_client.download(mxc)
+
+        if isinstance(download, DownloadError):
+            self.logger.log(f"Error downloading file: {download.message}", "error")
+            return
+
+        return download
+
     def get_system_message(self, room: MatrixRoom | str) -> str:
         """Get the system message for a room.
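One caveat worth flagging: the call site above does `await self.download_file(...)`, but `download_file` is a plain `def`, and matrix-nio's `AsyncClient.download()` is itself a coroutine. So the `isinstance(download, DownloadError)` check runs against an un-awaited coroutine object and can never trigger; it is only the caller's outer `await` that actually drives the download. The `Optional[bytes]` annotation also doesn't match the caller's use of `download.content_type` and `download.body`, which belong to nio's `DownloadResponse`. A corrected sketch, assuming the nio imports added in this commit:

```python
async def download_file(self, mxc) -> Optional[DownloadResponse]:
    """Download a file from the homeserver, returning None on error."""
    download = await self.matrix_client.download(mxc)  # await the nio coroutine here

    if isinstance(download, DownloadError):
        self.logger.log(f"Error downloading file: {download.message}", "error")
        return None

    return download  # DownloadResponse with .body and .content_type
```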
OpenAI wrapper class

@@ -25,12 +25,13 @@ class OpenAI:
 
     operator: str = "OpenAI (https://openai.com)"
 
-    def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
+    def __init__(self, api_key, chat_model=None, image_model=None, base_url=None, logger=None):
         self.api_key = api_key
         self.chat_model = chat_model or self.chat_model
         self.image_model = image_model or self.image_model
         self.logger = logger or Logger()
-        self.base_url = openai.api_base
+        self.base_url = base_url or openai.base_url
+        self.openai_api = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
     def supports_chat_images(self):
         return "vision" in self.chat_model
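`supports_chat_images()` is a plain substring check, so only models whose configured name contains "vision" (e.g. `gpt-4-vision-preview`, current at the time of this commit) get the multi-part message body; any other model name, including later multimodal ones, would need the heuristic extended. How the flag is consumed in the bot, from the GPTBot hunk above:

```python
# String content for text-only models, parts list for vision models:
if not self.chat_api.supports_chat_images():
    event_body = event.body                              # plain string
else:
    event_body = [{"type": "text", "text": event.body}]  # image_url parts appended later
```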
@@ -74,18 +75,20 @@ class OpenAI:
 
 
         chat_partial = partial(
-            openai.ChatCompletion.acreate,
+            self.openai_api.chat.completions.create,
             model=self.chat_model,
            messages=messages,
-            api_key=self.api_key,
             user=user,
-            api_base=self.base_url,
+            max_tokens=4096
         )
         response = await self._request_with_retries(chat_partial)
 
-        result_text = response.choices[0].message['content']
-        tokens_used = response.usage["total_tokens"]
+        self.logger.log(response, "error")
+        self.logger.log(response.choices, "error")
+        self.logger.log(response.choices[0].message, "error")
+
+        result_text = response.choices[0].message.content
+        tokens_used = response.usage.total_tokens
         self.logger.log(f"Generated response with {tokens_used} tokens.")
         return result_text, tokens_used
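Two things stand out in this hunk. First, result extraction switches from dict indexing to attribute access, because `openai>=1` returns typed (pydantic) objects rather than plain dicts. Second, the three `self.logger.log(..., "error")` calls look like debug output committed along with the fix: they log every successful response at error level. A trimmed sketch of the same logic:

```python
response = await self._request_with_retries(chat_partial)

# openai>=1 returns typed objects; the old dict-style access would fail:
result_text = response.choices[0].message.content   # was: message['content']
tokens_used = response.usage.total_tokens           # was: usage["total_tokens"]
self.logger.log(f"Generated response with {tokens_used} tokens.")
```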
@@ -117,13 +120,10 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         self.logger.log(f"Classifying message '{query}'...")
 
         chat_partial = partial(
-            openai.ChatCompletion.acreate,
+            self.openai_api.chat.completions.create,
            model=self.chat_model,
             messages=messages,
-            api_key=self.api_key,
             user=user,
-            api_base=self.base_url,
-            quality=("hd" if model == "dall-e-3" else "normal")
         )
         response = await self._request_with_retries(chat_partial)
@@ -150,14 +150,12 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         self.logger.log(f"Generating image from prompt '{prompt}'...")
 
         image_partial = partial(
-            openai.Image.acreate,
+            self.openai_api.images.generate,
             model=self.image_model,
             prompt=prompt,
             n=1,
-            api_key=self.api_key,
             size="1024x1024",
             user=user,
-            api_base=self.base_url,
         )
         response = await self._request_with_retries(image_partial)
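The image call moves to the client's `images.generate` method the same way; the per-request `api_key`/`api_base` arguments disappear because the client instance already carries them. In `openai>=1` the result is an `ImagesResponse` whose items expose the generated image, roughly like this (a sketch; which field is populated depends on the requested `response_format`):

```python
response = await self.openai_api.images.generate(
    model="dall-e-3",           # placeholder; the bot passes self.image_model
    prompt="A watercolor fox",  # placeholder prompt
    n=1,
    size="1024x1024",
)
first = response.data[0]
image_url = first.url           # hosted URL (the default, response_format="url")
# or: image_b64 = first.b64_json  when response_format="b64_json" is requested
```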
!gptbot imagine command handler

@@ -19,7 +19,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
         bot.logger.log(f"Sending image...")
         await bot.send_image(room, image)
 
-        bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_api}", tokens_used)
+        bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_model}", tokens_used)
 
         return