diff --git a/classes/openai.py b/classes/openai.py index 059dab9..3423ba2 100644 --- a/classes/openai.py +++ b/classes/openai.py @@ -1,11 +1,13 @@ import openai import requests +import asyncio import json +from functools import partial from .logging import Logger -from typing import Dict, List, Tuple, Generator, Optional +from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any class OpenAI: api_key: str @@ -28,6 +30,32 @@ class OpenAI: self.chat_model = chat_model or self.chat_model self.logger = logger or Logger() + async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]: + """Retry a request a set number of times if it fails. + + Args: + request (partial): The request to make with retries. + attempts (int, optional): The number of attempts to make. Defaults to 5. + retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds. + + Returns: + AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request. + """ + # call the request function and return the response if it succeeds, else retry + current_attempt = 1 + while current_attempt <= attempts: + try: + response = await request() + return response + except Exception as e: + self.logger.log(f"Request failed: {e}", "error") + self.logger.log(f"Retrying in {retry_interval} seconds...") + await asyncio.sleep(retry_interval) + current_attempt += 1 + + # if all attempts failed, raise an exception + raise Exception("Request failed after all attempts.") + async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]: """Generate a response to a chat message. @@ -39,12 +67,16 @@ class OpenAI: """ self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...") - response = await openai.ChatCompletion.acreate( - model=self.chat_model, - messages=messages, - api_key=self.api_key, - user = user + + chat_partial = partial( + openai.ChatCompletion.acreate, + model=self.chat_model, + messages=messages, + api_key=self.api_key, + user=user ) + response = await self._request_with_retries(chat_partial) + result_text = response.choices[0].message['content'] tokens_used = response.usage["total_tokens"] @@ -78,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot self.logger.log(f"Classifying message '{query}'...") - response = await openai.ChatCompletion.acreate( - model=self.chat_model, - messages=messages, - api_key=self.api_key, - user = user + chat_partial = partial( + openai.ChatCompletion.acreate, + model=self.chat_model, + messages=messages, + api_key=self.api_key, + user=user ) + response = await self._request_with_retries(chat_partial) try: result = json.loads(response.choices[0].message['content']) @@ -107,13 +141,15 @@ Only the event_types mentioned above are allowed, you must not respond in any ot """ self.logger.log(f"Generating image from prompt '{prompt}'...") - response = await openai.Image.acreate( - prompt=prompt, - n=1, - api_key=self.api_key, - size="1024x1024", - user = user + image_partial = partial( + openai.Image.acreate, + prompt=prompt, + n=1, + api_key=self.api_key, + size="1024x1024", + user=user ) + response = await self._request_with_retries(image_partial) images = []