diff --git a/examples/agents/react.py b/examples/agents/react.py index faeaf0f20..59a4315af 100644 --- a/examples/agents/react.py +++ b/examples/agents/react.py @@ -12,21 +12,29 @@ # See the License for the specific language goveself.rning permissions and # limitations under the License. -from rai.agents import AgentRunner, ReActAgent -from rai.communication.ros2 import ROS2Connector, ROS2Context, ROS2HRIConnector + +from rai.agents.langchain.react_agent import ReActAgent +from rai.communication.hri_connector import HRIMessage +from rai.communication.ros2 import ROS2Connector, ROS2Context +from rai.communication.ros2.connectors.hri_connector import ROS2HRIConnector from rai.tools.ros2 import ROS2Toolkit @ROS2Context() def main(): - connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"]) ros2_connector = ROS2Connector() + hri_connector = ROS2HRIConnector() + agent = ReActAgent( - connectors={"hri": connector}, + target_connectors={ + "/to_human": hri_connector, + }, # agnet's output is sent to /to_human ros2 topic tools=ROS2Toolkit(connector=ros2_connector).get_tools(), - ) # type: ignore - runner = AgentRunner([agent]) - runner.run_and_wait_for_shutdown() + ) + agent.run() + agent(HRIMessage(text="What do you see?")) + agent.wait() # wait for agent to finish + agent.stop() if __name__ == "__main__": diff --git a/examples/agents/react_ros2.py b/examples/agents/react_ros2.py new file mode 100644 index 000000000..8eac7319d --- /dev/null +++ b/examples/agents/react_ros2.py @@ -0,0 +1,41 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language goveself.rning permissions and +# limitations under the License. + + +from rai.agents import AgentRunner +from rai.agents.langchain.react_agent import ReActAgent +from rai.communication.ros2 import ROS2Connector, ROS2Context +from rai.communication.ros2.connectors.hri_connector import ROS2HRIConnector +from rai.tools.ros2 import ROS2Toolkit + + +@ROS2Context() +def main(): + ros2_connector = ROS2Connector() + hri_connector = ROS2HRIConnector() + + agent = ReActAgent( + target_connectors={ + "/to_human": hri_connector, + }, + tools=ROS2Toolkit(connector=ros2_connector).get_tools(), + ) + # Agent will wait for messages published to /from_human ros2 topic + agent.subscribe_source("/from_human", hri_connector) + runner = AgentRunner([agent]) + runner.run() + + +if __name__ == "__main__": + main() diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py index 8be23db12..b28c98661 100644 --- a/src/rai_core/rai/agents/__init__.py +++ b/src/rai_core/rai/agents/__init__.py @@ -14,7 +14,7 @@ from rai.agents.base import BaseAgent from rai.agents.conversational_agent import create_conversational_agent -from rai.agents.react_agent import ReActAgent +from rai.agents.langchain.react_agent import ReActAgent from rai.agents.runner import AgentRunner, wait_for_shutdown from rai.agents.state_based import create_state_based_agent from rai.agents.tool_runner import ToolRunner diff --git a/src/rai_core/rai/agents/langchain/agent.py b/src/rai_core/rai/agents/langchain/agent.py new file mode 100644 index 000000000..1e4d2632f --- /dev/null +++ b/src/rai_core/rai/agents/langchain/agent.py @@ -0,0 +1,257 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, Dict, List, Literal, Optional, TypedDict + +from langchain_core.messages import BaseMessage +from langchain_core.runnables import Runnable + +from rai.agents.base import BaseAgent +from rai.agents.langchain import HRICallbackHandler +from rai.agents.langchain.runnables import ReActAgentState +from rai.communication.hri_connector import HRIConnector, HRIMessage +from rai.initialization import get_tracing_callbacks + + +class BaseState(TypedDict): + messages: List[BaseMessage] + + +newMessageBehaviorType = Literal[ + "take_all", + "keep_last", + "queue", + "interrupt_take_all", + "interrupt_keep_last", +] + + +class LangChainAgent(BaseAgent): + """ + Agent pareametrized by LangGraph runnable that communicates with environment using + `HRIConnector`. + + Parameters + ---------- + target_connectors : Dict[str, HRIConnector[HRIMessage]] + Dict of target_name: connector. Agent will send it's output to these targets using connectors. + runnable : Runnable + LangChain runnable that will be used to generate output. + state : BaseState | None, optional + State to seed the LangChain runnable. If None - empty state is used. + new_message_behavior : newMessageBehaviorType, optional + Describes how to handle new messages and interact with LangChain runnable. There are 2 main options: + 1. Agent waits for LangChain runnable to finish processing: + - "take_all": all messages from the queue are concatenated and processed. + - "keep_last": only the last received message is processed, others are dropped. + - "queue": only the first message from the queue is processed, others are kept in the queue. + 2. Agent interrupts LangChain runnable: + - "interrupt_take_all": same as "take_all" + - "interrupt_keep_last": same as "keep_last" + max_size : int, optional + Maximum number of messages to keep in the agent's queue. If exceeded, oldest messages are dropped. + + + Agent can be started using `run` method. Then it is triggered by `HRIMessage`s submited + by `__call__` method. They can be submitted in 2 ways: + - manually using `__call__` method. + - by subscribing to specific source using HRIConnector with `subscribe_source` method. + + Agent can be stopped using `stop` method. + + Due to asynchronous processing of the Agent, it is adviced to handle it's lifetime + with :py:class:`rai.agents.AgentRunner` when source is subscribed. + + Examples: + ```python + # ROS2 Example - agent triggered manually + from rai.agents import AgentRunner + hri_connector = ROS2HRIConnector() + runnable = create_langgraph() + agent = LangChainAgent( + target_connectors={"/to_human": hri_connector}, + runnable=runnable, + ) + agent.run() + agent(HRIMessage(text="Hello!")) + agent.wait() + agent.stop() + + # ROS2 Example - triggered by messages on ros2 topic + ... + runner = AgentRunner([agent]) + runner.run() + agent.source_callback("/from_human", hri_connector) + runner.wait_for_shutdown() + + # Agent will act messages published to rai_interfaces.msg.HRIMessage sent to /from_human topic + """ + + def __init__( + self, + target_connectors: Dict[str, HRIConnector[HRIMessage]], + runnable: Runnable, + state: BaseState | None = None, + new_message_behavior: newMessageBehaviorType = "interrupt_keep_last", + max_size: int = 100, + ): + super().__init__() + self.logger = logging.getLogger(__name__) + self.agent = runnable + self.new_message_behavior: newMessageBehaviorType = new_message_behavior + self.tracing_callbacks = get_tracing_callbacks() + self.state = state or ReActAgentState(messages=[]) + self._langchain_callback = HRICallbackHandler( + connectors=target_connectors, + aggregate_chunks=True, + logger=self.logger, + ) + + self._received_messages: Deque[HRIMessage] = deque() + self._buffer_lock = threading.Lock() + self.max_size = max_size + + self._thread: Optional[threading.Thread] = None + self._stop_event = threading.Event() + self._executor = ThreadPoolExecutor(max_workers=1) + self._interrupt_event = threading.Event() + self._agent_ready_event = threading.Event() + + def subscribe_source(self, source: str, connector: HRIConnector[HRIMessage]): + connector.register_callback( + source, + self.__call__, + ) + + def __call__(self, msg: HRIMessage): + with self._buffer_lock: + if ( + self.max_size is not None + and len(self._received_messages) >= self.max_size + ): + self.logger.warning("Buffer overflow. Dropping olders message") + self._received_messages.popleft() + if "interrupt" in self.new_message_behavior: + self._executor.submit(self._interrupt_agent_and_run) + self.logger.info(f"Received message: {msg}, {type(msg)}") + self._received_messages.append(msg) + + def run(self): + if self._thread is not None: + raise RuntimeError("Agent is already running") + self._thread = threading.Thread(target=self._run_loop) + self._thread.start() + self._agent_ready_event.set() + self.logger.info("Agent started") + + def ready(self): + return self._agent_ready_event.is_set() and len(self._received_messages) == 0 + + def wait(self): + while len(self._received_messages) > 0: + time.sleep(0.1) + + return self._agent_ready_event.wait() + + def _interrupt_agent_and_run(self): + if self.ready(): + self.logger.info("Agent is ready. No need to interrupt it.") + return + self.logger.info("Interrupting agent...") + self._interrupt_event.set() + self._agent_ready_event.wait() + self._interrupt_event.clear() + self.logger.info("Interrupting agent: DONE") + + def _run_agent(self): + if len(self._received_messages) == 0: + self._agent_ready_event.set() + self.logger.info("Waiting for messages...") + time.sleep(0.5) + return + self._agent_ready_event.clear() + try: + self.logger.info("Running agent...") + reduced_message = self._reduce_messages() + langchain_message = reduced_message.to_langchain() + self.state["messages"].append(langchain_message) + for _ in self.agent.stream( + self.state, + config={ + "callbacks": [self._langchain_callback, *self.tracing_callbacks] + }, + ): + if self._interrupt_event.is_set(): + break + finally: + self._agent_ready_event.set() + + def _run_loop(self): + while not self._stop_event.is_set(): + if self._agent_ready_event.wait(0.01): + self._run_agent() + + def stop(self): + self._stop_event.set() + self._interrupt_event.set() + self._agent_ready_event.wait() + if self._thread is not None: + self.logger.info("Stopping the agent. Please wait...") + self._thread.join() + self._thread = None + self.logger.info("Agent stopped") + + @staticmethod + def _apply_reduction_behavior( + method: newMessageBehaviorType, buffer: Deque[HRIMessage] + ) -> List[HRIMessage]: + output = list() + if "take_all" in method: + # Take all starting from the oldest + while len(buffer) > 0: + output.append(buffer.popleft()) + elif "keep_last" in method: + # Take the recently added message + output.append(buffer.pop()) + buffer.clear() + elif method == "queue": + # Take the first message from the queue. Let other messages wait. + output.append(buffer.popleft()) + else: + raise ValueError(f"Invalid new_message_behavior: {method}") + return output + + def _reduce_messages(self) -> HRIMessage: + text = "" + images = [] + audios = [] + with self._buffer_lock: + source_messages = self._apply_reduction_behavior( + self.new_message_behavior, self._received_messages + ) + for source_message in source_messages: + text += f"{source_message.text}\n" + images.extend(source_message.images) + audios.extend(source_message.audios) + return HRIMessage( + text=text, + images=images, + audios=audios, + message_author="human", + ) diff --git a/src/rai_core/rai/agents/langchain/callback.py b/src/rai_core/rai/agents/langchain/callback.py index 7317dda16..5b4be684e 100644 --- a/src/rai_core/rai/agents/langchain/callback.py +++ b/src/rai_core/rai/agents/langchain/callback.py @@ -14,7 +14,7 @@ import logging import threading -from typing import List, Optional +from typing import Dict, List, Optional from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler @@ -27,7 +27,7 @@ class HRICallbackHandler(BaseCallbackHandler): def __init__( self, - connectors: dict[str, HRIConnector[HRIMessage]], + connectors: Dict[str, HRIConnector[HRIMessage]], aggregate_chunks: bool = False, splitting_chars: Optional[List[str]] = None, max_buffer_size: int = 200, @@ -47,21 +47,20 @@ def _should_split(self, token: str) -> bool: return token in self.splitting_chars def _send_all_targets(self, tokens: str, done: bool = False): - self.logger.info( - f"Sending {len(tokens)} tokens to {len(self.connectors)} connectors" - ) - for connector_name, connector in self.connectors.items(): + for target, connector in self.connectors.items(): + self.logger.info(f"Sending {len(tokens)} tokens to target: {target}") try: - connector.send_all_targets( + to_send: HRIMessage = connector.build_message( AIMessage(content=tokens), self.current_conversation_id, self.current_chunk_id, done, ) - self.logger.debug(f"Sent {len(tokens)} tokens to {connector_name}") + connector.send_message(to_send, target) + self.logger.debug(f"Sent {len(tokens)} tokens to hri_connector.") except Exception as e: self.logger.error( - f"Failed to send {len(tokens)} tokens to {connector_name}: {e}" + f"Failed to send {len(tokens)} tokens to hri_connector: {e}" ) def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs): diff --git a/src/rai_core/rai/agents/langchain/react_agent.py b/src/rai_core/rai/agents/langchain/react_agent.py new file mode 100644 index 000000000..e96baea38 --- /dev/null +++ b/src/rai_core/rai/agents/langchain/react_agent.py @@ -0,0 +1,42 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.tools import BaseTool + +from rai.agents.langchain import create_react_runnable +from rai.agents.langchain.agent import LangChainAgent +from rai.agents.langchain.runnables import ReActAgentState +from rai.communication.hri_connector import HRIConnector, HRIMessage + + +class ReActAgent(LangChainAgent): + def __init__( + self, + target_connectors: Dict[str, HRIConnector[HRIMessage]], + llm: Optional[BaseChatModel] = None, + tools: Optional[List[BaseTool]] = None, + state: Optional[ReActAgentState] = None, + system_prompt: Optional[str] = None, + ): + runnable = create_react_runnable( + llm=llm, tools=tools, system_prompt=system_prompt + ) + super().__init__( + target_connectors=target_connectors, + runnable=runnable, + state=state, + ) diff --git a/src/rai_core/rai/agents/react_agent.py b/src/rai_core/rai/agents/react_agent.py deleted file mode 100644 index 8ab4b1645..000000000 --- a/src/rai_core/rai/agents/react_agent.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (C) 2025 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import threading -import time -from typing import Any, Dict, List, Optional, cast - -from langchain_core.language_models import BaseChatModel -from langchain_core.tools import BaseTool - -from rai.agents.base import BaseAgent -from rai.agents.langchain import HRICallbackHandler, create_react_runnable -from rai.agents.langchain.runnables import ReActAgentState -from rai.communication.hri_connector import HRIConnector, HRIMessage -from rai.initialization import get_tracing_callbacks - - -class ReActAgent(BaseAgent): - def __init__( - self, - connectors: dict[str, HRIConnector[HRIMessage]], - llm: Optional[BaseChatModel] = None, - tools: Optional[List[BaseTool]] = None, - state: Optional[ReActAgentState] = None, - system_prompt: Optional[str] = None, - ): - super().__init__() - self.logger = logging.getLogger(__name__) - self.agent = create_react_runnable( - llm=llm, tools=tools, system_prompt=system_prompt - ) - self.callback = HRICallbackHandler( - connectors=connectors, aggregate_chunks=True, logger=self.logger - ) - self.tracing_callbacks = get_tracing_callbacks() - self.state = state or ReActAgentState(messages=[]) - self.thread: Optional[threading.Thread] = None - self._stop_event = threading.Event() - self.connectors = connectors - - def run(self): - if self.thread is not None: - raise RuntimeError("Agent is already running") - self.thread = threading.Thread(target=self._run_loop) - self.thread.start() - - def _run_loop(self): - while not self._stop_event.is_set(): - received_messages = {} - try: - received_messages = self.receive_all_connectors() - except ValueError: - self.logger.info("Waiting for messages...") - if received_messages: - self.logger.info("Received messages") - reduced_message = self._reduce_messages(received_messages) - langchain_message = reduced_message.to_langchain() - self.state["messages"].append(langchain_message) - # callback is used to send messages to the connectors - self.agent.invoke( - self.state, - config={"callbacks": [self.callback, *self.tracing_callbacks]}, - ) - time.sleep(0.3) - - def stop(self): - self._stop_event.set() - if self.thread is not None: - self.logger.info("Stopping the agent. Please wait...") - self.thread.join() - self.thread = None - self.logger.info("Agent stopped") - - def receive_all_connectors(self) -> Dict[str, Dict[str, HRIMessage]]: - received_messages: Dict[str, Any] = {} - for connector_name, connector in self.connectors.items(): - received_message = cast( - HRIConnector[HRIMessage], connector - ).receive_all_sources() - received_message = {k: v for k, v in received_message.items() if v} - if received_message: - received_messages[connector_name] = received_message - return received_messages - - def _reduce_messages( - self, received_messages: Dict[str, Dict[str, HRIMessage]] - ) -> HRIMessage: - text = "" - images = [] - audios = [] - for connector_name, connector_sources in received_messages.items(): - text += f"{connector_name}\n" - for source_name, source_message in connector_sources.items(): - text += f"{source_name}: {source_message.text}\n" - images.extend(source_message.images) - audios.extend(source_message.audios) - return HRIMessage( - text=text, - images=images, - audios=audios, - message_author="human", - ) diff --git a/src/rai_core/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py index 9093c0b43..cf1f15dad 100644 --- a/src/rai_core/rai/communication/hri_connector.py +++ b/src/rai_core/rai/communication/hri_connector.py @@ -143,7 +143,7 @@ class HRIConnector(Generic[T], BaseConnector[T]): Used for sending and receiving messages between human and robot from various sources. """ - def _build_message( + def build_message( self, message: LangchainBaseMessage | RAIMultimodalMessage, communication_id: Optional[str] = None, diff --git a/src/rai_core/rai/communication/ros2/connectors/base.py b/src/rai_core/rai/communication/ros2/connectors/base.py index 963f78786..e48e1ff65 100644 --- a/src/rai_core/rai/communication/ros2/connectors/base.py +++ b/src/rai_core/rai/communication/ros2/connectors/base.py @@ -152,7 +152,7 @@ def send_message( qos_profile=qos_profile, ) - def general_callback_preprocessor(self, message: Any): + def general_callback_preprocessor(self, message: Any) -> T: return self.T_class(payload=message, metadata={"msg_type": str(type(message))}) def register_callback( diff --git a/src/rai_core/rai/communication/ros2/connectors/hri_connector.py b/src/rai_core/rai/communication/ros2/connectors/hri_connector.py index ac7149ccf..afe56d45a 100644 --- a/src/rai_core/rai/communication/ros2/connectors/hri_connector.py +++ b/src/rai_core/rai/communication/ros2/connectors/hri_connector.py @@ -12,130 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util import logging import uuid -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Optional -from rai.communication.ros2.api import ( - ConfigurableROS2TopicAPI, - TopicConfig, -) +from rclpy.qos import QoSProfile + +from rai.communication import HRIConnector from rai.communication.ros2.connectors.base import ROS2BaseConnector from rai.communication.ros2.messages import ROS2HRIMessage -try: - import rai_interfaces.msg -except ImportError: - logging.warning("rai_interfaces is not installed, ROS 2 HRIMessage will not work.") +if importlib.util.find_spec("rai_interfaces.msg") is None: + logging.warning( + "This feature is based on rai_interfaces.msg. Make sure rai_interfaces is installed." + ) -class ROS2HRIConnector(ROS2BaseConnector[ROS2HRIMessage]): +class ROS2HRIConnector(ROS2BaseConnector[ROS2HRIMessage], HRIConnector[ROS2HRIMessage]): def __init__( self, node_name: str = f"rai_ros2_hri_connector_{str(uuid.uuid4())[-12:]}", - targets: List[Union[str, Tuple[str, TopicConfig]]] = [], - sources: List[Union[str, Tuple[str, TopicConfig]]] = [], ): - configured_targets = [ - target[0] if isinstance(target, tuple) else target for target in targets - ] - configured_sources = [ - source[0] if isinstance(source, tuple) else source for source in sources - ] - self.configured_targets = configured_targets - self.configured_sources = configured_sources - - _targets = [ - ( - target - if isinstance(target, tuple) - else (target, TopicConfig(is_subscriber=False)) - ) - for target in targets - ] - _sources = [ - ( - source - if isinstance(source, tuple) - else (source, TopicConfig(is_subscriber=True)) - ) - for source in sources - ] super().__init__(node_name=node_name) - self._topic_api = ConfigurableROS2TopicAPI(self._node) - self._configure_publishers(_targets) - self._configure_subscribers(_sources) - - def _configure_publishers(self, targets: List[Tuple[str, TopicConfig]]): - for target in targets: - self._topic_api.configure_publisher(target[0], target[1]) - def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]): - for source in sources: - self._topic_api.configure_subscriber(source[0], source[1]) - - def send_message(self, message: ROS2HRIMessage, target: str, **kwargs): - self._topic_api.publish_configured( + def send_message( + self, + message: ROS2HRIMessage, + target: str, + *, + qos_profile: Optional[QoSProfile] = None, + auto_qos_matching: bool = True, + **kwargs, + ): + self._topic_api.publish( topic=target, msg_content=message.to_ros2_dict(), + msg_type="rai_interfaces/msg/HRIMessage", + auto_qos_matching=auto_qos_matching, + qos_profile=qos_profile, ) - def receive_message( + def register_callback( self, source: str, - timeout_sec: float = 1.0, + callback: Callable[[ROS2HRIMessage], None], + raw: bool = False, *, - message_author: Literal["human", "ai"] = "human", msg_type: Optional[str] = None, - auto_topic_type: bool = True, - **kwargs: Any, - ) -> ROS2HRIMessage: - msg = self._topic_api.receive( - topic=source, - timeout_sec=timeout_sec, - auto_topic_type=auto_topic_type, - ) - if not isinstance(msg, rai_interfaces.msg.HRIMessage): - raise ValueError( - f"Received message is not of type rai_interfaces.msg.HRIMessage, got {type(msg)}" - ) - return ROS2HRIMessage.from_ros2(msg, message_author) - - def create_service( - self, - service_name: str, - on_request: Callable, - on_done: Optional[Callable] = None, - *, - service_type: str, - **kwargs: Any, - ) -> str: - return self._service_api.create_service( - service_name=service_name, - callback=on_request, - on_done=on_done, - service_type=service_type, - **kwargs, - ) - - def create_action( - self, - action_name: str, - generate_feedback_callback: Callable, - *, - action_type: str, + qos_profile: Optional[QoSProfile] = None, + auto_qos_matching: bool = True, **kwargs: Any, ) -> str: - return self._actions_api.create_action_server( - action_name=action_name, - action_type=action_type, - execute_callback=generate_feedback_callback, + if msg_type is None: + msg_type = "rai_interfaces/msg/HRIMessage" + return super().register_callback( + source, + callback, + raw, + msg_type=msg_type, + qos_profile=qos_profile, + auto_qos_matching=auto_qos_matching, **kwargs, ) - def shutdown(self): - self._executor.shutdown() - self._thread.join() - self._actions_api.shutdown() - self._topic_api.shutdown() - self._node.destroy_node() + def general_callback_preprocessor(self, message: Any) -> ROS2HRIMessage: + return ROS2HRIMessage.from_ros2(message, message_author="human") diff --git a/src/rai_core/rai/communication/ros2/messages.py b/src/rai_core/rai/communication/ros2/messages.py index 0f05abf11..458bbef3e 100644 --- a/src/rai_core/rai/communication/ros2/messages.py +++ b/src/rai_core/rai/communication/ros2/messages.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util import logging from collections import OrderedDict from typing import Any, List, Literal, cast @@ -27,14 +28,14 @@ from rai.communication.base_connector import BaseMessage from rai.communication.hri_connector import HRIMessage -try: +if importlib.util.find_spec("rai_interfaces.msg") is None: + logging.warning("rai_interfaces is not installed, ROS 2 HRIMessage will not work.") +else: import rai_interfaces.msg from rai_interfaces.msg import HRIMessage as ROS2HRIMessage_ from rai_interfaces.msg._audio_message import ( AudioMessage as ROS2HRIMessage__Audio, ) -except ImportError: - logging.warning("rai_interfaces is not installed, ROS 2 HRIMessage will not work.") class ROS2Message(BaseMessage): @@ -64,7 +65,7 @@ def from_ros2( for audio_msg in cast(List[ROS2HRIMessage__Audio], msg.audios) ] communication_id = msg.communication_id if msg.communication_id != "" else None - return ROS2HRIMessage( + return cls( text=msg.text, images=pil_images, audios=audio_segments, diff --git a/tests/agents/test_langchain_agent.py b/tests/agents/test_langchain_agent.py new file mode 100644 index 000000000..c220f7bbe --- /dev/null +++ b/tests/agents/test_langchain_agent.py @@ -0,0 +1,41 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from typing import List + +import pytest +from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType + + +@pytest.mark.parametrize( + "new_message_behavior,in_buffer,out_buffer,output", + [ + ("take_all", [1, 2, 3], [], [1, 2, 3]), + ("keep_last", [1, 2, 3], [], [3]), + ("queue", [1, 2, 3], [2, 3], [1]), + ("interupt_take_all", [1, 2, 3], [], [1, 2, 3]), + ("interupt_keep_last", [1, 2, 3], [], [3]), + ], +) +def test_reduce_messages( + new_message_behavior: newMessageBehaviorType, + in_buffer: List, + out_buffer: List, + output: List, +): + buffer = deque(in_buffer) + output_ = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer) + assert output == output_ + assert buffer == deque(out_buffer) diff --git a/tests/communication/ros2/test_connectors.py b/tests/communication/ros2/test_connectors.py index ba42148af..7363b9c46 100644 --- a/tests/communication/ros2/test_connectors.py +++ b/tests/communication/ros2/test_connectors.py @@ -170,7 +170,7 @@ def test_ros2hri_default_message_publish( ros_setup: None, request: pytest.FixtureRequest ): topic_name = f"{request.node.originalname}_topic" # type: ignore - connector = ROS2HRIConnector(targets=[topic_name]) + connector = ROS2HRIConnector() hri_message_receiver = HRIMessageSubscriber(topic_name) executors, threads = multi_threaded_spinner([hri_message_receiver])