diff --git a/classes/bot.py b/classes/bot.py index a8681e4..c1679b5 100644 --- a/classes/bot.py +++ b/classes/bot.py @@ -1,8 +1,8 @@ import markdown2 import duckdb import tiktoken -import magic import asyncio +import functools from PIL import Image @@ -27,12 +27,11 @@ from nio import ( RoomLeaveError, RoomSendError, RoomVisibility, - RoomCreateResponse, RoomCreateError, ) from nio.crypto import Olm -from typing import Optional, List, Dict, Tuple +from typing import Optional, List from configparser import ConfigParser from datetime import datetime from io import BytesIO @@ -174,7 +173,7 @@ class GPTBot: async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]): messages = [] - n = n or bot.max_messages + n = n or self.max_messages room_id = room.room_id if isinstance(room, MatrixRoom) else room self.logger.log( @@ -585,7 +584,7 @@ class GPTBot: self.logger.log( "No database connection set up, using in-memory database. Data will be lost on bot shutdown.") IN_MEMORY = True - self.database = DuckDBPyConnection(":memory:") + self.database = duckdb.DuckDBPyConnection(":memory:") self.logger.log("Running migrations...") before, after = migrate(self.database) @@ -747,8 +746,14 @@ class GPTBot: await self.matrix_client.room_read_markers(room.room_id, event.event_id) if (not from_chat_command) and self.room_uses_classification(room): - classification, tokens = self.classification_api.classify_message( - event.body, room.room_id) + try: + classification, tokens = await self.classification_api.classify_message( + event.body, room.room_id) + except Exception as e: + self.logger.log(f"Error classifying message: {e}", "error") + await self.send_message( + room, "Something went wrong. Please try again.", True) + return self.log_api_usage( event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens) @@ -782,7 +787,7 @@ class GPTBot: chat_messages, self.max_tokens - 1, system_message=system_message) try: - response, tokens_used = self.chat_api.generate_chat_response( + response, tokens_used = await self.chat_api.generate_chat_response( chat_messages, user=room.room_id) except Exception as e: self.logger.log(f"Error generating response: {e}", "error") @@ -803,7 +808,7 @@ class GPTBot: else: # Send a notice to the room if there was an error self.logger.log("Didn't get a response from GPT API", "error") - await send_message( + await self.send_message( room, "Something went wrong. Please try again.", True) await self.matrix_client.room_typing(room.room_id, False) diff --git a/classes/openai.py b/classes/openai.py index 249701c..3423ba2 100644 --- a/classes/openai.py +++ b/classes/openai.py @@ -1,11 +1,13 @@ import openai import requests +import asyncio import json +from functools import partial from .logging import Logger -from typing import Dict, List, Tuple, Generator, Optional +from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any class OpenAI: api_key: str @@ -17,7 +19,7 @@ class OpenAI: @property def chat_api(self) -> str: return self.chat_model - + classification_api = chat_api image_api: str = "dalle" @@ -28,7 +30,33 @@ class OpenAI: self.chat_model = chat_model or self.chat_model self.logger = logger or Logger() - def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]: + async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]: + """Retry a request a set number of times if it fails. + + Args: + request (partial): The request to make with retries. + attempts (int, optional): The number of attempts to make. Defaults to 5. + retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds. + + Returns: + AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request. + """ + # call the request function and return the response if it succeeds, else retry + current_attempt = 1 + while current_attempt <= attempts: + try: + response = await request() + return response + except Exception as e: + self.logger.log(f"Request failed: {e}", "error") + self.logger.log(f"Retrying in {retry_interval} seconds...") + await asyncio.sleep(retry_interval) + current_attempt += 1 + + # if all attempts failed, raise an exception + raise Exception("Request failed after all attempts.") + + async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]: """Generate a response to a chat message. Args: @@ -37,22 +65,25 @@ class OpenAI: Returns: Tuple[str, int]: The response text and the number of tokens used. """ - self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...") - response = openai.ChatCompletion.create( - model=self.chat_model, - messages=messages, - api_key=self.api_key, - user = user + + chat_partial = partial( + openai.ChatCompletion.acreate, + model=self.chat_model, + messages=messages, + api_key=self.api_key, + user=user ) + response = await self._request_with_retries(chat_partial) + result_text = response.choices[0].message['content'] tokens_used = response.usage["total_tokens"] self.logger.log(f"Generated response with {tokens_used} tokens.") return result_text, tokens_used - def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]: + async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]: system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this: { "type": event_type, "prompt": prompt } @@ -66,10 +97,9 @@ class OpenAI: - If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message. Only the event_types mentioned above are allowed, you must not respond in any other way.""" - messages = [ { - "role": "system", + "role": "system", "content": system_message }, { @@ -80,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot self.logger.log(f"Classifying message '{query}'...") - response = openai.ChatCompletion.create( - model=self.chat_model, - messages=messages, - api_key=self.api_key, - user = user + chat_partial = partial( + openai.ChatCompletion.acreate, + model=self.chat_model, + messages=messages, + api_key=self.api_key, + user=user ) + response = await self._request_with_retries(chat_partial) try: result = json.loads(response.choices[0].message['content']) @@ -98,7 +130,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot return result, tokens_used - def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]: + async def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]: """Generate an image from a prompt. Args: @@ -107,16 +139,17 @@ Only the event_types mentioned above are allowed, you must not respond in any ot Yields: bytes: The image data. """ - self.logger.log(f"Generating image from prompt '{prompt}'...") - response = openai.Image.create( - prompt=prompt, - n=1, - api_key=self.api_key, - size="1024x1024", - user = user + image_partial = partial( + openai.Image.acreate, + prompt=prompt, + n=1, + api_key=self.api_key, + size="1024x1024", + user=user ) + response = await self._request_with_retries(image_partial) images = [] @@ -124,4 +157,4 @@ Only the event_types mentioned above are allowed, you must not respond in any ot image = requests.get(image.url).content images.append(image) - return images, len(images) \ No newline at end of file + return images, len(images) diff --git a/commands/classify.py b/commands/classify.py index e5813ff..c07bbfc 100644 --- a/commands/classify.py +++ b/commands/classify.py @@ -8,7 +8,12 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot): if prompt: bot.logger.log("Classifying message...") - response, tokens_used = bot.classification_api.classify_message(prompt, user=room.room_id) + try: + response, tokens_used = await bot.classification_api.classify_message(prompt, user=room.room_id) + except Exception as e: + bot.logger.log(f"Error classifying message: {e}", "error") + await bot.send_message(room, "Sorry, I couldn't classify the message. Please try again later.", True) + return message = f"The message you provided seems to be of type: {response['type']}." @@ -21,4 +26,4 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot): return - await bot.send_message(room, "You need to provide a prompt.", True) \ No newline at end of file + await bot.send_message(room, "You need to provide a prompt.", True) diff --git a/commands/imagine.py b/commands/imagine.py index 54e6c71..462a771 100644 --- a/commands/imagine.py +++ b/commands/imagine.py @@ -8,7 +8,12 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot): if prompt: bot.logger.log("Generating image...") - images, tokens_used = bot.image_api.generate_image(prompt, user=room.room_id) + try: + images, tokens_used = await bot.image_api.generate_image(prompt, user=room.room_id) + except Exception as e: + bot.logger.log(f"Error generating image: {e}", "error") + await bot.send_message(room, "Sorry, I couldn't generate an image. Please try again later.", True) + return for image in images: bot.logger.log(f"Sending image...") @@ -18,4 +23,4 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot): return - await bot.send_message(room, "You need to provide a prompt.", True) \ No newline at end of file + await bot.send_message(room, "You need to provide a prompt.", True) diff --git a/migrations/__init__.py b/migrations/__init__.py index 1cf8d53..1085ee2 100644 --- a/migrations/__init__.py +++ b/migrations/__init__.py @@ -45,7 +45,7 @@ def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_versi raise ValueError("Cannot migrate from a higher version to a lower version.") for version in range(from_version, to_version): - if version in MIGRATIONS: + if version + 1 in MIGRATIONS: MIGRATIONS[version + 1](db) - return from_version, to_version \ No newline at end of file + return from_version, to_version