feat: enhance error handling for file downloads

Introduce `DownloadException` to improve error reporting and handling when file downloads fail. Modified `download_file` method to accept a `raise_error` flag, which, when set, raises `DownloadException` upon a download error instead of just logging it. This enables the bot to respond with a specific error message to the room if a download fails during processing of speech-to-text, file messages, and image messages, enhancing user feedback on download failures.
This commit is contained in:
Kumi 2024-05-20 10:41:09 +02:00
parent 3f084ffdd3
commit e1695f0cce
Signed by: kumi
GPG key ID: ECBCC9082395383F
2 changed files with 18 additions and 4 deletions

View file

@ -54,6 +54,7 @@ from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from ..commands import COMMANDS
from ..tools import TOOLS, Handover, StopProcessing
from .ai.base import BaseAI
from .exceptions import DownloadException
class GPTBot:
@ -1192,7 +1193,9 @@ class GPTBot:
if message == event or (not message.event_id == event.event_id):
if self.room_uses_stt(room):
try:
download = await self.download_file(message.url)
download = await self.download_file(
message.url, raise_error=True
)
message_text = await self.stt_api.speech_to_text(
download.body
)
@ -1213,7 +1216,7 @@ class GPTBot:
elif isinstance(message, RoomMessageFile):
try:
download = await self.download_file(message.url)
download = await self.download_file(message.url, raise_error=True)
if download:
try:
text = download.body.decode("utf-8")
@ -1252,7 +1255,7 @@ class GPTBot:
):
try:
image_url = message.url
download = await self.download_file(image_url)
download = await self.download_file(image_url, raise_error=True)
if download:
pil_image = Image.open(BytesIO(download.body))
@ -1306,6 +1309,13 @@ class GPTBot:
)
except Exception as e:
if isinstance(e, DownloadException):
self.send_message(
room,
f"Could not process image due to download error: {e.args[0]}",
True,
)
self.logger.log(f"Error generating image from file: {e}", "error")
message_body = (
message.body
@ -1369,7 +1379,7 @@ class GPTBot:
await self.matrix_client.room_typing(room.room_id, False)
async def download_file(
self, mxc
self, mxc: str, raise_error: bool = False
) -> Union[DiskDownloadResponse, MemoryDownloadResponse]:
"""Download a file from the homeserver.
@ -1384,6 +1394,8 @@ class GPTBot:
if isinstance(download, DownloadError):
self.logger.log(f"Error downloading file: {download.message}", "error")
if raise_error:
raise DownloadException(download.message)
return
return download

View file

@ -0,0 +1,2 @@
class DownloadException(Exception):
pass