diff --git a/classes/bot.py b/classes/bot.py index c728633..0e3cbe4 100644 --- a/classes/bot.py +++ b/classes/bot.py @@ -40,6 +40,7 @@ from pathlib import Path import uuid import traceback +import json from .logging import Logger from migrations import migrate @@ -54,7 +55,8 @@ from .trackingmore import TrackingMore class GPTBot: # Default values database: Optional[duckdb.DuckDBPyConnection] = None - display_name = default_room_name = "GPTBot" # Default name of rooms created by the bot + # Default name of rooms created by the bot + display_name = default_room_name = "GPTBot" default_system_message: str = "You are a helpful assistant." # Force default system message to be included even if a custom room message is set force_system_message: bool = False @@ -72,6 +74,7 @@ class GPTBot: debug: bool = False logo: Optional[Image.Image] = None logo_uri: Optional[str] = None + allowed_users: List[str] = [] @classmethod def from_config(cls, config: ConfigParser): @@ -102,14 +105,19 @@ class GPTBot: "ForceSystemMessage", bot.force_system_message) bot.debug = config["GPTBot"].getboolean("Debug", bot.debug) - logo_path = config["GPTBot"].get("Logo", str(Path(__file__).parent.parent / "assets/logo.png")) + logo_path = config["GPTBot"].get("Logo", str( + Path(__file__).parent.parent / "assets/logo.png")) bot.logger.log(f"Loading logo from {logo_path}") if Path(logo_path).exists() and Path(logo_path).is_file(): bot.logo = Image.open(logo_path) - bot.display_name = config["GPTBot"].get("DisplayName", bot.display_name) + bot.display_name = config["GPTBot"].get( + "DisplayName", bot.display_name) + + if "AllowedUsers" in config["GPTBot"]: + bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"]) bot.chat_api = bot.image_api = bot.classification_api = OpenAI( config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger) @@ -258,6 +266,14 @@ class GPTBot: return device_id async def process_command(self, room: MatrixRoom, event: RoomMessageText): + """Process a command. Called from the event_callback() method. + Delegates to the appropriate command handler. + + Args: + room (MatrixRoom): The room the command was sent in. + event (RoomMessageText): The event containing the command. + """ + self.logger.log( f"Received command {event.body} from {event.sender} in room {room.room_id}") command = event.body.split()[1] if event.body.split()[1:] else None @@ -295,7 +311,45 @@ class GPTBot: if self.debug: await self.send_message(room, f"Error: {e}\n\n```\n{traceback.format_exc()}\n```", True) + def user_is_allowed(self, user_id: str) -> bool: + """Check if a user is allowed to use the bot. + + Args: + user_id (str): The user ID to check. + + Returns: + bool: Whether the user is allowed to use the bot. + """ + + return ( + user_id in self.allowed_users or + f"*:{user_id.split(':')[1]}" in self.allowed_users or + f"@*:{user_id.split(':')[1]}" in self.allowed_users + ) if self.allowed_users else True + async def event_callback(self, room: MatrixRoom, event: Event): + """Callback for events. + + Args: + room (MatrixRoom): The room the event was sent in. + event (Event): The event. + """ + + if event.sender == self.matrix_client.user_id: + return + + if not self.user_is_allowed(event.sender): + if len(room.users) == 2: + await self.matrix_client.room_send( + room.room_id, + "m.room.message", + { + "msgtype": "m.notice", + "body": f"You are not allowed to use this bot. Please contact {self.operator} for more information." + } + ) + return + task = asyncio.create_task(self._event_callback(room, event)) def room_uses_timing(self, room: MatrixRoom): @@ -589,7 +643,7 @@ class GPTBot: self.logo_uri = uri asyncio.create_task(self.matrix_client.set_avatar(uri)) - + for room in self.matrix_client.rooms.keys(): self.logger.log(f"Setting avatar for {room}...", "debug") asyncio.create_task(self.matrix_client.room_put_state(room, "m.room.avatar", { diff --git a/config.dist.ini b/config.dist.ini index e2f92e6..e3b79bc 100644 --- a/config.dist.ini +++ b/config.dist.ini @@ -98,6 +98,12 @@ Operator = Contact details not set # # DisplayName = GPTBot +# A list of allowed users +# If not defined, everyone is allowed to use the bot +# Use the "*:homeserver.matrix" syntax to allow everyone on a given homeserver +# +# AllowedUsers = ["*:matrix.local"] + [Database] # Settings for the DuckDB database.