ollama exceptions now with status_codes
igorbenav committed Oct 28, 2024
1 parent 3e63682 commit 2f6c3cd
Showing 3 changed files with 31 additions and 21 deletions.
8 changes: 2 additions & 6 deletions clientai/exceptions.py
@@ -54,16 +54,12 @@ class TimeoutError(ClientAIError):
     """Raised when a request to the AI provider times out."""
 
 
-def map_status_code_to_exception(
-    status_code: int, message: str, original_error: Optional[Exception] = None
-) -> Type[ClientAIError]:
+def map_status_code_to_exception(status_code: int) -> Type[ClientAIError]:
     """
     Maps an HTTP status code to the appropriate ClientAI exception class.
 
     Args:
         status_code (int): The HTTP status code.
-        message (str): The error message.
-        original_error (Exception, optional): The original exception caught.
 
     Returns:
         Type[ClientAIError]: The appropriate ClientAI exception class.
@@ -99,6 +95,6 @@ def raise_clientai_error(
         ClientAIError: The appropriate ClientAI exception.
     """
     exception_class = map_status_code_to_exception(
-        status_code, message, original_error
+        status_code,
     )
     raise exception_class(message, status_code, original_error)
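
For context, here is a sketch of how the slimmed-down helpers could fit together after this change. The actual code-to-exception table is collapsed in the hunk above, and the parameter order of `raise_clientai_error` is not visible either, so both are assumptions rather than copies of the real module; the exception classes are assumed to be exported from `clientai.exceptions`, as the provider and tests suggest.

```python
# Illustrative sketch only: the real mapping body in clientai/exceptions.py
# is collapsed in the diff above, so the exact table here is an assumption.
from typing import Optional, Type

from clientai.exceptions import (
    APIError,
    AuthenticationError,
    ClientAIError,
    InvalidRequestError,
    ModelError,
    RateLimitError,
    TimeoutError,
)


def map_status_code_to_exception(status_code: int) -> Type[ClientAIError]:
    """Map an HTTP status code to a ClientAI exception class."""
    table = {
        401: AuthenticationError,
        403: AuthenticationError,
        404: ModelError,
        408: TimeoutError,
        429: RateLimitError,
    }
    if status_code in table:
        return table[status_code]
    if 400 <= status_code < 500:
        return InvalidRequestError
    return APIError  # 5xx and anything unexpected


def raise_clientai_error(
    status_code: int, message: str, original_error: Optional[Exception] = None
) -> None:
    """Raise the mapped exception, keeping the message and original error."""
    # The parameter order here is assumed; only the body is visible in the hunk.
    exception_class = map_status_code_to_exception(status_code)
    raise exception_class(message, status_code, original_error)
```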
20 changes: 13 additions & 7 deletions clientai/ollama/provider.py
@@ -121,20 +121,26 @@ def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError:
 
         if isinstance(e, ollama.RequestError):
             if "authentication" in message.lower():
-                return AuthenticationError(message, original_error=e)
+                return AuthenticationError(
+                    message, status_code=401, original_error=e
+                )
             elif "rate limit" in message.lower():
-                return RateLimitError(message, original_error=e)
+                return RateLimitError(
+                    message, status_code=429, original_error=e
+                )
             elif "not found" in message.lower():
-                return ModelError(message, original_error=e)
+                return ModelError(message, status_code=404, original_error=e)
             else:
-                return InvalidRequestError(message, original_error=e)
+                return InvalidRequestError(
+                    message, status_code=400, original_error=e
+                )
         elif isinstance(e, ollama.ResponseError):
             if "timeout" in message.lower() or "timed out" in message.lower():
-                return TimeoutError(message, original_error=e)
+                return TimeoutError(message, status_code=408, original_error=e)
             else:
-                return APIError(message, original_error=e)
+                return APIError(message, status_code=500, original_error=e)
         else:
-            return ClientAIError(message, original_error=e)
+            return ClientAIError(message, status_code=500, original_error=e)
 
     def generate_text(
         self,
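
From the caller's side, the practical effect is that every error raised by the Ollama provider now carries an HTTP-style status code. A minimal caller-side sketch is shown below; how `provider` is constructed is outside this diff and assumed, while `generate_text(prompt=..., model=...)` matches the call used in the tests that follow.

```python
# Minimal caller-side sketch; the provider object and its construction are
# assumptions, only the generate_text call shape is taken from the tests.
from clientai.exceptions import ClientAIError, RateLimitError


def generate_with_status_handling(provider, prompt: str, model: str) -> str:
    """Call the provider and branch on the new status_code attribute."""
    try:
        return provider.generate_text(prompt=prompt, model=model)
    except RateLimitError as e:
        # Mapped from ollama.RequestError messages mentioning "rate limit".
        print(f"Rate limited ({e.status_code}); retry later: {e}")
        raise
    except ClientAIError as e:
        # All other mapped errors now expose a status code as well.
        print(f"Provider error {e.status_code}: {e}")
        raise
```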
24 changes: 16 additions & 8 deletions tests/ollama/test_exceptions.py
@@ -45,7 +45,8 @@ def test_generate_text_authentication_error(mock_ollama, provider):
     with pytest.raises(AuthenticationError) as exc_info:
         provider.generate_text(prompt="Test prompt", model="test-model")
 
-    assert str(exc_info.value) == "Authentication failed"
+    assert str(exc_info.value) == "[401] Authentication failed"
+    assert exc_info.value.status_code == 401
     assert exc_info.value.original_exception is error


@@ -56,7 +57,8 @@ def test_generate_text_rate_limit_error(mock_ollama, provider):
     with pytest.raises(RateLimitError) as exc_info:
         provider.generate_text(prompt="Test prompt", model="test-model")
 
-    assert str(exc_info.value) == "Rate limit exceeded"
+    assert str(exc_info.value) == "[429] Rate limit exceeded"
+    assert exc_info.value.status_code == 429
     assert exc_info.value.original_exception is error


@@ -67,7 +69,8 @@ def test_generate_text_model_error(mock_ollama, provider):
     with pytest.raises(ModelError) as exc_info:
         provider.generate_text(prompt="Test prompt", model="test-model")
 
-    assert str(exc_info.value) == "Model not found"
+    assert str(exc_info.value) == "[404] Model not found"
+    assert exc_info.value.status_code == 404
     assert exc_info.value.original_exception is error


@@ -78,7 +81,8 @@ def test_generate_text_invalid_request_error(mock_ollama, provider):
     with pytest.raises(InvalidRequestError) as exc_info:
         provider.generate_text(prompt="Test prompt", model="test-model")
 
-    assert str(exc_info.value) == "Invalid request"
+    assert str(exc_info.value) == "[400] Invalid request"
+    assert exc_info.value.status_code == 400
     assert exc_info.value.original_exception is error


@@ -89,7 +93,8 @@ def test_generate_text_timeout_error(mock_ollama, provider):
     with pytest.raises(TimeoutError) as exc_info:
         provider.generate_text(prompt="Test prompt", model="test-model")
 
-    assert str(exc_info.value) == "Request timed out"
+    assert str(exc_info.value) == "[408] Request timed out"
+    assert exc_info.value.status_code == 408
     assert exc_info.value.original_exception is error


@@ -100,7 +105,8 @@ def test_generate_text_api_error(mock_ollama, provider):
     with pytest.raises(APIError) as exc_info:
         provider.generate_text(prompt="Test prompt", model="test-model")
 
-    assert str(exc_info.value) == "API response error"
+    assert str(exc_info.value) == "[500] API response error"
+    assert exc_info.value.status_code == 500
     assert exc_info.value.original_exception is error


@@ -111,7 +117,8 @@ def test_chat_request_error(mock_ollama, provider, valid_chat_request):
    with pytest.raises(InvalidRequestError) as exc_info:
         provider.chat(**valid_chat_request)
 
-    assert str(exc_info.value) == "Invalid chat request"
+    assert str(exc_info.value) == "[400] Invalid chat request"
+    assert exc_info.value.status_code == 400
     assert exc_info.value.original_exception is error


@@ -122,5 +129,6 @@ def test_chat_response_error(mock_ollama, provider, valid_chat_request):
     with pytest.raises(APIError) as exc_info:
         provider.chat(**valid_chat_request)
 
-    assert str(exc_info.value) == "Chat API response error"
+    assert str(exc_info.value) == "[500] Chat API response error"
+    assert exc_info.value.status_code == 500
     assert exc_info.value.original_exception is error
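
These updated assertions rely on the base exception storing the status code and original error, and prefixing the status in its string form (e.g. "[401] Authentication failed"). A minimal base class consistent with them might look like the sketch below; the actual ClientAIError in clientai/exceptions.py is not shown in this diff and may differ in detail.

```python
# Sketch of a base class consistent with the assertions above; the real
# ClientAIError implementation is not part of this diff and may differ.
from typing import Optional


class ClientAIError(Exception):
    """Base error carrying an HTTP-style status code and the original exception."""

    def __init__(
        self,
        message: str,
        status_code: Optional[int] = None,
        original_error: Optional[Exception] = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        self.original_exception = original_error

    def __str__(self) -> str:
        # The tests expect "[<status_code>] <message>" when a status code is set.
        if self.status_code is not None:
            return f"[{self.status_code}] {self.message}"
        return self.message
```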
