Adds retry logic for failed openai requests
This commit is contained in:
parent 6c97c0f61d
commit b41a9ecd14
1 changed file with 53 additions and 17 deletions
@@ -1,11 +1,13 @@
 import openai
 import requests
+import asyncio
 import json
+from functools import partial

 from .logging import Logger

-from typing import Dict, List, Tuple, Generator, Optional
+from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any

 class OpenAI:
     api_key: str
@@ -28,6 +30,32 @@ class OpenAI:
         self.chat_model = chat_model or self.chat_model
         self.logger = logger or Logger()

+    async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
+        """Retry a request a set number of times if it fails.
+
+        Args:
+            request (partial): The request to make with retries.
+            attempts (int, optional): The number of attempts to make. Defaults to 5.
+            retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
+
+        Returns:
+            AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
+        """
+        # call the request function and return the response if it succeeds, else retry
+        current_attempt = 1
+        while current_attempt <= attempts:
+            try:
+                response = await request()
+                return response
+            except Exception as e:
+                self.logger.log(f"Request failed: {e}", "error")
+                self.logger.log(f"Retrying in {retry_interval} seconds...")
+                await asyncio.sleep(retry_interval)
+                current_attempt += 1
+
+        # if all attempts failed, raise an exception
+        raise Exception("Request failed after all attempts.")
+
     async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
         """Generate a response to a chat message.
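For readers skimming the diff: the pattern is to bind every argument of the API call into a functools.partial, which turns it into a zero-argument awaitable that _request_with_retries can re-invoke on failure. A minimal standalone sketch of that flow, using only the standard library (flaky_call and request_with_retries are hypothetical stand-ins, not names from this commit):

    import asyncio
    from functools import partial

    async def flaky_call(payload: str) -> str:
        """Stand-in for an OpenAI API coroutine; fails once, then succeeds."""
        if not hasattr(flaky_call, "called"):
            flaky_call.called = True
            raise RuntimeError("transient error")
        return f"ok: {payload}"

    async def request_with_retries(request: partial, attempts: int = 5, retry_interval: int = 2):
        # Mirrors the control flow of the new helper: await the bound request,
        # sleep on failure, give up after `attempts` tries.
        for _ in range(attempts):
            try:
                return await request()
            except Exception:
                await asyncio.sleep(retry_interval)
        raise Exception("Request failed after all attempts.")

    # The caller freezes the arguments once and hands the partial over.
    print(asyncio.run(request_with_retries(partial(flaky_call, "hello"), retry_interval=0)))  # ok: hello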
@@ -39,12 +67,16 @@ class OpenAI:
         """
         self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")

-        response = await openai.ChatCompletion.acreate(
+        chat_partial = partial(
+            openai.ChatCompletion.acreate,
             model=self.chat_model,
             messages=messages,
             api_key=self.api_key,
-            user = user
+            user=user
         )
+        response = await self._request_with_retries(chat_partial)

         result_text = response.choices[0].message['content']
         tokens_used = response.usage["total_tokens"]
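A note on behaviour: wrapping the ChatCompletion call in partial does not change what is sent to the API. Awaiting the partial is equivalent to the original direct await; it is only deferred, so the same request can be issued again if an attempt fails. A tiny illustration with a hypothetical coroutine (send is not an OpenAI function):

    import asyncio
    from functools import partial

    async def send(model: str, user: str) -> dict:
        return {"model": model, "user": user}

    bound = partial(send, model="some-model", user="alice")
    # Awaiting the partial gives the same result as the direct call...
    assert asyncio.run(bound()) == asyncio.run(send(model="some-model", user="alice"))
    # ...but `bound` can be awaited again if the first attempt fails.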
@@ -78,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot

         self.logger.log(f"Classifying message '{query}'...")

-        response = await openai.ChatCompletion.acreate(
+        chat_partial = partial(
+            openai.ChatCompletion.acreate,
             model=self.chat_model,
             messages=messages,
             api_key=self.api_key,
-            user = user
+            user=user
         )
+        response = await self._request_with_retries(chat_partial)

         try:
             result = json.loads(response.choices[0].message['content'])
@@ -107,13 +141,15 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
         """
         self.logger.log(f"Generating image from prompt '{prompt}'...")

-        response = await openai.Image.acreate(
+        image_partial = partial(
+            openai.Image.acreate,
             prompt=prompt,
             n=1,
             api_key=self.api_key,
             size="1024x1024",
-            user = user
+            user=user
         )
+        response = await self._request_with_retries(image_partial)

         images = []
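One design note on the helper: it waits a fixed retry_interval between attempts, which keeps the change small and predictable. A common alternative, not part of this commit, is exponential backoff with jitter; a hypothetical sketch for comparison:

    import asyncio
    import random
    from functools import partial

    async def request_with_backoff(request: partial, attempts: int = 5, base_delay: float = 1.0):
        """Hypothetical variant: double the wait after each failure, plus jitter."""
        for attempt in range(attempts):
            try:
                return await request()
            except Exception:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(base_delay * 2 ** attempt + random.random())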