From 588e3850fe5cfa45f8b12bc6ed3526940824ed35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Guillaume?= <62661249+leoguillaume@users.noreply.github.com> Date: Thu, 9 Jan 2025 18:52:57 +0100 Subject: [PATCH] feat: httpx timeout error handling (#139) Co-authored-by: leoguillaume --- app/clients/_modelclients.py | 4 +- app/endpoints/audio.py | 37 +++--- app/endpoints/chat.py | 69 ++++------ app/endpoints/completions.py | 19 +-- app/endpoints/embeddings.py | 20 +-- app/helpers/__init__.py | 11 +- app/helpers/_metricsmiddleware.py | 19 +++ .../_streamingresponsewithstatuscode.py | 62 +++++++++ app/main.py | 8 ++ app/tests/test_audio.py | 11 ++ app/tests/test_chat.py | 66 ++++++---- app/utils/logging.py | 32 ++++- app/utils/route.py | 120 ++++++++++++++++++ app/utils/variables.py | 2 +- 14 files changed, 360 insertions(+), 120 deletions(-) create mode 100644 app/helpers/_streamingresponsewithstatuscode.py create mode 100644 app/utils/route.py diff --git a/app/clients/_modelclients.py b/app/clients/_modelclients.py index b1587254..5336f5d3 100644 --- a/app/clients/_modelclients.py +++ b/app/clients/_modelclients.py @@ -114,8 +114,6 @@ def create_embeddings(self, *args, **kwargs): class ModelClient(OpenAI): - DEFAULT_TIMEOUT = 120 - def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE], *args, **kwargs) -> None: """ ModelClient class extends AsyncOpenAI class to support custom methods. @@ -162,7 +160,7 @@ def create(self, prompt: str, input: list[str], model: str) -> List[Rerank]: return data - self.rerank = RerankClient(model=self.id, base_url=self.base_url, api_key=self.api_key, timeout=self.DEFAULT_TIMEOUT) + self.rerank = RerankClient(model=self.id, base_url=self.base_url, api_key=self.api_key, timeout=DEFAULT_TIMEOUT) class ModelClients(dict): diff --git a/app/endpoints/audio.py b/app/endpoints/audio.py index 1762d169..2eb9c73f 100644 --- a/app/endpoints/audio.py +++ b/app/endpoints/audio.py @@ -1,16 +1,15 @@ -import json from typing import List, Literal -from fastapi import APIRouter, File, Form, HTTPException, Request, Security, UploadFile +from fastapi import APIRouter, File, Form, Request, Security, UploadFile from fastapi.responses import PlainTextResponse -import httpx from app.schemas.audio import AudioTranscription -from app.utils.exceptions import ModelNotFoundException +from app.utils.exceptions import WrongModelTypeException from app.utils.lifespan import clients, limiter +from app.utils.route import forward_request from app.utils.security import User, check_api_key, check_rate_limit from app.utils.settings import settings -from app.utils.variables import DEFAULT_TIMEOUT, AUDIO_MODEL_TYPE +from app.utils.variables import AUDIO_MODEL_TYPE, DEFAULT_TIMEOUT router = APIRouter() @@ -152,7 +151,7 @@ async def audio_transcriptions( client = clients.models[model] if client.type != AUDIO_MODEL_TYPE: - raise ModelNotFoundException() + raise WrongModelTypeException() # @TODO: Implement prompt # @TODO: Implement timestamp_granularities @@ -163,20 +162,16 @@ async def audio_transcriptions( url = f"{client.base_url}audio/transcriptions" headers = {"Authorization": f"Bearer {client.api_key}"} - try: - async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client: - response = await async_client.post( - url=url, - headers=headers, - files={"file": (file.filename, file_content, file.content_type)}, - data={"language": language, "response_format": response_format, "temperature": temperature}, - ) - response.raise_for_status() - if response_format == "text": - return PlainTextResponse(content=response.text) + response = await forward_request( + url=url, + method="POST", + headers=headers, + timeout=DEFAULT_TIMEOUT, + files={"file": (file.filename, file_content, file.content_type)}, + data={"language": language, "response_format": response_format, "temperature": temperature}, + ) - data = response.json() - return AudioTranscription(**data) + if response_format == "text": + return PlainTextResponse(content=response.text) - except Exception as e: - raise HTTPException(status_code=e.response.status_code, detail=json.loads(s=e.response.text)["message"]) + return AudioTranscription(**response.json()) diff --git a/app/endpoints/chat.py b/app/endpoints/chat.py index 6db34958..efa596e4 100644 --- a/app/endpoints/chat.py +++ b/app/endpoints/chat.py @@ -1,18 +1,16 @@ -import json from typing import List, Tuple, Union -from fastapi import APIRouter, HTTPException, Request, Security +from fastapi import APIRouter, Request, Security from fastapi.concurrency import run_in_threadpool -from fastapi.responses import StreamingResponse -import httpx -from app.helpers import ClientsManager, InternetManager, SearchManager +from app.helpers import ClientsManager, InternetManager, SearchManager, StreamingResponseWithStatusCode from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest from app.schemas.search import Search from app.schemas.security import User from app.schemas.settings import Settings from app.utils.exceptions import WrongModelTypeException from app.utils.lifespan import clients, limiter +from app.utils.route import forward_request, forward_stream from app.utils.security import check_api_key, check_rate_limit from app.utils.settings import settings from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE @@ -32,10 +30,12 @@ async def chat_completions( client = clients.models[body.model] if client.type != LANGUAGE_MODEL_TYPE: raise WrongModelTypeException() + body.model = client.id # replace alias by model id url = f"{client.base_url}chat/completions" headers = {"Authorization": f"Bearer {client.api_key}"} + # retrieval augmentation generation def retrieval_augmentation_generation( body: ChatCompletionRequest, clients: ClientsManager, settings: Settings ) -> Tuple[ChatCompletionRequest, List[Search]]: @@ -75,42 +75,27 @@ def retrieval_augmentation_generation( # not stream case if not body["stream"]: - try: - async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client: - response = await async_client.request(method="POST", url=url, headers=headers, json=body) - response.raise_for_status() - data = response.json() - data["search_results"] = searches - - return ChatCompletion(**data) - except Exception as e: - raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"]) + response = await forward_request( + url=url, + method="POST", + headers=headers, + json=body, + timeout=DEFAULT_TIMEOUT, + additional_data_value=searches, + additional_data_key="search_results", + ) + return ChatCompletion(**response.json()) # stream case - async def forward_stream(url: str, headers: dict, request: dict): - try: - error = None - async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client: - async with async_client.stream(method="POST", url=url, headers=headers, json=request) as response: - if response.status_code >= 400: - error = await response.aread().decode() - response.raise_for_status() - - i = 0 - async for chunk in response.aiter_raw(): - if i == 0: - chunks = chunk.decode(encoding="utf-8").split(sep="\n\n") - chunk = json.loads(chunks[0].lstrip("data: ")) - chunk["search_results"] = searches - chunks[0] = f"data: {json.dumps(chunk)}" - chunk = "\n\n".join(chunks).encode(encoding="utf-8") - i = 1 - yield chunk - - # @TODO: raise the error instead of forwarding it (raise model ) - except Exception as e: - error = error if error else {"error": {"type": e.__class__.__name__, "message": str(e), "code": 500}} - yield f"data: {json.dumps(error)}\n\n".encode(encoding="utf-8") - yield b"data: [DONE]\n\n" - - return StreamingResponse(content=forward_stream(url=url, headers=headers, request=body), media_type="text/event-stream") + return StreamingResponseWithStatusCode( + content=forward_stream( + url=url, + method="POST", + headers=headers, + json=body, + timeout=DEFAULT_TIMEOUT, + additional_data_value=searches, + additional_data_key="search_results", + ), + media_type="text/event-stream", + ) diff --git a/app/endpoints/completions.py b/app/endpoints/completions.py index 8ae3e12d..2203334a 100644 --- a/app/endpoints/completions.py +++ b/app/endpoints/completions.py @@ -1,12 +1,10 @@ -import json - -from fastapi import APIRouter, HTTPException, Request, Security -import httpx +from fastapi import APIRouter, Request, Security from app.schemas.completions import CompletionRequest, Completions from app.schemas.security import User from app.utils.exceptions import WrongModelTypeException from app.utils.lifespan import clients, limiter +from app.utils.route import forward_request from app.utils.security import check_api_key, check_rate_limit from app.utils.settings import settings from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE @@ -24,17 +22,10 @@ async def completions(request: Request, body: CompletionRequest, user: User = Se client = clients.models[body.model] if client.type != LANGUAGE_MODEL_TYPE: raise WrongModelTypeException() + body.model = client.id # replace alias by model id url = f"{client.base_url}completions" headers = {"Authorization": f"Bearer {client.api_key}"} - try: - async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client: - response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump()) - response.raise_for_status() - - data = response.json() - return Completions(**data) - - except Exception as e: - raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"]) + response = await forward_request(url=url, method="POST", headers=headers, json=body.model_dump(), timeout=DEFAULT_TIMEOUT) + return Completions(**response.json()) diff --git a/app/endpoints/embeddings.py b/app/endpoints/embeddings.py index 9fca1267..a92268b9 100644 --- a/app/endpoints/embeddings.py +++ b/app/endpoints/embeddings.py @@ -1,14 +1,13 @@ -from fastapi import APIRouter, Request, Security, HTTPException -import httpx -import json +from fastapi import APIRouter, Request, Security from app.schemas.embeddings import Embeddings, EmbeddingsRequest from app.schemas.security import User -from app.utils.settings import settings from app.utils.exceptions import WrongModelTypeException from app.utils.lifespan import clients, limiter +from app.utils.route import forward_request from app.utils.security import check_api_key, check_rate_limit -from app.utils.variables import EMBEDDINGS_MODEL_TYPE, DEFAULT_TIMEOUT +from app.utils.settings import settings +from app.utils.variables import DEFAULT_TIMEOUT, EMBEDDINGS_MODEL_TYPE router = APIRouter() @@ -24,15 +23,10 @@ async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Sec client = clients.models[body.model] if client.type != EMBEDDINGS_MODEL_TYPE: raise WrongModelTypeException() + body.model = client.id # replace alias by model id url = f"{client.base_url}embeddings" headers = {"Authorization": f"Bearer {client.api_key}"} - try: - async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client: - response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump()) - response.raise_for_status() - data = response.json() - return Embeddings(**data) - except Exception as e: - raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"]) + response = await forward_request(url=url, method="POST", headers=headers, json=body.model_dump(), timeout=DEFAULT_TIMEOUT) + return Embeddings(**response.json()) diff --git a/app/helpers/__init__.py b/app/helpers/__init__.py index 1adfe8c6..538b4114 100644 --- a/app/helpers/__init__.py +++ b/app/helpers/__init__.py @@ -4,5 +4,14 @@ from ._languagemodelreranker import LanguageModelReranker from ._metricsmiddleware import MetricsMiddleware from ._searchmanager import SearchManager +from ._streamingresponsewithstatuscode import StreamingResponseWithStatusCode -__all__ = ["ClientsManager", "FileUploader", "LanguageModelReranker", "InternetManager", "MetricsMiddleware", "SearchManager"] +__all__ = [ + "ClientsManager", + "FileUploader", + "LanguageModelReranker", + "InternetManager", + "MetricsMiddleware", + "SearchManager", + "StreamingResponseWithStatusCode", +] diff --git a/app/helpers/_metricsmiddleware.py b/app/helpers/_metricsmiddleware.py index 68ecdb0c..cd00a94c 100644 --- a/app/helpers/_metricsmiddleware.py +++ b/app/helpers/_metricsmiddleware.py @@ -5,6 +5,8 @@ from starlette.middleware.base import BaseHTTPMiddleware from app.clients import AuthenticationClient +from app.utils.logging import logger +from app.utils.logging import client_ip class MetricsMiddleware(BaseHTTPMiddleware): @@ -16,10 +18,27 @@ class MetricsMiddleware(BaseHTTPMiddleware): labelnames=["user", "endpoint", "model"], ) + async def __call__(self, scope, receive, send): + try: + await super().__call__(scope, receive, send) + except RuntimeError as exc: + # ignore the error when the request is disconnected by the client + if str(exc) == "No response returned.": + logger.info( + f'"{list(scope["route"].methods)[0]} {scope["route"].path} HTTP/{scope["http_version"]}" request disconnected by the client' + ) + request = Request(scope, receive=receive) + if await request.is_disconnected(): + return + raise + async def dispatch(self, request: Request, call_next) -> Response: endpoint = request.url.path content_type = request.headers.get("Content-Type", "") + client_addr = request.client.host + client_ip.set(client_addr) + if endpoint.startswith("/v1"): authorization = request.headers.get("Authorization") model = None diff --git a/app/helpers/_streamingresponsewithstatuscode.py b/app/helpers/_streamingresponsewithstatuscode.py new file mode 100644 index 00000000..a2526e96 --- /dev/null +++ b/app/helpers/_streamingresponsewithstatuscode.py @@ -0,0 +1,62 @@ +import json +from typing import AsyncIterator + +from fastapi.responses import StreamingResponse +from starlette.types import Send + + +class StreamingResponseWithStatusCode(StreamingResponse): + """ + Variation of StreamingResponse that can dynamically decide the HTTP status code, + based on the return value of the content iterator (parameter `content`). + Expects the content to yield either just str content as per the original `StreamingResponse` + or else tuples of (`content`: `str`, `status_code`: `int`). + """ + + body_iterator: AsyncIterator[str | bytes] + response_started: bool = False + + async def stream_response(self, send: Send) -> None: + more_body = True + try: + first_chunk = await self.body_iterator.__anext__() + if isinstance(first_chunk, tuple): + first_chunk_content, self.status_code = first_chunk + else: + first_chunk_content, self.status_code = first_chunk, 200 + + if isinstance(first_chunk_content, str): + first_chunk_content = first_chunk_content.encode(self.charset) + + await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) + + self.response_started = True + await send({"type": "http.response.body", "body": first_chunk_content, "more_body": more_body}) + + async for chunk in self.body_iterator: + if isinstance(chunk, tuple): + content, status_code = chunk + if status_code // 100 != 2: + # an error occurred mid-stream + if not isinstance(content, bytes): + content = content.encode(self.charset) + more_body = False + await send({"type": "http.response.body", "body": content, "more_body": more_body}) + return + else: + content = chunk + + if isinstance(content, str): + content = content.encode(self.charset) + more_body = True + await send({"type": "http.response.body", "body": content, "more_body": more_body}) + + except Exception: + more_body = False + error_resp = {"error": {"message": "Internal Server Error"}} + error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset) + if not self.response_started: + await send({"type": "http.response.start", "status": 500, "headers": self.raw_headers}) + await send({"type": "http.response.body", "body": error_event, "more_body": more_body}) + if more_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/app/main.py b/app/main.py index dcd2ed55..1bd6458b 100644 --- a/app/main.py +++ b/app/main.py @@ -6,6 +6,7 @@ from app.helpers import MetricsMiddleware from app.schemas.security import User from app.utils.lifespan import lifespan +from app.utils.logging import logger from app.utils.security import check_admin_api_key, check_api_key from app.utils.settings import settings @@ -20,6 +21,13 @@ redoc_url="/documentation", ) + +@app.get("/") +async def root(): + logger.info("Accès à la route principale") + return {"message": "Hello World"} + + # Prometheus metrics # @TODO: env_var_name="ENABLE_METRICS" app.instrumentator = Instrumentator().instrument(app=app) diff --git a/app/tests/test_audio.py b/app/tests/test_audio.py index 6d84713a..db59d3dd 100644 --- a/app/tests/test_audio.py +++ b/app/tests/test_audio.py @@ -45,6 +45,17 @@ def test_audio_transcriptions_mp3(self, args, session_user, setup): transcription = AudioTranscription(**response_json) assert isinstance(transcription, AudioTranscription) + def test_audio_transcriptions_text_output(self, args, session_user, setup): + """Test the POST /audio/transcriptions with text output""" + MODEL_ID = setup + + with open("app/tests/assets/audio.mp3", "rb") as f: + files = {"file": ("test.mp3", f, "audio/mpeg")} + data = {"model": MODEL_ID, "language": "fr", "response_format": "text"} + response = session_user.post(f"{args['base_url']}/audio/transcriptions", files=files, data=data) + assert response.status_code == 200, f"error: audio transcription failed ({response.status_code})" + assert isinstance(response.text, str), f"error: expected text output ({response.text})" + def test_audio_transcriptions_wav(self, args, session_user, setup): """Test the POST /audio/transcriptions endpoint with WAV file""" MODEL_ID = setup diff --git a/app/tests/test_chat.py b/app/tests/test_chat.py index 11dad851..e5c3c954 100644 --- a/app/tests/test_chat.py +++ b/app/tests/test_chat.py @@ -7,7 +7,6 @@ from app.schemas.chat import ChatCompletion, ChatCompletionChunk from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE -from app.utils.settings import settings @pytest.fixture(scope="module") @@ -43,7 +42,7 @@ def setup(args, session_user): @pytest.mark.usefixtures("args", "session_user", "setup") class TestChat: def test_chat_completions_unstreamed_response(self, args, session_user, setup): - """Test the GET /chat/completions response status code.""" + """Test the POST /chat/completions unstreamed response.""" MODEL_ID, _, _, _ = setup params = { "model": MODEL_ID, @@ -60,7 +59,7 @@ def test_chat_completions_unstreamed_response(self, args, session_user, setup): assert isinstance(chat_completion, ChatCompletion) def test_chat_completions_streamed_response(self, args, session_user, setup): - """Test the GET /chat/completions response status code.""" + """Test the POST /chat/completions streamed response.""" MODEL_ID, _, _, _ = setup params = { "model": MODEL_ID, @@ -82,7 +81,7 @@ def test_chat_completions_streamed_response(self, args, session_user, setup): assert isinstance(chat_completion_chunk, ChatCompletionChunk), f"error: retrieve chat completions chunk {chunk}" def test_chat_completions_unknown_params(self, args, session_user, setup): - """Test the GET /chat/completions unknown params.""" + """Test the POST /chat/completions unknown params.""" MODEL_ID, _, _, _ = setup params = { "model": MODEL_ID, @@ -95,6 +94,34 @@ def test_chat_completions_unknown_params(self, args, session_user, setup): response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})" + def test_chat_completions_invalid_params(self, args, session_user, setup): + """Test the POST /chat/completions response with unknown params.""" + MODEL_ID, _, _, _ = setup + params = { + "model": MODEL_ID, + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "stream": False, + "n": 1, + "max_tokens": 10, + "test": "test", # invalid param + } + response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) + assert response.status_code == 400, f"error: retrieve chat completions ({response.status_code})" + + def test_chat_completions_streamed_invalid_params(self, args, session_user, setup): + """Test the POST /chat/completions streamed response with unknown params.""" + MODEL_ID, _, _, _ = setup + params = { + "model": MODEL_ID, + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "stream": True, + "n": 1, + "max_tokens": 10, + "test": "test", # invalid param + } + response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) + assert response.status_code == 400, f"error: retrieve chat completions ({response.status_code})" + def test_chat_completions_context_too_large(self, args, session_user, setup): MODEL_ID, MAX_CONTEXT_LENGTH, _, _ = setup @@ -109,6 +136,19 @@ def test_chat_completions_context_too_large(self, args, session_user, setup): response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) assert response.status_code == 400, f"error: retrieve chat completions ({response.status_code})" + def test_chat_completions_streamed_context_too_large(self, args, session_user, setup): + MODEL_ID, MAX_CONTEXT_LENGTH, _, _ = setup + prompt = "test" * (MAX_CONTEXT_LENGTH + 100) + params = { + "model": MODEL_ID, + "messages": [{"role": "user", "content": prompt}], + "stream": True, + "n": 1, + "max_tokens": 10, + } + response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) + assert response.status_code == 400, f"error: retrieve chat completions ({response.status_code})" + def test_chat_completions_search_unstreamed_response(self, args, session_user, setup): """Test the GET /chat/completions search unstreamed response.""" MODEL_ID, _, DOCUMENT_IDS, COLLECTION_ID = setup @@ -261,21 +301,3 @@ def test_chat_completions_search_wrong_collection(self, args, session_user, setu } response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) assert response.status_code == 404, f"error: retrieve chat completions ({response.status_code})" - - def test_chat_completions_model_alias(self, args, session_user, setup): - """Test the GET /chat/completions model alias.""" - MODEL_ID, _, _, _ = setup - - model_id = list(settings.models.aliases.keys())[0] - aliases = settings.models.aliases[model_id] - - params = { - "model": aliases[0], - "messages": [{"role": "user", "content": "Hello, how are you?"}], - "stream": False, - "n": 1, - "max_tokens": 10, - } - - response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) - assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code}" diff --git a/app/utils/logging.py b/app/utils/logging.py index 4540adc9..83e611d5 100644 --- a/app/utils/logging.py +++ b/app/utils/logging.py @@ -1,6 +1,32 @@ +from contextvars import ContextVar import logging +from logging import Logger +import sys +from typing import Optional + from app.utils.settings import settings -logging.basicConfig(format="%(levelname)s:%(asctime)s:%(name)s: %(message)s", level=logging.INFO) -logger = logging.getLogger(__name__) -logger.setLevel(settings.log_level) +client_ip: ContextVar[Optional[str]] = ContextVar("client_ip", default=None) + + +class ClientIPFilter(logging.Filter): + def filter(self, record): + client_addr = client_ip.get() + record.client_ip = client_addr if client_addr else "." + return True + + +def setup_logger() -> Logger: + logger = logging.getLogger(name="app") + logger.setLevel(level=settings.log_level) + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter("[%(asctime)s][%(process)d:%(threadName)s][%(levelname)s] %(client_ip)s - %(message)s") + handler.setFormatter(formatter) + + logger.addFilter(ClientIPFilter()) + logger.addHandler(handler) + + return logger + + +logger = setup_logger() diff --git a/app/utils/route.py b/app/utils/route.py new file mode 100644 index 00000000..78530c5f --- /dev/null +++ b/app/utils/route.py @@ -0,0 +1,120 @@ +import ast +from json import dumps, loads +from typing import Optional + +from fastapi import HTTPException +import httpx + + +async def forward_request( + url: str, + method: str, + headers: Optional[dict] = None, + json: Optional[dict] = None, + files: Optional[dict] = None, + data: Optional[dict] = None, + timeout: Optional[int] = None, + additional_data_value: Optional[list] = None, + additional_data_key: Optional[str] = None, +) -> httpx.Response: + """ + Forward a request to an API and add additional data to the response if provided. + + Args: + url(str): The URL to forward the request to. + method(str): The method to use for the request. + headers(dict): The headers to use for the request. + json(dict): The JSON body to use for the request. + files(dict): The files to use for the request. + data(dict): The data to use for the request. + timeout(int): The timeout to use for the request. + additional_data_value(list): The value to add to the response. + additional_data_key(str): The key to add the value to. + + Returns: + httpx.Response: The response from the API. + """ + async with httpx.AsyncClient(timeout=timeout) as async_client: + try: + response = await async_client.request(method=method, url=url, headers=headers, json=json, files=files, data=data, timeout=timeout) + except httpx.TimeoutException or httpx.ReadTimeout or httpx.ConnectTimeout or httpx.WriteTimeout or httpx.PoolTimeout as e: + raise HTTPException(status_code=504, detail="Request timed out, model is not available.") + except Exception as e: + raise HTTPException(status_code=500, detail=type(e).__name__) + try: + response.raise_for_status() + except httpx.HTTPStatusError: + message = loads(response.text) + + # format error message + if "message" in message: + try: + message = ast.literal_eval(message["message"]) + except Exception: + message = message["message"] + raise HTTPException(status_code=response.status_code, detail=message) + + # add additional data to the response + if additional_data_value and additional_data_key: + data = response.json() + data[additional_data_key] = additional_data_value + response = httpx.Response(status_code=response.status_code, content=dumps(data)) + + return response + + +async def forward_stream( + url: str, + method: str, + headers: Optional[dict] = None, + json: Optional[dict] = None, + files: Optional[dict] = None, + data: Optional[dict] = None, + timeout: Optional[int] = None, + additional_data_value: Optional[list] = None, + additional_data_key: Optional[str] = None, +): + """ + Streams the response from the API and adds additional data to the response if provided. + + Args: + url(str): The URL to forward the request to. + method(str): The method to use for the request. + headers(dict): The headers to use for the request. + json(dict): The JSON body to use for the request. + files(dict): The files to use for the request. + data(dict): The data to use for the request. + timeout(int): The timeout to use for the request. + additional_data_value(list): The value to add to the response (only on the first chunk). + additional_data_key(str): The key to add the value to (only on the first chunk). + """ + async with httpx.AsyncClient(timeout=timeout) as async_client: + try: + async with async_client.stream(method=method, url=url, headers=headers, json=json, files=files, data=data) as response: + first_chunk = True + async for chunk in response.aiter_raw(): + # format error message + if response.status_code // 100 != 2: + chunks = loads(chunk.decode(encoding="utf-8")) + if "message" in chunks: + try: + chunks["message"] = ast.literal_eval(chunks["message"]) + except Exception: + pass + chunk = dumps(chunks).encode(encoding="utf-8") + + # add additional data to the first chunk + elif first_chunk and additional_data_value and additional_data_key: + chunks = chunk.decode(encoding="utf-8").split(sep="\n\n") + chunk = loads(chunks[0].lstrip("data: ")) + chunk[additional_data_key] = additional_data_value + chunks[0] = f"data: {dumps(chunk)}" + chunk = "\n\n".join(chunks).encode(encoding="utf-8") + + first_chunk = False + yield chunk, response.status_code + + except httpx.TimeoutException or httpx.ReadTimeout or httpx.ConnectTimeout or httpx.WriteTimeout or httpx.PoolTimeout as e: + yield dumps({"detail": "Request timed out, model is not available."}).encode(), 504 + except Exception as e: + yield dumps({"detail": type(e).__name__}).encode(), 500 diff --git a/app/utils/variables.py b/app/utils/variables.py index 3a417be8..da33e54a 100644 --- a/app/utils/variables.py +++ b/app/utils/variables.py @@ -1,4 +1,4 @@ -DEFAULT_TIMEOUT = 120 +DEFAULT_TIMEOUT = 3 INTERNET_COLLECTION_DISPLAY_ID = "internet"