Compare commits
3 commits
main
...
v0.2.0-dev
Author | SHA1 | Date | |
---|---|---|---|
ef3118cbe3 | |||
9abea6e3f8 | |||
94b2457a39 |
11 changed files with 153 additions and 232 deletions
|
@ -67,6 +67,10 @@ LogLevel = info
|
||||||
#
|
#
|
||||||
# Model = gpt-3.5-turbo
|
# Model = gpt-3.5-turbo
|
||||||
|
|
||||||
|
# The Image Generation model you want to use.
|
||||||
|
#
|
||||||
|
# ImageModel = dall-e-2
|
||||||
|
|
||||||
# Your OpenAI API key
|
# Your OpenAI API key
|
||||||
#
|
#
|
||||||
# Find this in your OpenAI account:
|
# Find this in your OpenAI account:
|
||||||
|
|
|
@ -7,7 +7,7 @@ allow-direct-references = true
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "matrix-gptbot"
|
name = "matrix-gptbot"
|
||||||
version = "0.1.1"
|
version = "0.2.0-dev"
|
||||||
|
|
||||||
authors = [
|
authors = [
|
||||||
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
|
{ name="Kumi Mitterer", email="gptbot@kumi.email" },
|
||||||
|
@ -18,10 +18,6 @@ readme = "README.md"
|
||||||
license = { file="LICENSE" }
|
license = { file="LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
packages = [
|
|
||||||
"src/gptbot"
|
|
||||||
]
|
|
||||||
|
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
|
@ -29,13 +25,17 @@ classifiers = [
|
||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"matrix-nio[e2e]",
|
"mautrix[all]",
|
||||||
"markdown2[all]",
|
"markdown2[all]",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"python-magic",
|
"python-magic",
|
||||||
"pillow",
|
"pillow",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
packages = [
|
||||||
|
{ include = "gptbot", where = "src" },
|
||||||
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
openai = [
|
openai = [
|
||||||
"openai",
|
"openai",
|
||||||
|
@ -54,7 +54,6 @@ all = [
|
||||||
]
|
]
|
||||||
|
|
||||||
dev = [
|
dev = [
|
||||||
"matrix-gptbot[all]",
|
|
||||||
"black",
|
"black",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -63,7 +62,10 @@ dev = [
|
||||||
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
|
"Bug Tracker" = "https://kumig.it/kumitterer/matrix-gptbot/issues"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
gptbot = "gptbot:main"
|
gptbot = "gptbot.__main___:main"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src/gptbot"]
|
only-include = ["src/gptbot"]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel.sources]
|
||||||
|
"src" = ""
|
|
@ -1,5 +1,5 @@
|
||||||
openai
|
openai
|
||||||
matrix-nio[e2e]
|
mautrix
|
||||||
markdown2[all]
|
markdown2[all]
|
||||||
tiktoken
|
tiktoken
|
||||||
duckdb
|
duckdb
|
||||||
|
|
|
@ -10,8 +10,7 @@ import asyncio
|
||||||
def sigterm_handler(_signo, _stack_frame):
|
def sigterm_handler(_signo, _stack_frame):
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
|
def main():
|
||||||
if __name__ == "__main__":
|
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -46,3 +45,7 @@ if __name__ == "__main__":
|
||||||
print("Received KeyboardInterrupt - exiting...")
|
print("Received KeyboardInterrupt - exiting...")
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
print("Received SIGTERM - exiting...")
|
print("Received SIGTERM - exiting...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -12,6 +12,8 @@ from nio import (
|
||||||
Response,
|
Response,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from mautrix.types import Event, MessageEvent, StateEvent
|
||||||
|
|
||||||
from .test import test_callback
|
from .test import test_callback
|
||||||
from .sync import sync_callback
|
from .sync import sync_callback
|
||||||
from .invite import room_invite_callback
|
from .invite import room_invite_callback
|
||||||
|
@ -27,9 +29,5 @@ RESPONSE_CALLBACKS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
EVENT_CALLBACKS = {
|
EVENT_CALLBACKS = {
|
||||||
Event: test_callback,
|
MessageEvent: message_callback,
|
||||||
InviteEvent: room_invite_callback,
|
|
||||||
RoomMessageText: message_callback,
|
|
||||||
MegolmEvent: message_callback,
|
|
||||||
RoomMemberEvent: roommember_callback,
|
|
||||||
}
|
}
|
36
src/gptbot/callbacks/base.py
Normal file
36
src/gptbot/callbacks/base.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from ..classes.bot import GPTBot
|
||||||
|
|
||||||
|
from nio import Event
|
||||||
|
|
||||||
|
class BaseEventCallback:
|
||||||
|
EVENTS = [] # List of events that this callback should be called for
|
||||||
|
|
||||||
|
def __init__(self, bot: GPTBot):
|
||||||
|
"""Initialize the callback with the bot instance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot (GPTBot): GPTBot instance
|
||||||
|
"""
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
async def process(self, event: Event, *args, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"BaseEventCallback.process() must be implemented by subclasses"
|
||||||
|
)
|
||||||
|
|
||||||
|
class BaseResponseCallback:
|
||||||
|
RESPONSES = [] # List of responses that this callback should be called for
|
||||||
|
|
||||||
|
def __init__(self, bot: GPTBot):
|
||||||
|
"""Initialize the callback with the bot instance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bot (GPTBot): GPTBot instance
|
||||||
|
"""
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
async def process(self, response: Response, *args, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"BaseResponseCallback.process() must be implemented by subclasses"
|
||||||
|
)
|
||||||
|
|
0
src/gptbot/callbacks/invite_new/__init__.py
Normal file
0
src/gptbot/callbacks/invite_new/__init__.py
Normal file
|
@ -1,9 +1,9 @@
|
||||||
from nio import MatrixRoom, RoomMessageText, MegolmEvent, RoomKeyRequestError, RoomKeyRequestResponse
|
from mautrix.types import MessageEvent
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
async def message_callback(room: MatrixRoom | str, event: RoomMessageText | MegolmEvent, bot):
|
async def message_callback(event: MessageEvent, bot):
|
||||||
bot.logger.log(f"Received message from {event.sender} in room {room.room_id}")
|
bot.logger.log(f"Received message from {event.sender} in room {event.room_id}")
|
||||||
|
|
||||||
sent = datetime.fromtimestamp(event.server_timestamp / 1000)
|
sent = datetime.fromtimestamp(event.server_timestamp / 1000)
|
||||||
received = datetime.now()
|
received = datetime.now()
|
||||||
|
@ -34,18 +34,18 @@ async def message_callback(room: MatrixRoom | str, event: RoomMessageText | Mego
|
||||||
bot.logger.log("Message is from bot itself - ignoring")
|
bot.logger.log("Message is from bot itself - ignoring")
|
||||||
|
|
||||||
elif event.body.startswith("!gptbot"):
|
elif event.body.startswith("!gptbot"):
|
||||||
await bot.process_command(room, event)
|
await bot.process_command(event)
|
||||||
|
|
||||||
elif event.body.startswith("!"):
|
elif event.body.startswith("!"):
|
||||||
bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring")
|
bot.logger.log(f"Received {event.body} - might be a command, but not for this bot - ignoring")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await bot.process_query(room, event)
|
await bot.process_query(event)
|
||||||
|
|
||||||
processed = datetime.now()
|
processed = datetime.now()
|
||||||
processing_time = processed - received
|
processing_time = processed - received
|
||||||
|
|
||||||
bot.logger.log(f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)")
|
bot.logger.log(f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)")
|
||||||
|
|
||||||
if bot.room_uses_timing(room):
|
if bot.room_uses_timing(event.room_id):
|
||||||
await bot.send_message(room, f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)", True)
|
await bot.send_message(event.room_id, f"Message processing took {processing_time.total_seconds()} seconds (latency: {latency.total_seconds()} seconds)", True)
|
|
@ -1,11 +1,10 @@
|
||||||
from nio import MatrixRoom, Event
|
from mautrix.types import Event
|
||||||
|
|
||||||
async def test_callback(room: MatrixRoom, event: Event, bot):
|
async def test_callback(event: Event, bot):
|
||||||
"""Test callback for debugging purposes.
|
"""Test callback for debugging purposes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room the event was sent in.
|
|
||||||
event (Event): The event that was sent.
|
event (Event): The event that was sent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
bot.logger.log(f"Test callback called: {room.room_id} {event.event_id} {event.sender} {event.__class__}")
|
bot.logger.log(f"Test callback called: {event.room_id} {event.event_id} {event.sender} {event.__class__}")
|
|
@ -5,31 +5,22 @@ import functools
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from nio import (
|
from mautrix.client import Client
|
||||||
AsyncClient,
|
from mautrix.types import (
|
||||||
AsyncClientConfig,
|
RoomID,
|
||||||
WhoamiResponse,
|
UserID,
|
||||||
DevicesResponse,
|
EventType,
|
||||||
Event,
|
MessageType,
|
||||||
Response,
|
MessageEvent,
|
||||||
MatrixRoom,
|
RoomDirectoryVisibility,
|
||||||
Api,
|
)
|
||||||
RoomMessagesError,
|
from mautrix.errors import (
|
||||||
MegolmEvent,
|
MForbidden,
|
||||||
GroupEncryptionError,
|
MNotFound,
|
||||||
EncryptionError,
|
MUnknownToken,
|
||||||
RoomMessageText,
|
MForbidden,
|
||||||
RoomSendResponse,
|
MatrixError,
|
||||||
SyncResponse,
|
|
||||||
RoomMessageNotice,
|
|
||||||
JoinError,
|
|
||||||
RoomLeaveError,
|
|
||||||
RoomSendError,
|
|
||||||
RoomVisibility,
|
|
||||||
RoomCreateError,
|
|
||||||
)
|
)
|
||||||
from nio.crypto import Olm
|
|
||||||
from nio.store import SqliteStore
|
|
||||||
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
|
@ -65,7 +56,7 @@ class GPTBot:
|
||||||
force_system_message: bool = False
|
force_system_message: bool = False
|
||||||
max_tokens: int = 3000 # Maximum number of input tokens
|
max_tokens: int = 3000 # Maximum number of input tokens
|
||||||
max_messages: int = 30 # Maximum number of messages to consider as input
|
max_messages: int = 30 # Maximum number of messages to consider as input
|
||||||
matrix_client: Optional[AsyncClient] = None
|
matrix_client: Optional[Client] = None
|
||||||
sync_token: Optional[str] = None
|
sync_token: Optional[str] = None
|
||||||
logger: Optional[Logger] = Logger()
|
logger: Optional[Logger] = Logger()
|
||||||
chat_api: Optional[OpenAI] = None
|
chat_api: Optional[OpenAI] = None
|
||||||
|
@ -138,7 +129,8 @@ class GPTBot:
|
||||||
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
|
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
|
||||||
|
|
||||||
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
||||||
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger
|
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
|
||||||
|
config["OpenAI"].get("ImageModel"), bot.logger
|
||||||
)
|
)
|
||||||
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
||||||
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
||||||
|
@ -161,13 +153,11 @@ class GPTBot:
|
||||||
|
|
||||||
assert "Matrix" in config, "Matrix config not found"
|
assert "Matrix" in config, "Matrix config not found"
|
||||||
|
|
||||||
homeserver = config["Matrix"]["Homeserver"]
|
bot.homeserver = config["Matrix"]["Homeserver"]
|
||||||
bot.matrix_client = AsyncClient(homeserver)
|
bot.access_token = config["Matrix"]["AccessToken"]
|
||||||
bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
|
bot.user_id = config["Matrix"].get("UserID")
|
||||||
bot.matrix_client.user_id = config["Matrix"].get("UserID")
|
bot.device_id = config["Matrix"].get("DeviceID")
|
||||||
bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
|
|
||||||
|
|
||||||
# Return the new GPTBot instance
|
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
async def _get_user_id(self) -> str:
|
async def _get_user_id(self) -> str:
|
||||||
|
@ -178,68 +168,12 @@ class GPTBot:
|
||||||
str: The user ID of the bot.
|
str: The user ID of the bot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.matrix_client, "Matrix client not set up"
|
pass
|
||||||
|
# TODO: Implement
|
||||||
|
|
||||||
user_id = self.matrix_client.user_id
|
async def _last_n_messages(self, room: str | RoomID, n: Optional[int]):
|
||||||
|
pass
|
||||||
if not user_id:
|
# TODO: Implement
|
||||||
assert self.matrix_client.access_token, "Access token not set up"
|
|
||||||
|
|
||||||
response = await self.matrix_client.whoami()
|
|
||||||
|
|
||||||
if isinstance(response, WhoamiResponse):
|
|
||||||
user_id = response.user_id
|
|
||||||
else:
|
|
||||||
raise Exception(f"Could not get user ID: {response}")
|
|
||||||
|
|
||||||
return user_id
|
|
||||||
|
|
||||||
async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
|
|
||||||
messages = []
|
|
||||||
n = n or self.max_messages
|
|
||||||
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
|
||||||
|
|
||||||
self.logger.log(
|
|
||||||
f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...",
|
|
||||||
"debug",
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self.matrix_client.room_messages(
|
|
||||||
room_id=room_id,
|
|
||||||
start=self.sync_token,
|
|
||||||
limit=2 * n,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(response, RoomMessagesError):
|
|
||||||
raise Exception(
|
|
||||||
f"Error fetching messages: {response.message} (status code {response.status_code})",
|
|
||||||
"error",
|
|
||||||
)
|
|
||||||
|
|
||||||
for event in response.chunk:
|
|
||||||
if len(messages) >= n:
|
|
||||||
break
|
|
||||||
if isinstance(event, MegolmEvent):
|
|
||||||
try:
|
|
||||||
event = await self.matrix_client.decrypt_event(event)
|
|
||||||
except (GroupEncryptionError, EncryptionError):
|
|
||||||
self.logger.log(
|
|
||||||
f"Could not decrypt message {event.event_id} in room {room_id}",
|
|
||||||
"error",
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
|
|
||||||
if event.body.startswith("!gptbot ignoreolder"):
|
|
||||||
break
|
|
||||||
if (not event.body.startswith("!")) or (
|
|
||||||
event.body.startswith("!gptbot")
|
|
||||||
):
|
|
||||||
messages.append(event)
|
|
||||||
|
|
||||||
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
|
|
||||||
|
|
||||||
# Reverse the list so that messages are in chronological order
|
|
||||||
return messages[::-1]
|
|
||||||
|
|
||||||
def _truncate(
|
def _truncate(
|
||||||
self,
|
self,
|
||||||
|
@ -298,20 +232,17 @@ class GPTBot:
|
||||||
if not device_id:
|
if not device_id:
|
||||||
assert self.matrix_client.access_token, "Access token not set up"
|
assert self.matrix_client.access_token, "Access token not set up"
|
||||||
|
|
||||||
devices = await self.matrix_client.devices()
|
# TODO: Implement
|
||||||
|
|
||||||
if isinstance(devices, DevicesResponse):
|
|
||||||
device_id = devices.devices[0].id
|
|
||||||
|
|
||||||
return device_id
|
return device_id
|
||||||
|
|
||||||
async def process_command(self, room: MatrixRoom, event: RoomMessageText):
|
async def process_command(self, room: RoomID, event: MessageEvent):
|
||||||
"""Process a command. Called from the event_callback() method.
|
"""Process a command. Called from the event_callback() method.
|
||||||
Delegates to the appropriate command handler.
|
Delegates to the appropriate command handler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room the command was sent in.
|
room (RoomID): The room the command was sent in.
|
||||||
event (RoomMessageText): The event containing the command.
|
event (MessageEvent): The event containing the command.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
|
@ -322,11 +253,11 @@ class GPTBot:
|
||||||
|
|
||||||
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
|
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
|
||||||
|
|
||||||
def room_uses_classification(self, room: MatrixRoom | str) -> bool:
|
def room_uses_classification(self, room: RoomID | str) -> bool:
|
||||||
"""Check if a room uses classification.
|
"""Check if a room uses classification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom | str): The room to check.
|
room (RoomID | str): The room to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the room uses classification.
|
bool: Whether the room uses classification.
|
||||||
|
@ -342,7 +273,7 @@ class GPTBot:
|
||||||
|
|
||||||
return False if not result else bool(int(result[0]))
|
return False if not result else bool(int(result[0]))
|
||||||
|
|
||||||
async def _event_callback(self, room: MatrixRoom, event: Event):
|
async def _event_callback(self, room: RoomID, event: MessageEvent):
|
||||||
self.logger.log("Received event: " + str(event.event_id), "debug")
|
self.logger.log("Received event: " + str(event.event_id), "debug")
|
||||||
try:
|
try:
|
||||||
for eventtype, callback in EVENT_CALLBACKS.items():
|
for eventtype, callback in EVENT_CALLBACKS.items():
|
||||||
|
@ -378,12 +309,12 @@ class GPTBot:
|
||||||
else True
|
else True
|
||||||
)
|
)
|
||||||
|
|
||||||
async def event_callback(self, room: MatrixRoom, event: Event):
|
async def event_callback(self, room: RoomID, event: MessageEvent):
|
||||||
"""Callback for events.
|
"""Callback for events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room the event was sent in.
|
room (RoomID): The room the event was sent in.
|
||||||
event (Event): The event.
|
event (MessageEvent): The event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if event.sender == self.matrix_client.user_id:
|
if event.sender == self.matrix_client.user_id:
|
||||||
|
@ -403,11 +334,11 @@ class GPTBot:
|
||||||
|
|
||||||
task = asyncio.create_task(self._event_callback(room, event))
|
task = asyncio.create_task(self._event_callback(room, event))
|
||||||
|
|
||||||
def room_uses_timing(self, room: MatrixRoom):
|
def room_uses_timing(self, room: RoomID):
|
||||||
"""Check if a room uses timing.
|
"""Check if a room uses timing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room to check.
|
room (RoomID): The room to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the room uses timing.
|
bool: Whether the room uses timing.
|
||||||
|
@ -423,14 +354,6 @@ class GPTBot:
|
||||||
|
|
||||||
return False if not result else bool(int(result[0]))
|
return False if not result else bool(int(result[0]))
|
||||||
|
|
||||||
async def _response_callback(self, response: Response):
|
|
||||||
for response_type, callback in RESPONSE_CALLBACKS.items():
|
|
||||||
if isinstance(response, response_type):
|
|
||||||
await callback(response, self)
|
|
||||||
|
|
||||||
async def response_callback(self, response: Response):
|
|
||||||
task = asyncio.create_task(self._response_callback(response))
|
|
||||||
|
|
||||||
async def accept_pending_invites(self):
|
async def accept_pending_invites(self):
|
||||||
"""Accept all pending invites."""
|
"""Accept all pending invites."""
|
||||||
|
|
||||||
|
@ -492,12 +415,12 @@ class GPTBot:
|
||||||
return response.content_uri
|
return response.content_uri
|
||||||
|
|
||||||
async def send_image(
|
async def send_image(
|
||||||
self, room: MatrixRoom, image: bytes, message: Optional[str] = None
|
self, room: RoomID, image: bytes, message: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""Send an image to a room.
|
"""Send an image to a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room to send the image to.
|
room (RoomID): The room to send the image to.
|
||||||
image (bytes): The image to send.
|
image (bytes): The image to send.
|
||||||
message (str, optional): The message to send with the image. Defaults to None.
|
message (str, optional): The message to send with the image. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
@ -539,13 +462,21 @@ class GPTBot:
|
||||||
|
|
||||||
self.logger.log("Sent image", "debug")
|
self.logger.log("Sent image", "debug")
|
||||||
|
|
||||||
|
async def handle_event(self, event):
|
||||||
|
"""Handle an event."""
|
||||||
|
|
||||||
|
for event_type, callback in EVENT_CALLBACKS.items():
|
||||||
|
if isinstance(event, event_type):
|
||||||
|
print(event_type, callback)
|
||||||
|
await callback(event, self)
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self, room: MatrixRoom | str, message: str, notice: bool = False
|
self, room: RoomID | str, message: str, notice: bool = False
|
||||||
):
|
):
|
||||||
"""Send a message to a room.
|
"""Send a message to a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room to send the message to.
|
room (RoomID): The room to send the message to.
|
||||||
message (str): The message to send.
|
message (str): The message to send.
|
||||||
notice (bool): Whether to send the message as a notice. Defaults to False.
|
notice (bool): Whether to send the message as a notice. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
@ -618,13 +549,13 @@ class GPTBot:
|
||||||
return
|
return
|
||||||
|
|
||||||
def log_api_usage(
|
def log_api_usage(
|
||||||
self, message: Event | str, room: MatrixRoom | str, api: str, tokens: int
|
self, message: MessageEvent | str, room: RoomID | str, api: str, tokens: int
|
||||||
):
|
):
|
||||||
"""Log API usage to the database.
|
"""Log API usage to the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message (Event): The event that triggered the API usage.
|
message (MessageEvent): The event that triggered the API usage.
|
||||||
room (MatrixRoom | str): The room the event was sent in.
|
room (RoomID | str): The room the event was sent in.
|
||||||
api (str): The API that was used.
|
api (str): The API that was used.
|
||||||
tokens (int): The number of tokens used.
|
tokens (int): The number of tokens used.
|
||||||
"""
|
"""
|
||||||
|
@ -648,14 +579,11 @@ class GPTBot:
|
||||||
|
|
||||||
# Set up the Matrix client
|
# Set up the Matrix client
|
||||||
|
|
||||||
assert self.matrix_client, "Matrix client not set up"
|
self.matrix_client: Client = self.matrix_client or Client(base_url=self.homeserver, token=self.access_token)
|
||||||
assert self.matrix_client.access_token, "Access token not set up"
|
|
||||||
|
|
||||||
if not self.matrix_client.user_id:
|
iam = await self.matrix_client.whoami()
|
||||||
self.matrix_client.user_id = await self._get_user_id()
|
|
||||||
|
|
||||||
if not self.matrix_client.device_id:
|
self.logger.log(f"Logged in as {iam.user_id} (device ID: {iam.device_id})", "info")
|
||||||
self.matrix_client.device_id = await self._get_device_id()
|
|
||||||
|
|
||||||
# Set up database
|
# Set up database
|
||||||
|
|
||||||
|
@ -686,84 +614,33 @@ class GPTBot:
|
||||||
else:
|
else:
|
||||||
self.logger.log(f"Already at latest version {after}.")
|
self.logger.log(f"Already at latest version {after}.")
|
||||||
|
|
||||||
if IN_MEMORY:
|
# Set up event handlers
|
||||||
client_config = AsyncClientConfig(
|
self.matrix_client.add_event_handler(EventType.ALL, self.handle_event)
|
||||||
store_sync_tokens=True, encryption_enabled=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
matrix_store = SqliteStore
|
|
||||||
client_config = AsyncClientConfig(
|
|
||||||
store_sync_tokens=True, encryption_enabled=True, store=matrix_store
|
|
||||||
)
|
|
||||||
self.matrix_client.config = client_config
|
|
||||||
self.matrix_client.store = matrix_store(
|
|
||||||
self.matrix_client.user_id,
|
|
||||||
self.matrix_client.device_id,
|
|
||||||
'.', #store path
|
|
||||||
database_name=self.crypto_store_path or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.matrix_client.olm = Olm(
|
|
||||||
self.matrix_client.user_id,
|
|
||||||
self.matrix_client.device_id,
|
|
||||||
self.matrix_client.store,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.matrix_client.encrypted_rooms = (
|
|
||||||
self.matrix_client.store.load_encrypted_rooms()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run initial sync (now includes joining rooms)
|
# Run initial sync (now includes joining rooms)
|
||||||
sync = await self.matrix_client.sync(timeout=30000)
|
sync = await self.matrix_client.sync(timeout=30000)
|
||||||
if isinstance(sync, SyncResponse):
|
|
||||||
await self.response_callback(sync)
|
|
||||||
else:
|
|
||||||
self.logger.log(f"Initial sync failed, aborting: {sync}", "critical")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
# Set up callbacks
|
|
||||||
|
|
||||||
self.matrix_client.add_event_callback(self.event_callback, Event)
|
|
||||||
self.matrix_client.add_response_callback(self.response_callback, Response)
|
|
||||||
|
|
||||||
# Set custom name / logo
|
# Set custom name / logo
|
||||||
|
|
||||||
if self.display_name:
|
# TODO: Implement
|
||||||
self.logger.log(f"Setting display name to {self.display_name}", "debug")
|
|
||||||
asyncio.create_task(self.matrix_client.set_displayname(self.display_name))
|
|
||||||
if self.logo:
|
|
||||||
self.logger.log("Setting avatar...")
|
|
||||||
logo_bio = BytesIO()
|
|
||||||
self.logo.save(logo_bio, format=self.logo.format)
|
|
||||||
uri = await self.upload_file(
|
|
||||||
logo_bio.getvalue(), "logo", Image.MIME[self.logo.format]
|
|
||||||
)
|
|
||||||
self.logo_uri = uri
|
|
||||||
|
|
||||||
asyncio.create_task(self.matrix_client.set_avatar(uri))
|
|
||||||
|
|
||||||
for room in self.matrix_client.rooms.keys():
|
|
||||||
self.logger.log(f"Setting avatar for {room}...", "debug")
|
|
||||||
asyncio.create_task(
|
|
||||||
self.matrix_client.room_put_state(
|
|
||||||
room, "m.room.avatar", {"url": uri}, ""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start syncing events
|
# Start syncing events
|
||||||
self.logger.log("Starting sync loop...", "warning")
|
self.logger.log("Starting sync loop...", "warning")
|
||||||
try:
|
try:
|
||||||
await self.matrix_client.sync_forever(timeout=30000)
|
await self.matrix_client.start(None)
|
||||||
finally:
|
finally:
|
||||||
self.logger.log("Syncing one last time...", "warning")
|
self.logger.log("Syncing one last time...", "warning")
|
||||||
await self.matrix_client.sync(timeout=30000)
|
await self.matrix_client.sync(timeout=30000)
|
||||||
|
|
||||||
async def create_space(self, name, visibility=RoomVisibility.private) -> str:
|
async def create_space(
|
||||||
|
self, name, visibility=RoomDirectoryVisibility.PRIVATE
|
||||||
|
) -> str:
|
||||||
"""Create a space.
|
"""Create a space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): The name of the space.
|
name (str): The name of the space.
|
||||||
visibility (RoomVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private.
|
visibility (RoomDirectoryVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MatrixRoom: The created space.
|
MatrixRoom: The created space.
|
||||||
|
@ -780,13 +657,13 @@ class GPTBot:
|
||||||
return response.room_id
|
return response.room_id
|
||||||
|
|
||||||
async def add_rooms_to_space(
|
async def add_rooms_to_space(
|
||||||
self, space: MatrixRoom | str, rooms: List[MatrixRoom | str]
|
self, space: RoomID | str, rooms: List[RoomID | str]
|
||||||
):
|
):
|
||||||
"""Add rooms to a space.
|
"""Add rooms to a space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
space (MatrixRoom | str): The space to add the rooms to.
|
space (RoomID | str): The space to add the rooms to.
|
||||||
rooms (List[MatrixRoom | str]): The rooms to add to the space.
|
rooms (List[RoomID | str]): The rooms to add to the space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(space, MatrixRoom):
|
if isinstance(space, MatrixRoom):
|
||||||
|
@ -818,17 +695,17 @@ class GPTBot:
|
||||||
space,
|
space,
|
||||||
)
|
)
|
||||||
|
|
||||||
def respond_to_room_messages(self, room: MatrixRoom | str) -> bool:
|
def respond_to_room_messages(self, room: RoomID | str) -> bool:
|
||||||
"""Check whether the bot should respond to all messages sent in a room.
|
"""Check whether the bot should respond to all messages sent in a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom | str): The room to check.
|
room (RoomID | str): The room to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the bot should respond to all messages sent in the room.
|
bool: Whether the bot should respond to all messages sent in the room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(room, MatrixRoom):
|
if isinstance(room, RoomID):
|
||||||
room = room.room_id
|
room = room.room_id
|
||||||
|
|
||||||
with closing(self.database.cursor()) as cursor:
|
with closing(self.database.cursor()) as cursor:
|
||||||
|
@ -841,26 +718,26 @@ class GPTBot:
|
||||||
return True if not result else bool(int(result[0]))
|
return True if not result else bool(int(result[0]))
|
||||||
|
|
||||||
async def process_query(
|
async def process_query(
|
||||||
self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False
|
self, room: RoomID, event: MessageEvent, from_chat_command: bool = False
|
||||||
):
|
):
|
||||||
"""Process a query message. Generates a response and sends it to the room.
|
"""Process a query message. Generates a response and sends it to the room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom): The room the message was sent in.
|
room (RoomID): The room the message was sent in.
|
||||||
event (RoomMessageText): The event that triggered the query.
|
event (MessageEvent): The event that triggered the query.
|
||||||
from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False.
|
from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
from_chat_command
|
from_chat_command
|
||||||
or self.respond_to_room_messages(room)
|
or self.respond_to_room_messages(room)
|
||||||
or self.matrix_client.user_id in event.body
|
or self.matrix_client.whoami().user_id in event.body
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.matrix_client.room_typing(room.room_id, True)
|
# TODO: await self.matrix_client.room_typing(room.room_id, True)
|
||||||
|
|
||||||
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
# TODO: await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
||||||
|
|
||||||
if (not from_chat_command) and self.room_uses_classification(room):
|
if (not from_chat_command) and self.room_uses_classification(room):
|
||||||
try:
|
try:
|
||||||
|
@ -949,11 +826,11 @@ class GPTBot:
|
||||||
|
|
||||||
await self.matrix_client.room_typing(room.room_id, False)
|
await self.matrix_client.room_typing(room.room_id, False)
|
||||||
|
|
||||||
def get_system_message(self, room: MatrixRoom | str) -> str:
|
def get_system_message(self, room: RoomID | str) -> str:
|
||||||
"""Get the system message for a room.
|
"""Get the system message for a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room (MatrixRoom | str): The room to get the system message for.
|
room (RoomID | str): The room to get the system message for.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The system message.
|
str: The system message.
|
||||||
|
|
|
@ -21,13 +21,14 @@ class OpenAI:
|
||||||
return self.chat_model
|
return self.chat_model
|
||||||
|
|
||||||
classification_api = chat_api
|
classification_api = chat_api
|
||||||
image_api: str = "dalle"
|
image_model: str = "dall-e-2"
|
||||||
|
|
||||||
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
|
operator: str = "OpenAI ([https://openai.com](https://openai.com))"
|
||||||
|
|
||||||
def __init__(self, api_key, chat_model=None, logger=None):
|
def __init__(self, api_key, chat_model=None, image_model=None, logger=None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.chat_model = chat_model or self.chat_model
|
self.chat_model = chat_model or self.chat_model
|
||||||
|
self.image_model = image_model or self.image_model
|
||||||
self.logger = logger or Logger()
|
self.logger = logger or Logger()
|
||||||
self.base_url = openai.api_base
|
self.base_url = openai.api_base
|
||||||
|
|
||||||
|
@ -146,6 +147,7 @@ Only the event_types mentioned above are allowed, you must not respond in any ot
|
||||||
|
|
||||||
image_partial = partial(
|
image_partial = partial(
|
||||||
openai.Image.acreate,
|
openai.Image.acreate,
|
||||||
|
model=self.image_model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
n=1,
|
n=1,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
|
|
Loading…
Reference in a new issue