Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/stage' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
kemalbsoylu committed Jan 14, 2025
2 parents 2c5e392 + a51a2fa commit 2b89df6
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 40 deletions.
30 changes: 20 additions & 10 deletions oa/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import base64
import logging
from typing import Any
from django.urls import reverse

from openai import AsyncAssistantEventHandler
from openai.types.beta.threads import Text, TextDelta, ImageFile


logger = logging.getLogger(__name__)


class APIError(Exception):
def __init__(self, message, status=500):
self.message = message
Expand All @@ -29,10 +33,11 @@ def serialize_to_dict(obj: Any) -> Any:


class EventHandler(AsyncAssistantEventHandler):
def __init__(self, shared_data):
def __init__(self, request, shared_data):
super().__init__()
self.request = request
self.current_message = ""
self.shared_data = shared_data # Use shared_data instead of response_data
self.shared_data = shared_data
self.stream_done = False
self.current_annotations = []

Expand Down Expand Up @@ -76,13 +81,18 @@ async def on_message_done(self, message):
})

async def on_image_file_done(self, image_file: ImageFile) -> None:
image_url = reverse('api-1.0.0:serve_image_file', kwargs={
'file_id': image_file.file_id
})

self.current_message += f'<p><img src="{image_url}" style="max-width: 100%;"></p>'

# No annotations for images
try:
content_response = await self.request.auth['client'].files.content(image_file.file_id)
image_binary = content_response.read()
image_base64 = base64.b64encode(image_binary).decode('utf-8')
image_data = f"data:image/png;base64,{image_base64}"

self.current_message += f'<p><img src="{image_data}" style="max-width: 100%;"></p>'
except APIError as e:
logger.warning(f"Error fetching image file with id {image_file.file_id}: {e}")
self.current_message += f"<p>(Error fetching image file)</p>"

# Send an SSE event that includes the updated current_message
self.shared_data.append({
"type": "image_file",
"text": self.current_message,
Expand Down
35 changes: 15 additions & 20 deletions oa/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ async def cancel_run(request, thread_id, run_id):
async def stream_responses(request, assistant_id: str, thread_id: str):
async def event_stream():
shared_data = []
event_handler = EventHandler(shared_data=shared_data)
event_handler = EventHandler(request, shared_data=shared_data)
try:
async with request.auth['client'].beta.threads.runs.stream(
thread_id=thread_id,
Expand Down Expand Up @@ -782,19 +782,6 @@ async def event_stream():
return response


@api.get("files/image/{file_id}", auth=BearerAuth())
async def serve_image_file(request, file_id: str):
try:
content_response = await request.auth['client'].files.content(file_id)
image_binary = content_response.read()
return HttpResponse(image_binary, content_type='image/png')
except APIError as e:
return JsonResponse({"error": e.message}, status=e.status)
except OpenAIError as e:
logger.error(f"Error fetching image file: {e}")
return HttpResponseNotFound('Image not found')


@api.get("/thread/{thread_id}/messages", auth=BearerAuth())
async def get_thread_messages(request, thread_id):
try:
Expand Down Expand Up @@ -830,12 +817,15 @@ async def format_message(message):
# Fetch file path
if file_path := getattr(annotation, 'file_path', None):
file_path_file_id = getattr(file_path, 'file_id', None)
download_link = reverse('api-1.0.0:download_file', kwargs={
'file_id': file_path_file_id
})

# Replace the annotation text with the download link
text_content = text_content.replace(annotation.text, download_link)
download_link = reverse(
'api-1.0.0:download_file_trigger',
kwargs={'file_id': file_path_file_id}
)

html_snippet = f'<a href="#" onclick="downloadFile(\'{download_link}\')"><i class="bi bi-cloud-download"></i></a>'

text_content = text_content.replace(annotation.text, html_snippet)

content += f"<p>{text_content}</p>"

Expand Down Expand Up @@ -919,6 +909,11 @@ async def fetch_file(file_id):
return JsonResponse({'success': True, 'files': files})


@api.get("/download-trigger/{file_id}", auth=BearerAuth())
async def download_file_trigger(request, file_id: str):
return await download_file(request, file_id)


@api.get("/download/{file_id}", auth=BearerAuth())
async def download_file(request, file_id: str):
try:
Expand All @@ -927,7 +922,7 @@ async def download_file(request, file_id: str):

# Retrieve file content
file_content_response = await request.auth['client'].files.content(file_id)
file_content = await file_content_response.read() # Read the content as bytes
file_content = file_content_response.read() # Read the content as bytes

# Extract the filename from the full path
filename = os.path.basename(file_info.filename)
Expand Down
50 changes: 47 additions & 3 deletions oa/templates/chat/chat_js.html
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,47 @@
return new URLSearchParams(window.location.search).get(name);
}

async function downloadFile(fileUrl) {
try {
const response = await fetch(fileUrl, {
method: 'GET',
headers: {
'Authorization': `Bearer ${API_KEY}`
}
});
if (!response.ok) {
showToast("Failed!", "Failed to download!", "danger");
console.error('Failed to download!', response);
}

// Convert response to blob for download
const blob = await response.blob();

// Extract filename from Content-Disposition header if present
let filename = "downloaded_file";
const disposition = response.headers.get("Content-Disposition");
if (disposition && disposition.includes("filename=")) {
const match = /filename="([^"]*)"/.exec(disposition);
if (match && match[1]) {
filename = match[1];
}
}

// Create a temporary <a> to download the blob
const downloadUrl = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = downloadUrl;
a.download = filename;
document.body.appendChild(a);
a.click();
a.remove();
window.URL.revokeObjectURL(downloadUrl);
} catch (error) {
showToast("Failed!", "Download error!", "danger");
console.error("Download error:", error);
}
}

document.addEventListener("DOMContentLoaded", async function () {
// Call initializePage and wait for it to complete
await initializePage();
Expand Down Expand Up @@ -723,10 +764,13 @@
content = content.replace(annotation.text, replacementText);
} else if (annotation.file_path) {
const fileId = annotation.file_path.file_id;
const downloadUrlTemplate = "{% url 'api-1.0.0:download_file' file_id='FILE_ID' %}";
const downloadUrl = downloadUrlTemplate.replace('ASST_ID_PLACEHOLDER', fileId);

const downloadUrlTemplate = "{% url 'api-1.0.0:download_file_trigger' file_id='FILE_ID' %}";
const triggeredDownloadUrl = downloadUrlTemplate.replace('FILE_ID', fileId);

const linkHTML = `<a href="#" onclick="downloadFile('${triggeredDownloadUrl}')"><i class="bi bi-cloud-download"></i></a>`;

content = content.replace(annotation.text, downloadUrl);
content = content.replace(annotation.text, linkHTML);
}
}
return content;
Expand Down
54 changes: 47 additions & 7 deletions oa/templates/chat/shared_chat_js.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,48 @@
return new URLSearchParams(window.location.search).get(name);
}

async function downloadFile(fileUrl) {
try {
const response = await fetch(fileUrl, {
method: 'GET',
headers: {
'Authorization': `Bearer ${API_KEY}`,
'X-Token': shared_token
}
});
if (!response.ok) {
showToast("Failed!", "Failed to download!", "danger");
console.error('Failed to download!', response);
}

// Convert response to blob for download
const blob = await response.blob();

// Extract filename from Content-Disposition header if present
let filename = "downloaded_file";
const disposition = response.headers.get("Content-Disposition");
if (disposition && disposition.includes("filename=")) {
const match = /filename="([^"]*)"/.exec(disposition);
if (match && match[1]) {
filename = match[1];
}
}

// Create a temporary <a> to download the blob
const downloadUrl = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = downloadUrl;
a.download = filename;
document.body.appendChild(a);
a.click();
a.remove();
window.URL.revokeObjectURL(downloadUrl);
} catch (error) {
showToast("Failed!", "Download error!", "danger");
console.error("Download error:", error);
}
}

document.addEventListener("DOMContentLoaded", async function () {
// Call initializePage and wait for it to complete
await initializePage();
Expand Down Expand Up @@ -735,15 +777,13 @@
content = content.replace(annotation.text, replacementText);
} else if (annotation.file_path) {
const fileId = annotation.file_path.file_id;
const downloadUrlTemplate = "{% url 'api-1.0.0:download_file' file_id='FILE_ID' %}";
const downloadUrlTemplateToken = downloadUrlTemplate + `?token=${shared_token}`
const downloadUrl = downloadUrlTemplateToken.replace('ASST_ID_PLACEHOLDER', fileId);

console.log(downloadUrlTemplate);
console.log(downloadUrlTemplateToken);
console.log(downloadUrl);
const downloadUrlTemplate = "{% url 'api-1.0.0:download_file_trigger' file_id='FILE_ID' %}";
const triggeredDownloadUrl = downloadUrlTemplate.replace('FILE_ID', fileId);

const linkHTML = `<a href="#" onclick="downloadFile('${triggeredDownloadUrl}')"><i class="bi bi-cloud-download"></i></a>`;

content = content.replace(annotation.text, downloadUrl);
content = content.replace(annotation.text, linkHTML);
}
}
return content;
Expand Down

0 comments on commit 2b89df6

Please sign in to comment.