Merge branch 'justin-russell-bugfixes'

This commit is contained in:
Kumi 2023-05-20 19:11:54 +00:00
commit 3a1d1ea86a
Signed by: kumi
GPG key ID: ECBCC9082395383F
5 changed files with 89 additions and 41 deletions

View file

@ -1,8 +1,8 @@
import markdown2
import duckdb
import tiktoken
import magic
import asyncio
import functools
from PIL import Image
@ -27,12 +27,11 @@ from nio import (
RoomLeaveError,
RoomSendError,
RoomVisibility,
RoomCreateResponse,
RoomCreateError,
)
from nio.crypto import Olm
from typing import Optional, List, Dict, Tuple
from typing import Optional, List
from configparser import ConfigParser
from datetime import datetime
from io import BytesIO
@ -174,7 +173,7 @@ class GPTBot:
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
messages = []
n = n or bot.max_messages
n = n or self.max_messages
room_id = room.room_id if isinstance(room, MatrixRoom) else room
self.logger.log(
@ -585,7 +584,7 @@ class GPTBot:
self.logger.log(
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
IN_MEMORY = True
self.database = DuckDBPyConnection(":memory:")
self.database = duckdb.DuckDBPyConnection(":memory:")
self.logger.log("Running migrations...")
before, after = migrate(self.database)
@ -747,8 +746,14 @@ class GPTBot:
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
if (not from_chat_command) and self.room_uses_classification(room):
classification, tokens = self.classification_api.classify_message(
event.body, room.room_id)
try:
classification, tokens = await self.classification_api.classify_message(
event.body, room.room_id)
except Exception as e:
self.logger.log(f"Error classifying message: {e}", "error")
await self.send_message(
room, "Something went wrong. Please try again.", True)
return
self.log_api_usage(
event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
@ -782,7 +787,7 @@ class GPTBot:
chat_messages, self.max_tokens - 1, system_message=system_message)
try:
response, tokens_used = self.chat_api.generate_chat_response(
response, tokens_used = await self.chat_api.generate_chat_response(
chat_messages, user=room.room_id)
except Exception as e:
self.logger.log(f"Error generating response: {e}", "error")
@ -803,7 +808,7 @@ class GPTBot:
else:
# Send a notice to the room if there was an error
self.logger.log("Didn't get a response from GPT API", "error")
await send_message(
await self.send_message(
room, "Something went wrong. Please try again.", True)
await self.matrix_client.room_typing(room.room_id, False)

View file

@ -1,11 +1,13 @@
import openai
import requests
import asyncio
import json
from functools import partial
from .logging import Logger
from typing import Dict, List, Tuple, Generator, Optional
from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
class OpenAI:
api_key: str
@ -28,7 +30,33 @@ class OpenAI:
self.chat_model = chat_model or self.chat_model
self.logger = logger or Logger()
def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
"""Retry a request a set number of times if it fails.
Args:
request (partial): The request to make with retries.
attempts (int, optional): The number of attempts to make. Defaults to 5.
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
Returns:
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
"""
# call the request function and return the response if it succeeds, else retry
current_attempt = 1
while current_attempt <= attempts:
try:
response = await request()
return response
except Exception as e:
self.logger.log(f"Request failed: {e}", "error")
self.logger.log(f"Retrying in {retry_interval} seconds...")
await asyncio.sleep(retry_interval)
current_attempt += 1
# if all attempts failed, raise an exception
raise Exception("Request failed after all attempts.")
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
"""Generate a response to a chat message.
Args:
@ -37,22 +65,25 @@ class OpenAI:
Returns:
Tuple[str, int]: The response text and the number of tokens used.
"""
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
response = openai.ChatCompletion.create(
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user = user
chat_partial = partial(
openai.ChatCompletion.acreate,
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user=user
)
response = await self._request_with_retries(chat_partial)
result_text = response.choices[0].message['content']
tokens_used = response.usage["total_tokens"]
self.logger.log(f"Generated response with {tokens_used} tokens.")
return result_text, tokens_used
def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
{ "type": event_type, "prompt": prompt }
@ -66,7 +97,6 @@ class OpenAI:
- If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message.
Only the event_types mentioned above are allowed, you must not respond in any other way."""
messages = [
{
"role": "system",
@ -80,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Classifying message '{query}'...")
response = openai.ChatCompletion.create(
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user = user
chat_partial = partial(
openai.ChatCompletion.acreate,
model=self.chat_model,
messages=messages,
api_key=self.api_key,
user=user
)
response = await self._request_with_retries(chat_partial)
try:
result = json.loads(response.choices[0].message['content'])
@ -98,7 +130,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
return result, tokens_used
def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
async def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
"""Generate an image from a prompt.
Args:
@ -107,16 +139,17 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
Yields:
bytes: The image data.
"""
self.logger.log(f"Generating image from prompt '{prompt}'...")
response = openai.Image.create(
prompt=prompt,
n=1,
api_key=self.api_key,
size="1024x1024",
user = user
image_partial = partial(
openai.Image.acreate,
prompt=prompt,
n=1,
api_key=self.api_key,
size="1024x1024",
user=user
)
response = await self._request_with_retries(image_partial)
images = []

View file

@ -8,7 +8,12 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
if prompt:
bot.logger.log("Classifying message...")
response, tokens_used = bot.classification_api.classify_message(prompt, user=room.room_id)
try:
response, tokens_used = await bot.classification_api.classify_message(prompt, user=room.room_id)
except Exception as e:
bot.logger.log(f"Error classifying message: {e}", "error")
await bot.send_message(room, "Sorry, I couldn't classify the message. Please try again later.", True)
return
message = f"The message you provided seems to be of type: {response['type']}."

View file

@ -8,7 +8,12 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
if prompt:
bot.logger.log("Generating image...")
images, tokens_used = bot.image_api.generate_image(prompt, user=room.room_id)
try:
images, tokens_used = await bot.image_api.generate_image(prompt, user=room.room_id)
except Exception as e:
bot.logger.log(f"Error generating image: {e}", "error")
await bot.send_message(room, "Sorry, I couldn't generate an image. Please try again later.", True)
return
for image in images:
bot.logger.log(f"Sending image...")

View file

@ -45,7 +45,7 @@ def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_versi
raise ValueError("Cannot migrate from a higher version to a lower version.")
for version in range(from_version, to_version):
if version in MIGRATIONS:
if version + 1 in MIGRATIONS:
MIGRATIONS[version + 1](db)
return from_version, to_version