Skip to content

feat: Add tool support to STACKITChatGenerator #1964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/stackit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ concurrency:
env:
PYTHONUNBUFFERED: "1"
FORCE_COLOR: "1"
STACKIT: ${{ secrets.STACKIT_API_KEY }}
STACKIT_API_KEY: ${{ secrets.STACKIT_API_KEY }}

jobs:
run:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# SPDX-FileCopyrightText: 2025-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_to_dict
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import StreamingCallbackT
from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
from haystack.utils import serialize_callable
from haystack.utils.auth import Secret

Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(
api_base_url: Optional[str] = "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
generation_kwargs: Optional[Dict[str, Any]] = None,
*,
tools: Optional[Union[List[Tool], Toolset]] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -74,6 +76,9 @@ def __init__(
events as they become available, with the stream terminated by a data: [DONE] message.
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
- `random_seed`: The seed to use for random sampling.
:param tools:
A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
list of `Tool` objects or a `Toolset` instance.
:param timeout:
Timeout for STACKIT client calls. If not set, it defaults to either the `OPENAI_TIMEOUT` environment
variable, or 30 seconds.
Expand All @@ -93,6 +98,7 @@ def __init__(
generation_kwargs=generation_kwargs,
timeout=timeout,
max_retries=max_retries,
tools=tools,
http_client_kwargs=http_client_kwargs,
)

Expand All @@ -108,14 +114,14 @@ def to_dict(self) -> Dict[str, Any]:
# if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
# which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in
# the __init__)
# it would be hard to maintain the compatibility as superclass changes
return default_to_dict(
self,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict(),
tools=serialize_tools_or_toolset(self.tools),
timeout=self.timeout,
max_retries=self.max_retries,
http_client_kwargs=self.http_client_kwargs,
Expand Down
78 changes: 77 additions & 1 deletion integrations/stackit/tests/test_stackit_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest
import pytz
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
from haystack.tools import Tool
from haystack.utils.auth import Secret
from openai import OpenAIError
from openai.types import CompletionUsage
Expand All @@ -23,6 +24,24 @@ def chat_messages():
]


def weather(city: str):
    """Dummy tool target: report a fixed sunny, 32°C forecast for *city*."""
    # Hard-coded response — this stub only exists so tests can verify tool invocation.
    return "The weather in " + city + " is sunny and 32°C"


@pytest.fixture
def tools():
    """Fixture: a one-element list holding a `Tool` wrapping the `weather` stub."""
    # JSON-schema description of the single required `city` string argument.
    schema = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
    return [
        Tool(
            name="weather",
            description="useful to determine the weather in a given location",
            parameters=schema,
            function=weather,
        )
    ]


@pytest.fixture
def mock_chat_completion():
"""
Expand Down Expand Up @@ -254,3 +273,60 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert callback.counter > 1
assert "Paris" in callback.responses

@pytest.mark.skipif(
    not os.environ.get("STACKIT_API_KEY", None),
    reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_tools_and_response(self, tools):
    """
    Integration test that the STACKITChatGenerator component can run with tools and get a response.

    First turn: the model should emit two `weather` tool calls (Paris and Berlin).
    Second turn: tool results are fed back and the final assistant answer must
    mention both cities.
    """
    initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")]
    component = STACKITChatGenerator(
        # Only model that supports tool calls at the moment
        # This one does indeed run, but for some reason the tool call is put into
        # chat_completion.choices[0].message.content instead chat_completion.choices[0].message.tool_calls
        model="cortecs/Llama-3.3-70B-Instruct-FP8-Dynamic",
        tools=tools,
    )
    results = component.run(
        messages=initial_messages,
        generation_kwargs={"tool_choice": "auto"},
    )

    assert len(results["replies"]) == 1

    # The single reply is expected to carry the tool calls.
    tool_message = results["replies"][0]

    assert isinstance(tool_message, ChatMessage)
    tool_calls = tool_message.tool_calls
    # One call per city asked about in the prompt.
    assert len(tool_calls) == 2
    assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT)

    for tool_call in tool_calls:
        assert tool_call.id is not None
        assert isinstance(tool_call, ToolCall)
        assert tool_call.tool_name == "weather"

    # Sort by city so the assertion is independent of call ordering.
    arguments = [tool_call.arguments for tool_call in tool_calls]
    assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
    assert tool_message.meta["finish_reason"] == "tool_calls"

    new_messages = [
        initial_messages[0],
        tool_message,
        ChatMessage.from_tool(tool_result="22° C and sunny", origin=tool_calls[0]),
        ChatMessage.from_tool(tool_result="16° C and windy", origin=tool_calls[1]),
    ]
    # Pass the tool result to the model to get the final response
    results = component.run(new_messages)

    assert len(results["replies"]) == 1
    final_message = results["replies"][0]
    assert final_message.is_from(ChatRole.ASSISTANT)
    assert len(final_message.text) > 0
    assert "paris" in final_message.text.lower()
    assert "berlin" in final_message.text.lower()
Loading