diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index 5594c0e63..a10fde5f1 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -14,7 +14,8 @@ from __future__ import annotations -import datetime +from datetime import datetime +from datetime import timezone import logging from typing import Any from typing import Dict @@ -35,6 +36,7 @@ from ...agents.invocation_context import InvocationContext from ...events.event import Event +from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from ...utils.feature_decorator import working_in_progress from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY @@ -224,7 +226,7 @@ def _process_long_running_tool(a2a_part, event: Event) -> None: _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and a2a_part.root.metadata.get("id") in event.long_running_tool_ids + and a2a_part.root.data.get("id") in event.long_running_tool_ids ): a2a_part.root.metadata[_get_adk_metadata_key("is_long_running")] = True @@ -287,24 +289,34 @@ def _create_error_status_event( """ error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE + # Get context metadata and add error code + event_metadata = _get_context_metadata(event, invocation_context) + if event.error_code: + event_metadata[_get_adk_metadata_key("error_code")] = str(event.error_code) + return TaskStatusUpdateEvent( taskId=str(uuid.uuid4()), contextId=invocation_context.session.id, final=False, - metadata=_get_context_metadata(event, invocation_context), + metadata=event_metadata, status=TaskStatus( state=TaskState.failed, message=Message( messageId=str(uuid.uuid4()), role=Role.agent, parts=[TextPart(text=error_message)], + metadata={ + _get_adk_metadata_key("error_code"): str(event.error_code) + } + if event.error_code + else {}, ), - timestamp=datetime.datetime.now().isoformat(), + timestamp=datetime.now(timezone.utc).isoformat(), ), ) -def _create_running_status_event( +def _create_status_update_event( message: Message, invocation_context: InvocationContext, event: Event ) -> TaskStatusUpdateEvent: """Creates a TaskStatusUpdateEvent for running scenarios. @@ -317,15 +329,39 @@ def _create_running_status_event( Returns: A TaskStatusUpdateEvent with RUNNING state. """ + status = TaskStatus( + state=TaskState.working, + message=message, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + if any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get(_get_adk_metadata_key("is_long_running")) + is True + and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME + for part in message.parts + ): + status.state = TaskState.auth_required + elif any( + part.root.metadata.get( + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and part.root.metadata.get(_get_adk_metadata_key("is_long_running")) + is True + for part in message.parts + ): + status.state = TaskState.input_required + return TaskStatusUpdateEvent( taskId=str(uuid.uuid4()), contextId=invocation_context.session.id, final=False, - status=TaskStatus( - state=TaskState.working, - message=message, - timestamp=datetime.datetime.now().isoformat(), - ), + status=status, metadata=_get_context_metadata(event, invocation_context), ) @@ -370,7 +406,7 @@ def convert_event_to_a2a_events( # Handle regular message content message = convert_event_to_a2a_status_message(event, invocation_context) if message: - running_event = _create_running_status_event( + running_event = _create_status_update_event( message, invocation_context, event ) a2a_events.append(running_event) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index c47ac7276..d6acbbd32 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -18,6 +18,7 @@ from __future__ import annotations +import base64 import json import logging import sys @@ -45,6 +46,8 @@ A2A_DATA_PART_METADATA_TYPE_KEY = 'type' A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call' A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response' +A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result' +A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code' @working_in_progress @@ -67,7 +70,8 @@ def convert_a2a_part_to_genai_part( elif isinstance(part.file, a2a_types.FileWithBytes): return genai_types.Part( inline_data=genai_types.Blob( - data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType + data=base64.b64decode(part.file.bytes), + mime_type=part.file.mimeType, ) ) else: @@ -118,8 +122,12 @@ def convert_genai_part_to_a2a_part( part: genai_types.Part, ) -> Optional[a2a_types.Part]: """Convert a Google GenAI Part to an A2A Part.""" + if part.text: - return a2a_types.TextPart(text=part.text) + a2a_part = a2a_types.TextPart(text=part.text) + if part.thought is not None: + a2a_part.metadata = {_get_adk_metadata_key('thought'): part.thought} + return a2a_part if part.file_data: return a2a_types.FilePart( @@ -130,14 +138,22 @@ def convert_genai_part_to_a2a_part( ) if part.inline_data: - return a2a_types.Part( + a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithBytes( - bytes=part.inline_data.data, + bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), mimeType=part.inline_data.mime_type, ) ) ) + if part.video_metadata: + a2a_part.metadata = { + _get_adk_metadata_key( + 'video_metadata' + ): part.video_metadata.model_dump(by_alias=True, exclude_none=True) + } + + return a2a_part # Conver the funcall and function reponse to A2A DataPart. # This is mainly for converting human in the loop and auth request and @@ -172,6 +188,34 @@ def convert_genai_part_to_a2a_part( ) ) + if part.code_execution_result: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.code_execution_result.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + ) + }, + ) + ) + + if part.executable_code: + return a2a_types.Part( + root=a2a_types.DataPart( + data=part.executable_code.model_dump( + by_alias=True, exclude_none=True + ), + metadata={ + A2A_DATA_PART_METADATA_TYPE_KEY: ( + A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + ) + }, + ) + ) + logger.warning( 'Cannot convert unsupported part for Google GenAI part: %s', part, diff --git a/src/google/adk/a2a/converters/utils.py b/src/google/adk/a2a/converters/utils.py index ecbff1e10..3c38e9962 100644 --- a/src/google/adk/a2a/converters/utils.py +++ b/src/google/adk/a2a/converters/utils.py @@ -45,8 +45,15 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str: Returns: The A2A context id. + + Raises: + ValueError: If any of the input parameters are empty or None. """ - return [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id].join("$") + if not all([app_name, user_id, session_id]): + raise ValueError( + "All parameters (app_name, user_id, session_id) must be non-empty" + ) + return "$".join([ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id]) def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: @@ -64,8 +71,16 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]: if not context_id: return None, None, None - prefix, app_name, user_id, session_id = context_id.split("$") - if prefix == "ADK" and app_name and user_id and session_id: - return app_name, user_id, session_id + try: + parts = context_id.split("$") + if len(parts) != 4: + return None, None, None + + prefix, app_name, user_id, session_id = parts + if prefix == ADK_CONTEXT_ID_PREFIX and app_name and user_id and session_id: + return app_name, user_id, session_id + except ValueError: + # Handle any split errors gracefully + pass return None, None, None diff --git a/src/google/adk/a2a/executor/__init__.py b/src/google/adk/a2a/executor/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/src/google/adk/a2a/executor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py new file mode 100644 index 000000000..c2a52ea2d --- /dev/null +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -0,0 +1,241 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone +import inspect +import logging +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Optional +import uuid + +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.runners import Runner +from pydantic import BaseModel +from typing_extensions import override + +from ...utils.feature_decorator import working_in_progress +from ..converters.event_converter import convert_event_to_a2a_events +from ..converters.request_converter import convert_a2a_request_to_adk_run_args +from ..converters.utils import _to_a2a_context_id +from .task_result_aggregator import TaskResultAggregator + +logger = logging.getLogger('google_adk.' + __name__) + + +@working_in_progress +class A2aAgentExecutorConfig(BaseModel): + """Configuration for the A2aAgentExecutor.""" + + pass + + +@working_in_progress +class A2aAgentExecutor(AgentExecutor): + """An AgentExecutor that runs an ADK Agent against an A2A request and + publishes updates to an event queue. + """ + + def __init__( + self, + *, + runner: Runner | Callable[..., Runner | Awaitable[Runner]], + config: Optional[A2aAgentExecutorConfig] = None, + ): + super().__init__() + self._runner = runner + self._config = config + + async def _resolve_runner(self) -> Runner: + """Resolve the runner, handling cases where it's a callable that returns a Runner.""" + # If already resolved and cached, return it + if isinstance(self._runner, Runner): + return self._runner + if callable(self._runner): + # Call the function to get the runner + result = self._runner() + + # Handle async callables + if inspect.iscoroutine(result): + resolved_runner = await result + else: + resolved_runner = result + + # Cache the resolved runner for future calls + self._runner = resolved_runner + return resolved_runner + + raise TypeError( + 'Runner must be a Runner instance or a callable that returns a' + f' Runner, got {type(self._runner)}' + ) + + @override + async def cancel(self, context: RequestContext, event_queue: EventQueue): + """Cancel the execution.""" + # TODO: Implement proper cancellation logic if needed + raise NotImplementedError('Cancellation is not supported') + + @override + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ): + """Executes an A2A request and publishes updates to the event queue + specified. It runs as following: + * Takes the input from the A2A request + * Convert the input to ADK input content, and runs the ADK agent + * Collects output events of the underlying ADK Agent + * Converts the ADK output events into A2A task updates + * Publishes the updates back to A2A server via event queue + """ + if not context.message: + raise ValueError('A2A request must have a message') + + # for new task, create a task submitted event + if not context.current_task: + if not context.task_id: + context.task_id = str(uuid.uuid4()) + event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + message=context.message, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + contextId=context.context_id, + final=False, + ) + ) + + # Handle the request and publish updates to the event queue + try: + await self._handle_request(context, event_queue) + except Exception as e: + logger.error('Error handling A2A request: %s', e, exc_info=True) + # Publish failure event + event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=TaskState.failed, + timestamp=datetime.now(timezone.utc).isoformat(), + message=Message( + messageId=str(uuid.uuid4()), + role=Role.agent, + parts=[TextPart(text=str(e))], + ), + ), + contextId=context.context_id, + final=True, + ) + ) + + async def _handle_request( + self, + context: RequestContext, + event_queue: EventQueue, + ): + # Resolve the runner instance + runner = await self._resolve_runner() + + # Convert the a2a request to ADK run args + run_args = convert_a2a_request_to_adk_run_args(context) + # ensure the session exists + session = await self._prepare_session(context, run_args, runner) + + # create invocation context + invocation_context = runner._new_invocation_context( + session=session, + new_message=run_args['new_message'], + run_config=run_args['run_config'], + ) + + # publish the task working event + event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=TaskState.working, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + contextId=context.context_id, + final=False, + ) + ) + task_result_aggregator = TaskResultAggregator() + async for adk_event in runner.run_async(**run_args): + task_result_aggregator.process_event(adk_event) + for a2a_event in convert_event_to_a2a_events( + adk_event, invocation_context + ): + event_queue.enqueue_event(a2a_event) + + # publish the task result event - this is final + event_queue.enqueue_event( + TaskStatusUpdateEvent( + taskId=context.task_id, + status=TaskStatus( + state=task_result_aggregator.task_state, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + contextId=context.context_id, + final=True, + ) + ) + + async def _prepare_session( + self, context: RequestContext, run_args: dict[str, Any], runner: Runner + ): + + session_id = run_args['session_id'] + if not session_id: + session_id = str(uuid.uuid4()) + run_args['session_id'] = session_id + # override the non-existing context_id + context.context_id = _to_a2a_context_id( + app_name=runner.app_name, + user_id=run_args['user_id'], + session_id=session_id, + ) + # create a new session if not exists + user_id = run_args['user_id'] + session = await runner.session_service.get_session( + app_name=runner.app_name, + user_id=user_id, + session_id=session_id, + ) + if session is None: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + session_id=session_id, + ) + + return session diff --git a/src/google/adk/a2a/executor/task_result_aggregator.py b/src/google/adk/a2a/executor/task_result_aggregator.py new file mode 100644 index 000000000..5661dde6b --- /dev/null +++ b/src/google/adk/a2a/executor/task_result_aggregator.py @@ -0,0 +1,60 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from a2a.server.events import Event +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent + +from ...utils.feature_decorator import working_in_progress + + +@working_in_progress +class TaskResultAggregator: + """Aggregates the task status updates and provides the final task state.""" + + def __init__(self): + self._task_state = TaskState.working + + def process_event(self, event: Event): + """Process an event from the agent run and detect signals about the task status. + Priority of task state: + - failed + - auth_required + - input_required + - working + """ + if isinstance(event, TaskStatusUpdateEvent): + if event.status.state == TaskState.failed: + self._task_state = TaskState.failed + elif ( + event.status.state == TaskState.auth_required + and self._task_state != TaskState.failed + ): + self._task_state = TaskState.auth_required + elif ( + event.status.state == TaskState.input_required + and self._task_state + not in (TaskState.failed, TaskState.auth_required) + ): + self._task_state = TaskState.input_required + # final state is already recorded and make sure the intermediate state is + # always working because other state may terminate the event aggregation + # in a2a request handler + event.status.state = TaskState.working + + @property + def task_state(self) -> TaskState: + return self._task_state diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index 311ffc954..8ce9887fc 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -20,7 +20,7 @@ # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" ) # Import dependencies with version checking @@ -34,7 +34,7 @@ from google.adk.a2a.converters.event_converter import _convert_artifact_to_a2a_events from google.adk.a2a.converters.event_converter import _create_artifact_id from google.adk.a2a.converters.event_converter import _create_error_status_event - from google.adk.a2a.converters.event_converter import _create_running_status_event + from google.adk.a2a.converters.event_converter import _create_status_update_event from google.adk.a2a.converters.event_converter import _get_adk_metadata_key from google.adk.a2a.converters.event_converter import _get_context_metadata from google.adk.a2a.converters.event_converter import _process_long_running_tool @@ -63,7 +63,7 @@ class DummyTypes: _convert_artifact_to_a2a_events = lambda *args: None _create_artifact_id = lambda *args: None _create_error_status_event = lambda *args: None - _create_running_status_event = lambda *args: None + _create_status_update_event = lambda *args: None _get_adk_metadata_key = lambda *args: None _get_context_metadata = lambda *args: None _process_long_running_tool = lambda *args: None @@ -302,6 +302,8 @@ def test_process_long_running_tool_marks_tool(self): mock_a2a_part = Mock() mock_data_part = Mock(spec=DataPart) mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-123") mock_a2a_part.root = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} @@ -315,7 +317,11 @@ def test_process_long_running_tool_marks_tool(self): "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", "function_call", ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, ): + mock_get_key.side_effect = lambda key: f"adk_{key}" _process_long_running_tool(mock_a2a_part, self.mock_event) @@ -327,6 +333,8 @@ def test_process_long_running_tool_no_marking(self): mock_a2a_part = Mock() mock_data_part = Mock(spec=DataPart) mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"} + mock_data_part.data = Mock() + mock_data_part.data.get = Mock(return_value="tool-456") mock_a2a_part.root = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID @@ -340,7 +348,11 @@ def test_process_long_running_tool_no_marking(self): "google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL", "function_call", ), + patch( + "google.adk.a2a.converters.event_converter._get_adk_metadata_key" + ) as mock_get_key, ): + mock_get_key.side_effect = lambda key: f"adk_{key}" _process_long_running_tool(mock_a2a_part, self.mock_event) @@ -413,7 +425,7 @@ def test_convert_event_to_message_none_context(self): assert "Invocation context cannot be None" in str(exc_info.value) @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_error_status_event(self, mock_datetime, mock_uuid): """Test creation of error status event.""" mock_uuid.return_value = "test-uuid" @@ -433,7 +445,7 @@ def test_create_error_status_event(self, mock_datetime, mock_uuid): assert result.status.message.parts[0].root.text == "Test error message" @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): """Test creation of error status event without error message.""" mock_uuid.return_value = "test-uuid" @@ -447,7 +459,7 @@ def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid): assert result.status.message.parts[0].root.text == DEFAULT_ERROR_MESSAGE - @patch("google.adk.a2a.converters.event_converter.datetime.datetime") + @patch("google.adk.a2a.converters.event_converter.datetime") def test_create_running_status_event(self, mock_datetime): """Test creation of running status event.""" mock_datetime.now.return_value.isoformat.return_value = ( @@ -455,8 +467,9 @@ def test_create_running_status_event(self, mock_datetime): ) mock_message = Mock(spec=Message) + mock_message.parts = [] - result = _create_running_status_event( + result = _create_status_update_event( mock_message, self.mock_invocation_context, self.mock_event ) @@ -473,7 +486,7 @@ def test_create_running_status_event(self, mock_datetime): ) @patch("google.adk.a2a.converters.event_converter._create_error_status_event") @patch( - "google.adk.a2a.converters.event_converter._create_running_status_event" + "google.adk.a2a.converters.event_converter._create_status_update_event" ) def test_convert_event_to_a2a_events_full_scenario( self, @@ -560,7 +573,7 @@ def test_convert_event_to_a2a_events_message_only(self, mock_convert_message): mock_convert_message.return_value = mock_message with patch( - "google.adk.a2a.converters.event_converter._create_running_status_event" + "google.adk.a2a.converters.event_converter._create_status_update_event" ) as mock_create_running: mock_running_event = Mock() mock_create_running.return_value = mock_running_event diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 4b9bd47cf..f828a05ef 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -21,7 +21,7 @@ # Skip all tests in this module if Python version is less than 3.10 pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" ) # Import dependencies with version checking @@ -92,11 +92,14 @@ def test_convert_file_part_with_bytes(self): """Test conversion of A2A FilePart with bytes to GenAI Part.""" # Arrange test_bytes = b"test file content" - # Note: A2A FileWithBytes converts bytes to string automatically + # A2A FileWithBytes expects base64-encoded string + import base64 + + base64_encoded = base64.b64encode(test_bytes).decode("utf-8") a2a_part = a2a_types.Part( root=a2a_types.FilePart( file=a2a_types.FileWithBytes( - bytes=test_bytes, mimeType="text/plain" + bytes=base64_encoded, mimeType="text/plain" ) ) ) @@ -108,7 +111,7 @@ def test_convert_file_part_with_bytes(self): assert result is not None assert isinstance(result, genai_types.Part) assert result.inline_data is not None - # Source code now properly converts A2A string back to bytes for GenAI Blob + # The converter decodes base64 back to original bytes assert result.inline_data.data == test_bytes assert result.inline_data.mime_type == "text/plain" @@ -298,8 +301,11 @@ def test_convert_inline_data_part(self): assert isinstance(result, a2a_types.Part) assert isinstance(result.root, a2a_types.FilePart) assert isinstance(result.root.file, a2a_types.FileWithBytes) - # A2A FileWithBytes stores bytes as strings - assert result.root.file.bytes == test_bytes.decode("utf-8") + # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility + import base64 + + expected_base64 = base64.b64encode(test_bytes).decode("utf-8") + assert result.root.file.bytes == expected_base64 assert result.root.file.mimeType == "text/plain" def test_convert_function_call_part(self): @@ -406,6 +412,30 @@ def test_file_uri_round_trip(self): assert result_a2a_part.file.uri == original_uri assert result_a2a_part.file.mimeType == original_mime_type + def test_file_bytes_round_trip(self): + """Test round-trip conversion for file parts with bytes.""" + # Arrange + original_bytes = b"test file content for round trip" + original_mime_type = "application/octet-stream" + + # Start with GenAI part (the more common starting point) + genai_part = genai_types.Part( + inline_data=genai_types.Blob( + data=original_bytes, mime_type=original_mime_type + ) + ) + + # Act - Round trip: GenAI -> A2A -> GenAI + a2a_part = convert_genai_part_to_a2a_part(genai_part) + result_genai_part = convert_a2a_part_to_genai_part(a2a_part) + + # Assert + assert result_genai_part is not None + assert isinstance(result_genai_part, genai_types.Part) + assert result_genai_part.inline_data is not None + assert result_genai_part.inline_data.data == original_bytes + assert result_genai_part.inline_data.mime_type == original_mime_type + class TestEdgeCases: """Test cases for edge cases and error conditions.""" diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py new file mode 100644 index 000000000..481d15bfe --- /dev/null +++ b/tests/unittests/a2a/converters/test_utils.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +import pytest + + +class TestUtilsFunctions: + """Test suite for utils module functions.""" + + def test_get_adk_metadata_key_success(self): + """Test successful metadata key generation.""" + key = "test_key" + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_get_adk_metadata_key_empty_string(self): + """Test metadata key generation with empty string.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key("") + + def test_get_adk_metadata_key_none(self): + """Test metadata key generation with None.""" + with pytest.raises( + ValueError, match="Metadata key cannot be empty or None" + ): + _get_adk_metadata_key(None) + + def test_get_adk_metadata_key_whitespace(self): + """Test metadata key generation with whitespace string.""" + key = " " + result = _get_adk_metadata_key(key) + assert result == f"{ADK_METADATA_KEY_PREFIX}{key}" + + def test_to_a2a_context_id_success(self): + """Test successful context ID generation.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}$test-app$test-user$test-session" + assert result == expected + + def test_to_a2a_context_id_empty_app_name(self): + """Test context ID generation with empty app name.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("", "user", "session") + + def test_to_a2a_context_id_empty_user_id(self): + """Test context ID generation with empty user ID.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "", "session") + + def test_to_a2a_context_id_empty_session_id(self): + """Test context ID generation with empty session ID.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id("app", "user", "") + + def test_to_a2a_context_id_none_values(self): + """Test context ID generation with None values.""" + with pytest.raises( + ValueError, + match=( + "All parameters \\(app_name, user_id, session_id\\) must be" + " non-empty" + ), + ): + _to_a2a_context_id(None, "user", "session") + + def test_to_a2a_context_id_special_characters(self): + """Test context ID generation with special characters.""" + app_name = "test-app@2024" + user_id = "user_123" + session_id = "session-456" + + result = _to_a2a_context_id(app_name, user_id, session_id) + + expected = f"{ADK_CONTEXT_ID_PREFIX}$test-app@2024$user_123$session-456" + assert result == expected + + def test_from_a2a_context_id_success(self): + """Test successful context ID parsing.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}$test-app$test-user$test-session" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app" + assert user_id == "test-user" + assert session_id == "test-session" + + def test_from_a2a_context_id_none_input(self): + """Test context ID parsing with None input.""" + result = _from_a2a_context_id(None) + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_string(self): + """Test context ID parsing with empty string.""" + result = _from_a2a_context_id("") + assert result == (None, None, None) + + def test_from_a2a_context_id_invalid_prefix(self): + """Test context ID parsing with invalid prefix.""" + context_id = "INVALID$test-app$test-user$test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_few_parts(self): + """Test context ID parsing with too few parts.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}$test-app$test-user" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_too_many_parts(self): + """Test context ID parsing with too many parts.""" + context_id = ( + f"{ADK_CONTEXT_ID_PREFIX}$test-app$test-user$test-session$extra" + ) + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_empty_components(self): + """Test context ID parsing with empty components.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}$$test-user$test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_from_a2a_context_id_no_dollar_separator(self): + """Test context ID parsing without dollar separators.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}-test-app-test-user-test-session" + + result = _from_a2a_context_id(context_id) + + assert result == (None, None, None) + + def test_roundtrip_context_id(self): + """Test roundtrip conversion: to -> from.""" + app_name = "test-app" + user_id = "test-user" + session_id = "test-session" + + # Convert to context ID + context_id = _to_a2a_context_id(app_name, user_id, session_id) + + # Convert back + parsed_app, parsed_user, parsed_session = _from_a2a_context_id(context_id) + + assert parsed_app == app_name + assert parsed_user == user_id + assert parsed_session == session_id + + def test_from_a2a_context_id_special_characters(self): + """Test context ID parsing with special characters.""" + context_id = f"{ADK_CONTEXT_ID_PREFIX}$test-app@2024$user_123$session-456" + + app_name, user_id, session_id = _from_a2a_context_id(context_id) + + assert app_name == "test-app@2024" + assert user_id == "user_123" + assert session_id == "session-456" diff --git a/tests/unittests/a2a/executor/__init__.py b/tests/unittests/a2a/executor/__init__.py new file mode 100644 index 000000000..0a2669d7a --- /dev/null +++ b/tests/unittests/a2a/executor/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py new file mode 100644 index 000000000..c099f27df --- /dev/null +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -0,0 +1,616 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.server.agent_execution.context import RequestContext + from a2a.server.events.event_queue import EventQueue + from a2a.types import Message + from a2a.types import TaskState + from a2a.types import TextPart + from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor + from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig + from google.adk.events.event import Event + from google.adk.runners import Runner +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + RequestContext = DummyTypes() + EventQueue = DummyTypes() + Message = DummyTypes() + Role = DummyTypes() + TaskState = DummyTypes() + TaskStatus = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + TextPart = DummyTypes() + A2aAgentExecutor = DummyTypes() + A2aAgentExecutorConfig = DummyTypes() + Event = DummyTypes() + Runner = DummyTypes() + else: + raise e + + +class TestA2aAgentExecutor: + """Test suite for A2aAgentExecutor class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_runner = Mock(spec=Runner) + self.mock_runner.app_name = "test-app" + self.mock_runner.session_service = Mock() + self.mock_runner._new_invocation_context = Mock() + self.mock_runner.run_async = AsyncMock() + + self.mock_config = Mock(spec=A2aAgentExecutorConfig) + self.executor = A2aAgentExecutor( + runner=self.mock_runner, config=self.mock_config + ) + + self.mock_context = Mock(spec=RequestContext) + self.mock_context.message = Mock(spec=Message) + self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.current_task = None + self.mock_context.task_id = None + self.mock_context.context_id = "test-context-id" + + self.mock_event_queue = Mock(spec=EventQueue) + + async def _create_async_generator(self, items): + """Helper to create async generator from items.""" + for item in items: + yield item + + @pytest.mark.asyncio + async def test_execute_success_new_task(self): + """Test successful execution of a new task.""" + # Setup + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + # Verify working event was enqueued + working_event = self.mock_event_queue.enqueue_event.call_args_list[1][ + 0 + ][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False + + @pytest.mark.asyncio + async def test_execute_no_message_error(self): + """Test execution fails when no message is provided.""" + self.mock_context.message = None + + with pytest.raises(ValueError, match="A2A request must have a message"): + await self.executor.execute(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_existing_task(self): + """Test execution with existing task (no submitted event).""" + self.mock_context.current_task = Mock() + self.mock_context.task_id = "existing-task-id" + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify no submitted event (first call should be working event) + working_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert working_event.status.state == TaskState.working + assert working_event.final == False + + @pytest.mark.asyncio + async def test_prepare_session_new_session(self): + """Test session preparation when session doesn't exist.""" + run_args = { + "user_id": "test-user", + "session_id": None, + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) + mock_session = Mock() + mock_session.id = "new-session-id" + self.mock_runner.session_service.create_session = AsyncMock( + return_value=mock_session + ) + + # Execute + result = await self.executor._prepare_session( + self.mock_context, run_args, self.mock_runner + ) + + # Verify session was created + assert result == mock_session + assert run_args["session_id"] is not None + self.mock_runner.session_service.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_prepare_session_existing_session(self): + """Test session preparation when session exists.""" + run_args = { + "user_id": "test-user", + "session_id": "existing-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "existing-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Execute + result = await self.executor._prepare_session( + self.mock_context, run_args, self.mock_runner + ) + + # Verify existing session was returned + assert result == mock_session + self.mock_runner.session_service.create_session.assert_not_called() + + def test_constructor_with_callable_runner(self): + """Test constructor with callable runner.""" + callable_runner = Mock() + executor = A2aAgentExecutor(runner=callable_runner, config=self.mock_config) + + assert executor._runner == callable_runner + assert executor._config == self.mock_config + + @pytest.mark.asyncio + async def test_resolve_runner_direct_instance(self): + """Test _resolve_runner with direct Runner instance.""" + # Setup - already using direct runner instance in setup_method + runner = await self.executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_sync_callable(self): + """Test _resolve_runner with sync callable that returns Runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_async_callable(self): + """Test _resolve_runner with async callable that returns Runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_invalid_type(self): + """Test _resolve_runner with invalid runner type.""" + executor = A2aAgentExecutor(runner="invalid", config=self.mock_config) + + with pytest.raises( + TypeError, match="Runner must be a Runner instance or a callable" + ): + await executor._resolve_runner() + + @pytest.mark.asyncio + async def test_resolve_runner_callable_with_parameters(self): + """Test _resolve_runner with callable that normally takes parameters.""" + + def create_runner(*args, **kwargs): + # In real usage, this might use the args/kwargs to configure the runner + # For testing, we'll just return the mock runner + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + runner = await executor._resolve_runner() + assert runner == self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_caching(self): + """Test that _resolve_runner caches the result and doesn't call the callable multiple times.""" + call_count = 0 + + def create_runner(): + nonlocal call_count + call_count += 1 + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + # First call should invoke the callable + runner1 = await executor._resolve_runner() + assert runner1 == self.mock_runner + assert call_count == 1 + + # Second call should return cached result, not invoke callable again + runner2 = await executor._resolve_runner() + assert runner2 == self.mock_runner + assert runner1 is runner2 # Same instance + assert call_count == 1 # Callable was not called again + + # Verify that self._runner is now the resolved Runner instance + assert executor._runner is self.mock_runner + + @pytest.mark.asyncio + async def test_resolve_runner_async_caching(self): + """Test that _resolve_runner caches async callable results correctly.""" + call_count = 0 + + async def create_runner(): + nonlocal call_count + call_count += 1 + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + # First call should invoke the async callable + runner1 = await executor._resolve_runner() + assert runner1 == self.mock_runner + assert call_count == 1 + + # Second call should return cached result, not invoke callable again + runner2 = await executor._resolve_runner() + assert runner2 == self.mock_runner + assert runner1 is runner2 # Same instance + assert call_count == 1 # Async callable was not called again + + # Verify that self._runner is now the resolved Runner instance + assert executor._runner is self.mock_runner + + @pytest.mark.asyncio + async def test_execute_with_sync_callable_runner(self): + """Test execution with sync callable runner.""" + + def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) + + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + @pytest.mark.asyncio + async def test_execute_with_async_callable_runner(self): + """Test execution with async callable runner.""" + + async def create_runner(): + return self.mock_runner + + executor = A2aAgentExecutor(runner=create_runner, config=self.mock_config) + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with proper async generator + mock_event = Mock(spec=Event) + + async def mock_run_async(**kwargs): + async for item in self._create_async_generator([mock_event]): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [] + + # Execute + await executor.execute(self.mock_context, self.mock_event_queue) + + # Verify task submitted event was enqueued + assert self.mock_event_queue.enqueue_event.call_count >= 3 + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + @pytest.mark.asyncio + async def test_handle_request_integration(self): + """Test the complete request handling flow.""" + # Setup context with task_id + self.mock_context.task_id = "test-task-id" + + # Setup detailed mocks + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.return_value = { + "user_id": "test-user", + "session_id": "test-session", + "new_message": Mock(), + "run_config": Mock(), + } + + # Mock session service + mock_session = Mock() + mock_session.id = "test-session" + self.mock_runner.session_service.get_session = AsyncMock( + return_value=mock_session + ) + + # Mock invocation context + mock_invocation_context = Mock() + self.mock_runner._new_invocation_context.return_value = ( + mock_invocation_context + ) + + # Mock agent run with multiple events using proper async generator + mock_events = [Mock(spec=Event), Mock(spec=Event)] + + # Configure run_async to return the async generator when awaited + async def mock_run_async(**kwargs): + async for item in self._create_async_generator(mock_events): + yield item + + self.mock_runner.run_async = mock_run_async + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_event_to_a2a_events" + ) as mock_convert_events: + mock_convert_events.return_value = [Mock()] + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.TaskResultAggregator" + ) as mock_aggregator_class: + mock_aggregator = Mock() + mock_aggregator.task_state = TaskState.working + mock_aggregator_class.return_value = mock_aggregator + + # Execute + await self.executor._handle_request( + self.mock_context, self.mock_event_queue + ) + + # Verify working event was enqueued + working_events = [ + call[0][0] + for call in self.mock_event_queue.enqueue_event.call_args_list + if hasattr(call[0][0], "status") + and call[0][0].status.state == TaskState.working + ] + assert len(working_events) >= 1 + + # Verify aggregator processed events + assert mock_aggregator.process_event.call_count == len(mock_events) + + @pytest.mark.asyncio + async def test_cancel_with_task_id(self): + """Test cancellation with a task ID.""" + self.mock_context.task_id = "test-task-id" + + # The current implementation raises NotImplementedError + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_cancel_without_task_id(self): + """Test cancellation without a task ID.""" + self.mock_context.task_id = None + + # The current implementation raises NotImplementedError regardless of task_id + with pytest.raises( + NotImplementedError, match="Cancellation is not supported" + ): + await self.executor.cancel(self.mock_context, self.mock_event_queue) + + @pytest.mark.asyncio + async def test_execute_with_exception_handling(self): + """Test execution with exception handling.""" + self.mock_context.task_id = "test-task-id" + self.mock_context.current_task = ( + None # Make sure it goes through submitted event creation + ) + + with patch( + "google.adk.a2a.executor.a2a_agent_executor.convert_a2a_request_to_adk_run_args" + ) as mock_convert: + mock_convert.side_effect = Exception("Test error") + + # Execute (should not raise since we catch the exception) + await self.executor.execute(self.mock_context, self.mock_event_queue) + + # Verify both submitted and failure events were enqueued + # First call should be submitted event, last should be failure event + assert self.mock_event_queue.enqueue_event.call_count >= 2 + + # Check submitted event (first) + submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][ + 0 + ][0] + assert submitted_event.status.state == TaskState.submitted + assert submitted_event.final == False + + # Check failure event (last) + failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][ + 0 + ] + assert failure_event.status.state == TaskState.failed + assert failure_event.final == True diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py new file mode 100644 index 000000000..93c6ba07f --- /dev/null +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -0,0 +1,165 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import Mock + +import pytest + +# Skip all tests in this module if Python version is less than 3.10 +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 10), reason="A2A requires Python 3.10+" +) + +# Import dependencies with version checking +try: + from a2a.types import TaskState + from a2a.types import TaskStatus + from a2a.types import TaskStatusUpdateEvent + from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator +except ImportError as e: + if sys.version_info < (3, 10): + # Create dummy classes to prevent NameError during test collection + # Tests will be skipped anyway due to pytestmark + class DummyTypes: + pass + + TaskState = DummyTypes() + TaskStatus = DummyTypes() + TaskStatusUpdateEvent = DummyTypes() + TaskResultAggregator = DummyTypes() + else: + raise e + + +class TestTaskResultAggregator: + """Test suite for TaskResultAggregator class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.aggregator = TaskResultAggregator() + + def test_initial_state(self): + """Test the initial state of the aggregator.""" + assert self.aggregator.task_state == TaskState.working + + def test_process_failed_event(self): + """Test processing a failed task status event.""" + mock_event = Mock(spec=TaskStatusUpdateEvent) + mock_event.status = Mock(spec=TaskStatus) + mock_event.status.state = TaskState.failed + + self.aggregator.process_event(mock_event) + + assert self.aggregator.task_state == TaskState.failed + + def test_process_auth_required_event(self): + """Test processing an auth_required task status event.""" + mock_event = Mock(spec=TaskStatusUpdateEvent) + mock_event.status = Mock(spec=TaskStatus) + mock_event.status.state = TaskState.auth_required + + self.aggregator.process_event(mock_event) + + assert self.aggregator.task_state == TaskState.auth_required + + def test_process_input_required_event(self): + """Test processing an input_required task status event.""" + mock_event = Mock(spec=TaskStatusUpdateEvent) + mock_event.status = Mock(spec=TaskStatus) + mock_event.status.state = TaskState.input_required + + self.aggregator.process_event(mock_event) + + assert self.aggregator.task_state == TaskState.input_required + + def test_failed_state_priority(self): + """Test that failed state takes priority over other states.""" + # First set to failed + failed_event = Mock(spec=TaskStatusUpdateEvent) + failed_event.status = Mock(spec=TaskStatus) + failed_event.status.state = TaskState.failed + + self.aggregator.process_event(failed_event) + + # Then try to set to auth_required - should remain failed + auth_event = Mock(spec=TaskStatusUpdateEvent) + auth_event.status = Mock(spec=TaskStatus) + auth_event.status.state = TaskState.auth_required + + self.aggregator.process_event(auth_event) + + assert self.aggregator.task_state == TaskState.failed + + def test_auth_required_priority_over_input_required(self): + """Test that auth_required state takes priority over input_required.""" + # First set to auth_required + auth_event = Mock(spec=TaskStatusUpdateEvent) + auth_event.status = Mock(spec=TaskStatus) + auth_event.status.state = TaskState.auth_required + + self.aggregator.process_event(auth_event) + + # Then try to set to input_required - should remain auth_required + input_event = Mock(spec=TaskStatusUpdateEvent) + input_event.status = Mock(spec=TaskStatus) + input_event.status.state = TaskState.input_required + + self.aggregator.process_event(input_event) + + assert self.aggregator.task_state == TaskState.auth_required + + def test_non_task_status_event_ignored(self): + """Test that non-TaskStatusUpdateEvent events are ignored.""" + mock_event = Mock() # Not a TaskStatusUpdateEvent + original_state = self.aggregator.task_state + + self.aggregator.process_event(mock_event) + + assert self.aggregator.task_state == original_state + + def test_event_not_modified(self): + """Test that the input event state is set to working during processing.""" + mock_event = Mock(spec=TaskStatusUpdateEvent) + mock_event.status = Mock(spec=TaskStatus) + original_state = TaskState.auth_required + mock_event.status.state = original_state + + self.aggregator.process_event(mock_event) + + # Event state is modified to working as part of the processing + # (this ensures intermediate state is always working for a2a request handler) + assert mock_event.status.state == TaskState.working + + def test_state_transitions_sequence(self): + """Test a sequence of state transitions.""" + events = [ + (TaskState.working, TaskState.working), + (TaskState.input_required, TaskState.input_required), + (TaskState.auth_required, TaskState.auth_required), + (TaskState.failed, TaskState.failed), + (TaskState.working, TaskState.failed), # Should remain failed + ] + + for event_state, expected_aggregator_state in events: + mock_event = Mock(spec=TaskStatusUpdateEvent) + mock_event.status = Mock(spec=TaskStatus) + mock_event.status.state = event_state + + self.aggregator.process_event(mock_event) + + assert self.aggregator.task_state == expected_aggregator_state, ( + f"Expected {expected_aggregator_state}, got" + f" {self.aggregator.task_state}" + )