Adds retry logic for failed openai requests
This commit is contained in:
parent
6c97c0f61d
commit
b41a9ecd14
1 changed files with 53 additions and 17 deletions
|
@ -1,11 +1,13 @@
|
|||
import openai
|
||||
import requests
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
from .logging import Logger
|
||||
|
||||
from typing import Dict, List, Tuple, Generator, Optional
|
||||
from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
|
||||
|
||||
class OpenAI:
|
||||
api_key: str
|
||||
|
@ -28,6 +30,32 @@ class OpenAI:
|
|||
self.chat_model = chat_model or self.chat_model
|
||||
self.logger = logger or Logger()
|
||||
|
||||
async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
|
||||
"""Retry a request a set number of times if it fails.
|
||||
|
||||
Args:
|
||||
request (partial): The request to make with retries.
|
||||
attempts (int, optional): The number of attempts to make. Defaults to 5.
|
||||
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
|
||||
"""
|
||||
# call the request function and return the response if it succeeds, else retry
|
||||
current_attempt = 1
|
||||
while current_attempt <= attempts:
|
||||
try:
|
||||
response = await request()
|
||||
return response
|
||||
except Exception as e:
|
||||
self.logger.log(f"Request failed: {e}", "error")
|
||||
self.logger.log(f"Retrying in {retry_interval} seconds...")
|
||||
await asyncio.sleep(retry_interval)
|
||||
current_attempt += 1
|
||||
|
||||
# if all attempts failed, raise an exception
|
||||
raise Exception("Request failed after all attempts.")
|
||||
|
||||
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
|
||||
"""Generate a response to a chat message.
|
||||
|
||||
|
@ -39,12 +67,16 @@ class OpenAI:
|
|||
"""
|
||||
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
|
||||
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
|
||||
chat_partial = partial(
|
||||
openai.ChatCompletion.acreate,
|
||||
model=self.chat_model,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
user=user
|
||||
)
|
||||
response = await self._request_with_retries(chat_partial)
|
||||
|
||||
|
||||
result_text = response.choices[0].message['content']
|
||||
tokens_used = response.usage["total_tokens"]
|
||||
|
@ -78,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
|||
|
||||
self.logger.log(f"Classifying message '{query}'...")
|
||||
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
chat_partial = partial(
|
||||
openai.ChatCompletion.acreate,
|
||||
model=self.chat_model,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
user=user
|
||||
)
|
||||
response = await self._request_with_retries(chat_partial)
|
||||
|
||||
try:
|
||||
result = json.loads(response.choices[0].message['content'])
|
||||
|
@ -107,13 +141,15 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
|||
"""
|
||||
self.logger.log(f"Generating image from prompt '{prompt}'...")
|
||||
|
||||
response = await openai.Image.acreate(
|
||||
image_partial = partial(
|
||||
openai.Image.acreate,
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
api_key=self.api_key,
|
||||
size="1024x1024",
|
||||
user=user
|
||||
)
|
||||
response = await self._request_with_retries(image_partial)
|
||||
|
||||
images = []
|
||||
|
||||
|
|
Loading…
Reference in a new issue