Enable per-room model overrides and clean up code

Introduced the ability to specify and retrieve different OpenAI models on a per-room basis, allowing the bot's response behavior to be customized according to each room's preferences. Cleaned up code formatting across the bot implementation files for improved readability and maintainability. Additional logic now checks for model overrides when generating responses, ensuring the correct model is used as configured.

Refactors include streamlined database and API initializations and a refined method for processing message formatting to accommodate images, texts, and system messages consistently. This change differentiates default behavior from room-specific configurations, catering to diverse user needs without compromising on default settings.
This commit is contained in:
Kumi 2024-01-26 09:11:39 +01:00
parent ad0d694222
commit 87173ae284
Signed by: kumi
GPG key ID: ECBCC9082395383F
3 changed files with 184 additions and 47 deletions

View file

@ -184,6 +184,15 @@ class GPTBot:
"Logo", str(Path(__file__).parent.parent / "assets/logo.png") "Logo", str(Path(__file__).parent.parent / "assets/logo.png")
) )
@property
def allow_model_override(self) -> bool:
    """Return whether per-room chat model overrides are permitted.

    Reads the ``AllowModelOverride`` flag from the ``[GPTBot]`` section
    of the bot configuration.

    Returns:
        bool: True if rooms may override the chat model. Defaults to False
            when the option is absent.
    """
    gptbot_section = self.config["GPTBot"]
    return gptbot_section.getboolean("AllowModelOverride", False)
# User agent to use for HTTP requests # User agent to use for HTTP requests
USER_AGENT = "matrix-gptbot/dev (+https://kumig.it/kumitterer/matrix-gptbot)" USER_AGENT = "matrix-gptbot/dev (+https://kumig.it/kumitterer/matrix-gptbot)"
@ -208,11 +217,7 @@ class GPTBot:
if "Database" in config and "Path" in config["Database"] if "Database" in config and "Path" in config["Database"]
else None else None
) )
bot.database = ( bot.database = sqlite3.connect(bot.database_path) if bot.database_path else None
sqlite3.connect(bot.database_path)
if bot.database_path
else None
)
# Override default values # Override default values
if "GPTBot" in config: if "GPTBot" in config:
@ -224,14 +229,16 @@ class GPTBot:
if Path(bot.logo_path).exists() and Path(bot.logo_path).is_file(): if Path(bot.logo_path).exists() and Path(bot.logo_path).is_file():
bot.logo = Image.open(bot.logo_path) bot.logo = Image.open(bot.logo_path)
bot.chat_api = bot.image_api = bot.classification_api = bot.tts_api = bot.stt_api = OpenAI( bot.chat_api = (
bot.image_api
) = bot.classification_api = bot.tts_api = bot.stt_api = OpenAI(
bot=bot, bot=bot,
api_key=config["OpenAI"]["APIKey"], api_key=config["OpenAI"]["APIKey"],
chat_model=config["OpenAI"].get("Model"), chat_model=config["OpenAI"].get("Model"),
image_model=config["OpenAI"].get("ImageModel"), image_model=config["OpenAI"].get("ImageModel"),
tts_model=config["OpenAI"].get("TTSModel"), tts_model=config["OpenAI"].get("TTSModel"),
stt_model=config["OpenAI"].get("STTModel"), stt_model=config["OpenAI"].get("STTModel"),
base_url=config["OpenAI"].get("BaseURL") base_url=config["OpenAI"].get("BaseURL"),
) )
if "BaseURL" in config["OpenAI"]: if "BaseURL" in config["OpenAI"]:
@ -285,7 +292,12 @@ class GPTBot:
return user_id return user_id
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int], ignore_bot_commands: bool = False): async def _last_n_messages(
self,
room: str | MatrixRoom,
n: Optional[int],
ignore_bot_commands: bool = False,
):
messages = [] messages = []
n = n or self.max_messages n = n or self.max_messages
room_id = room.room_id if isinstance(room, MatrixRoom) else room room_id = room.room_id if isinstance(room, MatrixRoom) else room
@ -362,7 +374,13 @@ class GPTBot:
truncated_messages = [] truncated_messages = []
for message in [messages[0]] + list(reversed(messages[1:])): for message in [messages[0]] + list(reversed(messages[1:])):
content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else "" content = (
message["content"]
if isinstance(message["content"], str)
else message["content"][0]["text"]
if isinstance(message["content"][0].get("text"), str)
else ""
)
tokens = len(encoding.encode(content)) + 1 tokens = len(encoding.encode(content)) + 1
if total_tokens + tokens > max_tokens: if total_tokens + tokens > max_tokens:
break break
@ -658,9 +676,7 @@ class GPTBot:
"url": content_uri, "url": content_uri,
} }
status = await self.matrix_client.room_send( status = await self.matrix_client.room_send(room, "m.room.message", content)
room, "m.room.message", content
)
self.logger.log("Sent image", "debug") self.logger.log("Sent image", "debug")
@ -694,9 +710,7 @@ class GPTBot:
"url": content_uri, "url": content_uri,
} }
status = await self.matrix_client.room_send( status = await self.matrix_client.room_send(room, "m.room.message", content)
room, "m.room.message", content
)
self.logger.log("Sent file", "debug") self.logger.log("Sent file", "debug")
@ -789,7 +803,9 @@ class GPTBot:
self.matrix_client.device_id = await self._get_device_id() self.matrix_client.device_id = await self._get_device_id()
if not self.database: if not self.database:
self.database = sqlite3.connect(Path(__file__).parent.parent / "database.db") self.database = sqlite3.connect(
Path(__file__).parent.parent / "database.db"
)
self.logger.log("Running migrations...") self.logger.log("Running migrations...")
@ -987,6 +1003,28 @@ class GPTBot:
return True if not result else bool(int(result[0])) return True if not result else bool(int(result[0]))
async def get_room_model(self, room: MatrixRoom | str) -> str:
    """Look up the chat model configured for a room.

    Args:
        room (MatrixRoom | str): The room to check, given either as a
            room object or as a plain room ID string.

    Returns:
        str: The per-room model stored in ``room_settings`` if one
            exists, otherwise the chat API's default chat model.
    """
    room_id = room.room_id if isinstance(room, MatrixRoom) else room

    with closing(self.database.cursor()) as cursor:
        cursor.execute(
            "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
            (room_id, "model"),
        )
        row = cursor.fetchone()

    if row:
        return row[0]
    return self.chat_api.chat_model
async def process_query( async def process_query(
self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False
): ):
@ -1053,28 +1091,46 @@ class GPTBot:
for message in last_messages: for message in last_messages:
if isinstance(message, (RoomMessageNotice, RoomMessageText)): if isinstance(message, (RoomMessageNotice, RoomMessageText)):
role = ( role = (
"assistant" if message.sender == self.matrix_client.user_id else "user" "assistant"
if message.sender == self.matrix_client.user_id
else "user"
) )
if message == event or (not message.event_id == event.event_id): if message == event or (not message.event_id == event.event_id):
message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}] message_body = (
message.body
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": role, "content": message_body}) chat_messages.append({"role": role, "content": message_body})
elif isinstance(message, RoomMessageAudio) or (isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")): elif isinstance(message, RoomMessageAudio) or (
isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
):
role = ( role = (
"assistant" if message.sender == self.matrix_client.user_id else "user" "assistant"
if message.sender == self.matrix_client.user_id
else "user"
) )
if message == event or (not message.event_id == event.event_id): if message == event or (not message.event_id == event.event_id):
if self.room_uses_stt(room): if self.room_uses_stt(room):
try: try:
download = await self.download_file(message.url) download = await self.download_file(message.url)
message_text = await self.stt_api.speech_to_text(download.body) message_text = await self.stt_api.speech_to_text(
download.body
)
except Exception as e: except Exception as e:
self.logger.log(f"Error generating text from audio: {e}", "error") self.logger.log(
f"Error generating text from audio: {e}", "error"
)
message_text = message.body message_text = message.body
else: else:
message_text = message.body message_text = message.body
message_body = message_text if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message_text}] message_body = (
message_text
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message_text}]
)
chat_messages.append({"role": role, "content": message_body}) chat_messages.append({"role": role, "content": message_body})
elif isinstance(message, RoomMessageFile): elif isinstance(message, RoomMessageFile):
@ -1092,38 +1148,72 @@ class GPTBot:
if message.sender == self.matrix_client.user_id if message.sender == self.matrix_client.user_id
else "user" else "user"
) )
if message == event or (not message.event_id == event.event_id): if message == event or (
message_body = text if not self.chat_api.supports_chat_images() else [{"type": "text", "text": text}] not message.event_id == event.event_id
chat_messages.append({"role": role, "content": message_body}) ):
message_body = (
text
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": text}]
)
chat_messages.append(
{"role": role, "content": message_body}
)
except Exception as e: except Exception as e:
self.logger.log(f"Error generating text from file: {e}", "error") self.logger.log(f"Error generating text from file: {e}", "error")
message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}] message_body = (
message.body
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body}) chat_messages.append({"role": "system", "content": message_body})
elif self.chat_api.supports_chat_images() and isinstance(message, RoomMessageImage): elif self.chat_api.supports_chat_images() and isinstance(
message, RoomMessageImage
):
try: try:
image_url = message.url image_url = message.url
download = await self.download_file(image_url) download = await self.download_file(image_url)
if download: if download:
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}" encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
parent = chat_messages[-1] if chat_messages and chat_messages[-1]["role"] == ("assistant" if message.sender == self.matrix_client.user_id else "user") else None parent = (
chat_messages[-1]
if chat_messages
and chat_messages[-1]["role"]
== (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
)
else None
)
if not parent: if not parent:
chat_messages.append({"role": ("assistant" if message.sender == self.matrix_client.user_id else "user"), "content": []}) chat_messages.append(
{
"role": (
"assistant"
if message.sender == self.matrix_client.user_id
else "user"
),
"content": [],
}
)
parent = chat_messages[-1] parent = chat_messages[-1]
parent["content"].append({ parent["content"].append(
"type": "image_url", {"type": "image_url", "image_url": {"url": encoded_url}}
"image_url": { )
"url": encoded_url
}
})
except Exception as e: except Exception as e:
self.logger.log(f"Error generating image from file: {e}", "error") self.logger.log(f"Error generating image from file: {e}", "error")
message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}] message_body = (
message.body
if not self.chat_api.supports_chat_images()
else [{"type": "text", "text": message.body}]
)
chat_messages.append({"role": "system", "content": message_body}) chat_messages.append({"role": "system", "content": message_body})
# Truncate messages to fit within the token limit # Truncate messages to fit within the token limit
@ -1131,9 +1221,15 @@ class GPTBot:
chat_messages[1:], self.max_tokens - 1, system_message=system_message chat_messages[1:], self.max_tokens - 1, system_message=system_message
) )
# Check for a model override
if self.allow_model_override:
model = await self.get_room_model(room)
else:
model = self.chat_api.chat_model
try: try:
response, tokens_used = await self.chat_api.generate_chat_response( response, tokens_used = await self.chat_api.generate_chat_response(
chat_messages, user=event.sender, room=room.room_id chat_messages, user=event.sender, room=room.room_id, model=model
) )
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())

View file

@ -124,6 +124,7 @@ class OpenAI:
room: Optional[str] = None, room: Optional[str] = None,
allow_override: bool = True, allow_override: bool = True,
use_tools: bool = True, use_tools: bool = True,
model: Optional[str] = None,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Generate a response to a chat message. """Generate a response to a chat message.
@ -133,6 +134,7 @@ class OpenAI:
room (Optional[str], optional): The room to use the assistant for. Defaults to None. room (Optional[str], optional): The room to use the assistant for. Defaults to None.
allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True. allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True.
use_tools (bool, optional): Whether to use tools. Defaults to True. use_tools (bool, optional): Whether to use tools. Defaults to True.
model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
Returns: Returns:
Tuple[str, int]: The response text and the number of tokens used. Tuple[str, int]: The response text and the number of tokens used.
@ -141,6 +143,8 @@ class OpenAI:
f"Generating response to {len(messages)} messages for user {user} in room {room}..." f"Generating response to {len(messages)} messages for user {user} in room {room}..."
) )
chat_model = model or self.chat_model
# Check current recursion depth to prevent infinite loops # Check current recursion depth to prevent infinite loops
if use_tools: if use_tools:
@ -157,6 +161,7 @@ class OpenAI:
room=room, room=room,
allow_override=False, # TODO: Could this be a problem? allow_override=False, # TODO: Could this be a problem?
use_tools=False, use_tools=False,
model=model,
) )
tools = [ tools = [
@ -171,10 +176,9 @@ class OpenAI:
for tool_name, tool_class in TOOLS.items() for tool_name, tool_class in TOOLS.items()
] ]
chat_model = self.chat_model
original_messages = messages original_messages = messages
if allow_override and not "gpt-3.5-turbo" in self.chat_model: if allow_override and not "gpt-3.5-turbo" in model:
if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False): if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
self.logger.log(f"Overriding chat model to use tools") self.logger.log(f"Overriding chat model to use tools")
chat_model = "gpt-3.5-turbo-1106" chat_model = "gpt-3.5-turbo-1106"
@ -204,7 +208,7 @@ class OpenAI:
use_tools use_tools
and self.bot.config.getboolean("OpenAI", "EmulateTools", fallback=False) and self.bot.config.getboolean("OpenAI", "EmulateTools", fallback=False)
and not self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False) and not self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False)
and not "gpt-3.5-turbo" in self.chat_model and not "gpt-3.5-turbo" in chat_model
): ):
self.bot.logger.log("Using tool emulation mode.", "debug") self.bot.logger.log("Using tool emulation mode.", "debug")

View file

@ -80,6 +80,40 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
await bot.send_message(room, f"The current {setting} status is: '{value}'.", True) await bot.send_message(room, f"The current {setting} status is: '{value}'.", True)
return return
if bot.allow_model_override and setting == "model":
if value:
bot.logger.log(f"Setting chat model for {room.room_id} to {value}...")
with closing(bot.database.cursor()) as cur:
cur.execute(
"""INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
(room.room_id, "model", value, value)
)
bot.database.commit()
await bot.send_message(room, f"Alright, I've set the chat model to: '{value}'.", True)
return
bot.logger.log(f"Retrieving chat model for {room.room_id}...")
with closing(bot.database.cursor()) as cur:
cur.execute(
"""SELECT value FROM room_settings WHERE room_id = ? AND setting = ?;""",
(room.room_id, "model")
)
value = cur.fetchone()[0]
if not value:
value = bot.chat_api.chat_model
else:
value = str(value)
await bot.send_message(room, f"The current chat model is: '{value}'.", True)
return
message = f"""The following settings are available: message = f"""The following settings are available:
- system_message [message]: Get or set the system message to be sent to the chat model - system_message [message]: Get or set the system message to be sent to the chat model
@ -90,4 +124,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
- timing [true/false]: Get or set whether the bot should return information about the time it took to generate a response - timing [true/false]: Get or set whether the bot should return information about the time it took to generate a response
""" """
if bot.allow_model_override:
message += "- model [model]: Get or set the chat model to be used for this room"
await bot.send_message(room, message, True) await bot.send_message(room, message, True)