Skip to content

Commit

Permalink
Support AssistantDefinition contract in executor (#1657)
Browse files Browse the repository at this point in the history
# Description
This pull request supported `AssistantDefinition` contract in executor.
The most important changes include introducing a new
`AssistantDefinition` class and modifying methods to handle conversions
and placeholders.

Main interface changes:

*
[`src/promptflow/promptflow/contracts/types.py`](diffhunk://#diff-3cac7adc96dc2969cfdb50d8b19e5fdedbb4d38bad5a7821405df91bd02e6078R32-R60):
Added a new `AssistantDefinition` class.
*
[`src/promptflow/promptflow/executor/_tool_resolver.py`](diffhunk://#diff-714d8202b40acb4053e3f9b366ee4972b32f98afc8a2efe8a1750842f1facc65R98-R112):
Added a new method `_convert_to_assistant_definition` and made changes
to handle the conversion of values to assistant definitions.
*
[`src/promptflow/promptflow/executor/_assistant_tool_invoker.py`](diffhunk://#diff-a9ca3b4ac0f9cc1667221699844a2528d92d4e226464b7e862cc0707e4baf101L27-R33):
Added a new `setup` method to the `AssistantToolInvoker` class.
*
[`src/promptflow/promptflow/contracts/tool.py`](diffhunk://#diff-96d10cc01c636338eedcf0287bd0a5ab57c9e7e8d8fac4619aaa51e5d36de9abL14-R14):
Added support for the `AssistantDefinition` class as an input type in
the `tool` decorator.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Lina Tang <[email protected]>
  • Loading branch information
lumoslnt and Lina Tang authored Jan 8, 2024
1 parent fb625ab commit 2938296
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 38 deletions.
7 changes: 6 additions & 1 deletion src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from promptflow._constants import CONNECTION_NAME_PROPERTY

from .multimedia import Image
from .types import FilePath, PromptTemplate, Secret
from .types import AssistantDefinition, FilePath, PromptTemplate, Secret

logger = logging.getLogger(__name__)
T = TypeVar("T", bound="Enum")
Expand Down Expand Up @@ -39,6 +39,7 @@ class ValueType(str, Enum):
OBJECT = "object"
FILE_PATH = "file_path"
IMAGE = "image"
ASSISTANT_DEFINITION = "assistant_definition"

@staticmethod
def from_value(t: Any) -> "ValueType":
Expand Down Expand Up @@ -67,6 +68,8 @@ def from_value(t: Any) -> "ValueType":
return ValueType.STRING
if isinstance(t, list):
return ValueType.LIST
if isinstance(t, AssistantDefinition):
return ValueType.ASSISTANT_DEFINITION
return ValueType.OBJECT

@staticmethod
Expand Down Expand Up @@ -97,6 +100,8 @@ def from_type(t: type) -> "ValueType":
return ValueType.FILE_PATH
if t == Image:
return ValueType.IMAGE
if t == AssistantDefinition:
return ValueType.ASSISTANT_DEFINITION
return ValueType.OBJECT

def parse(self, v: Any) -> Any: # noqa: C901
Expand Down
31 changes: 31 additions & 0 deletions src/promptflow/promptflow/contracts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from dataclasses import dataclass


class Secret(str):
"""This class is used to hint a parameter is a secret to load."""
Expand All @@ -25,3 +27,32 @@ class FilePath(str):
"""This class is used to hint a parameter is a file path."""

pass


@dataclass
class AssistantDefinition:
"""This class is used to define an assistant definition."""

model: str
instructions: str
tools: list

@staticmethod
def deserialize(data: dict) -> "AssistantDefinition":
return AssistantDefinition(
model=data.get("model", ""),
instructions=data.get("instructions", ""),
tools=data.get("tools", [])
)

def serialize(self):
return {
"model": self.model,
"instructions": self.instructions,
"tools": self.tools,
}

def init_tool_invoker(self):
from promptflow.executor._assistant_tool_invoker import AssistantToolInvoker

return AssistantToolInvoker.init(self.tools)
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ def __init__(self, working_dir: Optional[Path] = None):
self._working_dir = working_dir or Path(os.getcwd())
self._assistant_tools: Dict[str, AssistantTool] = {}

def load_tools(self, tools: list):
@classmethod
def init(cls, tools: list, working_dir: Optional[Path] = None):
invoker = cls(working_dir=working_dir)
invoker._load_tools(tools)
return invoker

def _load_tools(self, tools: list):
for tool in tools:
if tool["type"] in ("code_interpreter", "retrieval"):
self._assistant_tools[tool["type"]] = AssistantTool(
Expand Down
29 changes: 28 additions & 1 deletion src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import inspect
import types
import yaml
from dataclasses import dataclass
from functools import partial
from pathlib import Path
Expand All @@ -17,7 +18,7 @@
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
from promptflow.contracts.types import PromptTemplate
from promptflow.contracts.types import AssistantDefinition, PromptTemplate
from promptflow.exceptions import ErrorTarget, PromptflowException, UserErrorException
from promptflow.executor._errors import (
ConnectionNotFound,
Expand Down Expand Up @@ -94,6 +95,21 @@ def _convert_to_custom_strong_type_connection_value(
module=module, to_class=custom_defined_connection_class_name
)

def _convert_to_assistant_definition(self, assistant_definition_path: str, input_name: str, node_name: str):
if assistant_definition_path is None or not (self._working_dir / assistant_definition_path).is_file():
raise InvalidSource(
target=ErrorTarget.EXECUTOR,
message_format="Input '{input_name}' for node '{node_name}' of value '{source_path}' "
"is not a valid path.",
input_name=input_name,
source_path=assistant_definition_path,
node_name=node_name,
)
file = self._working_dir / assistant_definition_path
with open(file, "r", encoding="utf-8") as file:
assistant_definition = yaml.safe_load(file)
return AssistantDefinition.deserialize(assistant_definition)

def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: types.ModuleType = None):
updated_inputs = {
k: v
Expand Down Expand Up @@ -125,6 +141,17 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: type
key=k, error_type_and_message=error_type_and_message,
target=ErrorTarget.EXECUTOR
) from e
elif value_type == ValueType.ASSISTANT_DEFINITION:
try:
updated_inputs[k].value = self._convert_to_assistant_definition(v.value, k, node.name)
except Exception as e:
error_type_and_message = f"({e.__class__.__name__}) {e}"
raise NodeInputValidationError(
message_format="Failed to load assistant definition from input '{key}': "
"{error_type_and_message}",
key=k, error_type_and_message=error_type_and_message,
target=ErrorTarget.EXECUTOR
) from e
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import multiprocessing
from types import GeneratorType

Expand All @@ -8,7 +9,7 @@
from promptflow.executor import FlowExecutor
from promptflow.executor._errors import ConnectionNotFound, InputTypeError, ResolveToolError

from ..utils import FLOW_ROOT, get_flow_sample_inputs, get_yaml_file
from ..utils import FLOW_ROOT, get_flow_folder, get_flow_sample_inputs, get_yaml_file

SAMPLE_FLOW = "web_classification_no_variants"

Expand Down Expand Up @@ -58,10 +59,12 @@ def skip_serp(self, flow_folder, dev_connections):
"connection_as_input",
"async_tools",
"async_tools_with_sync_tools",
"tool_with_assistant_definition",
],
)
def test_executor_exec_line(self, flow_folder, dev_connections):
self.skip_serp(flow_folder, dev_connections)
os.chdir(get_flow_folder(flow_folder))
executor = FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
flow_result = executor.exec_line(self.get_line_inputs())
assert not executor._run_tracker._flow_runs, "Flow runs in run tracker should be empty."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from promptflow.contracts.multimedia import Image
from promptflow.contracts.run_info import Status
from promptflow.contracts.tool import (
AssistantDefinition,
ConnectionType,
InputDefinition,
OutputDefinition,
Expand Down Expand Up @@ -66,6 +67,7 @@ class TestValueType:
(Secret("secret"), ValueType.SECRET),
(PromptTemplate("prompt"), ValueType.PROMPT_TEMPLATE),
(FilePath("file_path"), ValueType.FILE_PATH),
(AssistantDefinition("model", "instructions", []), ValueType.ASSISTANT_DEFINITION),
],
)
def test_from_value(self, value, expected):
Expand All @@ -84,6 +86,7 @@ def test_from_value(self, value, expected):
(PromptTemplate, ValueType.PROMPT_TEMPLATE),
(FilePath, ValueType.FILE_PATH),
(Image, ValueType.IMAGE),
(AssistantDefinition, ValueType.ASSISTANT_DEFINITION),
],
)
def test_from_type(self, value, expected):
Expand Down
15 changes: 14 additions & 1 deletion src/promptflow/tests/executor/unittests/contracts/test_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from promptflow.contracts.types import Secret, PromptTemplate, FilePath
from promptflow.contracts.types import AssistantDefinition, Secret, PromptTemplate, FilePath
from promptflow.executor._assistant_tool_invoker import AssistantToolInvoker


@pytest.mark.unittest
Expand All @@ -20,3 +21,15 @@ def test_prompt_template():
def test_file_path():
file_path = FilePath('my_file_path')
assert isinstance(file_path, str)


@pytest.mark.unittest
def test_assistant_definition():
data = {"model": "model", "instructions": "instructions", "tools": []}
assistant_definition = AssistantDefinition.deserialize(data)
assert isinstance(assistant_definition, AssistantDefinition)
assert assistant_definition.model == "model"
assert assistant_definition.instructions == "instructions"
assert assistant_definition.tools == []
assert assistant_definition.serialize() == data
assert isinstance(assistant_definition.init_tool_invoker(), AssistantToolInvoker)
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@

@pytest.mark.unittest
class TestAssistantToolInvoker:
@pytest.fixture
def invoker(self):
return AssistantToolInvoker(working_dir=Path(__file__).parent)

@pytest.fixture
def tool_definitions(self):
return [
Expand All @@ -28,7 +24,7 @@ def tool_definitions(self):
@pytest.mark.parametrize(
"predefined_inputs", [({}), ({"input_int": 1})]
)
def test_load_tools(self, invoker, predefined_inputs):
def test_load_tools(self, predefined_inputs):
input_int = 1
input_str = "test"
tool_definitions = [
Expand All @@ -43,7 +39,7 @@ def test_load_tools(self, invoker, predefined_inputs):
]

# Test load tools
invoker.load_tools(tool_definitions)
invoker = AssistantToolInvoker.init(tool_definitions, working_dir=Path(__file__).parent)
for tool_name, assistant_tool in invoker._assistant_tools.items():
assert tool_name in ("code_interpreter", "retrieval", "sample_tool")
assert assistant_tool.name == tool_name
Expand Down Expand Up @@ -86,10 +82,10 @@ def test_load_tools(self, invoker, predefined_inputs):
result = invoker.invoke_tool(func_name="sample_tool", kwargs=kwargs)
assert result == (input_int, input_str)

def test_load_tools_with_invalid_case(self, invoker):
def test_load_tools_with_invalid_case(self):
tool_definitions = [{"type": "invalid_type"}]
with pytest.raises(UnsupportedAssistantToolType) as exc_info:
invoker.load_tools(tool_definitions)
AssistantToolInvoker.init(tool_definitions)
assert "Unsupported assistant tool type" in exc_info.value.message

def _remove_predefined_inputs(self, value: any, predefined_inputs: list):
Expand Down
Loading

0 comments on commit 2938296

Please sign in to comment.