diff --git a/oa/api/utils.py b/oa/api/utils.py index f7aea11..04d6d61 100644 --- a/oa/api/utils.py +++ b/oa/api/utils.py @@ -1,10 +1,14 @@ +import base64 +import logging from typing import Any -from django.urls import reverse from openai import AsyncAssistantEventHandler from openai.types.beta.threads import Text, TextDelta, ImageFile +logger = logging.getLogger(__name__) + + class APIError(Exception): def __init__(self, message, status=500): self.message = message @@ -29,10 +33,11 @@ def serialize_to_dict(obj: Any) -> Any: class EventHandler(AsyncAssistantEventHandler): - def __init__(self, shared_data): + def __init__(self, request, shared_data): super().__init__() + self.request = request self.current_message = "" - self.shared_data = shared_data # Use shared_data instead of response_data + self.shared_data = shared_data self.stream_done = False self.current_annotations = [] @@ -76,13 +81,18 @@ async def on_message_done(self, message): }) async def on_image_file_done(self, image_file: ImageFile) -> None: - image_url = reverse('api-1.0.0:serve_image_file', kwargs={ - 'file_id': image_file.file_id - }) - - self.current_message += f'
' - - # No annotations for images + try: + content_response = await self.request.auth['client'].files.content(image_file.file_id) + image_binary = content_response.read() + image_base64 = base64.b64encode(image_binary).decode('utf-8') + image_data = f"data:image/png;base64,{image_base64}" + + self.current_message += f'' + except APIError as e: + logger.warning(f"Error fetching image file with id {image_file.file_id}: {e}") + self.current_message += f"(Error fetching image file)
" + + # Send an SSE event that includes the updated current_message self.shared_data.append({ "type": "image_file", "text": self.current_message, diff --git a/oa/api/views.py b/oa/api/views.py index c13e308..64e1551 100644 --- a/oa/api/views.py +++ b/oa/api/views.py @@ -695,7 +695,7 @@ async def cancel_run(request, thread_id, run_id): async def stream_responses(request, assistant_id: str, thread_id: str): async def event_stream(): shared_data = [] - event_handler = EventHandler(shared_data=shared_data) + event_handler = EventHandler(request, shared_data=shared_data) try: async with request.auth['client'].beta.threads.runs.stream( thread_id=thread_id, @@ -782,19 +782,6 @@ async def event_stream(): return response -@api.get("files/image/{file_id}", auth=BearerAuth()) -async def serve_image_file(request, file_id: str): - try: - content_response = await request.auth['client'].files.content(file_id) - image_binary = content_response.read() - return HttpResponse(image_binary, content_type='image/png') - except APIError as e: - return JsonResponse({"error": e.message}, status=e.status) - except OpenAIError as e: - logger.error(f"Error fetching image file: {e}") - return HttpResponseNotFound('Image not found') - - @api.get("/thread/{thread_id}/messages", auth=BearerAuth()) async def get_thread_messages(request, thread_id): try: @@ -830,12 +817,15 @@ async def format_message(message): # Fetch file path if file_path := getattr(annotation, 'file_path', None): file_path_file_id = getattr(file_path, 'file_id', None) - download_link = reverse('api-1.0.0:download_file', kwargs={ - 'file_id': file_path_file_id - }) - # Replace the annotation text with the download link - text_content = text_content.replace(annotation.text, download_link) + download_link = reverse( + 'api-1.0.0:download_file_trigger', + kwargs={'file_id': file_path_file_id} + ) + + html_snippet = f'' + + text_content = text_content.replace(annotation.text, html_snippet) content += f"{text_content}
" @@ -919,6 +909,11 @@ async def fetch_file(file_id): return JsonResponse({'success': True, 'files': files}) +@api.get("/download-trigger/{file_id}", auth=BearerAuth()) +async def download_file_trigger(request, file_id: str): + return await download_file(request, file_id) + + @api.get("/download/{file_id}", auth=BearerAuth()) async def download_file(request, file_id: str): try: @@ -927,7 +922,7 @@ async def download_file(request, file_id: str): # Retrieve file content file_content_response = await request.auth['client'].files.content(file_id) - file_content = await file_content_response.read() # Read the content as bytes + file_content = file_content_response.read() # Read the content as bytes # Extract the filename from the full path filename = os.path.basename(file_info.filename) diff --git a/oa/templates/chat/chat_js.html b/oa/templates/chat/chat_js.html index 13a6645..280d3da 100644 --- a/oa/templates/chat/chat_js.html +++ b/oa/templates/chat/chat_js.html @@ -20,6 +20,47 @@ return new URLSearchParams(window.location.search).get(name); } +async function downloadFile(fileUrl) { + try { + const response = await fetch(fileUrl, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${API_KEY}` + } + }); + if (!response.ok) { + showToast("Failed!", "Failed to download!", "danger"); + console.error('Failed to download!', response); + } + + // Convert response to blob for download + const blob = await response.blob(); + + // Extract filename from Content-Disposition header if present + let filename = "downloaded_file"; + const disposition = response.headers.get("Content-Disposition"); + if (disposition && disposition.includes("filename=")) { + const match = /filename="([^"]*)"/.exec(disposition); + if (match && match[1]) { + filename = match[1]; + } + } + + // Create a temporary to download the blob + const downloadUrl = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = downloadUrl; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + window.URL.revokeObjectURL(downloadUrl); + } catch (error) { + showToast("Failed!", "Download error!", "danger"); + console.error("Download error:", error); + } +} + document.addEventListener("DOMContentLoaded", async function () { // Call initializePage and wait for it to complete await initializePage(); @@ -723,10 +764,13 @@ content = content.replace(annotation.text, replacementText); } else if (annotation.file_path) { const fileId = annotation.file_path.file_id; - const downloadUrlTemplate = "{% url 'api-1.0.0:download_file' file_id='FILE_ID' %}"; - const downloadUrl = downloadUrlTemplate.replace('ASST_ID_PLACEHOLDER', fileId); + + const downloadUrlTemplate = "{% url 'api-1.0.0:download_file_trigger' file_id='FILE_ID' %}"; + const triggeredDownloadUrl = downloadUrlTemplate.replace('FILE_ID', fileId); + + const linkHTML = ``; - content = content.replace(annotation.text, downloadUrl); + content = content.replace(annotation.text, linkHTML); } } return content; diff --git a/oa/templates/chat/shared_chat_js.html b/oa/templates/chat/shared_chat_js.html index 04c027a..303805d 100644 --- a/oa/templates/chat/shared_chat_js.html +++ b/oa/templates/chat/shared_chat_js.html @@ -21,6 +21,48 @@ return new URLSearchParams(window.location.search).get(name); } +async function downloadFile(fileUrl) { + try { + const response = await fetch(fileUrl, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${API_KEY}`, + 'X-Token': shared_token + } + }); + if (!response.ok) { + showToast("Failed!", "Failed to download!", "danger"); + console.error('Failed to download!', response); + } + + // Convert response to blob for download + const blob = await response.blob(); + + // Extract filename from Content-Disposition header if present + let filename = "downloaded_file"; + const disposition = response.headers.get("Content-Disposition"); + if (disposition && disposition.includes("filename=")) { + const match = /filename="([^"]*)"/.exec(disposition); + if (match && match[1]) { + filename = match[1]; + } + } + + // Create a temporary to download the blob + const downloadUrl = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = downloadUrl; + a.download = filename; + document.body.appendChild(a); + a.click(); + a.remove(); + window.URL.revokeObjectURL(downloadUrl); + } catch (error) { + showToast("Failed!", "Download error!", "danger"); + console.error("Download error:", error); + } +} + document.addEventListener("DOMContentLoaded", async function () { // Call initializePage and wait for it to complete await initializePage(); @@ -735,15 +777,13 @@ content = content.replace(annotation.text, replacementText); } else if (annotation.file_path) { const fileId = annotation.file_path.file_id; - const downloadUrlTemplate = "{% url 'api-1.0.0:download_file' file_id='FILE_ID' %}"; - const downloadUrlTemplateToken = downloadUrlTemplate + `?token=${shared_token}` - const downloadUrl = downloadUrlTemplateToken.replace('ASST_ID_PLACEHOLDER', fileId); - console.log(downloadUrlTemplate); - console.log(downloadUrlTemplateToken); - console.log(downloadUrl); + const downloadUrlTemplate = "{% url 'api-1.0.0:download_file_trigger' file_id='FILE_ID' %}"; + const triggeredDownloadUrl = downloadUrlTemplate.replace('FILE_ID', fileId); + + const linkHTML = ``; - content = content.replace(annotation.text, downloadUrl); + content = content.replace(annotation.text, linkHTML); } } return content;