diff --git a/instructor/__init__.py b/instructor/__init__.py index 39d338346..e30810231 100644 --- a/instructor/__init__.py +++ b/instructor/__init__.py @@ -78,4 +78,4 @@ if importlib.util.find_spec("vertexai") is not None: from .client_vertexai import from_vertexai - __all__ += ["from_vertexai"] \ No newline at end of file + __all__ += ["from_vertexai"] diff --git a/instructor/cli/usage.py b/instructor/cli/usage.py index 1bc915236..3b1a3443c 100644 --- a/instructor/cli/usage.py +++ b/instructor/cli/usage.py @@ -118,11 +118,11 @@ def calculate_cost( def group_and_sum_by_date_and_snapshot(usage_data: list[dict[str, Any]]) -> Table: """Group and sum the usage data by date and snapshot, including costs.""" - summary: defaultdict[ - str, defaultdict[str, dict[str, Union[int, float]]] - ] = defaultdict( - lambda: defaultdict( - lambda: {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0} + summary: defaultdict[str, defaultdict[str, dict[str, Union[int, float]]]] = ( + defaultdict( + lambda: defaultdict( + lambda: {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0} + ) ) ) diff --git a/instructor/client.py b/instructor/client.py index 6e7b8d8f1..968aa0205 100644 --- a/instructor/client.py +++ b/instructor/client.py @@ -63,8 +63,7 @@ def create( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> Awaitable[T]: - ... + ) -> Awaitable[T]: ... @overload def create( @@ -75,8 +74,7 @@ def create( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> T: - ... + ) -> T: ... # TODO: we should overload a case where response_model is None def create( @@ -108,8 +106,7 @@ def create_partial( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> AsyncGenerator[T, None]: - ... + ) -> AsyncGenerator[T, None]: ... @overload def create_partial( @@ -120,8 +117,7 @@ def create_partial( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> Generator[T, None, None]: - ... + ) -> Generator[T, None, None]: ... def create_partial( self, @@ -155,8 +151,7 @@ def create_iterable( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> AsyncGenerator[T, None]: - ... + ) -> AsyncGenerator[T, None]: ... @overload def create_iterable( @@ -167,8 +162,7 @@ def create_iterable( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> Generator[T, None, None]: - ... + ) -> Generator[T, None, None]: ... def create_iterable( self, @@ -203,8 +197,7 @@ def create_with_completion( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> Awaitable[tuple[T, Any]]: - ... + ) -> Awaitable[tuple[T, Any]]: ... @overload def create_with_completion( @@ -215,8 +208,7 @@ def create_with_completion( validation_context: dict[str, Any] | None = None, strict: bool = True, **kwargs: Any, - ) -> tuple[T, Any]: - ... + ) -> tuple[T, Any]: ... def create_with_completion( self, @@ -432,8 +424,7 @@ def from_litellm( completion: Callable[..., Any], mode: instructor.Mode = instructor.Mode.TOOLS, **kwargs: Any, -) -> Instructor: - ... +) -> Instructor: ... @overload diff --git a/instructor/client_anthropic.py b/instructor/client_anthropic.py index 4a8753882..4b0803dbe 100644 --- a/instructor/client_anthropic.py +++ b/instructor/client_anthropic.py @@ -13,8 +13,7 @@ def from_anthropic( ), mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS, **kwargs: Any, -) -> instructor.Instructor: - ... +) -> instructor.Instructor: ... @overload @@ -26,8 +25,7 @@ def from_anthropic( ), mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS, **kwargs: Any, -) -> instructor.AsyncInstructor: - ... +) -> instructor.AsyncInstructor: ... def from_anthropic( diff --git a/instructor/client_cohere.py b/instructor/client_cohere.py index d7f99baf8..d823b870c 100644 --- a/instructor/client_cohere.py +++ b/instructor/client_cohere.py @@ -23,8 +23,7 @@ def from_cohere( client: cohere.Client, mode: instructor.Mode = instructor.Mode.COHERE_TOOLS, **kwargs: Any, -) -> instructor.Instructor: - ... +) -> instructor.Instructor: ... @overload @@ -32,8 +31,7 @@ def from_cohere( client: cohere.AsyncClient, mode: instructor.Mode = instructor.Mode.COHERE_TOOLS, **kwargs: Any, -) -> instructor.AsyncInstructor: - ... +) -> instructor.AsyncInstructor: ... def from_cohere( diff --git a/instructor/client_gemini.py b/instructor/client_gemini.py index 69218d744..13e2538f9 100644 --- a/instructor/client_gemini.py +++ b/instructor/client_gemini.py @@ -32,7 +32,9 @@ def from_gemini( use_async: bool = False, **kwargs: Any, ) -> instructor.Instructor | instructor.AsyncInstructor: - assert mode == instructor.Mode.GEMINI_JSON, "Mode must be instructor.Mode.GEMINI_JSON" + assert ( + mode == instructor.Mode.GEMINI_JSON + ), "Mode must be instructor.Mode.GEMINI_JSON" assert isinstance( client, diff --git a/instructor/client_groq.py b/instructor/client_groq.py index 4b4a9357a..ff72a9cc8 100644 --- a/instructor/client_groq.py +++ b/instructor/client_groq.py @@ -11,8 +11,7 @@ def from_groq( client: groq.Groq, mode: instructor.Mode = instructor.Mode.TOOLS, **kwargs: Any, -) -> instructor.Instructor: - ... +) -> instructor.Instructor: ... @overload @@ -20,8 +19,7 @@ def from_groq( client: groq.AsyncGroq, mode: instructor.Mode = instructor.Mode.TOOLS, **kwargs: Any, -) -> instructor.AsyncInstructor: - ... +) -> instructor.AsyncInstructor: ... def from_groq( diff --git a/instructor/client_mistral.py b/instructor/client_mistral.py index 2dd9b5e48..a586dccbb 100644 --- a/instructor/client_mistral.py +++ b/instructor/client_mistral.py @@ -12,8 +12,7 @@ def from_mistral( client: mistralai.client.MistralClient, mode: instructor.Mode = instructor.Mode.MISTRAL_TOOLS, **kwargs: Any, -) -> instructor.Instructor: - ... +) -> instructor.Instructor: ... @overload @@ -21,8 +20,7 @@ def from_mistral( client: mistralaiasynccli.MistralAsyncClient, mode: instructor.Mode = instructor.Mode.MISTRAL_TOOLS, **kwargs: Any, -) -> instructor.AsyncInstructor: - ... +) -> instructor.AsyncInstructor: ... def from_mistral( diff --git a/instructor/client_vertexai.py b/instructor/client_vertexai.py index 06c8a184a..07896a4ac 100644 --- a/instructor/client_vertexai.py +++ b/instructor/client_vertexai.py @@ -2,59 +2,60 @@ from typing import Any -from vertexai.preview.generative_models import ToolConfig #type: ignore[reportMissingTypeStubs] -import vertexai.generative_models as gm #type: ignore[reportMissingTypeStubs] +from vertexai.preview.generative_models import ToolConfig # type: ignore[reportMissingTypeStubs] +import vertexai.generative_models as gm # type: ignore[reportMissingTypeStubs] from pydantic import BaseModel import instructor -import jsonref #type: ignore[reportMissingTypeStubs] +import jsonref # type: ignore[reportMissingTypeStubs] def _create_vertexai_tool(model: BaseModel) -> gm.Tool: - schema: dict[Any, Any] = jsonref.replace_refs(model.model_json_schema()) #type: ignore[reportMissingTypeStubs] - + schema: dict[Any, Any] = jsonref.replace_refs(model.model_json_schema()) # type: ignore[reportMissingTypeStubs] + parameters: dict[Any, Any] = { "type": schema["type"], "properties": schema["properties"], - "required": schema["required"] + "required": schema["required"], } declaration = gm.FunctionDeclaration( - name=model.__name__, - description=model.__doc__, - parameters=parameters + name=model.__name__, description=model.__doc__, parameters=parameters ) tool = gm.Tool(function_declarations=[declaration]) return tool + def _vertexai_message_parser(message: dict[str, str]) -> gm.Content: return gm.Content( - role=message["role"], - parts=[ - gm.Part.from_text(message["content"]) - ] - ) + role=message["role"], parts=[gm.Part.from_text(message["content"])] + ) + -def vertexai_function_response_parser(response: gm.GenerationResponse, exception: Exception) -> gm.Content: +def vertexai_function_response_parser( + response: gm.GenerationResponse, exception: Exception +) -> gm.Content: return gm.Content( - parts=[ - gm.Part.from_function_response( - name=response.candidates[0].content.parts[0].function_call.name, - response={ - "content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors" - } - ) - ] + parts=[ + gm.Part.from_function_response( + name=response.candidates[0].content.parts[0].function_call.name, + response={ + "content": f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors" + }, ) + ] + ) + def vertexai_process_response(_kwargs: dict[str, Any], model: BaseModel): messages = _kwargs.pop("messages") contents = [ - _vertexai_message_parser(message) #type: ignore[reportUnkownArgumentType] - if isinstance(message, dict) else message + _vertexai_message_parser(message) # type: ignore[reportUnkownArgumentType] + if isinstance(message, dict) + else message for message in messages - ] + ] tool = _create_vertexai_tool(model=model) tool_config = ToolConfig( function_calling_config=ToolConfig.FunctionCallingConfig( @@ -70,7 +71,9 @@ def from_vertexai( _async: bool = False, **kwargs: Any, ) -> instructor.Instructor: - assert mode == instructor.Mode.VERTEXAI_TOOLS, "Mode must be instructor.Mode.VERTEXAI_TOOLS" + assert ( + mode == instructor.Mode.VERTEXAI_TOOLS + ), "Mode must be instructor.Mode.VERTEXAI_TOOLS" assert isinstance( client, gm.GenerativeModel diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py index 52c34ae44..23fbfa698 100644 --- a/instructor/dsl/parallel.py +++ b/instructor/dsl/parallel.py @@ -47,9 +47,11 @@ def from_response( if sys.version_info >= (3, 10): from types import UnionType + def is_union_type(typehint: type[Iterable[T]]) -> bool: return get_origin(get_args(typehint)[0]) in (Union, UnionType) else: + def is_union_type(typehint: type[Iterable[T]]) -> bool: return get_origin(get_args(typehint)[0]) is Union diff --git a/instructor/function_calls.py b/instructor/function_calls.py index 2749ba4c8..e340b802e 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -202,10 +202,10 @@ def parse_vertexai_tools( validation_context: Optional[dict[str, Any]] = None, strict: Optional[bool] = None, ) -> BaseModel: - strict=False - tool_call= completion.candidates[0].content.parts[0].function_call.args # type: ignore + strict = False + tool_call = completion.candidates[0].content.parts[0].function_call.args # type: ignore model = {} - for field in tool_call: # type: ignore + for field in tool_call: # type: ignore model[field] = tool_call[field] return cls.model_validate(model, context=validation_context, strict=strict) diff --git a/instructor/patch.py b/instructor/patch.py index cc06436a0..9a1b2a0ae 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -35,8 +35,7 @@ def __call__( max_retries: int = 1, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: - ... + ) -> T_Model: ... class AsyncInstructorChatCompletionCreate(Protocol): @@ -47,40 +46,35 @@ async def __call__( max_retries: int = 1, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: - ... + ) -> T_Model: ... @overload def patch( client: OpenAI, mode: Mode = Mode.TOOLS, -) -> OpenAI: - ... +) -> OpenAI: ... @overload def patch( client: AsyncOpenAI, mode: Mode = Mode.TOOLS, -) -> AsyncOpenAI: - ... +) -> AsyncOpenAI: ... @overload def patch( create: Callable[T_ParamSpec, T_Retval], mode: Mode = Mode.TOOLS, -) -> InstructorChatCompletionCreate: - ... +) -> InstructorChatCompletionCreate: ... @overload def patch( create: Awaitable[T_Retval], mode: Mode = Mode.TOOLS, -) -> InstructorChatCompletionCreate: - ... +) -> InstructorChatCompletionCreate: ... def patch( diff --git a/instructor/process_response.py b/instructor/process_response.py index a17ece5d2..a28263a39 100644 --- a/instructor/process_response.py +++ b/instructor/process_response.py @@ -419,10 +419,13 @@ def handle_response_model( } elif mode == Mode.VERTEXAI_TOOLS: from instructor.client_vertexai import vertexai_process_response - contents, tools, tool_config = vertexai_process_response(new_kwargs, response_model) - new_kwargs['contents'] = contents - new_kwargs['tools'] = tools - new_kwargs['tool_config'] = tool_config + + contents, tools, tool_config = vertexai_process_response( + new_kwargs, response_model + ) + new_kwargs["contents"] = contents + new_kwargs["tools"] = tools + new_kwargs["tool_config"] = tool_config else: raise ValueError(f"Invalid patch mode: {mode}") diff --git a/instructor/retry.py b/instructor/retry.py index 49f079d5f..5c077c00e 100644 --- a/instructor/retry.py +++ b/instructor/retry.py @@ -93,6 +93,7 @@ def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception): return if mode == Mode.VERTEXAI_TOOLS: from .client_vertexai import vertexai_function_response_parser + yield response.candidates[0].content yield vertexai_function_response_parser(response, exception) return diff --git a/tests/llm/test_vertexai/test_modes.py b/tests/llm/test_vertexai/test_modes.py index 9c50807aa..ec2260819 100644 --- a/tests/llm/test_vertexai/test_modes.py +++ b/tests/llm/test_vertexai/test_modes.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -import vertexai.generative_models as gm #type: ignore[reportMissingTypeStubs] +import vertexai.generative_models as gm # type: ignore[reportMissingTypeStubs] import instructor from .util import model, mode diff --git a/tests/llm/test_vertexai/test_retries.py b/tests/llm/test_vertexai/test_retries.py index 3a9f6a241..866683aa1 100644 --- a/tests/llm/test_vertexai/test_retries.py +++ b/tests/llm/test_vertexai/test_retries.py @@ -1,7 +1,7 @@ from typing import Annotated, cast from pydantic import AfterValidator, BaseModel, Field import instructor -import vertexai.generative_models as gm #type: ignore[reportMissingTypeStubs] +import vertexai.generative_models as gm # type: ignore[reportMissingTypeStubs] from .util import model, mode diff --git a/tests/llm/test_vertexai/test_simple_types.py b/tests/llm/test_vertexai/test_simple_types.py index 97aa0a320..41caac01c 100644 --- a/tests/llm/test_vertexai/test_simple_types.py +++ b/tests/llm/test_vertexai/test_simple_types.py @@ -1,6 +1,6 @@ import instructor import enum -import vertexai.generative_models as gm #type: ignore[reportMissingTypeStubs] +import vertexai.generative_models as gm # type: ignore[reportMissingTypeStubs] from typing import Literal, Union from .util import model, mode @@ -21,7 +21,7 @@ def test_literal(): def test_union(): - client = instructor.from_vertexai(gm.GenerativeModel(model) , mode) + client = instructor.from_vertexai(gm.GenerativeModel(model), mode) response = client.create( response_model=Union[int, str], diff --git a/tests/llm/test_vertexai/test_stream.py b/tests/llm/test_vertexai/test_stream.py index d66a4e3c6..df9d0511f 100644 --- a/tests/llm/test_vertexai/test_stream.py +++ b/tests/llm/test_vertexai/test_stream.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from pydantic import BaseModel import instructor -import vertexai.generative_models as gm #type: ignore[reportMissingTypeStubs] +import vertexai.generative_models as gm # type: ignore[reportMissingTypeStubs] from instructor.dsl.partial import Partial from .util import model, mode