Skip to content

Commit

Permalink
feat: added headers to the DIAL exception class (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Nov 26, 2024
1 parent 45681f3 commit 7632cda
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 56 deletions.
1 change: 1 addition & 0 deletions aidial_sdk/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def fastapi_exception_handler(request: Request, exc: Exception) -> JSONResponse:
return JSONResponse(
status_code=exc.status_code,
content=exc.detail,
headers=exc.headers,
)


Expand Down
8 changes: 7 additions & 1 deletion aidial_sdk/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import warnings
from http import HTTPStatus
from typing import Optional
from typing import Dict, Optional

from fastapi import HTTPException as FastAPIException
from fastapi.responses import JSONResponse
Expand All @@ -18,6 +18,7 @@ def __init__(
param: Optional[str] = None,
code: Optional[str] = None,
display_message: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
) -> None:
status_code = int(status_code)

Expand All @@ -27,8 +28,11 @@ def __init__(
self.param = param
self.code = code or str(status_code)
self.display_message = display_message
self.headers = headers

def __repr__(self):
# headers field is omitted deliberately
# since it may contain sensitive information
return (
"%s(message=%r, status_code=%r, type=%r, param=%r, code=%r, display_message=%r)"
% (
Expand Down Expand Up @@ -59,12 +63,14 @@ def to_fastapi_response(self) -> JSONResponse:
return JSONResponse(
status_code=self.status_code,
content=self.json_error(),
headers=self.headers,
)

def to_fastapi_exception(self) -> FastAPIException:
return FastAPIException(
status_code=self.status_code,
detail=self.json_error(),
headers=self.headers,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aidial_sdk.chat_completion import ChatCompletion, Request, Response


def raise_exception(exception_type: str):
def _raise_exception(exception_type: str):
if exception_type == "sdk_exception":
raise DIALException("Test error", 503)
elif exception_type == "fastapi_exception":
Expand All @@ -15,16 +15,38 @@ def raise_exception(exception_type: str):
return 1 / 0
elif exception_type == "sdk_exception_with_display_message":
raise DIALException("Test error", 503, display_message="I'm broken")
elif exception_type == "sdk_exception_with_headers":
raise DIALException(
"Too many requests", 429, headers={"Retry-After": "42"}
)
else:
raise DIALException("Unexpected error")


class BrokenApplication(ChatCompletion):
class ImmediatelyBrokenApplication(ChatCompletion):
"""
Application which breaks immediately after receiving a request.
"""

async def chat_completion(
self, request: Request, response: Response
) -> None:
raise_exception(request.messages[0].text())
_raise_exception(request.messages[0].text())


class RuntimeBrokenApplication(ChatCompletion):
"""
Application which breaks after producing some output.
"""

async def chat_completion(
self, request: Request, response: Response
) -> None:
response.set_response_id("test_id")
response.set_created(0)

with response.create_single_choice() as choice:
choice.append_content("Test content")
await response.aflush()

_raise_exception(request.messages[0].text())
20 changes: 0 additions & 20 deletions tests/applications/broken_in_runtime.py

This file was deleted.

86 changes: 54 additions & 32 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import dataclasses
import json
from typing import Any, Dict, List

import pytest
from starlette.testclient import TestClient

from aidial_sdk import DIALApp
from tests.applications.broken_immediately import BrokenApplication
from tests.applications.broken_in_runtime import RuntimeBrokenApplication
from tests.applications.broken import (
ImmediatelyBrokenApplication,
RuntimeBrokenApplication,
)
from tests.applications.noop import NoopApplication

DEFAULT_RUNTIME_ERROR = {
Expand All @@ -24,11 +28,20 @@
}
}

error_testdata = [
("fastapi_exception", 500, DEFAULT_RUNTIME_ERROR),
("value_error_exception", 500, DEFAULT_RUNTIME_ERROR),
("zero_division_exception", 500, DEFAULT_RUNTIME_ERROR),
(

@dataclasses.dataclass
class ErrorTestCase:
content: Any
response_code: int
response_error: dict
response_headers: Dict[str, str] = dataclasses.field(default_factory=dict)


error_testcases: List[ErrorTestCase] = [
ErrorTestCase("fastapi_exception", 500, DEFAULT_RUNTIME_ERROR),
ErrorTestCase("value_error_exception", 500, DEFAULT_RUNTIME_ERROR),
ErrorTestCase("zero_division_exception", 500, DEFAULT_RUNTIME_ERROR),
ErrorTestCase(
"sdk_exception",
503,
{
Expand All @@ -39,7 +52,7 @@
}
},
),
(
ErrorTestCase(
"sdk_exception_with_display_message",
503,
{
Expand All @@ -51,7 +64,7 @@
}
},
),
(
ErrorTestCase(
None,
400,
{
Expand All @@ -62,7 +75,7 @@
}
},
),
(
ErrorTestCase(
[{"type": "text", "text": "hello"}],
400,
{
Expand All @@ -73,57 +86,66 @@
}
},
),
ErrorTestCase(
"sdk_exception_with_headers",
429,
{
"error": {
"message": "Too many requests",
"type": "runtime_error",
"code": "429",
}
},
{"Retry-after": "42"},
),
]


@pytest.mark.parametrize(
"type, response_status_code, response_content", error_testdata
)
def test_error(type, response_status_code, response_content):
@pytest.mark.parametrize("test_case", error_testcases)
def test_error(test_case: ErrorTestCase):
dial_app = DIALApp()
dial_app.add_chat_completion("test_app", BrokenApplication())
dial_app.add_chat_completion("test_app", ImmediatelyBrokenApplication())

test_app = TestClient(dial_app)

response = test_app.post(
"/openai/deployments/test_app/chat/completions",
json={
"messages": [{"role": "user", "content": type}],
"messages": [{"role": "user", "content": test_case.content}],
"stream": False,
},
headers={"Api-Key": "TEST_API_KEY"},
)

assert response.status_code == response_status_code
assert response.json() == response_content
assert response.status_code == test_case.response_code
assert response.json() == test_case.response_error

for k, v in test_case.response_headers.items():
assert response.headers.get(k) == v

@pytest.mark.parametrize(
"type, response_status_code, response_content", error_testdata
)
def test_streaming_error(type, response_status_code, response_content):

@pytest.mark.parametrize("test_case", error_testcases)
def test_streaming_error(test_case: ErrorTestCase):
dial_app = DIALApp()
dial_app.add_chat_completion("test_app", BrokenApplication())
dial_app.add_chat_completion("test_app", ImmediatelyBrokenApplication())

test_app = TestClient(dial_app)

response = test_app.post(
"/openai/deployments/test_app/chat/completions",
json={
"messages": [{"role": "user", "content": type}],
"messages": [{"role": "user", "content": test_case.content}],
"stream": True,
},
headers={"Api-Key": "TEST_API_KEY"},
)

assert response.status_code == response_status_code
assert response.json() == response_content
assert response.status_code == test_case.response_code
assert response.json() == test_case.response_error


@pytest.mark.parametrize(
"type, response_status_code, response_content", error_testdata
)
def test_runtime_streaming_error(type, response_status_code, response_content):
@pytest.mark.parametrize("test_case", error_testcases)
def test_runtime_streaming_error(test_case: ErrorTestCase):
dial_app = DIALApp()
dial_app.add_chat_completion("test_app", RuntimeBrokenApplication())

Expand All @@ -132,7 +154,7 @@ def test_runtime_streaming_error(type, response_status_code, response_content):
response = test_app.post(
"/openai/deployments/test_app/chat/completions",
json={
"messages": [{"role": "user", "content": type}],
"messages": [{"role": "user", "content": test_case.content}],
"stream": True,
},
headers={"Api-Key": "TEST_API_KEY"},
Expand Down Expand Up @@ -183,7 +205,7 @@ def test_runtime_streaming_error(type, response_status_code, response_content):
"object": "chat.completion.chunk",
}
elif index == 6:
assert json.loads(data) == response_content
assert json.loads(data) == test_case.response_error
elif index == 8:
assert data == "[DONE]"

Expand Down

0 comments on commit 7632cda

Please sign in to comment.