From 4113a02232a2cde5e18e021b311258a6dafe9875 Mon Sep 17 00:00:00 2001
From: Kumi
Date: Sat, 11 Nov 2023 12:27:19 +0100
Subject: [PATCH] Add image input on models that support it, fix some bugs, bump required OpenAI version

---
 pyproject.toml                 |  4 +-
 requirements.txt               |  2 +-
 src/gptbot/classes/bot.py      | 76 ++++++++++++++++++++++++++++++----
 src/gptbot/classes/openai.py   | 26 ++++++------
 src/gptbot/commands/imagine.py |  2 +-
 5 files changed, 84 insertions(+), 26 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 196b186..405526d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ allow-direct-references = true
 
 [project]
 name = "matrix-gptbot"
-version = "0.1.1"
+version = "0.2.0"
 authors = [
   { name="Kumi Mitterer", email="gptbot@kumi.email" },
 ]
@@ -38,7 +38,7 @@ dependencies = [
 
 [project.optional-dependencies]
 openai = [
-  "openai",
+  "openai>=1.2",
 ]
 
 wolframalpha = [
diff --git a/requirements.txt b/requirements.txt
index ae9ac08..28a5039 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-openai
+openai>=1.2
 matrix-nio[e2e]
 markdown2[all]
 tiktoken
diff --git a/src/gptbot/classes/bot.py b/src/gptbot/classes/bot.py
index 101d17d..a7b7d49 100644
--- a/src/gptbot/classes/bot.py
+++ b/src/gptbot/classes/bot.py
@@ -27,6 +27,12 @@ from nio import (
     RoomSendError,
     RoomVisibility,
     RoomCreateError,
+    RoomMessageMedia,
+    RoomMessageImage,
+    RoomMessageFile,
+    RoomMessageAudio,
+    DownloadError,
+    DownloadResponse,
 )
 from nio.crypto import Olm
 from nio.store import SqliteStore
@@ -38,6 +44,7 @@ from io import BytesIO
 from pathlib import Path
 from contextlib import closing
 
+import base64
 import uuid
 import traceback
 import json
@@ -139,7 +146,7 @@ class GPTBot:
 
         bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
             config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
-            config["OpenAI"].get("ImageModel"), bot.logger
+            config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"), bot.logger
         )
         bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
         bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
@@ -220,6 +227,7 @@ class GPTBot:
         for event in response.chunk:
             if len(messages) >= n:
                 break
+
             if isinstance(event, MegolmEvent):
                 try:
                     event = await self.matrix_client.decrypt_event(event)
@@ -229,14 +237,22 @@ class GPTBot:
                         "error",
                     )
                     continue
-            if isinstance(event, (RoomMessageText, RoomMessageNotice)):
+
+            if isinstance(event, RoomMessageText):
                 if event.body.startswith("!gptbot ignoreolder"):
                     break
-                if (not event.body.startswith("!")) or (
-                    event.body.startswith("!gptbot") and not ignore_bot_commands
-                ):
+                if (not event.body.startswith("!")) or (not ignore_bot_commands):
                     messages.append(event)
 
+            if isinstance(event, RoomMessageNotice):
+                if not ignore_bot_commands:
+                    messages.append(event)
+
+            if isinstance(event, RoomMessageMedia):
+                if event.sender != self.matrix_client.user_id:
+                    if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
+                        messages.append(event)
+
         self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
 
         # Reverse the list so that messages are in chronological order
@@ -275,7 +291,7 @@ class GPTBot:
         truncated_messages = []
 
         for message in [messages[0]] + list(reversed(messages[1:])):
-            content = message["content"]
+            content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
             tokens = len(encoding.encode(content)) + 1
             if total_tokens + tokens > max_tokens:
                 break
@@ -906,14 +922,39 @@ class GPTBot:
 
         chat_messages = [{"role": "system", "content": system_message}]
 
-        for message in last_messages:
+        text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
+
+        for message in text_messages:
             role = (
                 "assistant" if message.sender == self.matrix_client.user_id else "user"
             )
             if not message.event_id == event.event_id:
                 chat_messages.append({"role": role, "content": message.body})
 
-        chat_messages.append({"role": "user", "content": event.body})
+        if not self.chat_api.supports_chat_images():
+            event_body = event.body
+        else:
+            event_body = [
+                {
+                    "type": "text",
+                    "text": event.body
+                }
+            ]
+
+            for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
+                image_url = m.url
+                download = await self.download_file(image_url)
+
+                if download:
+                    encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
+                    event_body.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": encoded_url
+                        }
+                    })
+
+        chat_messages.append({"role": "user", "content": event_body})
 
         # Truncate messages to fit within the token limit
         truncated_messages = self._truncate(
@@ -926,6 +967,7 @@ class GPTBot:
             )
         except Exception as e:
             self.logger.log(f"Error generating response: {e}", "error")
+
             await self.send_message(
                 room, "Something went wrong. Please try again.", True
             )
@@ -954,6 +996,24 @@ class GPTBot:
 
         await self.matrix_client.room_typing(room.room_id, False)
 
+    async def download_file(self, mxc) -> Optional[DownloadResponse]:
+        """Download a file from the homeserver.
+
+        Args:
+            mxc (str): The MXC URI of the file to download.
+
+        Returns:
+            Optional[DownloadResponse]: The download response, or None if there was an error.
+        """
+
+        download = await self.matrix_client.download(mxc)
+
+        if isinstance(download, DownloadError):
+            self.logger.log(f"Error downloading file: {download.message}", "error")
+            return
+
+        return download
+
     def get_system_message(self, room: MatrixRoom | str) -> str:
         """Get the system message for a room.
 
diff --git a/src/gptbot/classes/openai.py b/src/gptbot/classes/openai.py
index 97741d3..2c34c0d 100644
--- a/src/gptbot/classes/openai.py
+++ b/src/gptbot/classes/openai.py
@@ -25,12 +25,13 @@ class OpenAI:
 
     operator: str = "OpenAI (https://openai.com)"
 
-    def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
+    def __init__(self, api_key, chat_model=None, image_model=None, base_url=None, logger=None):
         self.api_key = api_key
         self.chat_model = chat_model or self.chat_model
         self.image_model = image_model or self.image_model
         self.logger = logger or Logger()
-        self.base_url = openai.api_base
+        self.base_url = base_url or openai.base_url
+        self.openai_api = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
     def supports_chat_images(self):
         return "vision" in self.chat_model
@@ -74,18 +75,20 @@ class OpenAI:
 
         chat_partial = partial(
-            openai.ChatCompletion.acreate,
+            self.openai_api.chat.completions.create,
             model=self.chat_model,
             messages=messages,
-            api_key=self.api_key,
             user=user,
-            api_base=self.base_url,
+            max_tokens=4096
         )
         response = await self._request_with_retries(chat_partial)
 
+        self.logger.log(response, "error")
+        self.logger.log(response.choices, "error")
+        self.logger.log(response.choices[0].message, "error")
-        result_text = response.choices[0].message['content']
-        tokens_used = response.usage["total_tokens"]
+        result_text = response.choices[0].message.content
+        tokens_used = response.usage.total_tokens
 
         self.logger.log(f"Generated response with {tokens_used} tokens.")
         return result_text, tokens_used
@@ -117,13 +120,10 @@ Only the event_types mentioned above are allowed, you must not respond in any other way."""
         self.logger.log(f"Classifying message '{query}'...")
 
         chat_partial = partial(
-            openai.ChatCompletion.acreate,
+            self.openai_api.chat.completions.create,
             model=self.chat_model,
             messages=messages,
-            api_key=self.api_key,
             user=user,
-            api_base=self.base_url,
-            quality=("hd" if model == "dall-e-3" else "normal")
         )
         response = await self._request_with_retries(chat_partial)
@@ -150,14 +150,12 @@ Only the event_types mentioned above are allowed, you must not respond in any other way."""
         self.logger.log(f"Generating image from prompt '{prompt}'...")
 
         image_partial = partial(
-            openai.Image.acreate,
+            self.openai_api.images.generate,
             model=self.image_model,
             prompt=prompt,
             n=1,
-            api_key=self.api_key,
             size="1024x1024",
             user=user,
-            api_base=self.base_url,
         )
         response = await self._request_with_retries(image_partial)
diff --git a/src/gptbot/commands/imagine.py b/src/gptbot/commands/imagine.py
index 462a771..cc7d4e8 100644
--- a/src/gptbot/commands/imagine.py
+++ b/src/gptbot/commands/imagine.py
@@ -19,7 +19,7 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
         bot.logger.log(f"Sending image...")
         await bot.send_image(room, image)
 
-        bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_api}", tokens_used)
+        bot.log_api_usage(event, room, f"{bot.image_api.api_code}-{bot.image_api.image_model}", tokens_used)
 
         return
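
Notes:

For vision-capable models, the bot.py changes replace the plain-string "content" field with a list of typed parts and inline downloaded media as base64 data URLs. Below is a minimal standalone sketch of the resulting call against the openai>=1.2 client; the model name, file name, prompt, and API key are illustrative placeholders, not taken from the patch.

    import asyncio
    import base64

    from openai import AsyncOpenAI

    async def main():
        client = AsyncOpenAI(api_key="sk-...")  # placeholder key

        # Inline the image as a data URL, as bot.py does with downloaded media
        with open("example.png", "rb") as f:  # placeholder file
            data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

        response = await client.chat.completions.create(
            model="gpt-4-vision-preview",  # any model for which supports_chat_images() is true
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is in this image?"},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            max_tokens=4096,  # same cap the patch passes to chat completions
        )
        print(response.choices[0].message.content, response.usage.total_tokens)

    asyncio.run(main())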
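
The openai.py changes track the 1.x client migration: module-level openai.ChatCompletion.acreate / openai.Image.acreate calls with per-request api_key/api_base arguments give way to a single AsyncOpenAI instance, and responses become typed objects (attribute access) instead of dicts. A sketch of the image call under the same assumptions, with a placeholder model, prompt, and key:

    import asyncio

    from openai import AsyncOpenAI

    async def main():
        # base_url may be omitted to use the default OpenAI endpoint
        client = AsyncOpenAI(api_key="sk-...")  # placeholder key

        response = await client.images.generate(
            model="dall-e-2",  # placeholder; the bot passes self.image_model
            prompt="A watercolor painting of a teapot",  # placeholder prompt
            n=1,
            size="1024x1024",
        )

        # With the default response_format, each result entry carries a hosted URL
        print(response.data[0].url)

    asyncio.run(main())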