Merge branch 'justin-russell-bugfixes'
This commit is contained in:
commit
3a1d1ea86a
5 changed files with 89 additions and 41 deletions
|
@ -1,8 +1,8 @@
|
||||||
import markdown2
|
import markdown2
|
||||||
import duckdb
|
import duckdb
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import magic
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import functools
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -27,12 +27,11 @@ from nio import (
|
||||||
RoomLeaveError,
|
RoomLeaveError,
|
||||||
RoomSendError,
|
RoomSendError,
|
||||||
RoomVisibility,
|
RoomVisibility,
|
||||||
RoomCreateResponse,
|
|
||||||
RoomCreateError,
|
RoomCreateError,
|
||||||
)
|
)
|
||||||
from nio.crypto import Olm
|
from nio.crypto import Olm
|
||||||
|
|
||||||
from typing import Optional, List, Dict, Tuple
|
from typing import Optional, List
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
@ -174,7 +173,7 @@ class GPTBot:
|
||||||
|
|
||||||
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
|
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
|
||||||
messages = []
|
messages = []
|
||||||
n = n or bot.max_messages
|
n = n or self.max_messages
|
||||||
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
||||||
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
|
@ -585,7 +584,7 @@ class GPTBot:
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
|
"No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
|
||||||
IN_MEMORY = True
|
IN_MEMORY = True
|
||||||
self.database = DuckDBPyConnection(":memory:")
|
self.database = duckdb.DuckDBPyConnection(":memory:")
|
||||||
|
|
||||||
self.logger.log("Running migrations...")
|
self.logger.log("Running migrations...")
|
||||||
before, after = migrate(self.database)
|
before, after = migrate(self.database)
|
||||||
|
@ -747,8 +746,14 @@ class GPTBot:
|
||||||
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
||||||
|
|
||||||
if (not from_chat_command) and self.room_uses_classification(room):
|
if (not from_chat_command) and self.room_uses_classification(room):
|
||||||
classification, tokens = self.classification_api.classify_message(
|
try:
|
||||||
event.body, room.room_id)
|
classification, tokens = await self.classification_api.classify_message(
|
||||||
|
event.body, room.room_id)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.log(f"Error classifying message: {e}", "error")
|
||||||
|
await self.send_message(
|
||||||
|
room, "Something went wrong. Please try again.", True)
|
||||||
|
return
|
||||||
|
|
||||||
self.log_api_usage(
|
self.log_api_usage(
|
||||||
event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
|
event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
|
||||||
|
@ -782,7 +787,7 @@ class GPTBot:
|
||||||
chat_messages, self.max_tokens - 1, system_message=system_message)
|
chat_messages, self.max_tokens - 1, system_message=system_message)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, tokens_used = self.chat_api.generate_chat_response(
|
response, tokens_used = await self.chat_api.generate_chat_response(
|
||||||
chat_messages, user=room.room_id)
|
chat_messages, user=room.room_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.log(f"Error generating response: {e}", "error")
|
self.logger.log(f"Error generating response: {e}", "error")
|
||||||
|
@ -803,7 +808,7 @@ class GPTBot:
|
||||||
else:
|
else:
|
||||||
# Send a notice to the room if there was an error
|
# Send a notice to the room if there was an error
|
||||||
self.logger.log("Didn't get a response from GPT API", "error")
|
self.logger.log("Didn't get a response from GPT API", "error")
|
||||||
await send_message(
|
await self.send_message(
|
||||||
room, "Something went wrong. Please try again.", True)
|
room, "Something went wrong. Please try again.", True)
|
||||||
|
|
||||||
await self.matrix_client.room_typing(room.room_id, False)
|
await self.matrix_client.room_typing(room.room_id, False)
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from .logging import Logger
|
from .logging import Logger
|
||||||
|
|
||||||
from typing import Dict, List, Tuple, Generator, Optional
|
from typing import Dict, List, Tuple, Generator, AsyncGenerator, Optional, Any
|
||||||
|
|
||||||
class OpenAI:
|
class OpenAI:
|
||||||
api_key: str
|
api_key: str
|
||||||
|
@ -28,7 +30,33 @@ class OpenAI:
|
||||||
self.chat_model = chat_model or self.chat_model
|
self.chat_model = chat_model or self.chat_model
|
||||||
self.logger = logger or Logger()
|
self.logger = logger or Logger()
|
||||||
|
|
||||||
def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
|
async def _request_with_retries(self, request: partial, attempts: int = 5, retry_interval: int = 2) -> AsyncGenerator[Any | list | Dict, None]:
|
||||||
|
"""Retry a request a set number of times if it fails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (partial): The request to make with retries.
|
||||||
|
attempts (int, optional): The number of attempts to make. Defaults to 5.
|
||||||
|
retry_interval (int, optional): The interval in seconds between attempts. Defaults to 2 seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncGenerator[Any | list | Dict, None]: The OpenAI response for the request.
|
||||||
|
"""
|
||||||
|
# call the request function and return the response if it succeeds, else retry
|
||||||
|
current_attempt = 1
|
||||||
|
while current_attempt <= attempts:
|
||||||
|
try:
|
||||||
|
response = await request()
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.log(f"Request failed: {e}", "error")
|
||||||
|
self.logger.log(f"Retrying in {retry_interval} seconds...")
|
||||||
|
await asyncio.sleep(retry_interval)
|
||||||
|
current_attempt += 1
|
||||||
|
|
||||||
|
# if all attempts failed, raise an exception
|
||||||
|
raise Exception("Request failed after all attempts.")
|
||||||
|
|
||||||
|
async def generate_chat_response(self, messages: List[Dict[str, str]], user: Optional[str] = None) -> Tuple[str, int]:
|
||||||
"""Generate a response to a chat message.
|
"""Generate a response to a chat message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -37,22 +65,25 @@ class OpenAI:
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, int]: The response text and the number of tokens used.
|
Tuple[str, int]: The response text and the number of tokens used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
|
self.logger.log(f"Generating response to {len(messages)} messages using {self.chat_model}...")
|
||||||
|
|
||||||
response = openai.ChatCompletion.create(
|
|
||||||
model=self.chat_model,
|
chat_partial = partial(
|
||||||
messages=messages,
|
openai.ChatCompletion.acreate,
|
||||||
api_key=self.api_key,
|
model=self.chat_model,
|
||||||
user = user
|
messages=messages,
|
||||||
|
api_key=self.api_key,
|
||||||
|
user=user
|
||||||
)
|
)
|
||||||
|
response = await self._request_with_retries(chat_partial)
|
||||||
|
|
||||||
|
|
||||||
result_text = response.choices[0].message['content']
|
result_text = response.choices[0].message['content']
|
||||||
tokens_used = response.usage["total_tokens"]
|
tokens_used = response.usage["total_tokens"]
|
||||||
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
self.logger.log(f"Generated response with {tokens_used} tokens.")
|
||||||
return result_text, tokens_used
|
return result_text, tokens_used
|
||||||
|
|
||||||
def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
|
async def classify_message(self, query: str, user: Optional[str] = None) -> Tuple[Dict[str, str], int]:
|
||||||
system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
|
system_message = """You are a classifier for different types of messages. You decide whether an incoming message is meant to be a prompt for an AI chat model, or meant for a different API. You respond with a JSON object like this:
|
||||||
|
|
||||||
{ "type": event_type, "prompt": prompt }
|
{ "type": event_type, "prompt": prompt }
|
||||||
|
@ -66,7 +97,6 @@ class OpenAI:
|
||||||
- If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message.
|
- If for any reason you are unable to classify the message (for example, if it infringes on your terms of service), the event_type is "error", and the prompt is a message explaining why you are unable to process the message.
|
||||||
|
|
||||||
Only the event_types mentioned above are allowed, you must not respond in any other way."""
|
Only the event_types mentioned above are allowed, you must not respond in any other way."""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -80,12 +110,14 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
|
|
||||||
self.logger.log(f"Classifying message '{query}'...")
|
self.logger.log(f"Classifying message '{query}'...")
|
||||||
|
|
||||||
response = openai.ChatCompletion.create(
|
chat_partial = partial(
|
||||||
model=self.chat_model,
|
openai.ChatCompletion.acreate,
|
||||||
messages=messages,
|
model=self.chat_model,
|
||||||
api_key=self.api_key,
|
messages=messages,
|
||||||
user = user
|
api_key=self.api_key,
|
||||||
|
user=user
|
||||||
)
|
)
|
||||||
|
response = await self._request_with_retries(chat_partial)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = json.loads(response.choices[0].message['content'])
|
result = json.loads(response.choices[0].message['content'])
|
||||||
|
@ -98,7 +130,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
|
|
||||||
return result, tokens_used
|
return result, tokens_used
|
||||||
|
|
||||||
def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
|
async def generate_image(self, prompt: str, user: Optional[str] = None) -> Generator[bytes, None, None]:
|
||||||
"""Generate an image from a prompt.
|
"""Generate an image from a prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -107,16 +139,17 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
Yields:
|
Yields:
|
||||||
bytes: The image data.
|
bytes: The image data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.logger.log(f"Generating image from prompt '{prompt}'...")
|
self.logger.log(f"Generating image from prompt '{prompt}'...")
|
||||||
|
|
||||||
response = openai.Image.create(
|
image_partial = partial(
|
||||||
prompt=prompt,
|
openai.Image.acreate,
|
||||||
n=1,
|
prompt=prompt,
|
||||||
api_key=self.api_key,
|
n=1,
|
||||||
size="1024x1024",
|
api_key=self.api_key,
|
||||||
user = user
|
size="1024x1024",
|
||||||
|
user=user
|
||||||
)
|
)
|
||||||
|
response = await self._request_with_retries(image_partial)
|
||||||
|
|
||||||
images = []
|
images = []
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,12 @@ async def command_classify(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if prompt:
|
if prompt:
|
||||||
bot.logger.log("Classifying message...")
|
bot.logger.log("Classifying message...")
|
||||||
|
|
||||||
response, tokens_used = bot.classification_api.classify_message(prompt, user=room.room_id)
|
try:
|
||||||
|
response, tokens_used = await bot.classification_api.classify_message(prompt, user=room.room_id)
|
||||||
|
except Exception as e:
|
||||||
|
bot.logger.log(f"Error classifying message: {e}", "error")
|
||||||
|
await bot.send_message(room, "Sorry, I couldn't classify the message. Please try again later.", True)
|
||||||
|
return
|
||||||
|
|
||||||
message = f"The message you provided seems to be of type: {response['type']}."
|
message = f"The message you provided seems to be of type: {response['type']}."
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,12 @@ async def command_imagine(room: MatrixRoom, event: RoomMessageText, bot):
|
||||||
if prompt:
|
if prompt:
|
||||||
bot.logger.log("Generating image...")
|
bot.logger.log("Generating image...")
|
||||||
|
|
||||||
images, tokens_used = bot.image_api.generate_image(prompt, user=room.room_id)
|
try:
|
||||||
|
images, tokens_used = await bot.image_api.generate_image(prompt, user=room.room_id)
|
||||||
|
except Exception as e:
|
||||||
|
bot.logger.log(f"Error generating image: {e}", "error")
|
||||||
|
await bot.send_message(room, "Sorry, I couldn't generate an image. Please try again later.", True)
|
||||||
|
return
|
||||||
|
|
||||||
for image in images:
|
for image in images:
|
||||||
bot.logger.log(f"Sending image...")
|
bot.logger.log(f"Sending image...")
|
||||||
|
|
|
@ -45,7 +45,7 @@ def migrate(db: DuckDBPyConnection, from_version: Optional[int] = None, to_versi
|
||||||
raise ValueError("Cannot migrate from a higher version to a lower version.")
|
raise ValueError("Cannot migrate from a higher version to a lower version.")
|
||||||
|
|
||||||
for version in range(from_version, to_version):
|
for version in range(from_version, to_version):
|
||||||
if version in MIGRATIONS:
|
if version + 1 in MIGRATIONS:
|
||||||
MIGRATIONS[version + 1](db)
|
MIGRATIONS[version + 1](db)
|
||||||
|
|
||||||
return from_version, to_version
|
return from_version, to_version
|
Loading…
Reference in a new issue