feat: enhance error handling for file downloads

Introduce `DownloadException` to improve error reporting and handling when file downloads fail. Modified `download_file` method to accept a `raise_error` flag, which, when set, raises `DownloadException` upon a download error instead of just logging it. This enables the bot to respond with a specific error message to the room if a download fails during processing of speech-to-text, file messages, and image messages, enhancing user feedback on download failures.
This commit is contained in:
Kumi 2024-05-20 10:41:09 +02:00
parent 3f084ffdd3
commit e1695f0cce
Signed by: kumi
GPG key ID: ECBCC9082395383F
2 changed files with 18 additions and 4 deletions

View file

@ -54,6 +54,7 @@ from ..callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from ..commands import COMMANDS from ..commands import COMMANDS
from ..tools import TOOLS, Handover, StopProcessing from ..tools import TOOLS, Handover, StopProcessing
from .ai.base import BaseAI from .ai.base import BaseAI
from .exceptions import DownloadException
class GPTBot: class GPTBot:
@ -1192,7 +1193,9 @@ class GPTBot:
if message == event or (not message.event_id == event.event_id): if message == event or (not message.event_id == event.event_id):
if self.room_uses_stt(room): if self.room_uses_stt(room):
try: try:
download = await self.download_file(message.url) download = await self.download_file(
message.url, raise_error=True
)
message_text = await self.stt_api.speech_to_text( message_text = await self.stt_api.speech_to_text(
download.body download.body
) )
@ -1213,7 +1216,7 @@ class GPTBot:
elif isinstance(message, RoomMessageFile): elif isinstance(message, RoomMessageFile):
try: try:
download = await self.download_file(message.url) download = await self.download_file(message.url, raise_error=True)
if download: if download:
try: try:
text = download.body.decode("utf-8") text = download.body.decode("utf-8")
@ -1252,7 +1255,7 @@ class GPTBot:
): ):
try: try:
image_url = message.url image_url = message.url
download = await self.download_file(image_url) download = await self.download_file(image_url, raise_error=True)
if download: if download:
pil_image = Image.open(BytesIO(download.body)) pil_image = Image.open(BytesIO(download.body))
@ -1306,6 +1309,13 @@ class GPTBot:
) )
except Exception as e: except Exception as e:
if isinstance(e, DownloadException):
self.send_message(
room,
f"Could not process image due to download error: {e.args[0]}",
True,
)
self.logger.log(f"Error generating image from file: {e}", "error") self.logger.log(f"Error generating image from file: {e}", "error")
message_body = ( message_body = (
message.body message.body
@ -1369,7 +1379,7 @@ class GPTBot:
await self.matrix_client.room_typing(room.room_id, False) await self.matrix_client.room_typing(room.room_id, False)
async def download_file( async def download_file(
self, mxc self, mxc: str, raise_error: bool = False
) -> Union[DiskDownloadResponse, MemoryDownloadResponse]: ) -> Union[DiskDownloadResponse, MemoryDownloadResponse]:
"""Download a file from the homeserver. """Download a file from the homeserver.
@ -1384,6 +1394,8 @@ class GPTBot:
if isinstance(download, DownloadError): if isinstance(download, DownloadError):
self.logger.log(f"Error downloading file: {download.message}", "error") self.logger.log(f"Error downloading file: {download.message}", "error")
if raise_error:
raise DownloadException(download.message)
return return
return download return download

View file

@ -0,0 +1,2 @@
class DownloadException(Exception):
pass