Adds retry logic for failed openai requests

This commit is contained in:
Justin 2023-05-19 15:37:04 -05:00
parent 6c97c0f61d
commit b41a9ecd14

View file

@ -1,11 +1,13 @@
import openai
import requests
import asyncio
import json
from functools import partial
from .logging import Logger
from typing import Dict, List, Tuple, Generator, Optional
from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
class OpenAI:
api_key: str
@ -28,6 +30,32 @@ class OpenAI:
self.chat_model = chat_model or self.chat_model
self.logger = logger or Logger()
async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
"""Retry a request a set number of times if it fails.
Args:
request (partial): The request to make with retries.
attempts (int, optional): The number of attempts to make. Defaults to 5.
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
Returns:
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
"""
# call the request function and return the response if it succeeds, else retry
current_attempt = 1
while current_attempt <= attempts:
try:
response = await request()
return response
except Exception as e:
self.logger.log(f"Request failed: {e}", "error")
self.logger.log(f"Retrying in {retry_interval} seconds...")
await asyncio.sleep(retry_interval)
current_attempt += 1
# if all attempts failed, raise an exception
raise Exception("Request failed after all attempts.")
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
"""Generate a response to a chat message.
@ -39,12 +67,16 @@ class OpenAI:
"""
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
response = await openai.ChatCompletion.acreate(
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user = user
chat_partial = partial(
openai.ChatCompletion.acreate,
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user=user
)
response = await self._request_with_retries(chat_partial)
result_text = response.choices[0].message['content']
tokens_used = response.usage["total_tokens"]
@ -78,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Classifying message '{query}'...")
response = await openai.ChatCompletion.acreate(
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user = user
chat_partial = partial(
openai.ChatCompletion.acreate,
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user=user
)
response = await self._request_with_retries(chat_partial)
try:
result = json.loads(response.choices[0].message['content'])
@ -107,13 +141,15 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
"""
self.logger.log(f"Generating image from prompt '{prompt}'...")
response = await openai.Image.acreate(
prompt=prompt,
n=1,
api_key=self.api_key,
size="1024x1024",
user = user
image_partial = partial(
openai.Image.acreate,
prompt=prompt,
n=1,
api_key=self.api_key,
size="1024x1024",
user=user
)
response = await self._request_with_retries(image_partial)
images = []