Merge branch 'justin-russell-bugfixes'

This commit is contained in:
Kumi 2023-05-20 19:11:54 +00:00
commit 3a1d1ea86a
Signed by: kumi
GPG key ID: ECBCC9082395383F
5 changed files with 89 additions and 41 deletions

View file

@ -1,8 +1,8 @@
import markdown2 import markdown2
import duckdb import duckdb
import tiktoken import tiktoken
import magic
import asyncio import asyncio
import functools
from PIL import Image from PIL import Image
@ -27,12 +27,11 @@ from nio import (
RoomLeaveError, RoomLeaveError,
RoomSendError, RoomSendError,
RoomVisibility, RoomVisibility,
RoomCreateResponse,
RoomCreateError, RoomCreateError,
) )
from nio.crypto import Olm from nio.crypto import Olm
from typing import Optional, List, Dict, Tuple from typing import Optional, List
from configparser import ConfigParser from configparser import ConfigParser
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
@ -174,7 +173,7 @@ class GPTBot:
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]): async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
messages = [] messages = []
n = n or bot.max_messages n = n or self.max_messages
room_id = room.room_id if isinstance(room, MatrixRoom) else room room_id = room.room_id if isinstance(room, MatrixRoom) else room
self.logger.log( self.logger.log(
@ -585,7 +584,7 @@ class GPTBot:
self.logger.log( self.logger.log(
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.") "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
IN_MEMORY = True IN_MEMORY = True
self.database = DuckDBPyConnection(":memory:") self.database = duckdb.DuckDBPyConnection(":memory:")
self.logger.log("Running migrations...") self.logger.log("Running migrations...")
before, after = migrate(self.database) before, after = migrate(self.database)
@ -747,8 +746,14 @@ class GPTBot:
await self.matrix_client.room_read_markers(room.room_id, event.event_id) await self.matrix_client.room_read_markers(room.room_id, event.event_id)
if (not from_chat_command) and self.room_uses_classification(room): if (not from_chat_command) and self.room_uses_classification(room):
classification, tokens = self.classification_api.classify_message( try:
event.body, room.room_id) classification, tokens = await self.classification_api.classify_message(
event.body, room.room_id)
except Exception as e:
self.logger.log(f"Error classifying message: {e}", "error")
await self.send_message(
room, "Something went wrong. Please try again.", True)
return
self.log_api_usage( self.log_api_usage(
event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens) event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
@ -782,7 +787,7 @@ class GPTBot:
chat_messages, self.max_tokens - 1, system_message=system_message) chat_messages, self.max_tokens - 1, system_message=system_message)
try: try:
response, tokens_used = self.chat_api.generate_chat_response( response, tokens_used = await self.chat_api.generate_chat_response(
chat_messages, user=room.room_id) chat_messages, user=room.room_id)
except Exception as e: except Exception as e:
self.logger.log(f"Error generating response: {e}", "error") self.logger.log(f"Error generating response: {e}", "error")
@ -803,7 +808,7 @@ class GPTBot:
else: else:
# Send a notice to the room if there was an error # Send a notice to the room if there was an error
self.logger.log("Didn't get a response from GPT API", "error") self.logger.log("Didn't get a response from GPT API", "error")
await send_message( await self.send_message(
room, "Something went wrong. Please try again.", True) room, "Something went wrong. Please try again.", True)
await self.matrix_client.room_typing(room.room_id, False) await self.matrix_client.room_typing(room.room_id, False)

View file

@ -1,11 +1,13 @@
import openai import openai
import requests import requests
import asyncio
import json import json
from functools import partial
from .logging import Logger from .logging import Logger
from typing import Dict, List, Tuple, Generator, Optional from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
class OpenAI: class OpenAI:
api_key: str api_key: str
@ -17,7 +19,7 @@ class OpenAI:
@property @property
def chat_api(self) -> str: def chat_api(self) -> str:
return self.chat_model return self.chat_model
classification_api = chat_api classification_api = chat_api
image_api: str = "dalle" image_api: str = "dalle"
@ -28,7 +30,33 @@ class OpenAI:
self.chat_model = chat_model or self.chat_model self.chat_model = chat_model or self.chat_model
self.logger = logger or Logger() self.logger = logger or Logger()
def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]: async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
"""Retry a request a set number of times if it fails.
Args:
request (partial): The request to make with retries.
attempts (int, optional): The number of attempts to make. Defaults to 5.
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
Returns:
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
"""
# call the request function and return the response if it succeeds, else retry
current_attempt = 1
while current_attempt <= attempts:
try:
response = await request()
return response
except Exception as e:
self.logger.log(f"Request failed: {e}", "error")
self.logger.log(f"Retrying in {retry_interval} seconds...")
await asyncio.sleep(retry_interval)
current_attempt += 1
# if all attempts failed, raise an exception
raise Exception("Request failed after all attempts.")
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
"""Generate a response to a chat message. """Generate a response to a chat message.
Args: Args:
@ -37,22 +65,25 @@ class OpenAI:
Returns: Returns:
Tuple[str, int]: The response text and the number of tokens used. Tuple[str, int]: The response text and the number of tokens used.
""" """
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...") self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
response = openai.ChatCompletion.create(
model=self.chat_model, chat_partial = partial(
messages=messages, openai.ChatCompletion.acreate,
api_key=self.api_key, model=self.chat_model,
user = user messages=messages,
api_key=self.api_key,
user=user
) )
response = await self._request_with_retries(chat_partial)
result_text = response.choices[0].message['content'] result_text = response.choices[0].message['content']
tokens_used = response.usage["total_tokens"] tokens_used = response.usage["total_tokens"]
self.logger.log(f"Generated response with {tokens_used} tokens.") self.logger.log(f"Generated response with {tokens_used} tokens.")
return result_text, tokens_used return result_text, tokens_used
def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]: async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this: system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
{ "type": event_type, "prompt": prompt } { "type": event_type, "prompt": prompt }
@ -66,10 +97,9 @@ class OpenAI:
- If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message. - If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message.
Only the event_types mentioned above are allowed, you must not respond in any other way.""" Only the event_types mentioned above are allowed, you must not respond in any other way."""
messages = [ messages = [
{ {
"role": "system", "role": "system",
"content": system_message "content": system_message
}, },
{ {
@ -80,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
self.logger.log(f"Classifying message '{query}'...") self.logger.log(f"Classifying message '{query}'...")
response = openai.ChatCompletion.create( chat_partial = partial(
model=self.chat_model, openai.ChatCompletion.acreate,
messages=messages, model=self.chat_model,
api_key=self.api_key, messages=messages,
user = user api_key=self.api_key,
user=user
) )
response = await self._request_with_retries(chat_partial)
try: try:
result = json.loads(response.choices[0].message['content']) result = json.loads(response.choices[0].message['content'])
@ -98,7 +130,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
return result, tokens_used return result, tokens_used
def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]: async def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
"""Generate an image from a prompt. """Generate an image from a prompt.
Args: Args:
@ -107,16 +139,17 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
Yields: Yields:
bytes: The image data. bytes: The image data.
""" """
self.logger.log(f"Generating image from prompt '{prompt}'...") self.logger.log(f"Generating image from prompt '{prompt}'...")
response = openai.Image.create( image_partial = partial(
prompt=prompt, openai.Image.acreate,
n=1, prompt=prompt,
api_key=self.api_key, n=1,
size="1024x1024", api_key=self.api_key,
user = user size="1024x1024",
user=user
) )
response = await self._request_with_retries(image_partial)
images = [] images = []
@ -124,4 +157,4 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
image = requests.get(image.url).content image = requests.get(image.url).content
images.append(image) images.append(image)
return images, len(images) return images, len(images)

View file

@ -8,7 +8,12 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
if prompt: if prompt:
bot.logger.log("Classifying message...") bot.logger.log("Classifying message...")
response, tokens_used = bot.classification_api.classify_message(prompt, user=room.room_id) try:
response, tokens_used = await bot.classification_api.classify_message(prompt, user=room.room_id)
except Exception as e:
bot.logger.log(f"Error classifying message: {e}", "error")
await bot.send_message(room, "Sorry, I couldn't classify the message. Please try again later.", True)
return
message = f"The message you provided seems to be of type: {response['type']}." message = f"The message you provided seems to be of type: {response['type']}."
@ -21,4 +26,4 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
return return
await bot.send_message(room, "You need to provide a prompt.", True) await bot.send_message(room, "You need to provide a prompt.", True)

View file

@ -8,7 +8,12 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
if prompt: if prompt:
bot.logger.log("Generating image...") bot.logger.log("Generating image...")
images, tokens_used = bot.image_api.generate_image(prompt, user=room.room_id) try:
images, tokens_used = await bot.image_api.generate_image(prompt, user=room.room_id)
except Exception as e:
bot.logger.log(f"Error generating image: {e}", "error")
await bot.send_message(room, "Sorry, I couldn't generate an image. Please try again later.", True)
return
for image in images: for image in images:
bot.logger.log(f"Sending image...") bot.logger.log(f"Sending image...")
@ -18,4 +23,4 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
return return
await bot.send_message(room, "You need to provide a prompt.", True) await bot.send_message(room, "You need to provide a prompt.", True)

View file

@ -45,7 +45,7 @@ def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_versi
raise ValueError("Cannot migrate from a higher version to a lower version.") raise ValueError("Cannot migrate from a higher version to a lower version.")
for version in range(from_version, to_version): for version in range(from_version, to_version):
if version in MIGRATIONS: if version + 1 in MIGRATIONS:
MIGRATIONS[version + 1](db) MIGRATIONS[version + 1](db)
return from_version, to_version return from_version, to_version