From 2a90a5201c85eef30c782f92ebea27efff37316f Mon Sep 17 00:00:00 2001
From: Zain AHMAD
Date: Thu, 26 Dec 2024 18:39:21 +0800
Subject: [PATCH] Add AWS Bedrock support via BEDROCK_JSON and BEDROCK_TOOLS
 modes

Add a from_bedrock() factory that patches the boto3 bedrock-runtime
client's converse() method, register Mode.BEDROCK_TOOLS and
Mode.BEDROCK_JSON, and parse JSON out of Bedrock Converse responses.
---
 .cursorignore                  |   1 +
 instructor/__init__.py         |   7 +-
 instructor/client_bedrock.py   |  56 ++++++++++++++++
 instructor/function_calls.py   | 105 ++++++++++++++++++++++++-----
 instructor/mode.py             |   2 +
 instructor/patch.py            |   9 ++-
 instructor/process_response.py | 117 +++++++++++++++++++++++--------
 instructor/retry.py            |  54 +++++++++++----
 instructor/utils.py            |  58 ++++++++++++----
 9 files changed, 337 insertions(+), 72 deletions(-)
 create mode 100644 .cursorignore
 create mode 100644 instructor/client_bedrock.py

diff --git a/.cursorignore b/.cursorignore
new file mode 100644
index 000000000..6f9f00ff4
--- /dev/null
+++ b/.cursorignore
@@ -0,0 +1 @@
+# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
diff --git a/instructor/__init__.py b/instructor/__init__.py
index efd503c22..148c233c7 100644
--- a/instructor/__init__.py
+++ b/instructor/__init__.py
@@ -92,7 +92,12 @@
     __all__ += ["from_vertexai"]
 
+if importlib.util.find_spec("boto3") is not None:
+    from .client_bedrock import from_bedrock
+
+    __all__ += ["from_bedrock"]
+
 if importlib.util.find_spec("writerai") is not None:
     from .client_writer import from_writer
 
-    __all__ += ["from_writer"]
\ No newline at end of file
+    __all__ += ["from_writer"]
diff --git a/instructor/client_bedrock.py b/instructor/client_bedrock.py
new file mode 100644
index 000000000..21e3133cd
--- /dev/null
+++ b/instructor/client_bedrock.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from botocore.client import BaseClient
+
+import instructor
+from instructor.client import AsyncInstructor, Instructor
+
+# NOTE: boto3 clients are synchronous, so from_bedrock returns a sync
+# Instructor. (The original duplicate @overload stubs shared one signature
+# with two return types, which is an invalid overload pair, and used
+# boto3.client -- a factory function, not a type -- as an annotation; they
+# are dropped here.)
+logger = logging.getLogger("instructor")
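+
+
+# Minimal usage sketch (assumptions: AWS credentials and Bedrock model access
+# are configured out of band; the model id below is illustrative only):
+#
+#     import boto3
+#     import instructor
+#     from pydantic import BaseModel
+#
+#     class User(BaseModel):
+#         name: str
+#         age: int
+#
+#     client = instructor.from_bedrock(boto3.client("bedrock-runtime"))
+#     user = client.create(
+#         modelId="anthropic.claude-3-sonnet-20240229-v1:0",
+#         messages=[{"role": "user", "content": [{"text": "John is 25"}]}],
+#         response_model=User,
+#     )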
+
+
+# NOTE: currently unused placeholder; the handler instructor actually calls
+# lives in process_response.handle_bedrock_json.
+def handle_bedrock_json(
+    response_model: Any,
+    new_kwargs: Any,
+) -> tuple[Any, Any]:
+    logger.debug(f"handle_bedrock_json: response_model {response_model}")
+    logger.debug(f"handle_bedrock_json: new_kwargs {new_kwargs}")
+    return response_model, new_kwargs
+
+
+def from_bedrock(
+    client: BaseClient,
+    mode: instructor.Mode = instructor.Mode.BEDROCK_JSON,
+    **kwargs: Any,
+) -> Instructor | AsyncInstructor:
+    assert mode in {
+        instructor.Mode.BEDROCK_TOOLS,
+        instructor.Mode.BEDROCK_JSON,
+    }, "Mode must be one of Mode.BEDROCK_TOOLS or Mode.BEDROCK_JSON"
+    assert isinstance(
+        client,
+        BaseClient,
+    ), "Client must be a boto3 client (botocore.client.BaseClient)"
+    create = client.converse  # the Bedrock Runtime Converse API entry point
+
+    return Instructor(
+        client=client,
+        create=instructor.patch(create=create, mode=mode),
+        provider=instructor.Provider.BEDROCK,
+        mode=mode,
+        **kwargs,
+    )
diff --git a/instructor/function_calls.py b/instructor/function_calls.py
index 8507c2cd6..4f56bbb79 100644
--- a/instructor/function_calls.py
+++ b/instructor/function_calls.py
@@ -1,6 +1,7 @@
 # type: ignore
 import json
 import logging
+import re
 from functools import wraps
 from typing import Annotated, Any, Optional, TypeVar, cast
 from docstring_parser import parse
@@ -45,7 +46,9 @@ def openai_schema(cls) -> dict[str, Any]:
         schema = cls.model_json_schema()
         docstring = parse(cls.__doc__ or "")
         parameters = {
-            k: v for k, v in schema.items() if k not in ("title", "description")
+            k: v
+            for k, v in schema.items()
+            if k not in ("title", "description")
         }
         for param in docstring.params:
             if (name := param.arg_name) in parameters["properties"] and (
@@ -55,7 +58,9 @@ def openai_schema(cls) -> dict[str, Any]:
                 parameters["properties"][name]["description"] = description
 
         parameters["required"] = sorted(
-            k for k, v in parameters["properties"].items() if "default" not in v
+            k
+            for k, v in parameters["properties"].items()
+            if "default" not in v
         )
 
         if "description" not in schema:
@@ -88,7 +93,9 @@ def gemini_schema(cls) -> Any:
         function = genai_types.FunctionDeclaration(
             name=cls.openai_schema["name"],
             description=cls.openai_schema["description"],
-            parameters=map_to_gemini_function_schema(cls.openai_schema["parameters"]),
+            parameters=map_to_gemini_function_schema(
+                cls.openai_schema["parameters"]
+            ),
         )
         return function
 
@@ -113,31 +120,52 @@ def from_response(
             cls (OpenAISchema): An instance of the class
         """
         if mode == Mode.ANTHROPIC_TOOLS:
-            return cls.parse_anthropic_tools(completion, validation_context, strict)
+            return cls.parse_anthropic_tools(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.ANTHROPIC_JSON:
-            return cls.parse_anthropic_json(completion, validation_context, strict)
+            return cls.parse_anthropic_json(
+                completion, validation_context, strict
+            )
+
+        if mode == Mode.BEDROCK_JSON:
+            return cls.parse_bedrock_json(
+                completion, validation_context, strict
+            )
 
         if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
             return cls.parse_vertexai_tools(completion, validation_context)
 
         if mode == Mode.VERTEXAI_JSON:
-            return cls.parse_vertexai_json(completion, validation_context, strict)
+            return cls.parse_vertexai_json(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.COHERE_TOOLS:
-            return cls.parse_cohere_tools(completion, validation_context, strict)
+            return cls.parse_cohere_tools(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.GEMINI_JSON:
-            return cls.parse_gemini_json(completion, validation_context, strict)
+            return cls.parse_gemini_json(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.GEMINI_TOOLS:
-            return cls.parse_gemini_tools(completion, validation_context, strict)
+            return cls.parse_gemini_tools(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.COHERE_JSON_SCHEMA:
-            return cls.parse_cohere_json_schema(completion, validation_context, strict)
+            return cls.parse_cohere_json_schema(
+                completion, validation_context, strict
+            )
 
         if mode == Mode.WRITER_TOOLS:
-            return cls.parse_writer_tools(completion, validation_context, strict)
+            return cls.parse_writer_tools(
+                completion, validation_context, strict
+            )
 
         if completion.choices[0].finish_reason == "length":
             raise IncompleteOutputException(last_completion=completion)
@@ -190,12 +218,17 @@ def parse_anthropic_tools(
     ) -> BaseModel:
         from anthropic.types import Message
 
-        if isinstance(completion, Message) and completion.stop_reason == "max_tokens":
+        if (
+            isinstance(completion, Message)
+            and completion.stop_reason == "max_tokens"
+        ):
             raise IncompleteOutputException(last_completion=completion)
 
         # Anthropic returns arguments as a dict, dump to json for model validation below
         tool_calls = [
-            json.dumps(c.input) for c in completion.content if c.type == "tool_use"
+            json.dumps(c.input)
+            for c in completion.content
+            if c.type == "tool_use"
         ]  # TODO update with anthropic specific types
 
         tool_calls_validator = TypeAdapter(
@@ -237,7 +270,39 @@ def parse_anthropic_json(
         # Allow control characters.
         parsed = json.loads(extra_text, strict=False)
         # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
-        return cls.model_validate(parsed, context=validation_context, strict=False)
+        return cls.model_validate(
+            parsed, context=validation_context, strict=False
+        )
+
+    @classmethod
+    def parse_bedrock_json(
+        cls: type[BaseModel],
+        completion: Any,
+        validation_context: Optional[dict[str, Any]] = None,
+        strict: Optional[bool] = None,
+    ) -> BaseModel:
+        if isinstance(completion, dict):
+            # Converse responses nest the generated text under
+            # output -> message -> content[0] -> text.
+            text = completion["output"]["message"]["content"][0]["text"]
+
+            # Strip a ```json ... ``` fence if the model wrapped its output in one.
+            match = re.search(r"```?json(.*?)```?", text, re.DOTALL)
+            if match:
+                text = match.group(1).strip()
+
+            text = re.sub(r"```?json|\\n", "", text).strip()
+            logger.debug(
+                f"instructor.function_calls: parse_bedrock_json: text {text}"
+            )
+        else:
+            text = completion.text
+        return cls.model_validate_json(
+            text, context=validation_context, strict=strict
+        )
 
     @classmethod
     def parse_gemini_json(
@@ -256,7 +321,9 @@ def parse_gemini_json(
         try:
             extra_text = extract_json_from_codeblock(text)  # type: ignore
         except UnboundLocalError:
-            raise ValueError("Unable to extract JSON from completion text") from None
+            raise ValueError(
+                "Unable to extract JSON from completion text"
+            ) from None
 
         if strict:
             return cls.model_validate_json(
@@ -266,7 +333,9 @@ def parse_gemini_json(
         # Allow control characters.
         parsed = json.loads(extra_text, strict=False)
         # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/
-        return cls.model_validate(parsed, context=validation_context, strict=False)
+        return cls.model_validate(
+            parsed, context=validation_context, strict=False
+        )
 
     @classmethod
     def parse_vertexai_tools(
@@ -279,7 +348,9 @@ def parse_vertexai_tools(
         for field in tool_call:  # type: ignore
             model[field] = tool_call[field]
         # We enable strict=False because the conversion from protobuf -> dict often results in types like ints being cast to floats, as a result in order for model.validate to work we need to disable strict mode.
-        return cls.model_validate(model, context=validation_context, strict=False)
+        return cls.model_validate(
+            model, context=validation_context, strict=False
+        )
 
     @classmethod
     def parse_vertexai_json(
diff --git a/instructor/mode.py b/instructor/mode.py
index ebd330b40..cf8544fd4 100644
--- a/instructor/mode.py
+++ b/instructor/mode.py
@@ -28,6 +28,8 @@ class Mode(enum.Enum):
     FIREWORKS_TOOLS = "fireworks_tools"
     FIREWORKS_JSON = "fireworks_json"
     WRITER_TOOLS = "writer_tools"
+    BEDROCK_TOOLS = "bedrock_tools"
+    BEDROCK_JSON = "bedrock_json"
 
     @classmethod
     def warn_mode_functions_deprecation(cls):
diff --git a/instructor/patch.py b/instructor/patch.py
index cc5802672..99d3a50d5 100644
--- a/instructor/patch.py
+++ b/instructor/patch.py
@@ -131,6 +131,9 @@ def patch(  # type: ignore
 
     logger.debug(f"Patching `client.chat.completions.create` with {mode=}")
 
+    # `create` may still be None here when only a client was passed in.
+    logger.debug(f"instructor.patch: patching {getattr(create, '__name__', None)}")
+
     if create is not None:
         func = create
     elif client is not None:
@@ -183,7 +186,7 @@ def new_create_sync(
         **kwargs: T_ParamSpec.kwargs,
     ) -> T_Model:
         context = handle_context(context, validation_context)
-
+        logger.debug(f"instructor.patch: patched_function {func.__name__}")
         response_model, new_kwargs = handle_response_model(
             response_model=response_model, mode=mode, **kwargs
         )  # type: ignore
@@ -228,6 +231,8 @@ def apatch(client: AsyncOpenAI, mode: Mode = Mode.TOOLS) -> AsyncOpenAI:
     import warnings
 
     warnings.warn(
-        "apatch is deprecated, use patch instead", DeprecationWarning, stacklevel=2
+        "apatch is deprecated, use patch instead",
+        DeprecationWarning,
+        stacklevel=2,
     )
     return patch(client, mode=mode)
diff --git a/instructor/process_response.py b/instructor/process_response.py
index f44031ab0..420e7251d 100644
--- a/instructor/process_response.py
+++ b/instructor/process_response.py
@@ -14,18 +14,23 @@
 from openai.types.chat import ChatCompletion
 from pydantic import BaseModel, create_model
 
 from instructor.mode import Mode
 from instructor.dsl.iterable import IterableBase, IterableModel
 from instructor.dsl.parallel import (
-    ParallelBase, 
-    ParallelModel, 
-    handle_parallel_model, 
+    ParallelBase,
+    ParallelModel,
+    handle_parallel_model,
     get_types_array,
     VertexAIParallelBase,
-    VertexAIParallelModel
+    VertexAIParallelModel,
 )
 from instructor.dsl.partial import PartialBase
-from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
+from instructor.dsl.simple_type import (
+    AdapterBase,
+    ModelAdapter,
+    is_simple_type,
+)
 from instructor.function_calls import OpenAISchema, openai_schema
 from instructor.utils import (
     merge_consecutive_messages,
@@ -146,6 +151,12 @@ def process_response(
         f"Instructor Raw Response: {response}",
     )
 
+    logger.debug(
+        f"instructor.process_response.py: response_model {response_model}"
+    )
+
     if response_model is None:
         logger.debug("No response model, returning response as is")
         return response
@@ -183,6 +194,10 @@ def process_response(
         return model.content
 
     model._raw_response = response
+
+    logger.debug(f"instructor.process_response.py: model {model}")
+
     return model
 
 
@@ -210,7 +225,9 @@ def handle_functions(
 ) -> tuple[type[T], dict[str, Any]]:
     Mode.warn_mode_functions_deprecation()
     new_kwargs["functions"] = [response_model.openai_schema]
-    new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
+    new_kwargs["function_call"] = {
+        "name": response_model.openai_schema["name"]
+    }
     return response_model, new_kwargs
 
 
@@ -311,7 +328,9 @@ def handle_json_modes(
             "content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA",
         },
     )
-    new_kwargs["messages"] = merge_consecutive_messages(new_kwargs["messages"])
+    new_kwargs["messages"] = merge_consecutive_messages(
+        new_kwargs["messages"]
+    )
 
     if new_kwargs["messages"][0]["role"] != "system":
         new_kwargs["messages"].insert(
@@ -383,13 +402,16 @@ def handle_anthropic_json(
     )
 
     new_kwargs["system"] = combine_system_messages(
-        new_kwargs.get("system"), [{"type": "text", "text": json_schema_message}]
+        new_kwargs.get("system"),
+        [{"type": "text", "text": json_schema_message}],
     )
 
     return response_model, new_kwargs
 
 
-def handle_cohere_modes(new_kwargs: dict[str, Any]) -> tuple[None, dict[str, Any]]:
+def handle_cohere_modes(
+    new_kwargs: dict[str, Any]
+) -> tuple[None, dict[str, Any]]:
     messages = new_kwargs.pop("messages", [])
     chat_history = []
     for message in messages[:-1]:
@@ -459,13 +481,15 @@ def handle_gemini_json(
     )
 
     if new_kwargs["messages"][0]["role"] != "system":
-        new_kwargs["messages"].insert(0, {"role": "system", "content": message})
+        new_kwargs["messages"].insert(
+            0, {"role": "system", "content": message}
+        )
     else:
         new_kwargs["messages"][0]["content"] += f"\n\n{message}"
 
-    new_kwargs["generation_config"] = new_kwargs.get("generation_config", {}) | {
-        "response_mime_type": "application/json"
-    }
+    new_kwargs["generation_config"] = new_kwargs.get(
+        "generation_config", {}
+    ) | {"response_mime_type": "application/json"}
 
     new_kwargs = update_gemini_kwargs(new_kwargs)
     return response_model, new_kwargs
@@ -498,17 +522,19 @@ def handle_vertexai_parallel_tools(
     assert (
         new_kwargs.get("stream", False) is False
     ), "stream=True is not supported when using PARALLEL_TOOLS mode"
-    
+
     from instructor.client_vertexai import vertexai_process_response
-    
+
     # Extract concrete types before passing to vertexai_process_response
     model_types = list(get_types_array(response_model))
-    contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types)
-    
+    contents, tools, tool_config = vertexai_process_response(
+        new_kwargs, model_types
+    )
+
     new_kwargs["contents"] = contents
     new_kwargs["tools"] = tools
     new_kwargs["tool_config"] = tool_config
-    
+
     return VertexAIParallelModel(typehint=response_model), new_kwargs
 
 
@@ -517,7 +543,9 @@ def handle_vertexai_tools(
 ) -> tuple[type[T], dict[str, Any]]:
     from instructor.client_vertexai import vertexai_process_response
 
-    contents, tools, tool_config = vertexai_process_response(new_kwargs, response_model)
+    contents, tools, tool_config = vertexai_process_response(
+        new_kwargs, response_model
+    )
 
     new_kwargs["contents"] = contents
     new_kwargs["tools"] = tools
@@ -539,6 +567,36 @@ def handle_vertexai_json(
     return response_model, new_kwargs
 
 
+def handle_bedrock_json(
+    response_model: type[T], new_kwargs: dict[str, Any]
+) -> tuple[type[T], dict[str, Any]]:
+    logger.debug(f"handle_bedrock_json: response_model {response_model}")
+    logger.debug(f"handle_bedrock_json: new_kwargs {new_kwargs}")
+    json_message = dedent(
+        f"""
+        As a genius expert, your task is to understand the content and provide
+        the parsed objects in json that match the following json_schema:\n
+
+        {json.dumps(response_model.model_json_schema(), indent=2, ensure_ascii=False)}
+
+        Make sure to return an instance of the JSON, not the schema itself
+        and don't include any other text in the response apart from the json
+        """
+    )
+    system_message = new_kwargs.pop("system", None)
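+    # Bedrock's Converse API takes `system` as a list of content blocks,
+    # e.g. [{"text": "..."}], not as a "system"-role message, so the schema
+    # instruction is appended as one more text block below.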
+    if not system_message:
+        new_kwargs["system"] = [{"text": json_message}]
+    else:
+        if not isinstance(system_message, list):
+            raise ValueError(
+                """system must be a list of system content blocks, e.g. [{"text": ...}]; see
+                https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
+                """
+            )
+        system_message.append({"text": json_message})
+    return response_model, new_kwargs
+
+
 def handle_cohere_json_schema(
     response_model: type[T], new_kwargs: dict[str, Any]
 ) -> tuple[type[T], dict[str, Any]]:
@@ -584,9 +642,9 @@ def handle_cerebras_json(
     Your response should consist only of a valid JSON object that `{response_model.__name__}.model_validate_json()` can successfully parse.
     """
 
-    new_kwargs["messages"] = [{"role": "system", "content": instruction}] + new_kwargs[
-        "messages"
-    ]
+    new_kwargs["messages"] = [
+        {"role": "system", "content": instruction}
+    ] + new_kwargs["messages"]
 
     return response_model, new_kwargs
 
@@ -612,7 +670,7 @@ def handle_cohere_tools(
 
 
 def handle_writer_tools(
-    response_model: type[T], new_kwargs: dict[str, Any] 
+    response_model: type[T], new_kwargs: dict[str, Any]
 ) -> tuple[type[T], dict[str, Any]]:
     new_kwargs["tools"] = [
         {
@@ -691,6 +749,7 @@ def handle_response_model(
     """
 
     new_kwargs = kwargs.copy()
+    logger.debug(f"instructor.process_response.py: new_kwargs -> {new_kwargs}")
     autodetect_images = new_kwargs.pop("autodetect_images", False)
 
     if response_model is None:
@@ -706,7 +765,9 @@ def handle_response_model(
         )
         if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
             # Handle OpenAI style or Anthropic style messages
-            new_kwargs["messages"] = [m for m in messages if m["role"] != "system"]
+            new_kwargs["messages"] = [
+                m for m in messages if m["role"] != "system"
+            ]
             if "system" not in new_kwargs:
                 system_message = extract_system_messages(messages)
                 if system_message:
@@ -744,10 +805,13 @@ def handle_response_model(
         Mode.FIREWORKS_JSON: handle_fireworks_json,
         Mode.FIREWORKS_TOOLS: handle_fireworks_tools,
         Mode.WRITER_TOOLS: handle_writer_tools,
+        Mode.BEDROCK_JSON: handle_bedrock_json,
     }
 
     if mode in mode_handlers:
-        response_model, new_kwargs = mode_handlers[mode](response_model, new_kwargs)
+        response_model, new_kwargs = mode_handlers[mode](
+            response_model, new_kwargs
+        )
     else:
         raise ValueError(f"Invalid patch mode: {mode}")
 
@@ -763,7 +827,8 @@ def handle_response_model(
                 "mode": mode.value,
                 "response_model": (
                     response_model.__name__
-                    if response_model is not None and hasattr(response_model, "__name__")
+                    if response_model is not None
+                    and hasattr(response_model, "__name__")
                     else str(response_model)
                 ),
                 "new_kwargs": new_kwargs,
diff --git a/instructor/retry.py b/instructor/retry.py
index 853a73a54..ab52316f5 100644
--- a/instructor/retry.py
+++ b/instructor/retry.py
@@ -4,17 +4,24 @@
 import logging
 from json import JSONDecodeError
-from typing import Any, Callable, TypeVar
+from typing import Any, Callable, TypeVar, Union
 
 from instructor.exceptions import InstructorRetryException
 from instructor.hooks import Hooks
 from instructor.mode import Mode
 from instructor.reask import handle_reask_kwargs
-from instructor.process_response import process_response, process_response_async
+from instructor.process_response import (
+    process_response,
+    process_response_async,
+)
 from instructor.utils import update_total_usage
 from instructor.validators import AsyncValidationError
 from openai.types.chat import ChatCompletion
-from openai.types.completion_usage import CompletionUsage, CompletionTokensDetails, PromptTokensDetails
+from openai.types.completion_usage import (
+    CompletionUsage,
+    CompletionTokensDetails,
+    PromptTokensDetails,
+)
 from pydantic import BaseModel, ValidationError
 from tenacity import (
     AsyncRetrying,
@@ -33,7 +40,9 @@
 T = TypeVar("T")
 
 
-def initialize_retrying(max_retries: int | Retrying | AsyncRetrying, is_async: bool):
+def initialize_retrying(
+    max_retries: int | Retrying | AsyncRetrying, is_async: bool
+):
     """
     Initialize the retrying mechanism based on the type (synchronous or asynchronous).
 
@@ -71,9 +80,16 @@ def initialize_usage(mode: Mode) -> CompletionUsage | Any:
     Returns:
         CompletionUsage | Any: Initialized usage object.
     """
-    total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0,
-        completion_tokens_details = CompletionTokensDetails(audio_tokens=0, reasoning_tokens=0),
-        prompt_tokens_details = PromptTokensDetails(audio_tokens=0, cached_tokens=0)
+    total_usage = CompletionUsage(
+        completion_tokens=0,
+        prompt_tokens=0,
+        total_tokens=0,
+        completion_tokens_details=CompletionTokensDetails(
+            audio_tokens=0, reasoning_tokens=0
+        ),
+        prompt_tokens_details=PromptTokensDetails(
+            audio_tokens=0, cached_tokens=0
+        ),
     )
     if mode in {Mode.ANTHROPIC_TOOLS, Mode.ANTHROPIC_JSON}:
         from anthropic.types import Usage as AnthropicUsage
@@ -136,7 +152,9 @@ def retry_sync(
         response = None
         for attempt in max_retries:
             with attempt:
-                logger.debug(f"Retrying, attempt: {attempt.retry_state.attempt_number}")
+                logger.debug(
+                    f"Retrying, attempt: {attempt.retry_state.attempt_number}"
+                )
                 try:
                     hooks.emit_completion_arguments(*args, **kwargs)
                     response = func(*args, **kwargs)
@@ -144,6 +162,10 @@ def retry_sync(
                     response = update_total_usage(
                         response=response, total_usage=total_usage
                     )
+
+                    logger.debug(f"instructor.retry.py: response {response}")
+
                     return process_response(  # type: ignore
                         response=response,
                         response_model=response_model,
@@ -170,7 +192,8 @@ def retry_sync(
                         n_attempts=attempt.retry_state.attempt_number,
                         #! deprecate messages soon
                         messages=kwargs.get(
-                            "messages", kwargs.get("contents", kwargs.get("chat_history", []))
+                            "messages",
+                            kwargs.get("contents", kwargs.get("chat_history", [])),
                         ),
                         create_kwargs=kwargs,
                         total_usage=total_usage,
@@ -215,7 +238,9 @@ async def retry_async(
     try:
         response = None
         async for attempt in max_retries:
-            logger.debug(f"Retrying, attempt: {attempt.retry_state.attempt_number}")
+            logger.debug(
+                f"Retrying, attempt: {attempt.retry_state.attempt_number}"
+            )
             with attempt:
                 try:
                     hooks.emit_completion_arguments(*args, **kwargs)
@@ -233,7 +258,11 @@ async def retry_async(
                         mode=mode,
                         stream=kwargs.get("stream", False),
                     )
-                except (ValidationError, JSONDecodeError, AsyncValidationError) as e:
+                except (
+                    ValidationError,
+                    JSONDecodeError,
+                    AsyncValidationError,
+                ) as e:
                     logger.debug(f"Parse error: {e}")
                     hooks.emit_parse_error(e)
                     kwargs = handle_reask_kwargs(
@@ -251,7 +280,8 @@ async def retry_async(
                         n_attempts=attempt.retry_state.attempt_number,
                         #! deprecate messages soon
                         messages=kwargs.get(
-                            "messages", kwargs.get("contents", kwargs.get("chat_history", []))
+                            "messages",
+                            kwargs.get("contents", kwargs.get("chat_history", [])),
                         ),
                         create_kwargs=kwargs,
                         total_usage=total_usage,
diff --git a/instructor/utils.py b/instructor/utils.py
index 55d746760..49a7de4fa 100644
--- a/instructor/utils.py
+++ b/instructor/utils.py
@@ -55,6 +55,7 @@ class Provider(Enum):
     FIREWORKS = "fireworks"
     WRITER = "writer"
     UNKNOWN = "unknown"
+    BEDROCK = "bedrock"
 
 
 def get_provider(base_url: str) -> Provider:
@@ -93,7 +94,9 @@ def extract_json_from_codeblock(content: str) -> str:
         return content[first_paren : last_paren + 1]
 
 
-def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
+def extract_json_from_stream(
+    chunks: Iterable[str],
+) -> Generator[str, None, None]:
     capturing = False
     brace_count = 0
     for chunk in chunks:
@@ -141,23 +144,33 @@ def update_total_usage(
         return None
 
     response_usage = getattr(response, "usage", None)
-    if isinstance(response_usage, OpenAIUsage) and isinstance(total_usage, OpenAIUsage):
+    if isinstance(response_usage, OpenAIUsage) and isinstance(
+        total_usage, OpenAIUsage
+    ):
         total_usage.completion_tokens += response_usage.completion_tokens or 0
         total_usage.prompt_tokens += response_usage.prompt_tokens or 0
         total_usage.total_tokens += response_usage.total_tokens or 0
         if (rtd := response_usage.completion_tokens_details) and (
             ttd := total_usage.completion_tokens_details
         ):
-            ttd.audio_tokens = (ttd.audio_tokens or 0) + (rtd.audio_tokens or 0)
+            ttd.audio_tokens = (ttd.audio_tokens or 0) + (
+                rtd.audio_tokens or 0
+            )
             ttd.reasoning_tokens = (ttd.reasoning_tokens or 0) + (
                 rtd.reasoning_tokens or 0
             )
         if (rpd := response_usage.prompt_tokens_details) and (
             tpd := total_usage.prompt_tokens_details
         ):
-            tpd.audio_tokens = (tpd.audio_tokens or 0) + (rpd.audio_tokens or 0)
-            tpd.cached_tokens = (tpd.cached_tokens or 0) + (rpd.cached_tokens or 0)
-        response.usage = total_usage  # Replace each response usage with the total usage
+            tpd.audio_tokens = (tpd.audio_tokens or 0) + (
+                rpd.audio_tokens or 0
+            )
+            tpd.cached_tokens = (tpd.cached_tokens or 0) + (
+                rpd.cached_tokens or 0
+            )
+        response.usage = (
+            total_usage  # Replace each response usage with the total usage
+        )
         return response
 
     # Anthropic usage.
@@ -174,7 +187,9 @@ def update_total_usage(
     except ImportError:
         pass
 
-    logger.debug("No compatible response.usage found, token usage not updated.")
+    logger.debug(
+        "No compatible response.usage found, token usage not updated."
+    )
     return response
 
 
@@ -215,7 +230,9 @@ def is_async(func: Callable[..., Any]) -> bool:
     return is_coroutine
 
 
-def merge_consecutive_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+def merge_consecutive_messages(
+    messages: list[dict[str, Any]]
+) -> list[dict[str, Any]]:
     # merge all consecutive user messages into a single message
     new_messages: list[dict[str, Any]] = []
     # Detect whether all messages have a flat content (i.e. all string)
@@ -227,7 +244,10 @@ def merge_consecutive_messages(
         # If content is not flat, transform it into a list of text
         new_content = [{"type": "text", "text": new_content}]
 
-        if len(new_messages) > 0 and message["role"] == new_messages[-1]["role"]:
+        if (
+            len(new_messages) > 0
+            and message["role"] == new_messages[-1]["role"]
+        ):
             if flat_string:
                 # New content is a string
                 new_messages[-1]["content"] += f"\n\n{new_content}"
@@ -299,7 +319,9 @@ def transform_to_gemini_prompt(
     if messages_gemini:
         messages_gemini[0]["parts"].insert(0, f"*{system_prompt}*")
     else:
-        messages_gemini.append({"role": "user", "parts": [f"*{system_prompt}*"]})
+        messages_gemini.append(
+            {"role": "user", "parts": [f"*{system_prompt}*"]}
+        )
 
     return messages_gemini
 
@@ -341,7 +363,9 @@ def add_enum_format(obj: dict[str, Any]) -> dict[str, Any]:
 
     schema = add_enum_format(schema)
 
-    return FunctionSchema(**schema).model_dump(exclude_none=True, exclude_unset=True)
+    return FunctionSchema(**schema).model_dump(
+        exclude_none=True, exclude_unset=True
+    )
 
 
 def update_gemini_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
@@ -416,8 +440,12 @@ def combine_system_messages(
     raise ValueError("Unsupported system message type combination")
 
 
-def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]:
-    def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage:  # noqa: UP007
+def extract_system_messages(
+    messages: list[dict[str, Any]]
+) -> list[SystemMessage]:
+    def convert_message(
+        content: Union[str, dict[str, Any]]
+    ) -> SystemMessage:  # noqa: UP007
         if isinstance(content, str):
             return SystemMessage(type="text", text=content)
         elif isinstance(content, dict):
@@ -429,7 +457,9 @@ def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage:  # noqa
     for m in messages:
         if m["role"] == "system":
             # System message must always be a string or list of dictionaries
-            content = cast(Union[str, list[dict[str, Any]]], m["content"])  # noqa: UP007
+            content = cast(
+                Union[str, list[dict[str, Any]]], m["content"]
+            )  # noqa: UP007
             if isinstance(content, list):
                 result.extend(convert_message(item) for item in content)
             else: