Skip to content

Commit

Permalink
moved up logic for required and default inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
zilto authored and elijahbenizzy committed Oct 23, 2024
1 parent f3262d7 commit 6ec2f3c
Showing 1 changed file with 96 additions and 37 deletions.
133 changes: 96 additions & 37 deletions burr/integrations/haystack.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,72 @@ def __init__(
"""Create a Burr ``Action`` from a Haystack ``Component``.
:param component: Haystack ``Component`` to wrap
:param reads: State fields read and passed to ``Component.run()``
:param writes: State fields where results of ``Component.run()`` are written
:param name: Name of the action. Can be set later via ``.with_name()`` or in the
``ApplicationBuilder``.
:param reads: State fields read and passed to ``Component.run()``.
Use a mapping {socket: state_field} to rename Haystack input sockets (see example).
:param writes: State fields where results of ``Component.run()`` are written.
Use a mapping {state_field: socket} to rename Haystack output sockets (see example).
:param name: Name of the action. Can be set later via ``.with_name()``
or in ``ApplicationBuilder.with_actions()``.
:param bound_params: Parameters to bind to the `Component.run()` method.
Basic example:
.. code-block:: python
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from burr.core import ApplicationBuilder
from burr.integrations.haystack import HaystackAction
retrieve_documents = HaystackAction(
component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
name="retrieve_documents",
reads=["query_embedding"],
writes=["documents"],
)
app = (
ApplicationBuilder()
.with_actions(retrieve_documents)
.with_transitions("retrieve_documents", "retrieve_documents")
.with_entrypoint("retrieve_documents")
.build()
)
Pass the mapping ``{"foo": "state_field"}`` to read the value of ``state_field`` on the Burr state
and pass it to ``Component.run()`` as ``foo``.
.. code-block:: python
@component
class HaystackComponent:
@component.output_types()
def run(self, foo: int) -> dict:
return {}
HaystackAction(
component=HaystackComponent(),
reads={"foo": "state_field"},
writes=[]
)
Pass the mapping ``{"state_field": "bar"}`` to get the ``bar`` value from the results
of ``.run()`` and set the field ``state_field`` on the Burr state
.. code-block:: python
@component
class HaystackComponent:
@component.output_types(bar=int)
def run(self) -> dict:
return {"bar": 1}
HaystackAction(
component=HaystackComponent(),
reads=[],
writes={"state_field": "bar"}
)
Basic usage:
.. code-block:: python
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from burr.core import ApplicationBuilder
from burr.integrations.haystack import HaystackAction
retrieve_documents = HaystackAction(
component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
name="retrieve_documents",
reads=["query_embedding"],
writes=["documents"],
)
app = (
ApplicationBuilder()
.with_actions(retrieve_documents)
.with_transitions("retrieve_documents", "retrieve_documents")
.with_entrypoint("retrieve_documents")
.build()
)
"""
self._component = component
self._name = name
Expand Down Expand Up @@ -90,7 +127,12 @@ def __init__(

self._validate_output_sockets()

self._required_inputs, self._optional_inputs = self._get_required_and_optional_inputs()

def _validate_input_sockets(self) -> None:
"""Check that input socket names passed by the user match the Component's input sockets"""
# NOTE those are internal attributes, but we expect them be stable.
# reference: https://github.com/deepset-ai/haystack/blob/906177329bcc54f6946af361fcd3d0e334e6ce5f/haystack/core/component/component.py#L371
component_inputs = self._component.__haystack_input__._sockets_dict.keys()
for socket_name in self._input_socket_mapping.keys():
if socket_name not in component_inputs:
Expand All @@ -99,6 +141,9 @@ def _validate_input_sockets(self) -> None:
)

def _validate_output_sockets(self) -> None:
"""Check that output socket names passed by the user match the Component's output sockets"""
# NOTE those are internal attributes, but we expect them be stable.
# reference: https://github.com/deepset-ai/haystack/blob/906177329bcc54f6946af361fcd3d0e334e6ce5f/haystack/core/component/component.py#L449
component_outputs = self._component.__haystack_output__._sockets_dict.keys()
for socket_name in self._output_socket_mapping.values():
if socket_name not in component_outputs:
Expand All @@ -121,26 +166,40 @@ def writes(self) -> list[str]:
"""State fields where results of `Component.run()` are written."""
return self._writes

@property
def inputs(self) -> tuple[dict[str, str], dict[str, str]]:
"""Return dictionaries of required and optional inputs for `Component.run()`"""
required_inputs, optional_inputs = {}, {}
def _get_required_and_optional_inputs(self) -> tuple[list[str], list[str]]:
"""Iterate over Haystack Component input sockets and inspect default values.
If we expect the value to come from state or it's a bound parameter, skip this socket.
Otherwise, if it has a default value, it's optional.
"""
required_inputs, optional_inputs = [], []
# NOTE those are internal attributes, but we expect them be stable.
# reference: https://github.com/deepset-ai/haystack/blob/906177329bcc54f6946af361fcd3d0e334e6ce5f/haystack/core/component/component.py#L371
for socket_name, input_socket in self._component.__haystack_input__._sockets_dict.items():
state_field_name = self._input_socket_mapping.get(socket_name, socket_name)

# if we expect the value to come from state (previous actions) or it's a
# bound parameter, then this socket isn't a user-provided input
if state_field_name in self.reads or state_field_name in self._bound_params:
continue

# determine if input is required or optional based on the socket's default value
if input_socket.default_value == haystack_empty:
required_inputs[state_field_name] = input_socket.type
required_inputs.append(state_field_name)
else:
optional_inputs[state_field_name] = input_socket.type
optional_inputs.append(state_field_name)

return required_inputs, optional_inputs

@property
def inputs(self) -> list[str]:
"""Return a list of required inputs for ``Component.run()``
This corresponds to the Component's required input socket names.
"""
return self._required_inputs

@property
def optional_and_required_inputs(self) -> tuple[list[str], list[str]]:
"""Return a tuple of required and optional inputs for ``Component.run()``
This corresponds to the Component's required and optional input socket names.
"""
return self._required_inputs, self._optional_inputs

def run(self, state: State, **run_kwargs) -> dict[str, Any]:
"""Call the Haystack `Component.run()` method.
Expand Down

0 comments on commit 6ec2f3c

Please sign in to comment.