Adds retry logic for failed openai requests

This commit is contained in:
Justin 2023-05-19 15:37:04 -05:00
parent 6c97c0f61d
commit b41a9ecd14

View file

@ -1,11 +1,13 @@
import openai import openai
import requests import requests
import asyncio
import json import json
from functools import partial
from .logging import Logger from .logging import Logger
from typing import Dict, List, Tuple, Generator, Optional from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
class OpenAI: class OpenAI:
api_key: str api_key: str
@ -28,6 +30,32 @@ class OpenAI:
self.chat_model = chat_model or self.chat_model self.chat_model = chat_model or self.chat_model
self.logger = logger or Logger() self.logger = logger or Logger()
async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
"""Retry a request a set number of times if it fails.
Args:
request (partial): The request to make with retries.
attempts (int, optional): The number of attempts to make. Defaults to 5.
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
Returns:
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
"""
# call the request function and return the response if it succeeds, else retry
current_attempt = 1
while current_attempt <= attempts:
try:
response = await request()
return response
except Exception as e:
self.logger.log(f"Request failed: {e}", "error")
self.logger.log(f"Retrying in {retry_interval} seconds...")
await asyncio.sleep(retry_interval)
current_attempt += 1
# if all attempts failed, raise an exception
raise Exception("Request failed after all attempts.")
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]: async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
"""Generate a response to a chat message. """Generate a response to a chat message.
@ -39,12 +67,16 @@ class OpenAI:
""" """
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...") self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
response = await openai.ChatCompletion.acreate(
model=self.chat_model, chat_partial = partial(
messages=messages, openai.ChatCompletion.acreate,
api_key=self.api_key, model=self.chat_model,
user = user messages=messages,
api_key=self.api_key,
user=user
) )
response = await self._request_with_retries(chat_partial)
result_text = response.choices[0].message['content'] result_text = response.choices[0].message['content']
tokens_used = response.usage["total_tokens"] tokens_used = response.usage["total_tokens"]
@ -78,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Classifying message '{query}'...") self.logger.log(f"Classifying message '{query}'...")
response = await openai.ChatCompletion.acreate( chat_partial = partial(
model=self.chat_model, openai.ChatCompletion.acreate,
messages=messages, model=self.chat_model,
api_key=self.api_key, messages=messages,
user = user api_key=self.api_key,
user=user
) )
response = await self._request_with_retries(chat_partial)
try: try:
result = json.loads(response.choices[0].message['content']) result = json.loads(response.choices[0].message['content'])
@ -107,13 +141,15 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
""" """
self.logger.log(f"Generating image from prompt '{prompt}'...") self.logger.log(f"Generating image from prompt '{prompt}'...")
response = await openai.Image.acreate( image_partial = partial(
prompt=prompt, openai.Image.acreate,
n=1, prompt=prompt,
api_key=self.api_key, n=1,
size="1024x1024", api_key=self.api_key,
user = user size="1024x1024",
user=user
) )
response = await self._request_with_retries(image_partial)
images = [] images = []