From 5dfb0e0b36638ca742a69acb3a6bcc9886f2b6d6 Mon Sep 17 00:00:00 2001 From: Philip Kiely Date: Tue, 8 Jul 2025 21:33:12 -0700 Subject: [PATCH 1/2] Add baseten integration --- pyproject.toml | 3 + src/strands/models/baseten.py | 185 ++++++++++++++++ tests/strands/models/test_baseten.py | 304 +++++++++++++++++++++++++++ 3 files changed, 492 insertions(+) create mode 100644 src/strands/models/baseten.py create mode 100644 tests/strands/models/test_baseten.py diff --git a/pyproject.toml b/pyproject.toml index 7d865feff..363446237 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ packages = ["src/strands"] anthropic = [ "anthropic>=0.21.0,<1.0.0", ] +baseten = [ + "openai>=1.68.0,<2.0.0", +] dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", diff --git a/src/strands/models/baseten.py b/src/strands/models/baseten.py new file mode 100644 index 000000000..72daa9967 --- /dev/null +++ b/src/strands/models/baseten.py @@ -0,0 +1,185 @@ +"""Baseten model provider. + +- Docs: https://docs.baseten.co/ +""" + +import logging +from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast + +import openai +from openai.types.chat.parsed_chat_completion import ParsedChatCompletion +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import Messages +from ..types.models import OpenAIModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class Client(Protocol): + """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" + + @property + # pragma: no cover + def chat(self) -> Any: + """Chat completions interface.""" + ... + + +class BasetenModel(OpenAIModel): + """Baseten model provider implementation.""" + + client: Client + + class BasetenConfig(TypedDict, total=False): + """Configuration options for Baseten models. + + Attributes: + model_id: Model ID for the Baseten model. + For Model APIs, use model slugs like "deepseek-ai/DeepSeek-R1-0528" or "meta-llama/Llama-4-Maverick-17B-128E-Instruct". + For dedicated deployments, use the deployment ID. + base_url: Base URL for the Baseten API. + For Model APIs: https://inference.baseten.co/v1 + For dedicated deployments: https://model-xxxxxxx.api.baseten.co/environments/production/sync/v1 + params: Model parameters (e.g., max_tokens). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/chat/create. + """ + + model_id: str + base_url: Optional[str] + params: Optional[dict[str, Any]] + + def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[BasetenConfig]) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the Baseten client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the Baseten model. 
+ """ + self.config = dict(model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + + # Set default base URL for Model APIs if not provided + if "base_url" not in client_args and "base_url" not in self.config: + client_args["base_url"] = "https://inference.baseten.co/v1" + elif "base_url" in self.config: + client_args["base_url"] = self.config["base_url"] + + self.client = openai.OpenAI(**client_args) + + @override + def update_config(self, **model_config: Unpack[BasetenConfig]) -> None: # type: ignore[override] + """Update the Baseten model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> BasetenConfig: + """Get the Baseten model configuration. + + Returns: + The Baseten model configuration. + """ + return cast(BasetenModel.BasetenConfig, self.config) + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the Baseten model and get the streaming response. + + Args: + request: The formatted request to send to the Baseten model. + + Returns: + An iterable of response events from the Baseten model. + """ + response = self.client.chat.completions.create(**request) + + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + + tool_calls: dict[int, list[Any]] = {} + + for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + if choice.delta.content: + yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + + if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + yield { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice.delta.reasoning_content, + } + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + break + + yield {"chunk_type": "content_stop", "data_type": "text"} + + for tool_deltas in tool_calls.values(): + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + + for tool_delta in tool_deltas: + yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + + yield {"chunk_type": "content_stop", "data_type": "tool"} + + yield {"chunk_type": "message_stop", "data": choice.finish_reason} + + # Skip remaining events as we don't have use for anything except the final usage payload + for event in response: + _ = event + + yield {"chunk_type": "metadata", "data": event.usage} + + @override + def structured_output( + self, output_model: Type[T], prompt: Messages + ) -> Generator[dict[str, Union[T, Any]], None, None]: + """Get structured output from the model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + + Yields: + Model events with the last being the structured output. 
+ """ + response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore + model=self.get_config()["model_id"], + messages=super().format_request(prompt)["messages"], + response_format=output_model, + ) + + parsed: T | None = None + # Find the first choice with tool_calls + if len(response.choices) > 1: + raise ValueError("Multiple choices found in the Baseten response.") + + for choice in response.choices: + if isinstance(choice.message.parsed, output_model): + parsed = choice.message.parsed + break + + if parsed: + yield {"output": parsed} + else: + raise ValueError("No valid tool use or tool use input was found in the Baseten response.") \ No newline at end of file diff --git a/tests/strands/models/test_baseten.py b/tests/strands/models/test_baseten.py new file mode 100644 index 000000000..52867ac09 --- /dev/null +++ b/tests/strands/models/test_baseten.py @@ -0,0 +1,304 @@ +import unittest.mock + +import pydantic +import pytest + +import strands +from strands.models.baseten import BasetenModel + + +@pytest.fixture +def openai_client_cls(): + with unittest.mock.patch.object(strands.models.baseten.openai, "OpenAI") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def openai_client(openai_client_cls): + return openai_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "deepseek-ai/DeepSeek-R1-0528" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + + return BasetenModel(model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__model_apis(openai_client_cls, model_id): + model = BasetenModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) + + tru_config = model.get_config() + exp_config = {"model_id": "deepseek-ai/DeepSeek-R1-0528", "params": {"max_tokens": 1}} + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1", base_url="https://inference.baseten.co/v1") + + +def test__init__dedicated_deployment(openai_client_cls): + deployment_id = "dq4kr413" + environment = "production" + base_url = f"https://model-{deployment_id}.api.baseten.co/environments/{environment}/sync/v1" + + model = BasetenModel( + {"api_key": "k1"}, + model_id=deployment_id, + base_url=base_url, + environment=environment, + params={"max_tokens": 1} + ) + + tru_config = model.get_config() + exp_config = { + "model_id": "dq4kr413", + "base_url": base_url, + "environment": "production", + "params": {"max_tokens": 1} + } + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1", base_url=base_url) + + +def test__init__dedicated_deployment_custom_environment(openai_client_cls): + deployment_id = "dq4kr413" + environment = "staging" + base_url = f"https://model-{deployment_id}.api.baseten.co/environments/{environment}/sync/v1" + + model = BasetenModel( + {"api_key": "k1"}, + model_id=deployment_id, + base_url=base_url, + environment=environment, + params={"max_tokens": 1} + ) + + tru_config = model.get_config() + exp_config = { + "model_id": "dq4kr413", + "base_url": base_url, + "environment": "staging", + "params": {"max_tokens": 1} + } + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1", base_url=base_url) + + +def 
test__init__base_url_in_client_args(openai_client_cls, model_id): + custom_base_url = "https://custom.baseten.co/v1" + model = BasetenModel( + {"api_key": "k1", "base_url": custom_base_url}, + model_id=model_id + ) + + openai_client_cls.assert_called_once_with(api_key="k1", base_url=custom_base_url) + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_stream(openai_client, model): + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_1 = unittest.mock.Mock( + reasoning_content="", + content=None, + tool_calls=None, + ) + mock_delta_2 = unittest.mock.Mock( + reasoning_content="\nI'm thinking", + content=None, + tool_calls=None, + ) + mock_delta_3 = unittest.mock.Mock( + content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None + ) + + mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) + mock_delta_4 = unittest.mock.Mock( + content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None + ) + + mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)]) + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)]) + mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)]) + mock_event_6 = unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = iter( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6] + ) + + request = {"model": "deepseek-ai/DeepSeek-R1-0528", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} + response = model.stream(request) + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "\nI'm thinking"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"chunk_type": "metadata", "data": mock_event_6.usage}, + ] + + assert tru_events == exp_events + 
openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_empty(openai_client, model): + mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) + mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + mock_event_4 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "deepseek-ai/DeepSeek-R1-0528", "messages": [{"role": "user", "content": []}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_with_empty_choices(openai_client, model): + mock_delta = unittest.mock.Mock(content="content", tool_calls=None, reasoning_content=None) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + # Event with no choices attribute + mock_event_1 = unittest.mock.Mock(spec=[]) + + # Event with empty choices list + mock_event_2 = unittest.mock.Mock(choices=[]) + + # Valid event with content + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + + # Event with finish reason + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage info + mock_event_5 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] + ) + + request = {"model": "deepseek-ai/DeepSeek-R1-0528", "messages": [{"role": "user", "content": ["test"]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_structured_output(openai_client, model, test_output_model_cls): + mock_parsed_response = test_output_model_cls(name="test", age=25) + mock_choice = unittest.mock.Mock() + mock_choice.message.parsed = mock_parsed_response + mock_response = unittest.mock.Mock(choices=[mock_choice]) + + openai_client.beta.chat.completions.parse.return_value = mock_response + + prompt = [{"role": "user", "content": [{"text": "test"}]}] + result = list(model.structured_output(test_output_model_cls, prompt)) + + assert len(result) == 1 + assert result[0]["output"] == mock_parsed_response + + openai_client.beta.chat.completions.parse.assert_called_once() + + +def 
test_structured_output_multiple_choices(openai_client, model, test_output_model_cls): + mock_choice_1 = unittest.mock.Mock() + mock_choice_2 = unittest.mock.Mock() + mock_response = unittest.mock.Mock(choices=[mock_choice_1, mock_choice_2]) + + openai_client.beta.chat.completions.parse.return_value = mock_response + + prompt = [{"role": "user", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError, match="Multiple choices found in the Baseten response."): + list(model.structured_output(test_output_model_cls, prompt)) + + +def test_structured_output_no_valid_parsed(openai_client, model, test_output_model_cls): + mock_choice = unittest.mock.Mock() + mock_choice.message.parsed = None + mock_response = unittest.mock.Mock(choices=[mock_choice]) + + openai_client.beta.chat.completions.parse.return_value = mock_response + + prompt = [{"role": "user", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError, match="No valid tool use or tool use input was found in the Baseten response."): + list(model.structured_output(test_output_model_cls, prompt)) \ No newline at end of file From 315281938526908db1c4ef7c2a808567fb221f26 Mon Sep 17 00:00:00 2001 From: Philip Kiely Date: Mon, 14 Jul 2025 17:49:14 -0700 Subject: [PATCH 2/2] Update Strands with async and fixed bug with dedicated --- src/strands/models/baseten.py | 353 ++++++++++++++++++-- tests/strands/models/test_baseten.py | 393 +++++++++++++++-------- tests_integ/models/providers.py | 12 + tests_integ/models/test_model_baseten.py | 385 ++++++++++++++++++++++ 4 files changed, 983 insertions(+), 160 deletions(-) create mode 100644 tests_integ/models/test_model_baseten.py diff --git a/src/strands/models/baseten.py b/src/strands/models/baseten.py index 72daa9967..15b6c4652 100644 --- a/src/strands/models/baseten.py +++ b/src/strands/models/baseten.py @@ -3,16 +3,21 @@ - Docs: https://docs.baseten.co/ """ +import base64 +import json import logging -from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +import mimetypes +from typing import Any, AsyncGenerator, AsyncIterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import Messages -from ..types.models import OpenAIModel +from ..types.content import ContentBlock, Messages +from ..types.streaming import StreamEvent +from ..types.tools import ToolResult, ToolSpec, ToolUse +from .model import Model logger = logging.getLogger(__name__) @@ -29,7 +34,7 @@ def chat(self) -> Any: ... -class BasetenModel(OpenAIModel): +class BasetenModel(Model): """Baseten model provider implementation.""" client: Client @@ -73,7 +78,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: elif "base_url" in self.config: client_args["base_url"] = self.config["base_url"] - self.client = openai.OpenAI(**client_args) + self.client = openai.AsyncOpenAI(**client_args) @override def update_config(self, **model_config: Unpack[BasetenConfig]) -> None: # type: ignore[override] @@ -93,38 +98,306 @@ def get_config(self) -> BasetenConfig: """ return cast(BasetenModel.BasetenConfig, self.config) - @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the Baseten model and get the streaming response. 
+ @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format a Baseten compatible content block. + + Args: + content: Message content. + + Returns: + Baseten compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to a Baseten-compatible format. + """ + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format a Baseten compatible tool call. Args: - request: The formatted request to send to the Baseten model. + tool_use: Tool use requested by the model. Returns: - An iterable of response events from the Baseten model. + Baseten compatible tool call. """ - response = self.client.chat.completions.create(**request) + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @classmethod + def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format a Baseten compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + Baseten compatible tool message. + """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": [cls.format_request_message_content(content) for content in contents], + } + + def format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format a Baseten compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A Baseten compatible messages array. 
+ """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + # Check if this is a dedicated deployment (empty model_id) + is_dedicated_deployment = not self.get_config()["model_id"] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + self.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + self.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + self.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + # For dedicated deployments, convert content to string if it's a single text block + if is_dedicated_deployment and formatted_contents and len(formatted_contents) == 1 and "text" in formatted_contents[0]: + content_value = formatted_contents[0]["text"] + else: + content_value = formatted_contents + + formatted_message = { + "role": message["role"], + "content": content_value, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format a Baseten compatible request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A Baseten compatible request dictionary. + """ + request = { + "messages": self.format_request_messages(messages, system_prompt), + "stream": True, + "stream_options": {"include_usage": True}, + **cast(dict[str, Any], self.config.get("params", {})), + } + + # Only include tools if tool_specs is provided + if tool_specs: + request["tools"] = [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs + ] + + # Include model field - use actual model_id for Model APIs, placeholder for dedicated deployments + model_id = self.get_config()["model_id"] + if model_id: + request["model"] = model_id + else: + # For dedicated deployments, use a placeholder model name + request["model"] = "placeholder" + + return request + + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format a Baseten response event into a standardized message chunk. + + Args: + event: A response event from the Baseten model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the Baseten model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. 
+ """ + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking model") + response = await self.client.chat.completions.create(**request) + + logger.debug("got response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) tool_calls: dict[int, list[Any]] = {} - for event in response: + async for event in response: # Defensive: skip events with empty or missing choices if not getattr(event, "choices", None): continue choice = event.choices[0] if choice.delta.content: - yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + ) if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - yield { - "chunk_type": "content_delta", - "data_type": "reasoning_content", - "data": choice.delta.reasoning_content, - } + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice.delta.reasoning_content, + } + ) for tool_call in choice.delta.tool_calls or []: tool_calls.setdefault(tool_call.index, []).append(tool_call) @@ -132,41 +405,53 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: if choice.finish_reason: break - yield {"chunk_type": "content_stop", "data_type": "text"} + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) for tool_deltas in tool_calls.values(): - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) for tool_delta in tool_deltas: - yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) - yield {"chunk_type": "content_stop", "data_type": "tool"} + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - yield {"chunk_type": "message_stop", "data": choice.finish_reason} + yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) # Skip remaining events as we don't have use for anything except the final usage payload - for event in response: + async for event in response: _ = event - yield {"chunk_type": "metadata", "data": event.usage} + if event.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + + logger.debug("finished streaming response from model") @override - def structured_output( - self, output_model: Type[T], prompt: Messages - ) -> Generator[dict[str, Union[T, Any]], None, None]: + async def structured_output( + self, output_model: Type[T], prompt: Messages, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: """Get structured output from the model. Args: output_model: The output model to use for the agent. prompt: The prompt messages to use for the agent. + **kwargs: Additional keyword arguments for future extensibility. Yields: Model events with the last being the structured output. 
""" - response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore - model=self.get_config()["model_id"], - messages=super().format_request(prompt)["messages"], - response_format=output_model, + model_id = self.get_config()["model_id"] + parse_kwargs = { + "messages": self.format_request_messages(prompt), + "response_format": output_model, + } + + # Only include model field if model_id is not empty (for dedicated deployments) + if model_id: + parse_kwargs["model"] = model_id + + response: ParsedChatCompletion = await self.client.beta.chat.completions.parse( # type: ignore + **parse_kwargs ) parsed: T | None = None diff --git a/tests/strands/models/test_baseten.py b/tests/strands/models/test_baseten.py index 52867ac09..0ce28b0c1 100644 --- a/tests/strands/models/test_baseten.py +++ b/tests/strands/models/test_baseten.py @@ -9,7 +9,7 @@ @pytest.fixture def openai_client_cls(): - with unittest.mock.patch.object(strands.models.baseten.openai, "OpenAI") as mock_client_cls: + with unittest.mock.patch.object(strands.models.baseten.openai, "AsyncOpenAI") as mock_client_cls: yield mock_client_cls @@ -61,49 +61,19 @@ def test__init__model_apis(openai_client_cls, model_id): def test__init__dedicated_deployment(openai_client_cls): - deployment_id = "dq4kr413" - environment = "production" - base_url = f"https://model-{deployment_id}.api.baseten.co/environments/{environment}/sync/v1" + base_url = "https://model-abcd1234.api.baseten.co/environments/production/sync/v1" model = BasetenModel( {"api_key": "k1"}, - model_id=deployment_id, + model_id="abcd1234", base_url=base_url, - environment=environment, params={"max_tokens": 1} ) tru_config = model.get_config() exp_config = { - "model_id": "dq4kr413", + "model_id": "abcd1234", "base_url": base_url, - "environment": "production", - "params": {"max_tokens": 1} - } - - assert tru_config == exp_config - - openai_client_cls.assert_called_once_with(api_key="k1", base_url=base_url) - - -def test__init__dedicated_deployment_custom_environment(openai_client_cls): - deployment_id = "dq4kr413" - environment = "staging" - base_url = f"https://model-{deployment_id}.api.baseten.co/environments/{environment}/sync/v1" - - model = BasetenModel( - {"api_key": "k1"}, - model_id=deployment_id, - base_url=base_url, - environment=environment, - params={"max_tokens": 1} - ) - - tru_config = model.get_config() - exp_config = { - "model_id": "dq4kr413", - "base_url": base_url, - "environment": "staging", "params": {"max_tokens": 1} } @@ -131,69 +101,228 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_stream(openai_client, model): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) - mock_delta_1 = unittest.mock.Mock( - reasoning_content="", - content=None, - tool_calls=None, - ) - mock_delta_2 = unittest.mock.Mock( - reasoning_content="\nI'm thinking", - content=None, - tool_calls=None, - ) - mock_delta_3 = unittest.mock.Mock( - content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None +def test_format_request_message_content_text(): + """Test formatting text content blocks.""" + content = {"text": "Hello, world!"} + result = BasetenModel.format_request_message_content(content) + assert result == {"text": "Hello, world!", "type": "text"} + + +def test_format_request_message_content_document(): + """Test formatting document content blocks.""" + content = { + "document": { + "name": "test.pdf", + 
"format": "pdf", + "source": {"bytes": b"test content"} + } + } + result = BasetenModel.format_request_message_content(content) + assert result["type"] == "file" + assert result["file"]["filename"] == "test.pdf" + + +def test_format_request_message_content_image(): + """Test formatting image content blocks.""" + content = { + "image": { + "format": "png", + "source": {"bytes": b"test image"} + } + } + result = BasetenModel.format_request_message_content(content) + assert result["type"] == "image_url" + assert "image_url" in result + + +def test_format_request_message_content_unsupported(): + """Test handling unsupported content types.""" + content = {"unsupported": "data"} + with pytest.raises(TypeError, match="unsupported type"): + BasetenModel.format_request_message_content(content) + + +def test_format_request_messages_simple(): + """Test formatting simple messages.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + result = BasetenModel.format_request_messages(messages) + assert len(result) == 1 + assert result[0]["role"] == "user" + assert result[0]["content"] == [{"text": "Hello", "type": "text"}] + + +def test_format_request_messages_with_system_prompt(): + """Test formatting messages with system prompt.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." + result = BasetenModel.format_request_messages(messages, system_prompt) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[0]["content"] == system_prompt + + +def test_format_request_messages_with_tool_use(): + """Test formatting messages with tool use.""" + messages = [{ + "role": "assistant", + "content": [ + {"text": "I'll help you"}, + {"toolUse": {"name": "calculator", "input": {"a": 1, "b": 2}, "toolUseId": "call_1"}} + ] + }] + result = BasetenModel.format_request_messages(messages) + assert len(result) == 1 + assert result[0]["role"] == "assistant" + assert "tool_calls" in result[0] + + +def test_format_request_messages_with_tool_result(): + """Test formatting messages with tool result.""" + messages = [{ + "role": "tool", + "content": [ + {"toolResult": {"toolUseId": "call_1", "content": [{"json": {"result": 3}}]}} + ] + }] + result = BasetenModel.format_request_messages(messages) + assert len(result) == 1 + assert result[0]["role"] == "tool" + assert result[0]["tool_call_id"] == "call_1" + + +@pytest.mark.asyncio +async def test_stream_model_apis(openai_client): + """Test streaming with Model APIs.""" + model = BasetenModel( + {"api_key": "k1"}, + model_id="deepseek-ai/DeepSeek-R1-0528" ) + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + async def async_iter(): + yield mock_event + yield unittest.mock.Mock(usage=mock_usage) - mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_4 = unittest.mock.Mock( - content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None + openai_client.chat.completions.create.return_value = async_iter() + + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + response = model.stream(messages) + tru_events = [] + async for event in response: + tru_events.append(event) + + # Check that the first few events match 
expected format + assert len(tru_events) > 0 + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert tru_events[1] == {"contentBlockStart": {"start": {}}} + + # Verify the API was called with correct parameters + openai_client.chat.completions.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_dedicated_deployment(openai_client): + """Test streaming with dedicated deployment.""" + base_url = "https://model-abcd1234.api.baseten.co/environments/production/sync/v1" + + model = BasetenModel( + {"api_key": "k1"}, + model_id="abcd1234", + base_url=base_url ) + + mock_delta = unittest.mock.Mock(content="Response", tool_calls=None, reasoning_content=None) + mock_event = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_usage = unittest.mock.Mock(prompt_tokens=5, completion_tokens=3, total_tokens=8) - mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) + async def async_iter(): + yield mock_event + yield unittest.mock.Mock(usage=mock_usage) - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)]) - mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)]) - mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)]) - mock_event_6 = unittest.mock.Mock() + openai_client.chat.completions.create.return_value = async_iter() - openai_client.chat.completions.create.return_value = iter( - [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6] + messages = [{"role": "user", "content": [{"text": "Test"}]}] + response = model.stream(messages) + + tru_events = [] + async for event in response: + tru_events.append(event) + + assert len(tru_events) > 0 + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + openai_client.chat.completions.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_with_tools(openai_client, model): + """Test streaming with tool specifications.""" + mock_tool_call = unittest.mock.Mock(index=0) + mock_delta = unittest.mock.Mock( + content="I'll calculate", + tool_calls=[mock_tool_call], + reasoning_content=None ) + mock_event = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta)]) - request = {"model": "deepseek-ai/DeepSeek-R1-0528", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} - response = model.stream(request) - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "\nI'm thinking"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, - {"chunk_type": "content_stop", "data_type": 
"tool"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "message_stop", "data": "tool_calls"}, - {"chunk_type": "metadata", "data": mock_event_6.usage}, - ] - - assert tru_events == exp_events - openai_client.chat.completions.create.assert_called_once_with(**request) - - -def test_stream_empty(openai_client, model): + async def async_iter(): + yield mock_event + yield unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = async_iter() + + messages = [{"role": "user", "content": [{"text": "Calculate 2+2"}]}] + tool_specs = [{ + "name": "calculator", + "description": "A calculator tool", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}} + }] + + response = model.stream(messages, tool_specs) + + tru_events = [] + async for event in response: + tru_events.append(event) + + assert len(tru_events) > 0 + # Verify the request included tools + call_args = openai_client.chat.completions.create.call_args + assert "tools" in call_args[1] + + +@pytest.mark.asyncio +async def test_stream_with_system_prompt(openai_client, model): + """Test streaming with system prompt.""" + mock_delta = unittest.mock.Mock(content="Response", tool_calls=None, reasoning_content=None) + mock_event = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + async def async_iter(): + yield mock_event + yield unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = async_iter() + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." 
+ + response = model.stream(messages, system_prompt=system_prompt) + + tru_events = [] + async for event in response: + tru_events.append(event) + + assert len(tru_events) > 0 + # Verify the request included system prompt + call_args = openai_client.chat.completions.create.call_args + messages_arg = call_args[1]["messages"] + assert messages_arg[0]["role"] == "system" + assert messages_arg[0]["content"] == system_prompt + + +@pytest.mark.asyncio +async def test_stream_empty(openai_client, model): mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) @@ -202,25 +331,30 @@ def test_stream_empty(openai_client, model): mock_event_3 = unittest.mock.Mock() mock_event_4 = unittest.mock.Mock(usage=mock_usage) - openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + # Create async iterator for the response + async def async_iter(): + for event in [mock_event_1, mock_event_2, mock_event_3, mock_event_4]: + yield event + + openai_client.chat.completions.create.return_value = async_iter() - request = {"model": "deepseek-ai/DeepSeek-R1-0528", "messages": [{"role": "user", "content": []}]} - response = model.stream(request) + messages = [{"role": "user", "content": []}] + response = model.stream(messages) - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - {"chunk_type": "metadata", "data": mock_usage}, - ] + tru_events = [] + async for event in response: + tru_events.append(event) - assert tru_events == exp_events - openai_client.chat.completions.create.assert_called_once_with(**request) + # Check that we get the expected events + assert len(tru_events) > 0 + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert tru_events[1] == {"contentBlockStart": {"start": {}}} + + openai_client.chat.completions.create.assert_called_once() -def test_stream_with_empty_choices(openai_client, model): +@pytest.mark.asyncio +async def test_stream_with_empty_choices(openai_client, model): mock_delta = unittest.mock.Mock(content="content", tool_calls=None, reasoning_content=None) mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -239,29 +373,30 @@ def test_stream_with_empty_choices(openai_client, model): # Final event with usage info mock_event_5 = unittest.mock.Mock(usage=mock_usage) - openai_client.chat.completions.create.return_value = iter( - [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] - ) + # Create async iterator for the response + async def async_iter(): + for event in [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]: + yield event - request = {"model": "deepseek-ai/DeepSeek-R1-0528", "messages": [{"role": "user", "content": ["test"]}]} - response = model.stream(request) + openai_client.chat.completions.create.return_value = async_iter() - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - {"chunk_type": 
"metadata", "data": mock_usage}, - ] + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) - assert tru_events == exp_events - openai_client.chat.completions.create.assert_called_once_with(**request) + tru_events = [] + async for event in response: + tru_events.append(event) + + # Check that we get the expected events + assert len(tru_events) > 0 + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert tru_events[1] == {"contentBlockStart": {"start": {}}} + + openai_client.chat.completions.create.assert_called_once() -def test_structured_output(openai_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output(openai_client, model, test_output_model_cls): mock_parsed_response = test_output_model_cls(name="test", age=25) mock_choice = unittest.mock.Mock() mock_choice.message.parsed = mock_parsed_response @@ -270,7 +405,9 @@ def test_structured_output(openai_client, model, test_output_model_cls): openai_client.beta.chat.completions.parse.return_value = mock_response prompt = [{"role": "user", "content": [{"text": "test"}]}] - result = list(model.structured_output(test_output_model_cls, prompt)) + result = [] + async for event in model.structured_output(test_output_model_cls, prompt): + result.append(event) assert len(result) == 1 assert result[0]["output"] == mock_parsed_response @@ -278,7 +415,8 @@ def test_structured_output(openai_client, model, test_output_model_cls): openai_client.beta.chat.completions.parse.assert_called_once() -def test_structured_output_multiple_choices(openai_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output_multiple_choices(openai_client, model, test_output_model_cls): mock_choice_1 = unittest.mock.Mock() mock_choice_2 = unittest.mock.Mock() mock_response = unittest.mock.Mock(choices=[mock_choice_1, mock_choice_2]) @@ -288,10 +426,12 @@ def test_structured_output_multiple_choices(openai_client, model, test_output_mo prompt = [{"role": "user", "content": [{"text": "test"}]}] with pytest.raises(ValueError, match="Multiple choices found in the Baseten response."): - list(model.structured_output(test_output_model_cls, prompt)) + async for _ in model.structured_output(test_output_model_cls, prompt): + pass -def test_structured_output_no_valid_parsed(openai_client, model, test_output_model_cls): +@pytest.mark.asyncio +async def test_structured_output_no_valid_parsed(openai_client, model, test_output_model_cls): mock_choice = unittest.mock.Mock() mock_choice.message.parsed = None mock_response = unittest.mock.Mock(choices=[mock_choice]) @@ -301,4 +441,5 @@ def test_structured_output_no_valid_parsed(openai_client, model, test_output_mod prompt = [{"role": "user", "content": [{"text": "test"}]}] with pytest.raises(ValueError, match="No valid tool use or tool use input was found in the Baseten response."): - list(model.structured_output(test_output_model_cls, prompt)) \ No newline at end of file + async for _ in model.structured_output(test_output_model_cls, prompt): + pass \ No newline at end of file diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 543f58480..5c9b43331 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -10,6 +10,7 @@ from strands.models import BedrockModel, Model from strands.models.anthropic import AnthropicModel +from strands.models.baseten import BasetenModel from strands.models.litellm import LiteLLMModel from strands.models.llamaapi import 
LlamaAPIModel from strands.models.mistral import MistralModel @@ -69,6 +70,16 @@ def __init__(self): max_tokens=512, ), ) +baseten = ProviderInfo( + id="baseten", + environment_variable="BASETEN_API_KEY", + factory=lambda: BasetenModel( + model_id="deepseek-ai/DeepSeek-V3-0324", + client_args={ + "api_key": os.getenv("BASETEN_API_KEY"), + }, + ), +) bedrock = ProviderInfo(id="bedrock", factory=lambda: BedrockModel()) cohere = ProviderInfo( id="cohere", @@ -133,6 +144,7 @@ def __init__(self): all_providers = [ bedrock, anthropic, + baseten, cohere, llama, litellm, diff --git a/tests_integ/models/test_model_baseten.py b/tests_integ/models/test_model_baseten.py new file mode 100644 index 000000000..5cc4b0a22 --- /dev/null +++ b/tests_integ/models/test_model_baseten.py @@ -0,0 +1,385 @@ +"""Integration tests for Baseten model provider.""" + +import os +import pytest +from typing import AsyncGenerator + +from strands.models.baseten import BasetenModel +from strands.types.content import Messages +from strands.types.streaming import StreamEvent + + +@pytest.fixture +def baseten_model_apis(): + """Create a BasetenModel instance for Model APIs testing.""" + return BasetenModel( + model_id="deepseek-ai/DeepSeek-R1-0528", + client_args={ + "api_key": os.getenv("BASETEN_API_KEY"), + }, + ) + + +@pytest.fixture +def baseten_dedicated_deployment(): + """Create a BasetenModel instance for dedicated deployment testing.""" + # This would need a real base URL + base_url = os.getenv("BASETEN_BASE_URL", "https://model-test-deployment.api.baseten.co/environments/production/sync/v1") + + return BasetenModel( + model_id="test-deployment", + base_url=base_url, + client_args={ + "api_key": os.getenv("BASETEN_API_KEY"), + }, + ) + + +@pytest.mark.asyncio +async def test_baseten_model_apis_streaming(baseten_model_apis): + """Test streaming with Baseten Model APIs.""" + if not os.getenv("BASETEN_API_KEY"): + pytest.skip("BASETEN_API_KEY not set") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello, how are you?"}]} + ] + + events = [] + async for event in baseten_model_apis.stream(messages): + events.append(event) + + # Verify we get the expected event types + assert len(events) > 0 + assert events[0]["messageStart"]["role"] == "assistant" + + # Check for content events + content_events = [e for e in events if "contentBlockDelta" in e] + assert len(content_events) > 0 + + # Check for message stop + stop_events = [e for e in events if "messageStop" in e] + assert len(stop_events) > 0 + + +@pytest.mark.asyncio +async def test_baseten_model_apis_with_system_prompt(baseten_model_apis): + """Test streaming with Baseten Model APIs and system prompt.""" + if not os.getenv("BASETEN_API_KEY"): + pytest.skip("BASETEN_API_KEY not set") + + messages: Messages = [ + {"role": "user", "content": [{"text": "What is 2+2?"}]} + ] + system_prompt = "You are a helpful math assistant. Always provide clear explanations." 
+ + events = [] + async for event in baseten_model_apis.stream(messages, system_prompt=system_prompt): + events.append(event) + + assert len(events) > 0 + assert events[0]["messageStart"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_baseten_model_apis_structured_output(baseten_model_apis): + """Test structured output with Baseten Model APIs.""" + if not os.getenv("BASETEN_API_KEY"): + pytest.skip("BASETEN_API_KEY not set") + + from pydantic import BaseModel + + class MathResult(BaseModel): + answer: int + explanation: str + + messages: Messages = [ + {"role": "user", "content": [{"text": "What is 5 + 3? Provide the answer as a number and explain your reasoning."}]} + ] + + results = [] + async for result in baseten_model_apis.structured_output(MathResult, messages): + results.append(result) + + assert len(results) == 1 + assert "output" in results[0] + assert isinstance(results[0]["output"], MathResult) + assert results[0]["output"].answer == 8 + + +@pytest.mark.asyncio +async def test_baseten_model_apis_with_tools(baseten_model_apis): + """Test streaming with Baseten Model APIs and tools.""" + if not os.getenv("BASETEN_API_KEY"): + pytest.skip("BASETEN_API_KEY not set") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Calculate 10 + 20"}]} + ] + + tool_specs = [ + { + "name": "calculator", + "description": "A simple calculator that can perform basic arithmetic", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "The mathematical expression to evaluate" + } + }, + "required": ["expression"] + } + } + } + ] + + events = [] + async for event in baseten_model_apis.stream(messages, tool_specs=tool_specs): + events.append(event) + + assert len(events) > 0 + assert events[0]["messageStart"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_baseten_model_apis_complex_messages(baseten_model_apis): + """Test streaming with complex message structures.""" + if not os.getenv("BASETEN_API_KEY"): + pytest.skip("BASETEN_API_KEY not set") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there! 
How can I help you today?"}]}, + {"role": "user", "content": [{"text": "What's the weather like?"}]} + ] + + events = [] + async for event in baseten_model_apis.stream(messages): + events.append(event) + + assert len(events) > 0 + assert events[0]["messageStart"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_baseten_dedicated_deployment_streaming(baseten_dedicated_deployment): + """Test streaming with Baseten dedicated deployment.""" + if not os.getenv("BASETEN_API_KEY") or not os.getenv("BASETEN_BASE_URL"): + pytest.skip("BASETEN_API_KEY or BASETEN_BASE_URL not set") + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello from dedicated deployment!"}]} + ] + + events = [] + async for event in baseten_dedicated_deployment.stream(messages): + events.append(event) + + assert len(events) > 0 + assert events[0]["messageStart"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_baseten_dedicated_deployment_structured_output(baseten_dedicated_deployment): + """Test structured output with Baseten dedicated deployment.""" + if not os.getenv("BASETEN_API_KEY") or not os.getenv("BASETEN_BASE_URL"): + pytest.skip("BASETEN_API_KEY or BASETEN_BASE_URL not set") + + from pydantic import BaseModel + + class SimpleResponse(BaseModel): + message: str + + messages: Messages = [ + {"role": "user", "content": [{"text": "Say hello"}]} + ] + + results = [] + async for result in baseten_dedicated_deployment.structured_output(SimpleResponse, messages): + results.append(result) + + assert len(results) == 1 + assert "output" in results[0] + assert isinstance(results[0]["output"], SimpleResponse) + + +def test_baseten_config_management(): + """Test configuration management for BasetenModel.""" + model = BasetenModel( + model_id="test-model", + params={"max_tokens": 100, "temperature": 0.7} + ) + + # Test initial config + config = model.get_config() + assert config["model_id"] == "test-model" + assert config["params"]["max_tokens"] == 100 + assert config["params"]["temperature"] == 0.7 + + # Test config update + model.update_config(params={"max_tokens": 200, "temperature": 0.5}) + updated_config = model.get_config() + assert updated_config["params"]["max_tokens"] == 200 + assert updated_config["params"]["temperature"] == 0.5 + + +def test_baseten_model_apis_configuration(): + """Test Model APIs configuration.""" + model = BasetenModel( + model_id="deepseek-ai/DeepSeek-R1-0528", + client_args={"api_key": "test-key"} + ) + + config = model.get_config() + assert config["model_id"] == "deepseek-ai/DeepSeek-R1-0528" + # Should use default base URL for Model APIs + assert "base_url" not in config + + +def test_baseten_dedicated_deployment_configuration(): + """Test dedicated deployment configuration.""" + base_url = "https://model-test-deployment.api.baseten.co/environments/production/sync/v1" + + model = BasetenModel( + model_id="test-deployment", + base_url=base_url, + client_args={"api_key": "test-key"} + ) + + config = model.get_config() + assert config["model_id"] == "test-deployment" + assert config["base_url"] == base_url + + +def test_baseten_message_formatting(): + """Test message formatting methods.""" + # Test text content formatting + text_content = {"text": "Hello, world!"} + formatted = BasetenModel.format_request_message_content(text_content) + assert formatted == {"text": "Hello, world!", "type": "text"} + + # Test document content formatting + doc_content = { + "document": { + "name": "test.pdf", + "format": "pdf", + "source": {"bytes": b"test content"} + 
} + } + formatted = BasetenModel.format_request_message_content(doc_content) + assert formatted["type"] == "file" + assert formatted["file"]["filename"] == "test.pdf" + + # Test image content formatting + img_content = { + "image": { + "format": "png", + "source": {"bytes": b"test image"} + } + } + formatted = BasetenModel.format_request_message_content(img_content) + assert formatted["type"] == "image_url" + + +def test_baseten_message_formatting_with_tools(): + """Test message formatting with tool use and tool results.""" + # Test tool use formatting + tool_use = { + "name": "calculator", + "input": {"expression": "2+2"}, + "toolUseId": "call_1" + } + formatted = BasetenModel.format_request_message_tool_call(tool_use) + assert formatted["function"]["name"] == "calculator" + assert formatted["id"] == "call_1" + + # Test tool result formatting + tool_result = { + "toolUseId": "call_1", + "content": [{"json": {"result": 4}}] + } + formatted = BasetenModel.format_request_tool_message(tool_result) + assert formatted["role"] == "tool" + assert formatted["tool_call_id"] == "call_1" + + +def test_baseten_messages_formatting(): + """Test complete message formatting.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]} + ] + + formatted = BasetenModel.format_request_messages(messages) + assert len(formatted) == 2 + assert formatted[0]["role"] == "user" + assert formatted[1]["role"] == "assistant" + + +def test_baseten_messages_formatting_with_system_prompt(): + """Test message formatting with system prompt.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." + + formatted = BasetenModel.format_request_messages(messages, system_prompt) + assert len(formatted) == 2 + assert formatted[0]["role"] == "system" + assert formatted[0]["content"] == system_prompt + assert formatted[1]["role"] == "user" + + +def test_baseten_request_formatting(): + """Test complete request formatting.""" + model = BasetenModel(model_id="test-model") + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object"}} + } + ] + + request = model.format_request(messages, tool_specs, "You are helpful") + + assert request["model"] == "test-model" + assert request["stream"] is True + assert "tools" in request + assert len(request["tools"]) == 1 + assert request["tools"][0]["function"]["name"] == "test_tool" + + +def test_baseten_chunk_formatting(): + """Test response chunk formatting.""" + model = BasetenModel(model_id="test-model") + + # Test message start + chunk = model.format_chunk({"chunk_type": "message_start"}) + assert chunk == {"messageStart": {"role": "assistant"}} + + # Test content start + chunk = model.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + assert chunk == {"contentBlockStart": {"start": {}}} + + # Test content delta + chunk = model.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": "Hello"}) + assert chunk == {"contentBlockDelta": {"delta": {"text": "Hello"}}} + + # Test content stop + chunk = model.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + assert chunk == {"contentBlockStop": {}} + + # Test message stop + chunk = model.format_chunk({"chunk_type": "message_stop", "data": "stop"}) + assert chunk == {"messageStop": {"stopReason": "end_turn"}} + + +def 
test_baseten_unsupported_content_type():
+    """Test handling of unsupported content types."""
+    unsupported_content = {"unsupported": "data"}
+
+    with pytest.raises(TypeError, match="unsupported type"):
+        BasetenModel.format_request_message_content(unsupported_content)
\ No newline at end of file
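
Usage sketch for the provider these patches add, following the BasetenConfig docstring, the async stream()/structured_output() signatures from PATCH 2/2, and the BASETEN_API_KEY variable used in tests_integ; the dedicated-deployment ID and base URL below are placeholders in the same shape as the test values:

import asyncio
import os

from pydantic import BaseModel

from strands.models.baseten import BasetenModel

# Model APIs: pass a model slug; the client defaults to https://inference.baseten.co/v1.
model = BasetenModel(
    client_args={"api_key": os.getenv("BASETEN_API_KEY")},
    model_id="deepseek-ai/DeepSeek-R1-0528",
    params={"max_tokens": 512},
)

# Dedicated deployment: pass the deployment ID and its sync/v1 base URL explicitly
# (shown for configuration only; "abcd1234" is a placeholder deployment ID).
dedicated = BasetenModel(
    client_args={"api_key": os.getenv("BASETEN_API_KEY")},
    model_id="abcd1234",
    base_url="https://model-abcd1234.api.baseten.co/environments/production/sync/v1",
)


class Answer(BaseModel):
    value: int
    explanation: str


async def main() -> None:
    messages = [{"role": "user", "content": [{"text": "What is 2 + 2?"}]}]

    # stream() is an async generator of formatted StreamEvent chunks
    # (messageStart, contentBlockStart, contentBlockDelta, messageStop, metadata, ...).
    async for event in model.stream(messages, system_prompt="You are a concise assistant."):
        print(event)

    # structured_output() yields a single {"output": <parsed pydantic instance>} event.
    async for result in model.structured_output(Answer, messages):
        print(result["output"])


asyncio.run(main())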