From 80650d667b2e6b5020a6da4251260989b118668a Mon Sep 17 00:00:00 2001 From: Igor Benav Date: Fri, 25 Oct 2024 03:19:58 -0300 Subject: [PATCH 1/3] unified error handling --- clientai/_typing.py | 6 +- clientai/exceptions.py | 104 ++++++++++++++++++ clientai/ollama/_typing.py | 6 +- clientai/ollama/provider.py | 110 +++++++++++++------ clientai/openai/_typing.py | 6 +- clientai/openai/provider.py | 152 +++++++++++++++++++------- clientai/replicate/_typing.py | 9 +- clientai/replicate/provider.py | 142 +++++++++++++++++------- docs/usage/error_handling.md | 170 +++++++++++++++++++++++++++++ tests/ollama/test_exceptions.py | 126 +++++++++++++++++++++ tests/openai/test_exceptions.py | 148 +++++++++++++++++++++++++ tests/openai/test_provider.py | 11 -- tests/replicate/test_exceptions.py | 132 ++++++++++++++++++++++ tests/replicate/test_provider.py | 21 ++-- 14 files changed, 993 insertions(+), 150 deletions(-) create mode 100644 clientai/exceptions.py create mode 100644 docs/usage/error_handling.md create mode 100644 tests/ollama/test_exceptions.py create mode 100644 tests/openai/test_exceptions.py create mode 100644 tests/replicate/test_exceptions.py diff --git a/clientai/_typing.py b/clientai/_typing.py index c18ea64..ec94651 100644 --- a/clientai/_typing.py +++ b/clientai/_typing.py @@ -26,8 +26,7 @@ def generate_text( return_full_response: bool = False, stream: bool = False, **kwargs: Any, - ) -> GenericResponse[R, T, S]: - ... + ) -> GenericResponse[R, T, S]: ... def chat( self, @@ -36,8 +35,7 @@ def chat( return_full_response: bool = False, stream: bool = False, **kwargs: Any, - ) -> GenericResponse[R, T, S]: - ... + ) -> GenericResponse[R, T, S]: ... P = TypeVar("P", bound=AIProviderProtocol) diff --git a/clientai/exceptions.py b/clientai/exceptions.py new file mode 100644 index 0000000..a7680d2 --- /dev/null +++ b/clientai/exceptions.py @@ -0,0 +1,104 @@ +from typing import Optional, Type + + +class ClientAIError(Exception): + """Base exception class for ClientAI errors.""" + + def __init__( + self, + message: str, + status_code: Optional[int] = None, + original_error: Optional[Exception] = None, + ): + super().__init__(message) + self.status_code = status_code + self.original_error = original_error + + def __str__(self): + error_msg = super().__str__() + if self.status_code: + error_msg = f"[{self.status_code}] {error_msg}" + return error_msg + + @property + def original_exception(self) -> Optional[Exception]: + """Returns the original exception object if available.""" + return self.original_error + + +class AuthenticationError(ClientAIError): + """Raised when there's an authentication problem with the AI provider.""" + + +class APIError(ClientAIError): + """Raised when there's an API-related error from the AI provider.""" + + +class RateLimitError(ClientAIError): + """Raised when the AI provider's rate limit is exceeded.""" + + +class InvalidRequestError(ClientAIError): + """Raised when the request to the AI provider is invalid.""" + + +class ModelError(ClientAIError): + """Raised when there's an issue with the specified model.""" + + +class ProviderNotInstalledError(ClientAIError): + """Raised when the required provider package is not installed.""" + + +class TimeoutError(ClientAIError): + """Raised when a request to the AI provider times out.""" + + +def map_status_code_to_exception( + status_code: int, message: str, original_error: Optional[Exception] = None +) -> Type[ClientAIError]: + """ + Maps an HTTP status code to the appropriate ClientAI exception class. 
+ + Args: + status_code (int): The HTTP status code. + message (str): The error message. + original_error (Exception, optional): The original exception caught. + + Returns: + Type[ClientAIError]: The appropriate ClientAI exception class. + """ + if status_code == 401: + return AuthenticationError + elif status_code == 429: + return RateLimitError + elif status_code == 400: + return InvalidRequestError + elif status_code == 404: + return ModelError + elif status_code == 408: + return TimeoutError + elif status_code >= 500: + return APIError + else: + return APIError + + +def raise_clientai_error( + status_code: int, message: str, original_error: Optional[Exception] = None +) -> None: + """ + Raises the appropriate ClientAI exception based on the status code. + + Args: + status_code (int): The HTTP status code. + message (str): The error message. + original_error (Exception, optional): The original exception caught. + + Raises: + ClientAIError: The appropriate ClientAI exception. + """ + exception_class = map_status_code_to_exception( + status_code, message, original_error + ) + raise exception_class(message, status_code, original_error) diff --git a/clientai/ollama/_typing.py b/clientai/ollama/_typing.py index 0fc12ab..e9fbf0a 100644 --- a/clientai/ollama/_typing.py +++ b/clientai/ollama/_typing.py @@ -53,8 +53,7 @@ class OllamaChatResponse(TypedDict): class OllamaClientProtocol(Protocol): def generate( self, model: str, prompt: str, stream: bool = False, **kwargs: Any - ) -> Union[OllamaResponse, Iterator[OllamaStreamResponse]]: - ... + ) -> Union[OllamaResponse, Iterator[OllamaStreamResponse]]: ... def chat( self, @@ -62,8 +61,7 @@ def chat( messages: List[Message], stream: bool = False, **kwargs: Any, - ) -> Union[OllamaChatResponse, Iterator[OllamaStreamResponse]]: - ... + ) -> Union[OllamaChatResponse, Iterator[OllamaStreamResponse]]: ... Client = "ollama.Client" diff --git a/clientai/ollama/provider.py b/clientai/ollama/provider.py index b40b4ad..00cfb09 100644 --- a/clientai/ollama/provider.py +++ b/clientai/ollama/provider.py @@ -2,6 +2,15 @@ from typing import Any, List, Optional, Union, cast from ..ai_provider import AIProvider +from ..exceptions import ( + APIError, + AuthenticationError, + ClientAIError, + InvalidRequestError, + ModelError, + RateLimitError, + TimeoutError, +) from . import OLLAMA_INSTALLED from ._typing import ( Message, @@ -98,6 +107,35 @@ def _stream_chat_response( else: yield chunk["message"]["content"] + def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError: + """ + Maps an Ollama exception to the appropriate ClientAI exception. + + Args: + e (Exception): The exception caught during the API call. + + Returns: + ClientAIError: An instance of the appropriate ClientAI exception. 
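+
+        Example:
+            A sketch of how this mapping is exercised inside the provider
+            (the model name here is a placeholder):
+
+            ```python
+            try:
+                response = self.client.generate(model="llama2", prompt="Hello")
+            except Exception as e:
+                raise self._map_exception_to_clientai_error(e)
+            ```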
+ """ + message = str(e) + + if isinstance(e, ollama.RequestError): + if "authentication" in message.lower(): + return AuthenticationError(message, original_error=e) + elif "rate limit" in message.lower(): + return RateLimitError(message, original_error=e) + elif "not found" in message.lower(): + return ModelError(message, original_error=e) + else: + return InvalidRequestError(message, original_error=e) + elif isinstance(e, ollama.ResponseError): + if "timeout" in message.lower() or "timed out" in message.lower(): + return TimeoutError(message, original_error=e) + else: + return APIError(message, original_error=e) + else: + return ClientAIError(message, original_error=e) + def generate_text( self, prompt: str, @@ -152,24 +190,28 @@ def generate_text( print(chunk, end="", flush=True) ``` """ - response = self.client.generate( - model=model, prompt=prompt, stream=stream, **kwargs - ) - - if stream: - return cast( - OllamaGenericResponse, - self._stream_generate_response( - cast(Iterator[OllamaStreamResponse], response), - return_full_response, - ), + try: + response = self.client.generate( + model=model, prompt=prompt, stream=stream, **kwargs ) - else: - response = cast(OllamaResponse, response) - if return_full_response: - return response + + if stream: + return cast( + OllamaGenericResponse, + self._stream_generate_response( + cast(Iterator[OllamaStreamResponse], response), + return_full_response, + ), + ) else: - return response["response"] + response = cast(OllamaResponse, response) + if return_full_response: + return response + else: + return response["response"] + + except Exception as e: + raise self._map_exception_to_clientai_error(e) def chat( self, @@ -231,21 +273,25 @@ def chat( print(chunk, end="", flush=True) ``` """ - response = self.client.chat( - model=model, messages=messages, stream=stream, **kwargs - ) - - if stream: - return cast( - OllamaGenericResponse, - self._stream_chat_response( - cast(Iterator[OllamaChatResponse], response), - return_full_response, - ), + try: + response = self.client.chat( + model=model, messages=messages, stream=stream, **kwargs ) - else: - response = cast(OllamaChatResponse, response) - if return_full_response: - return response + + if stream: + return cast( + OllamaGenericResponse, + self._stream_chat_response( + cast(Iterator[OllamaChatResponse], response), + return_full_response, + ), + ) else: - return response["message"]["content"] + response = cast(OllamaChatResponse, response) + if return_full_response: + return response + else: + return response["message"]["content"] + + except Exception as e: + raise self._map_exception_to_clientai_error(e) diff --git a/clientai/openai/_typing.py b/clientai/openai/_typing.py index 4703b81..9610908 100644 --- a/clientai/openai/_typing.py +++ b/clientai/openai/_typing.py @@ -71,8 +71,7 @@ class OpenAIStreamResponse: class OpenAIChatCompletionProtocol(Protocol): def create( self, **kwargs: Any - ) -> Union[OpenAIResponse, Iterator[OpenAIStreamResponse]]: - ... + ) -> Union[OpenAIResponse, Iterator[OpenAIStreamResponse]]: ... class OpenAIChatProtocol(Protocol): @@ -90,8 +89,7 @@ def create( messages: List[Message], stream: bool = False, **kwargs: Any, - ) -> Union[OpenAIResponse, OpenAIStreamResponse]: - ... + ) -> Union[OpenAIResponse, OpenAIStreamResponse]: ... 
OpenAIProvider = Any diff --git a/clientai/openai/provider.py b/clientai/openai/provider.py index afaec06..6820607 100644 --- a/clientai/openai/provider.py +++ b/clientai/openai/provider.py @@ -3,6 +3,15 @@ from .._common_types import Message from ..ai_provider import AIProvider +from ..exceptions import ( + APIError, + AuthenticationError, + ClientAIError, + InvalidRequestError, + ModelError, + RateLimitError, + TimeoutError, +) from . import OPENAI_INSTALLED from ._typing import ( OpenAIClientProtocol, @@ -13,10 +22,12 @@ if OPENAI_INSTALLED: import openai # type: ignore + from openai import AuthenticationError as OpenAIAuthenticationError Client = openai.OpenAI else: Client = None # type: ignore + OpenAIAuthenticationError = Exception # type: ignore class Provider(AIProvider): @@ -76,6 +87,59 @@ def _stream_response( if content: yield content + def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError: + """ + Maps an OpenAI exception to the appropriate ClientAI exception. + + Args: + e (Exception): The exception caught during the API call. + + Raises: + ClientAIError: An instance of the appropriate ClientAI exception. + """ + error_message = str(e) + status_code = None + + if hasattr(e, "status_code"): + status_code = e.status_code + else: + try: + status_code = int( + error_message.split("Error code: ")[1].split(" -")[0] + ) + except (IndexError, ValueError): + pass + + if ( + isinstance(e, OpenAIAuthenticationError) + or "incorrect api key" in error_message.lower() + ): + return AuthenticationError( + error_message, status_code, original_error=e + ) + elif ( + isinstance(e, openai.OpenAIError) + or "error code:" in error_message.lower() + ): + if status_code == 429 or "rate limit" in error_message.lower(): + return RateLimitError( + error_message, status_code, original_error=e + ) + elif status_code == 404 or "not found" in error_message.lower(): + return ModelError(error_message, status_code, original_error=e) + elif status_code == 400 or "invalid" in error_message.lower(): + return InvalidRequestError( + error_message, status_code, original_error=e + ) + elif status_code == 408 or "timeout" in error_message.lower(): + return TimeoutError( + error_message, status_code, original_error=e + ) + elif status_code and status_code >= 500: + return APIError(error_message, status_code, original_error=e) + + return ClientAIError(error_message, status_code, original_error=e) + def generate_text( self, prompt: str, @@ -100,6 +164,9 @@ def generate_text( OpenAIGenericResponse: The generated text, full response object, or an iterator for streaming responses. + Raises: + ClientAIError: If an error occurs during the API call. 
+ Examples: Generate text (text only): ```python @@ -117,7 +184,7 @@ def generate_text( model="gpt-3.5-turbo", return_full_response=True ) - print(response["choices"][0]["message"]["content"]) + print(response.choices[0].message.content) ``` Generate text (streaming): @@ -130,27 +197,31 @@ def generate_text( print(chunk, end="", flush=True) ``` """ - response = self.client.chat.completions.create( - model=model, - messages=[{"role": "user", "content": prompt}], - stream=stream, - **kwargs, - ) - - if stream: - return cast( - OpenAIGenericResponse, - self._stream_response( - cast(Iterator[OpenAIStreamResponse], response), - return_full_response, - ), + try: + response = self.client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + stream=stream, + **kwargs, ) - else: - response = cast(OpenAIResponse, response) - if return_full_response: - return response + + if stream: + return cast( + OpenAIGenericResponse, + self._stream_response( + cast(Iterator[OpenAIStreamResponse], response), + return_full_response, + ), + ) else: - return response.choices[0].message.content + response = cast(OpenAIResponse, response) + if return_full_response: + return response + else: + return response.choices[0].message.content + + except Exception as e: + raise self._map_exception_to_clientai_error(e) def chat( self, @@ -177,6 +248,9 @@ def chat( OpenAIGenericResponse: The chat response, full response object, or an iterator for streaming responses. + Raises: + ClientAIError: If an error occurs during the API call. + Examples: Chat (message content only): ```python @@ -199,7 +273,7 @@ def chat( model="gpt-3.5-turbo", return_full_response=True ) - print(response["choices"][0]["message"]["content"]) + print(response.choices[0].message.content) ``` Chat (streaming): @@ -212,21 +286,25 @@ def chat( print(chunk, end="", flush=True) ``` """ - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - - if stream: - return cast( - OpenAIGenericResponse, - self._stream_response( - cast(Iterator[OpenAIStreamResponse], response), - return_full_response, - ), + try: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs ) - else: - response = cast(OpenAIResponse, response) - if return_full_response: - return response + + if stream: + return cast( + OpenAIGenericResponse, + self._stream_response( + cast(Iterator[OpenAIStreamResponse], response), + return_full_response, + ), + ) else: - return response.choices[0].message.content + response = cast(OpenAIResponse, response) + if return_full_response: + return response + else: + return response.choices[0].message.content + + except Exception as e: + raise self._map_exception_to_clientai_error(e) diff --git a/clientai/replicate/_typing.py b/clientai/replicate/_typing.py index 58afc58..1df9920 100644 --- a/clientai/replicate/_typing.py +++ b/clientai/replicate/_typing.py @@ -12,8 +12,7 @@ class ReplicatePredictionProtocol(Protocol): error: Optional[str] output: Any - def stream(self) -> Iterator[Any]: - ... + def stream(self) -> Iterator[Any]: ... ReplicatePrediction = ReplicatePredictionProtocol @@ -60,12 +59,10 @@ class ReplicateResponse(TypedDict): class ReplicatePredictionsProtocol(Protocol): @staticmethod - def create(**kwargs: Any) -> ReplicatePredictionProtocol: - ... + def create(**kwargs: Any) -> ReplicatePredictionProtocol: ... @staticmethod - def get(id: str) -> ReplicatePredictionProtocol: - ... 
+ def get(id: str) -> ReplicatePredictionProtocol: ... class ReplicateClientProtocol(Protocol): diff --git a/clientai/replicate/provider.py b/clientai/replicate/provider.py index 12a91d9..8f0eed6 100644 --- a/clientai/replicate/provider.py +++ b/clientai/replicate/provider.py @@ -1,9 +1,18 @@ import time from collections.abc import Iterator -from typing import Any, List, Union, cast +from typing import Any, List, Optional, Union, cast from .._common_types import Message from ..ai_provider import AIProvider +from ..exceptions import ( + APIError, + AuthenticationError, + ClientAIError, + InvalidRequestError, + ModelError, + RateLimitError, + TimeoutError, +) from . import REPLICATE_INSTALLED from ._typing import ( ReplicateClientProtocol, @@ -15,10 +24,12 @@ if REPLICATE_INSTALLED: import replicate # type: ignore + from replicate.exceptions import ReplicateError Client = replicate.Client else: Client = None # type: ignore + ReplicateError = Exception # type: ignore class Provider(AIProvider): @@ -85,7 +96,7 @@ def _wait_for_prediction( Raises: TimeoutError: If the prediction doesn't complete within the max_wait_time. - Exception: If the prediction fails. + APIError: If the prediction fails. """ start_time = time.time() while time.time() - start_time < max_wait_time: @@ -93,9 +104,14 @@ def _wait_for_prediction( if prediction.status == "succeeded": return prediction elif prediction.status == "failed": - raise Exception(f"Prediction failed: {prediction.error}") + raise self._map_exception_to_clientai_error( + Exception(f"Prediction failed: {prediction.error}") + ) time.sleep(1) - raise TimeoutError("Prediction timed out") + + raise self._map_exception_to_clientai_error( + Exception("Prediction timed out"), status_code=408 + ) def _stream_response( self, @@ -121,6 +137,46 @@ def _stream_response( else: yield self._process_output(event) + def _map_exception_to_clientai_error( + self, e: Exception, status_code: Optional[int] = None + ) -> ClientAIError: + """ + Maps a Replicate exception to the appropriate ClientAI exception. + + Args: + e (Exception): The exception caught during the API call. + status_code (int, optional): The HTTP status code, if available. + + Returns: + ClientAIError: An instance of the appropriate ClientAI exception. 
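+
+        Example:
+            A sketch of how this mapping is exercised inside the provider
+            (the model identifier here is a placeholder):
+
+            ```python
+            try:
+                prediction = self.client.predictions.create(
+                    model="owner/name", input={"prompt": "Hello"}
+                )
+            except Exception as e:
+                raise self._map_exception_to_clientai_error(e)
+            ```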
+ """ + error_message = str(e) + status_code = status_code or getattr(e, "status_code", None) + + if ( + "authentication" in error_message.lower() + or "unauthorized" in error_message.lower() + ): + return AuthenticationError( + error_message, status_code, original_error=e + ) + elif "rate limit" in error_message.lower(): + return RateLimitError(error_message, status_code, original_error=e) + elif "not found" in error_message.lower(): + return ModelError(error_message, status_code, original_error=e) + elif "invalid" in error_message.lower(): + return InvalidRequestError( + error_message, status_code, original_error=e + ) + elif "timeout" in error_message.lower() or status_code == 408: + return TimeoutError(error_message, status_code, original_error=e) + elif status_code == 400: + return InvalidRequestError( + error_message, status_code, original_error=e + ) + else: + return APIError(error_message, status_code, original_error=e) + def generate_text( self, prompt: str, @@ -177,24 +233,28 @@ def generate_text( print(chunk, end="", flush=True) ``` """ - prediction = self.client.predictions.create( - model=model, input={"prompt": prompt}, stream=stream, **kwargs - ) + try: + prediction = self.client.predictions.create( + model=model, input={"prompt": prompt}, stream=stream, **kwargs + ) - if stream: - return self._stream_response(prediction, return_full_response) - else: - completed_prediction = self._wait_for_prediction(prediction.id) - if return_full_response: - response = cast( - ReplicateResponse, completed_prediction.__dict__.copy() - ) - response["output"] = self._process_output( - completed_prediction.output - ) - return response + if stream: + return self._stream_response(prediction, return_full_response) else: - return self._process_output(completed_prediction.output) + completed_prediction = self._wait_for_prediction(prediction.id) + if return_full_response: + response = cast( + ReplicateResponse, completed_prediction.__dict__.copy() + ) + response["output"] = self._process_output( + completed_prediction.output + ) + return response + else: + return self._process_output(completed_prediction.output) + + except Exception as e: + raise self._map_exception_to_clientai_error(e) def chat( self, @@ -257,24 +317,30 @@ def chat( print(chunk, end="", flush=True) ``` """ - prompt = "\n".join([f"{m['role']}: {m['content']}" for m in messages]) - prompt += "\nassistant: " + try: + prompt = "\n".join( + [f"{m['role']}: {m['content']}" for m in messages] + ) + prompt += "\nassistant: " - prediction = self.client.predictions.create( - model=model, input={"prompt": prompt}, stream=stream, **kwargs - ) + prediction = self.client.predictions.create( + model=model, input={"prompt": prompt}, stream=stream, **kwargs + ) - if stream: - return self._stream_response(prediction, return_full_response) - else: - completed_prediction = self._wait_for_prediction(prediction.id) - if return_full_response: - response = cast( - ReplicateResponse, completed_prediction.__dict__.copy() - ) - response["output"] = self._process_output( - completed_prediction.output - ) - return response + if stream: + return self._stream_response(prediction, return_full_response) else: - return self._process_output(completed_prediction.output) + completed_prediction = self._wait_for_prediction(prediction.id) + if return_full_response: + response = cast( + ReplicateResponse, completed_prediction.__dict__.copy() + ) + response["output"] = self._process_output( + completed_prediction.output + ) + return response + else: + return 
self._process_output(completed_prediction.output) + + except Exception as e: + raise self._map_exception_to_clientai_error(e) diff --git a/docs/usage/error_handling.md b/docs/usage/error_handling.md new file mode 100644 index 0000000..d1ecb1e --- /dev/null +++ b/docs/usage/error_handling.md @@ -0,0 +1,170 @@ +# Error Handling in ClientAI + +ClientAI provides a robust error handling system that unifies exceptions across different AI providers. This guide covers how to handle potential errors when using ClientAI. + +## Table of Contents + +1. [Exception Hierarchy](#exception-hierarchy) +2. [Handling Errors](#handling-errors) +3. [Provider-Specific Error Mapping](#provider-specific-error-mapping) +4. [Best Practices](#best-practices) + +## Exception Hierarchy + +ClientAI uses a custom exception hierarchy to provide consistent error handling across different AI providers: + +```python +from clientai.exceptions import ( + ClientAIError, + AuthenticationError, + RateLimitError, + InvalidRequestError, + ModelError, + TimeoutError, + APIError +) +``` + +- `ClientAIError`: Base exception class for all ClientAI errors. +- `AuthenticationError`: Raised when there's an authentication problem with the AI provider. +- `RateLimitError`: Raised when the AI provider's rate limit is exceeded. +- `InvalidRequestError`: Raised when the request to the AI provider is invalid. +- `ModelError`: Raised when there's an issue with the specified model. +- `TimeoutError`: Raised when a request to the AI provider times out. +- `APIError`: Raised when there's an API-related error from the AI provider. + +## Handling Errors + +Here's how to handle potential errors when using ClientAI: + +```python +from clientai import ClientAI +from clientai.exceptions import ( + ClientAIError, + AuthenticationError, + RateLimitError, + InvalidRequestError, + ModelError, + TimeoutError, + APIError +) + +client = ClientAI('openai', api_key="your-openai-api-key") + +try: + response = client.generate_text("Tell me a joke", model="gpt-3.5-turbo") + print(f"Generated text: {response}") +except AuthenticationError as e: + print(f"Authentication error: {e}") +except RateLimitError as e: + print(f"Rate limit exceeded: {e}") +except InvalidRequestError as e: + print(f"Invalid request: {e}") +except ModelError as e: + print(f"Model error: {e}") +except TimeoutError as e: + print(f"Request timed out: {e}") +except APIError as e: + print(f"API error: {e}") +except ClientAIError as e: + print(f"An unexpected ClientAI error occurred: {e}") +``` + +## Provider-Specific Error Mapping + +ClientAI maps provider-specific errors to its custom exception hierarchy. 
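Each provider implements this mapping in a private `_map_exception_to_clientai_error` method: it inspects the caught exception (its type, message, and status code when available) and returns the matching ClientAI exception, which the calling method then raises. A minimal sketch of the shared pattern, reusing the `raise_clientai_error` helper from `clientai.exceptions` (the function name and the 500 fallback below are illustrative, not part of the library):

```python
from clientai.exceptions import raise_clientai_error


def normalize_provider_error(e: Exception) -> None:
    """Hypothetical helper: turn any provider exception into a ClientAI error."""
    # Use the provider's HTTP status code when the SDK exposes one;
    # fall back to 500 so unknown failures surface as APIError.
    status_code = getattr(e, "status_code", None) or 500
    raise_clientai_error(status_code, str(e), original_error=e)
```

The built-in providers implement the same idea with SDK-specific checks.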
For example:

### OpenAI

```python
def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError:
    error_message = str(e)
    status_code = getattr(e, 'status_code', None)

    if isinstance(e, OpenAIAuthenticationError) or "incorrect api key" in error_message.lower():
        return AuthenticationError(error_message, status_code, original_error=e)
    elif status_code == 429 or "rate limit" in error_message.lower():
        return RateLimitError(error_message, status_code, original_error=e)
    elif status_code == 404 or "not found" in error_message.lower():
        return ModelError(error_message, status_code, original_error=e)
    elif status_code == 400 or "invalid" in error_message.lower():
        return InvalidRequestError(error_message, status_code, original_error=e)
    elif status_code == 408 or "timeout" in error_message.lower():
        return TimeoutError(error_message, status_code, original_error=e)
    elif status_code and status_code >= 500:
        return APIError(error_message, status_code, original_error=e)

    return ClientAIError(error_message, status_code, original_error=e)
```

(This snippet is simplified from the actual implementation, which also tries to parse a status code out of the error message; the calling method raises the returned exception.)

### Replicate

```python
def _map_exception_to_clientai_error(self, e: Exception, status_code: Optional[int] = None) -> ClientAIError:
    error_message = str(e)
    status_code = status_code or getattr(e, 'status_code', None)

    if "authentication" in error_message.lower() or "unauthorized" in error_message.lower():
        return AuthenticationError(error_message, status_code, original_error=e)
    elif "rate limit" in error_message.lower():
        return RateLimitError(error_message, status_code, original_error=e)
    elif "not found" in error_message.lower():
        return ModelError(error_message, status_code, original_error=e)
    elif "invalid" in error_message.lower():
        return InvalidRequestError(error_message, status_code, original_error=e)
    elif "timeout" in error_message.lower() or status_code == 408:
        return TimeoutError(error_message, status_code, original_error=e)
    elif status_code == 400:
        return InvalidRequestError(error_message, status_code, original_error=e)
    else:
        return APIError(error_message, status_code, original_error=e)
```

## Best Practices

1. **Specific Exception Handling**: Catch the most specific exception type first, and fall back to `ClientAIError` only for cases you cannot handle individually.

2. **Logging**: Log errors for debugging and monitoring purposes.

    ```python
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    try:
        response = client.generate_text("Tell me a joke", model="gpt-3.5-turbo")
    except ClientAIError as e:
        logger.error(f"An error occurred: {e}", exc_info=True)
    ```

3. **Retry Logic**: Implement retry logic with exponential backoff for transient errors like rate limiting.

    ```python
    import time
    from clientai.exceptions import RateLimitError

    def retry_generate(prompt, model, max_retries=3, delay=1):
        for attempt in range(max_retries):
            try:
                return client.generate_text(prompt, model=model)
            except RateLimitError as e:
                if attempt == max_retries - 1:
                    raise
                wait_time = e.retry_after if hasattr(e, 'retry_after') else delay * (2 ** attempt)
                logger.warning(f"Rate limit reached. Waiting for {wait_time} seconds...")
                time.sleep(wait_time)
    ```

4. **Graceful Degradation**: Implement fallback options when errors occur.

    ```python
    def generate_with_fallback(prompt, primary_client, fallback_client):
        try:
            return primary_client.generate_text(prompt, model="gpt-3.5-turbo")
        except ClientAIError as e:
            logger.warning(f"Primary client failed: {e}. 
Falling back to secondary client.") + return fallback_client.generate_text(prompt, model="llama-2-70b-chat") + ``` + +By following these practices and utilizing ClientAI's unified error handling system, you can create more robust and maintainable applications that gracefully handle errors across different AI providers. \ No newline at end of file diff --git a/tests/ollama/test_exceptions.py b/tests/ollama/test_exceptions.py new file mode 100644 index 0000000..b636386 --- /dev/null +++ b/tests/ollama/test_exceptions.py @@ -0,0 +1,126 @@ +from unittest.mock import patch + +import pytest + +from clientai.exceptions import ( + APIError, + AuthenticationError, + InvalidRequestError, + ModelError, + RateLimitError, + TimeoutError, +) +from clientai.ollama.provider import Provider + + +@pytest.fixture +def provider(): + return Provider() + + +@pytest.fixture(autouse=True) +def mock_ollama(): + with patch("clientai.ollama.provider.ollama") as mock: + mock.RequestError = type("RequestError", (Exception,), {}) + mock.ResponseError = type("ResponseError", (Exception,), {}) + yield mock + + +@pytest.fixture +def valid_chat_request(): + return { + "model": "test-model", + "messages": [{"role": "user", "content": "Test message"}], + "stream": False, + "format": "", + "options": None, + "keep_alive": None, + } + + +def test_generate_text_authentication_error(mock_ollama, provider): + error = mock_ollama.RequestError("Authentication failed") + mock_ollama.generate.side_effect = error + + with pytest.raises(AuthenticationError) as exc_info: + provider.generate_text(prompt="Test prompt", model="test-model") + + assert str(exc_info.value) == "Authentication failed" + assert exc_info.value.original_exception is error + + +def test_generate_text_rate_limit_error(mock_ollama, provider): + error = mock_ollama.RequestError("Rate limit exceeded") + mock_ollama.generate.side_effect = error + + with pytest.raises(RateLimitError) as exc_info: + provider.generate_text(prompt="Test prompt", model="test-model") + + assert str(exc_info.value) == "Rate limit exceeded" + assert exc_info.value.original_exception is error + + +def test_generate_text_model_error(mock_ollama, provider): + error = mock_ollama.RequestError("Model not found") + mock_ollama.generate.side_effect = error + + with pytest.raises(ModelError) as exc_info: + provider.generate_text(prompt="Test prompt", model="test-model") + + assert str(exc_info.value) == "Model not found" + assert exc_info.value.original_exception is error + + +def test_generate_text_invalid_request_error(mock_ollama, provider): + error = mock_ollama.RequestError("Invalid request") + mock_ollama.generate.side_effect = error + + with pytest.raises(InvalidRequestError) as exc_info: + provider.generate_text(prompt="Test prompt", model="test-model") + + assert str(exc_info.value) == "Invalid request" + assert exc_info.value.original_exception is error + + +def test_generate_text_timeout_error(mock_ollama, provider): + error = mock_ollama.ResponseError("Request timed out") + mock_ollama.generate.side_effect = error + + with pytest.raises(TimeoutError) as exc_info: + provider.generate_text(prompt="Test prompt", model="test-model") + + assert str(exc_info.value) == "Request timed out" + assert exc_info.value.original_exception is error + + +def test_generate_text_api_error(mock_ollama, provider): + error = mock_ollama.ResponseError("API response error") + mock_ollama.generate.side_effect = error + + with pytest.raises(APIError) as exc_info: + provider.generate_text(prompt="Test prompt", 
model="test-model") + + assert str(exc_info.value) == "API response error" + assert exc_info.value.original_exception is error + + +def test_chat_request_error(mock_ollama, provider, valid_chat_request): + error = mock_ollama.RequestError("Invalid chat request") + mock_ollama.chat.side_effect = error + + with pytest.raises(InvalidRequestError) as exc_info: + provider.chat(**valid_chat_request) + + assert str(exc_info.value) == "Invalid chat request" + assert exc_info.value.original_exception is error + + +def test_chat_response_error(mock_ollama, provider, valid_chat_request): + error = mock_ollama.ResponseError("Chat API response error") + mock_ollama.chat.side_effect = error + + with pytest.raises(APIError) as exc_info: + provider.chat(**valid_chat_request) + + assert str(exc_info.value) == "Chat API response error" + assert exc_info.value.original_exception is error diff --git a/tests/openai/test_exceptions.py b/tests/openai/test_exceptions.py new file mode 100644 index 0000000..014be62 --- /dev/null +++ b/tests/openai/test_exceptions.py @@ -0,0 +1,148 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from clientai.exceptions import ( + APIError, + AuthenticationError, + ClientAIError, + InvalidRequestError, + ModelError, + RateLimitError, + TimeoutError, +) +from clientai.openai.provider import Provider + + +class MockOpenAIError(Exception): + def __init__(self, message, error_type, status_code): + self.message = message + self.type = error_type + self.status_code = status_code + + def __str__(self): + return f"Error code: {self.status_code} - {self.message}" + + +class MockOpenAIAuthenticationError(MockOpenAIError): + def __init__(self, message, status_code=401): + super().__init__(message, "invalid_request_error", status_code) + + +@pytest.fixture +def provider(): + return Provider(api_key="test_key") + + +@pytest.fixture +def mock_openai_client(): + with patch("clientai.openai.provider.Client") as mock_client: + mock_instance = MagicMock() + mock_client.return_value = mock_instance + + mock_instance.with_api_key.return_value = mock_instance + mock_instance.chat.completions.create.return_value = MagicMock() + + yield mock_instance + + +def test_generate_text_authentication_error(mock_openai_client, provider): + error = MockOpenAIAuthenticationError( + "Incorrect API key provided: test_key. You can find your API key at https://platform.openai.com/account/api-keys." 
+ ) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(AuthenticationError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Incorrect API key provided" in str(exc_info.value) + assert exc_info.value.status_code == 401 + assert isinstance( + exc_info.value.original_error, MockOpenAIAuthenticationError + ) + + +def test_generate_text_rate_limit_error(mock_openai_client, provider): + error = MockOpenAIError("Rate limit exceeded", "rate_limit_error", 429) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(RateLimitError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.status_code == 429 + assert isinstance(exc_info.value.original_error, MockOpenAIError) + + +def test_generate_text_model_error(mock_openai_client, provider): + error = MockOpenAIError("Model not found", "model_not_found", 404) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(ModelError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Model not found" in str(exc_info.value) + assert exc_info.value.status_code == 404 + assert isinstance(exc_info.value.original_error, MockOpenAIError) + + +def test_generate_text_invalid_request_error(mock_openai_client, provider): + error = MockOpenAIError("Invalid request", "invalid_request_error", 400) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(InvalidRequestError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Invalid request" in str(exc_info.value) + assert exc_info.value.status_code == 400 + assert isinstance(exc_info.value.original_error, MockOpenAIError) + + +def test_generate_text_timeout_error(mock_openai_client, provider): + error = MockOpenAIError("Request timed out", "timeout_error", 408) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(TimeoutError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Request timed out" in str(exc_info.value) + assert exc_info.value.status_code == 408 + assert isinstance(exc_info.value.original_error, MockOpenAIError) + + +def test_generate_text_api_error(mock_openai_client, provider): + error = MockOpenAIError("API error", "api_error", 500) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(APIError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "API error" in str(exc_info.value) + assert exc_info.value.status_code == 500 + assert isinstance(exc_info.value.original_error, MockOpenAIError) + + +def test_chat_error(mock_openai_client, provider): + error = MockOpenAIError("Chat error", "invalid_request_error", 400) + mock_openai_client.chat.completions.create.side_effect = error + + with pytest.raises(InvalidRequestError) as exc_info: + provider.chat( + [{"role": "user", "content": "Test message"}], "test-model" + ) + + assert "Chat error" in str(exc_info.value) + assert exc_info.value.status_code == 400 + assert isinstance(exc_info.value.original_error, MockOpenAIError) + + +def test_generic_error(mock_openai_client, provider): + mock_openai_client.chat.completions.create.side_effect = Exception( + "Unexpected error" + ) + + with pytest.raises(ClientAIError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Unexpected error" in str(exc_info.value) + 
assert isinstance(exc_info.value.original_error, Exception) diff --git a/tests/openai/test_provider.py b/tests/openai/test_provider.py index 50a9940..c3f721e 100644 --- a/tests/openai/test_provider.py +++ b/tests/openai/test_provider.py @@ -1,6 +1,5 @@ from unittest.mock import Mock, patch -import openai import pytest from clientai.openai._typing import ( @@ -233,16 +232,6 @@ def test_chat_stream(mock_client, provider): ) -def test_openai_error(mock_client, provider): - mock_client.chat.completions.create.side_effect = openai.OpenAIError( - "API Error" - ) - - with pytest.raises(openai.OpenAIError) as exc_info: - provider.generate_text("Test prompt", "gpt-3.5-turbo") - assert "API Error" in str(exc_info.value) - - def test_import_error(): with patch("clientai.openai.provider.OPENAI_INSTALLED", False): with pytest.raises(ImportError) as exc_info: diff --git a/tests/replicate/test_exceptions.py b/tests/replicate/test_exceptions.py new file mode 100644 index 0000000..8545b9a --- /dev/null +++ b/tests/replicate/test_exceptions.py @@ -0,0 +1,132 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from clientai.exceptions import ( + APIError, + AuthenticationError, + InvalidRequestError, + ModelError, + RateLimitError, + TimeoutError, +) +from clientai.replicate.provider import Provider + + +class MockReplicateError(Exception): + def __init__(self, message, status_code=None): + self.message = message + self.status_code = status_code + + def __str__(self): + return self.message + + +@pytest.fixture +def provider(): + return Provider(api_key="test_key") + + +@pytest.fixture +def mock_replicate_client(): + with patch("clientai.replicate.provider.Client") as mock_client: + mock_instance = MagicMock() + mock_client.return_value = mock_instance + yield mock_instance + + +def test_generate_text_authentication_error(mock_replicate_client, provider): + error = MockReplicateError("Authentication failed", status_code=401) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(AuthenticationError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Authentication failed" in str(exc_info.value) + assert exc_info.value.status_code == 401 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_generate_text_rate_limit_error(mock_replicate_client, provider): + error = MockReplicateError("Rate limit exceeded", status_code=429) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(RateLimitError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.status_code == 429 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_generate_text_model_error(mock_replicate_client, provider): + error = MockReplicateError("Model not found", status_code=404) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(ModelError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Model not found" in str(exc_info.value) + assert exc_info.value.status_code == 404 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_generate_text_invalid_request_error(mock_replicate_client, provider): + error = MockReplicateError("Invalid request", status_code=400) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(InvalidRequestError) as exc_info: + 
provider.generate_text("Test prompt", "test-model") + + assert "Invalid request" in str(exc_info.value) + assert exc_info.value.status_code == 400 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_generate_text_api_error(mock_replicate_client, provider): + error = MockReplicateError("API error", status_code=500) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(APIError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "API error" in str(exc_info.value) + assert exc_info.value.status_code == 500 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_generate_text_timeout_error(mock_replicate_client, provider): + error = MockReplicateError("Request timed out", status_code=408) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(TimeoutError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Request timed out" in str(exc_info.value) + assert exc_info.value.status_code == 408 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_chat_error(mock_replicate_client, provider): + error = MockReplicateError("Chat error", status_code=400) + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(InvalidRequestError) as exc_info: + provider.chat( + [{"role": "user", "content": "Test message"}], "test-model" + ) + + assert "Chat error" in str(exc_info.value) + assert exc_info.value.status_code == 400 + assert isinstance(exc_info.value.original_error, MockReplicateError) + + +def test_generic_error(mock_replicate_client, provider): + error = Exception("Unexpected error") + mock_replicate_client.predictions.create.side_effect = error + + with pytest.raises(APIError) as exc_info: + provider.generate_text("Test prompt", "test-model") + + assert "Unexpected error" in str(exc_info.value) + assert isinstance(exc_info.value.original_error, Exception) diff --git a/tests/replicate/test_provider.py b/tests/replicate/test_provider.py index bfedd84..815d613 100644 --- a/tests/replicate/test_provider.py +++ b/tests/replicate/test_provider.py @@ -1,8 +1,8 @@ from unittest.mock import Mock, patch import pytest -import replicate.exceptions +from clientai.exceptions import APIError, TimeoutError from clientai.replicate.provider import Provider VALID_MODEL = "owner/name" @@ -144,9 +144,12 @@ def test_wait_for_prediction_timeout(mock_client, provider): ) mock_client.predictions.get.return_value = mock_prediction - with pytest.raises(TimeoutError): + with pytest.raises(TimeoutError) as exc_info: provider._wait_for_prediction("test_id", max_wait_time=1) + assert "Prediction timed out" in str(exc_info.value) + assert exc_info.value.status_code == 408 + def test_wait_for_prediction_failure(mock_client, provider): mock_prediction = MockPrediction( @@ -154,16 +157,6 @@ def test_wait_for_prediction_failure(mock_client, provider): ) mock_client.predictions.get.return_value = mock_prediction - with pytest.raises(Exception) as exc_info: + with pytest.raises(APIError) as exc_info: provider._wait_for_prediction("test_id") - assert str(exc_info.value) == "Prediction failed: Test error" - - -def test_replicate_error(mock_client, provider): - mock_client.predictions.create.side_effect = ( - replicate.exceptions.ReplicateError("API Error") - ) - - with pytest.raises(replicate.exceptions.ReplicateError) as exc_info: - provider.generate_text("Test prompt", VALID_MODEL) - assert "API Error" 
in str(exc_info.value) + assert "Prediction failed: Test error" in str(exc_info.value) From 1c7a02b776ab1a51941b3e9493650cacb38f0d64 Mon Sep 17 00:00:00 2001 From: Igor Benav Date: Fri, 25 Oct 2024 03:22:03 -0300 Subject: [PATCH 2/3] update for version 0.2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 43b93ac..43a6eae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "clientai" -version = "0.1.2" +version = "0.2.0" description = "Simple unified API for multiple AI services." authors = ["Igor Benav "] readme = "README.md" From c3b3aa9a9c43c90b6f065616348bf349c82297fc Mon Sep 17 00:00:00 2001 From: Igor Benav Date: Fri, 25 Oct 2024 03:22:51 -0300 Subject: [PATCH 3/3] add error handling page --- mkdocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mkdocs.yml b/mkdocs.yml index 2cf5671..8120523 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,7 +59,7 @@ nav: - usage/text_generation.md - usage/chat_functionality.md - usage/multiple_providers.md - # - usage/error_handling.md + - usage/error_handling.md - Examples: - Overview: examples/overview.md - Examples: