diff --git a/docs/demos/rosbot_xl.md b/docs/demos/rosbot_xl.md
index bca41f11d..6f0173bab 100644
--- a/docs/demos/rosbot_xl.md
+++ b/docs/demos/rosbot_xl.md
@@ -6,15 +6,19 @@ This demo utilizes Open 3D Engine simulation and allows you to work with RAI on
## Quick start
+> [!TIP]
+> The demo uses the `complex_model` LLM configured in [../../config.toml](../../config.toml). The configured model should support multimodal inputs and tool calling.
+
1. Download the newest binary release:
- Ubuntu 22.04 & ros2 humble: [link](https://robotec-ml-roscon2024-demos.s3.eu-central-1.amazonaws.com/ROSCON_Release/RAIROSBotDemo_1.0.0_jammyhumble.zip)
- Ubuntu 24.04 & ros2 jazzy: [link](https://robotec-ml-roscon2024-demos.s3.eu-central-1.amazonaws.com/ROSCON_Release/RAIROSBotDemo_1.0.0_noblejazzy.zip)
-2. Install required packages
+2. Download and install the required packages
```bash
- sudo apt install ros-${ROS_DISTRO}-ackermann-msgs ros-${ROS_DISTRO}-gazebo-msgs ros-${ROS_DISTRO}-control-toolbox ros-${ROS_DISTRO}-nav2-bringup
+ vcs import < demos.repos
+ rosdep install --from-paths src --ignore-src -r -y
poetry install --with openset
```
@@ -32,56 +36,28 @@ This demo utilizes Open 3D Engine simulation and allows you to work with RAI on
If you would like more freedom to adapt the simulation to your needs, you can make changes using
[O3DE Editor](https://www.docs.o3de.org/docs/welcome-guide/) and build the project
yourself.
-Please refer to [rai husarion rosbot xl demo][rai rosbot demo] for more details.
+Please refer to [rai husarion rosbot xl demo](https://github.com/RobotecAI/rai-rosbot-xl-demo) for more details.
# Running RAI
-1. Robot identity
-
- Process of setting up the robot identity is described in [create_robots_whoami](../create_robots_whoami.md).
- We provide ready whoami for RosBotXL in the package.
-
- ```bash
- cd rai
- vcs import < demos.repos
- colcon build --symlink-install --packages-select rosbot_xl_whoami
- ```
-
-2. Running rai nodes and agents, navigation stack and O3DE simulation.
+1. Run the RAI nodes and agents, the navigation stack, and the O3DE simulation.
```bash
ros2 launch ./examples/rosbot-xl.launch.py game_launcher:=path/to/RAIROSBotXLDemo.GameLauncher
```
-3. Play with the demo, adding tasks to the RAI agent. Here are some examples:
+2. Run the Streamlit GUI:
```bash
- # Ask robot where it is. RAI will use camera to describe the environment
- ros2 action send_goal -f /perform_task rai_interfaces/action/Task "{priority: 10, description: '', task: 'Where are you?'}"
-
- # See integration with the navigation stack
- ros2 action send_goal -f /perform_task rai_interfaces/action/Task "{priority: 10, description: '', task: 'Drive 1 meter forward'}"
- ros2 action send_goal -f /perform_task rai_interfaces/action/Task "{priority: 10, description: '', task: 'Spin 90 degrees'}"
-
- # Try out more complicated tasks
- ros2 action send_goal -f /perform_task rai_interfaces/action/Task "{priority: 10, description: '', task: ' Drive forward if the path is clear, otherwise backward'}"
+ streamlit run examples/rosbot-xl-demo.py
```
-> **NOTE**: For now agent is capable of performing only 1 task at once.
-> Human-Robot Interaction module is not yet included in the demo (coming soon!).
-
-### What is happening?
-
-By looking at the example code in [rai/examples/rosbot-xl-demo.py](../../examples/rosbot-xl-demo.py) `examples` you can see that:
-
-- This node has no information about the robot besides what it can get from `rai_whoami_node`.
-- Topics can be whitelisted to only receive information about the robot.
-- Before every LLM decision, `rai_node` sends its state to the LLM Agent. By default, it contains ros interfaces (topics, services, actions) and logs summary, but the state can be extended.
-- In the example we are also adding description of the camera image to the state.
-
-If you wish, you can learn more about [configuring RAI for a specific robot](../create_robots_whoami.md).
+3. Play with the demo, prompting the agent to perform tasks. Here are some examples:
-[rai rosbot demo]: https://github.com/RobotecAI/rai-rosbot-xl-demo
+ - Where are you now?
+ - What do you see?
+   - What is the position of the bed?
+ - Navigate to the kitchen.
> [!TIP]
> If you are having trouble running the binary, you can build it from source [here](https://github.com/RobotecAI/rai-rosbot-xl-demo).
diff --git a/docs/developer_guide/tools.md b/docs/developer_guide/tools.md
new file mode 100644
index 000000000..026ac96a7
--- /dev/null
+++ b/docs/developer_guide/tools.md
@@ -0,0 +1,247 @@
+# Tools
+
+Tools are a fundamental concept in LangChain that allows AI models to interact with external systems and perform specific operations. Think of tools as callable functions that bridge the gap between natural language understanding and system execution.
+
+RAI offers a comprehensive set of pre-built tools, both general-purpose and ROS 2-specific, available [here](../../src/rai_core/rai/tools/ros2). However, in some cases you may need to develop custom tools tailored to a specific robot or application. This guide demonstrates how to create custom tools in RAI using the [LangChain framework](https://python.langchain.com/docs/).
+
+RAI supports two primary approaches for implementing tools, each with distinct advantages:
+
+### `BaseTool` Class
+
+- Offers full control over tool behavior and lifecycle
+- Allows configuration parameters
+- Supports stateful operations (e.g., maintaining ROS 2 connector instances)
+
+### `@tool` Decorator
+
+- Provides a lightweight, functional approach
+- Ideal for stateless operations
+- Minimizes boilerplate code
+- Suited for simple, single-purpose tools
+
+Use the `BaseTool` class when state management or extensive configuration is required. Choose the `@tool` decorator for simple, stateless functionality where conciseness is preferred.
+
+---
+
+## Creating a Custom Tool
+
+LangChain tools typically return either a string or a tuple containing a string and an artifact.
+
+RAI extends LangChain’s tool capabilities by supporting **multimodal tools**—tools that return not only text but also other content types, such as images, audio, or structured data. This is achieved using a special object called `MultimodalArtifact` along with a custom `ToolRunner` class.
+
+---
+
+### Single-Modal Tool (Text Output)
+
+Here’s an example of a single-modal tool implemented using class inheritance (in this and the following snippet, `robot` stands in for your robot's API object):
+
+```python
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, Field
+from typing import Type
+
+
+class GrabObjectToolInput(BaseModel):
+ """Input schema for the GrabObjectTool."""
+ object_name: str = Field(description="The name of the object to grab")
+
+
+class GrabObjectTool(BaseTool):
+ """Tool for grabbing objects using a robot."""
+ name: str = "grab_object"
+ description: str = "Grabs a specified object using the robot's manipulator"
+ args_schema: Type[GrabObjectToolInput] = GrabObjectToolInput
+
+ def _run(self, object_name: str) -> str:
+ """Execute the object grabbing operation."""
+ try:
+ status = robot.grab_object(object_name)
+ return f"Successfully grabbed object: {object_name}, status: {status}"
+ except Exception as e:
+ return f"Failed to grab object: {object_name}, error: {str(e)}"
+```
+
+Alternatively, using the `@tool` decorator:
+
+```python
+from langchain_core.tools import tool
+
+@tool
+def grab_object(object_name: str) -> str:
+ """Grabs a specified object using the robot's manipulator."""
+ try:
+ status = robot.grab_object(object_name)
+ return f"Successfully grabbed object: {object_name}, status: {status}"
+ except Exception as e:
+ return f"Failed to grab object: {object_name}, error: {str(e)}"
+```
+
+---
+
+### Multimodal Tool (Text + Image Output)
+
+RAI supports multimodal tools through the `rai.agents.tool_runner.ToolRunner` class. Such tools must be executed by this runner, either directly or via runnables such as [`create_react_runnable`](../../src/rai_core/rai/agents/langchain/runnables.py), so that their multimedia output is handled correctly.
+
+```python
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, Field
+from typing import Type, Tuple
+from rai.messages import MultimodalArtifact
+
+
+class Get360ImageToolInput(BaseModel):
+ """Input schema for the Get360ImageTool."""
+ topic: str = Field(description="The topic name for the 360 image")
+
+
+class Get360ImageTool(BaseTool):
+ """Tool for retrieving 360-degree images."""
+ name: str = "get_360_image"
+ description: str = "Retrieves a 360-degree image from the specified topic"
+ args_schema: Type[Get360ImageToolInput] = Get360ImageToolInput
+ response_format: str = "content_and_artifact"
+
+ def _run(self, topic: str) -> Tuple[str, MultimodalArtifact]:
+ try:
+ image = robot.get_360_image(topic)
+ return "Successfully retrieved 360 image", MultimodalArtifact(images=[image])
+ except Exception as e:
+ return f"Failed to retrieve image: {str(e)}", MultimodalArtifact(images=[])
+```
+
+---
+
+### ROS 2 Tools
+
+RAI includes a base class for ROS 2 tools, supporting configuration of readable, writable, and forbidden topics/actions/services, as well as a ROS 2 connector. TODO(docs): link docs to the ARIConnector.
+
+```python
+from rai.tools.ros2.base import BaseROS2Tool
+from pydantic import BaseModel, Field
+from typing import Type, cast
+from sensor_msgs.msg import PointCloud2
+
+
+class GetROS2LidarDataToolInput(BaseModel):
+ """Input schema for the GetROS2LidarDataTool."""
+ topic: str = Field(description="The topic name for the LiDAR data")
+
+
+class GetROS2LidarDataTool(BaseROS2Tool):
+ """Tool for retrieving and processing LiDAR data."""
+ name: str = "get_ros2_lidar_data"
+ description: str = "Retrieves and processes LiDAR data from the specified topic"
+ args_schema: Type[GetROS2LidarDataToolInput] = GetROS2LidarDataToolInput
+
+ def _run(self, topic: str) -> str:
+ try:
+            lidar_data = self.connector.receive_message(topic)
+            msg = cast(PointCloud2, lidar_data.payload)
+            # Process the LiDAR data in `msg` (e.g., extract obstacle positions)
+            return "Successfully processed LiDAR data. Detected objects: ..."
+ except Exception as e:
+ return f"Failed to process LiDAR data: {str(e)}"
+```
+
+Refer to the [BaseROS2Tool source code](../../src/rai_core/rai/tools/ros2/base.py) for more information.
+
+---
+
+## Tool Initialization
+
+Tools can be initialized with parameters such as a connector, enabling custom configurations for ROS 2 environments.
+
+```python
+from rai.communication.ros2 import ROS2ARIConnector
+from rai.tools.ros2 import (
+ GetROS2ImageTool,
+ GetROS2TopicsNamesAndTypesTool,
+ PublishROS2MessageTool,
+)
+
+def initialize_tools(connector: ROS2ARIConnector):
+ """Initialize and configure ROS 2 tools.
+
+ Returns:
+ list: A list of configured tools.
+ """
+ readable_names = ["/color_image5", "/depth_image5", "/color_camera_info5"]
+    forbidden_names = ["/cmd_vel"]
+ writable_names = ["/to_human"]
+
+ return [
+ GetROS2ImageTool(
+ connector=connector, readable=readable_names, forbidden=forbidden_names
+ ),
+ GetROS2TopicsNamesAndTypesTool(
+ connector=connector,
+ readable=readable_names,
+ forbidden=forbidden_names,
+ writable=writable_names,
+ ),
+ PublishROS2MessageTool(
+ connector=connector, writable=writable_names, forbidden=forbidden_names
+ ),
+ ]
+```
+
+---
+
+### Using Tools in a RAI Agent (Distributed Setup)
+
+TODO(docs): add link to the BaseAgent docs (regarding distributed setup)
+
+```python
+from rai.agents import ReActAgent
+from rai.communication import ROS2ARIConnector, ROS2HRIConnector
+from rai.tools.ros2 import ROS2Toolkit
+from rai.utils import ROS2Context, wait_for_shutdown
+
+@ROS2Context()
+def main() -> None:
+ """Initialize and run the RAI agent with configured tools."""
+ connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
+ ari_connector = ROS2ARIConnector()
+ agent = ReActAgent(
+ connectors={"hri": connector},
+ tools=initialize_tools(connector=ari_connector),
+ )
+ agent.run()
+ wait_for_shutdown([agent])
+
+# Example:
+# ros2 topic pub /from_human rai_interfaces/msg/HRIMessage "{\"text\": \"What do you see?\"}"
+# ros2 topic echo /to_human rai_interfaces/msg/HRIMessage
+```
+
+---
+
+### Using Tools in LangChain/LangGraph Agent (Local Setup)
+
+```python
+from langchain_core.messages import HumanMessage
+from rai.agents.langchain import create_react_runnable
+from rai.communication.ros2 import ROS2ARIConnector
+from rai.utils import ROS2Context
+
+@ROS2Context()
+def main():
+    ari_connector = ROS2ARIConnector()
+    agent = create_react_runnable(
+        tools=initialize_tools(connector=ari_connector),
+        system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
+    )
+    state = {"messages": []}
+    while True:
+        input_text = input("Enter a prompt: ")
+        state["messages"].append(HumanMessage(content=input_text))
+        state = agent.invoke(state)
+        print(state["messages"][-1].content)  # latest assistant reply
+```
+
+---
+
+## Related Topics
+
+- [Connectors](../communication/connectors.md)
+- [ROS2ARIConnector](../communication/ros2.md)
+- [ROS2HRIConnector](../communication/ros2.md)
diff --git a/examples/manipulation-demo.py b/examples/manipulation-demo.py
index 92c820031..197532fe7 100644
--- a/examples/manipulation-demo.py
+++ b/examples/manipulation-demo.py
@@ -19,7 +19,7 @@
from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
-from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
+from rai.tools.ros2 import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool
diff --git a/examples/rosbot-xl-demo.py b/examples/rosbot-xl-demo.py
index d71cd934e..694d52ce9 100644
--- a/examples/rosbot-xl-demo.py
+++ b/examples/rosbot-xl-demo.py
@@ -12,148 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
+
import rclpy
-import rclpy.executors
-import rclpy.logging
import streamlit as st
-from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
-from rai.agents.conversational_agent import create_conversational_agent
-from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
+from langchain_core.tools import BaseTool
+from rai.agents import ReActAgent
from rai.communication.ros2 import ROS2ARIConnector
-from rai.messages import HumanMultimodalMessage
+from rai.frontend.streamlit import run_streamlit_app
from rai.tools.ros.manipulation import GetGrabbingPointTool, GetObjectPositionsTool
-from rai.tools.ros2 import ROS2Toolkit
+from rai.tools.ros2 import (
+ GetROS2ImageConfiguredTool,
+ GetROS2TransformConfiguredTool,
+ Nav2Toolkit,
+)
from rai.tools.time import WaitForSecondsTool
from rai.utils.model_initialization import get_llm_model
-from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool
+
+# Set page configuration first
+st.set_page_config(
+ page_title="RAI ROSBotXL Demo",
+ page_icon=":robot:",
+)
@st.cache_resource
def initialize_agent():
rclpy.init()
- SYSTEM_PROMPT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests.
- Do not make assumptions about the environment you are currently in.
- You can use ros2 topics, services and actions to operate.
-
- As a first step check transforms by getting 1 message from /tf topic
- use /cmd_vel topic very carefully. Obstacle detection works only with nav2 stack, so be careful when it is not used. >
- be patient with running ros2 actions. usually the take some time to run.
- Always check your transform before and after you perform ros2 actions, so that you can verify if it worked.
-
- Navigation tips:
- - it's good to start finding objects by rotating, then navigating to some diverse location with occasional rotations. Remember to frequency detect objects.
- - for driving forward/backward or to some coordinates, ros2 actions are better.
- - for driving for some specific time or in specific manner (like shaper or turns) it good to use /cmd_vel topic
- - you are currently unable to read map or point-cloud, so please avoid subscribing to such topics.
- - if you are asked to drive towards some object, it's good to:
- 1. check the camera image and verify if objects can be seen
- 2. if only driving forward is required, do it
- 3. if obstacle avoidance might be required, use ros2 actions navigate_*, but first check your current position, then very accurately estimate the goal pose.
- - it is good to verify using given information if the robot is not stuck
- - navigation actions sometimes fail. Their output can be read from rosout. You can also tell if they partially worked by checking the robot position and rotation.
- - before using any ros2 interfaces, always make sure to check you are using the right interface
- - processing camera image takes 5-10s. Take it into account that if the robot is moving, the information can be outdated. Handle it by good planning of your movements.
- - you are encouraged to use wait tool in between checking the status of actions
- - to find some object navigate around and check the surrounding area
- - when the goal is accomplished please make sure to cancel running actions
- - when you reach the navigation goal - double check if you reached it by checking the current position
- - if you detect collision, please stop operation
-
- - you will be given your camera image description. Based on this information you can reason about positions of objects.
- - be careful and aboid obstacles
-
- Here are the corners of your environment:
- (-2.76,9.04, 0.0),
- (4.62, 9.07, 0.0),
- (-2.79, -3.83, 0.0),
- (4.59, -3.81, 0.0)
-
- This is location of places:
- Kitchen:
- (2.06, -0.23, 0.0),
- (2.07, -1.43, 0.0),
- (-2.44, -0.38, 0.0),
- (-2.56, -1.47, 0.0)
-
- # Living room:
- (-2.49, 1.87, 0.0),
- (-2.50, 5.49, 0.0),
- (0.79, 5.73, 0.0),
- (0.92, 1.01, 0.0)
-
- Before starting anything, make sure to load available topics, services and actions.
+ SYSTEM_PROMPT = """
+ You are an intelligent autonomous agent embodied in ROSBotXL—this robot is your body, your interface with the physical world.
+ You operate within a known indoor environment. Key locations include:
+ Kitchen (center): (-0.2175, -0.8775, 0.0)
+ Living Room (center): (-0.82, 3.525, 0.0)
+ ROSBotXL is equipped with a camera, enabling you to visually perceive your surroundings.
+ You can obtain real-time images from the ROS 2 topic using the get_ros2_camera_image tool.
+ When executing tasks that require time to complete—such as navigating between locations,
+ waiting for an event, or monitoring a process—you must use the WaitForSecondsTool to pause appropriately during or between steps.
+ This ensures smooth and realistic operation.
+ Your mission is to understand and faithfully execute the user's commands using your tools, sensors, and spatial knowledge.
+ Always plan ahead: analyze the task, evaluate the context, and reason through your actions to ensure they are effective, safe, and aligned with the goal.
+ Act with intelligence and autonomy. Be proactive, deliberate, and aware of your environment.
+ Your job is to transform user intent into meaningful, goal-driven behavior within the physical world.
"""
connector = ROS2ARIConnector()
-
- agent = create_conversational_agent(
- llm=get_llm_model("complex_model", streaming=True),
- system_prompt=SYSTEM_PROMPT,
- tools=[
- *ROS2Toolkit(
- connector=connector, forbidden=["/tf", "/cmd_vel"]
- ).get_tools(),
- WaitForSecondsTool(),
- GetDetectionTool(connector=connector, node=connector.node),
- GetDistanceToObjectsTool(connector=connector, node=connector.node),
- GetObjectPositionsTool(
+ tools: List[BaseTool] = [
+ GetROS2TransformConfiguredTool(
+ connector=connector,
+ source_frame="map",
+ target_frame="base_link",
+ timeout_sec=5.0,
+ ),
+ GetROS2ImageConfiguredTool(
+ connector=connector,
+ topic="/camera/camera/color/image_raw",
+ response_format="content_and_artifact",
+ ),
+ WaitForSecondsTool(),
+ GetObjectPositionsTool(
+ connector=connector,
+ target_frame="map",
+ source_frame="sensor_frame",
+ camera_topic="/camera/camera/color/image_raw",
+ depth_topic="/camera/camera/depth/image_rect_raw",
+ camera_info_topic="/camera/camera/color/camera_info",
+ get_grabbing_point_tool=GetGrabbingPointTool(
connector=connector,
- target_frame="map",
- source_frame="sensor_frame",
- camera_topic="/camera/camera/color/image_raw",
- depth_topic="/camera/camera/depth/image_rect_raw",
- camera_info_topic="/camera/camera/color/camera_info",
- get_grabbing_point_tool=GetGrabbingPointTool(
- connector=connector,
- ),
),
- ],
- )
+ ),
+ *Nav2Toolkit(connector=connector).get_tools(),
+ ]
+ # Initialize an empty connectors dictionary since we're using the agent in direct mode
+ # In a distributed setup, connectors would be used to handle communication between
+ # components, routing agent inputs/outputs through the distributed system
+ connectors = {}
+
+ agent = ReActAgent(
+ connectors=connectors,
+ llm=get_llm_model("complex_model", streaming=True),
+ system_prompt=SYSTEM_PROMPT,
+ tools=tools,
+ ).agent
connector.node.declare_parameter("conversion_ratio", 1.0)
return agent
def main():
- st.set_page_config(
- page_title="RAI ROSBotXL Demo",
- page_icon=":robot:",
+ run_streamlit_app(
+ initialize_agent(),
+ "RAI ROSBotXL Demo",
+ "Hi! I am a ROSBotXL robot. What can I do for you?",
)
- st.title("RAI ROSBotXL Demo")
- st.markdown("---")
-
- st.sidebar.header("Tool Calls History")
-
- if "graph" not in st.session_state:
- graph = initialize_agent()
- st.session_state["graph"] = graph
-
- if "messages" not in st.session_state:
- st.session_state["messages"] = [
- AIMessage(content="Hi! I am ROSBotXL. What can I do for you?")
- ]
-
- prompt = st.chat_input()
- for msg in st.session_state.messages:
- if isinstance(msg, AIMessage):
- if msg.content:
- st.chat_message("assistant").write(msg.content)
- elif isinstance(msg, HumanMultimodalMessage):
- continue
- elif isinstance(msg, HumanMessage):
- st.chat_message("user").write(msg.content)
- elif isinstance(msg, ToolMessage):
- with st.sidebar.expander(f"Tool: {msg.name}", expanded=False):
- st.code(msg.content, language="json")
-
- if prompt:
- st.session_state.messages.append(HumanMessage(content=prompt))
- st.chat_message("user").write(prompt)
- with st.chat_message("assistant"):
- st_callback = get_streamlit_cb(st.container())
- streamlit_invoke(
- st.session_state["graph"], st.session_state.messages, [st_callback]
- )
if __name__ == "__main__":
diff --git a/examples/rosbot-xl_allowlist.txt b/examples/rosbot-xl_allowlist.txt
deleted file mode 100644
index bf9a3d6bd..000000000
--- a/examples/rosbot-xl_allowlist.txt
+++ /dev/null
@@ -1,22 +0,0 @@
-/rosout
-/camera/camera/color/image_raw
-/camera/camera/depth/image_rect_raw
-/map
-/scan
-/diagnostics
-/cmd_vel
-/led_strip
-/backup
-/compute_path_through_poses
-/compute_path_to_pose
-/dock_robot
-/drive_on_heading
-/follow_gps_waypoints
-/follow_path
-/follow_waypoints
-/navigate_through_poses
-/navigate_to_pose
-/smooth_path
-/spin
-/undock_robot
-/wait
diff --git a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
index 4a6dabeb0..325ffda93 100644
--- a/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
+++ b/src/rai_bench/rai_bench/examples/o3de_test_benchmark.py
@@ -27,7 +27,7 @@
GetObjectPositionsTool,
MoveToPointTool,
)
-from rai.tools.ros2.topics import (
+from rai.tools.ros2 import (
GetROS2ImageTool,
GetROS2TopicsNamesAndTypesTool,
)
diff --git a/src/rai_core/rai/__init__.py b/src/rai_core/rai/__init__.py
index 1813d1263..97ceef6f0 100644
--- a/src/rai_core/rai/__init__.py
+++ b/src/rai_core/rai/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Robotec.AI
+# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from rai.apps.high_level_api import ROS2Agent
-
-__all__ = ["ROS2Agent"]
diff --git a/src/rai_core/rai/agents/integrations/streamlit.py b/src/rai_core/rai/agents/integrations/streamlit.py
index 73360893a..42f45b62d 100644
--- a/src/rai_core/rai/agents/integrations/streamlit.py
+++ b/src/rai_core/rai/agents/integrations/streamlit.py
@@ -14,12 +14,14 @@
import base64
import inspect
-from typing import Any, Callable, Dict, TypeVar
+from typing import Any, Callable, Dict, List, TypeVar
import cv2
import numpy as np
import streamlit as st
from langchain_core.callbacks.base import BaseCallbackHandler
+from langchain_core.messages import BaseMessage
+from langchain_core.runnables import Runnable, RunnableConfig
from streamlit.delta_generator import DeltaGenerator
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
@@ -170,7 +172,12 @@ def wrapper(*args, **kwargs) -> fn_return_type:
return st_cb
-def streamlit_invoke(graph, messages, callables):
+def streamlit_invoke(
+ graph: Runnable[Any, Any], messages: List[BaseMessage], callables: List[Callable]
+):
if not isinstance(callables, list):
raise TypeError("callables must be a list")
- return graph.invoke({"messages": messages}, config={"callbacks": callables})
+ return graph.invoke(
+ {"messages": messages},
+ config=RunnableConfig({"callbacks": callables, "recursion_limit": 100}),
+ )
diff --git a/src/rai_core/rai/agents/langchain/runnables.py b/src/rai_core/rai/agents/langchain/runnables.py
index 928581047..f733a540e 100644
--- a/src/rai_core/rai/agents/langchain/runnables.py
+++ b/src/rai_core/rai/agents/langchain/runnables.py
@@ -16,7 +16,7 @@
from typing import List, Optional, TypedDict, cast
from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage
+from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.graph import START, StateGraph
@@ -38,7 +38,7 @@ class ReActAgentState(TypedDict):
messages: List[BaseMessage]
-def llm_node(llm: BaseChatModel, state: ReActAgentState):
+def llm_node(llm: BaseChatModel, system_prompt: Optional[str], state: ReActAgentState):
"""Process messages using the LLM.
Parameters
@@ -58,13 +58,18 @@ def llm_node(llm: BaseChatModel, state: ReActAgentState):
ValueError
If state is invalid or LLM processing fails
"""
-
+ if system_prompt:
+ # at this point, state['messages'] length should at least be 1
+ if not isinstance(state["messages"][0], SystemMessage):
+ state["messages"].insert(0, SystemMessage(content=system_prompt))
ai_msg = llm.invoke(state["messages"])
state["messages"].append(ai_msg)
def create_react_runnable(
- llm: Optional[BaseChatModel] = None, tools: Optional[List[BaseTool]] = None
+ llm: Optional[BaseChatModel] = None,
+ tools: Optional[List[BaseTool]] = None,
+ system_prompt: Optional[str] = None,
) -> Runnable[ReActAgentState, ReActAgentState]:
"""Create a react agent that can process messages and optionally use tools.
@@ -101,9 +106,9 @@ def create_react_runnable(
graph.add_edge("tools", "llm")
# Bind tools to LLM
bound_llm = cast(BaseChatModel, llm.bind_tools(tools))
- graph.add_node("llm", partial(llm_node, bound_llm))
+ graph.add_node("llm", partial(llm_node, bound_llm, system_prompt))
else:
- graph.add_node("llm", partial(llm_node, llm))
+ graph.add_node("llm", partial(llm_node, llm, system_prompt))
# Compile the graph
return graph.compile()
diff --git a/src/rai_core/rai/agents/react_agent.py b/src/rai_core/rai/agents/react_agent.py
index 98059ec27..f6097fbcd 100644
--- a/src/rai_core/rai/agents/react_agent.py
+++ b/src/rai_core/rai/agents/react_agent.py
@@ -34,10 +34,13 @@ def __init__(
llm: Optional[BaseChatModel] = None,
tools: Optional[List[BaseTool]] = None,
state: Optional[ReActAgentState] = None,
+ system_prompt: Optional[str] = None,
):
super().__init__(connectors=connectors)
self.logger = logging.getLogger(__name__)
- self.agent = create_react_runnable(llm=llm, tools=tools)
+ self.agent = create_react_runnable(
+ llm=llm, tools=tools, system_prompt=system_prompt
+ )
self.callback = HRICallbackHandler(
connectors=connectors, aggregate_chunks=True, logger=self.logger
)
diff --git a/src/rai_core/rai/apps/high_level_api.py b/src/rai_core/rai/apps/high_level_api.py
index dfe1087b5..224602007 100644
--- a/src/rai_core/rai/apps/high_level_api.py
+++ b/src/rai_core/rai/apps/high_level_api.py
@@ -18,7 +18,7 @@
from langchain_core.messages import BaseMessage, HumanMessage
from rai.agents.conversational_agent import create_conversational_agent
-from rai.tools.ros.cli import ros2_action, ros2_interface, ros2_service, ros2_topic
+from rai.tools.ros2.cli import ros2_action, ros2_interface, ros2_service, ros2_topic
from rai.utils.model_initialization import get_llm_model
diff --git a/src/rai_core/rai/communication/ros2/connectors/action_mixin.py b/src/rai_core/rai/communication/ros2/connectors/action_mixin.py
index 34b068d2a..22b456f46 100644
--- a/src/rai_core/rai/communication/ros2/connectors/action_mixin.py
+++ b/src/rai_core/rai/communication/ros2/connectors/action_mixin.py
@@ -24,7 +24,7 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
raise AttributeError(
f"{self.__class__.__name__} instance must have an attribute 'actions_api' of type ROS2ActionAPI"
)
- self._actions_api = self._actions_api # to make the type checker happy
+ self._actions_api: ROS2ActionAPI # to make the type checker happy
if not isinstance(self._actions_api, ROS2ActionAPI):
raise AttributeError(
f"{self.__class__.__name__} instance must have an attribute 'actions_api' of type ROS2ActionAPI"
@@ -60,4 +60,4 @@ def start_action(
return handle
def terminate_action(self, action_handle: str, **kwargs: Any):
- self._actions_api.terminate_goal(action_handle)
+ return self._actions_api.terminate_goal(action_handle)
diff --git a/src/rai_core/rai/frontend/__init__.py b/src/rai_core/rai/frontend/__init__.py
new file mode 100644
index 000000000..b8b774cc1
--- /dev/null
+++ b/src/rai_core/rai/frontend/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .streamlit import run_streamlit_app
+
+__all__ = ["run_streamlit_app"]
diff --git a/src/rai_core/rai/utils/streamlit.py b/src/rai_core/rai/frontend/streamlit.py
similarity index 82%
rename from src/rai_core/rai/utils/streamlit.py
rename to src/rai_core/rai/frontend/streamlit.py
index e4b385e6f..a2423ee3c 100644
--- a/src/rai_core/rai/utils/streamlit.py
+++ b/src/rai_core/rai/frontend/streamlit.py
@@ -12,21 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any
+
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import Runnable
from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
-from rai.agents.langchain.runnables import ReActAgentState
from rai.messages import HumanMultimodalMessage
-def run_streamlit_app(agent: Runnable[ReActAgentState, ReActAgentState]):
- st.set_page_config(
- page_title="RAI Manipulation Demo",
- page_icon=":robot:",
- )
- st.title("RAI Manipulation Demo")
+def run_streamlit_app(agent: Runnable[Any, Any], page_title: str, initial_message: str):
+ st.title(page_title)
st.markdown("---")
st.sidebar.header("Tool Calls History")
@@ -35,9 +32,7 @@ def run_streamlit_app(agent: Runnable[ReActAgentState, ReActAgentState]):
st.session_state["graph"] = agent
if "messages" not in st.session_state:
- st.session_state["messages"] = [
- AIMessage(content="Hi! I am a robotic arm. What can I do for you?")
- ]
+ st.session_state["messages"] = [AIMessage(content=initial_message)]
prompt = st.chat_input()
for msg in st.session_state.messages:
diff --git a/src/rai_core/rai/tools/debugging_assistant.py b/src/rai_core/rai/tools/debugging_assistant.py
index 54edd4c6e..789c0c956 100644
--- a/src/rai_core/rai/tools/debugging_assistant.py
+++ b/src/rai_core/rai/tools/debugging_assistant.py
@@ -17,7 +17,7 @@
from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
-from rai.tools.ros.cli import (
+from rai.tools.ros2.cli import (
ros2_action,
ros2_interface,
ros2_node,
diff --git a/src/rai_core/rai/tools/ros/__init__.py b/src/rai_core/rai/tools/ros/__init__.py
index c3be30b81..250c4b10a 100644
--- a/src/rai_core/rai/tools/ros/__init__.py
+++ b/src/rai_core/rai/tools/ros/__init__.py
@@ -13,31 +13,10 @@
# limitations under the License.
-from .cli import (
- ros2_action,
- ros2_interface,
- ros2_node,
- ros2_param,
- ros2_service,
- ros2_topic,
-)
-from .native import Ros2BaseInput, Ros2BaseTool
from .tools import (
AddDescribedWaypointToDatabaseTool,
- GetCurrentPositionTool,
- GetOccupancyGridTool,
)
__all__ = [
"AddDescribedWaypointToDatabaseTool",
- "GetCurrentPositionTool",
- "GetOccupancyGridTool",
- "Ros2BaseInput",
- "Ros2BaseTool",
- "ros2_action",
- "ros2_interface",
- "ros2_node",
- "ros2_param",
- "ros2_service",
- "ros2_topic",
]
diff --git a/src/rai_core/rai/tools/ros/deprecated.py b/src/rai_core/rai/tools/ros/deprecated.py
deleted file mode 100644
index ec8e3ba8e..000000000
--- a/src/rai_core/rai/tools/ros/deprecated.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import base64
-import logging
-from typing import Any, Callable, cast
-
-import cv2
-import rclpy
-import rclpy.qos
-from cv_bridge import CvBridge
-from rclpy.duration import Duration
-from rclpy.qos import (
- QoSDurabilityPolicy,
- QoSHistoryPolicy,
- QoSLivelinessPolicy,
- QoSProfile,
- QoSReliabilityPolicy,
-)
-from sensor_msgs.msg import Image
-
-from rai.tools.ros.utils import wait_for_message
-
-
-class SingleMessageGrabber:
- def __init__(
- self,
- topic: str,
- message_type: type,
- timeout_sec: int,
- logging_level: int = logging.INFO,
- postprocess: Callable[[Any], Any] = lambda x: x,
- ):
- self.topic = topic
- self.message_type = message_type
- self.timeout_sec = timeout_sec
- self.logger = logging.getLogger(self.__class__.__name__)
- self.logger.setLevel(logging_level)
- self.postprocess = getattr(self, "postprocess", postprocess)
-
- def grab_message(self) -> Any:
- node = rclpy.create_node(self.__class__.__name__ + "_node") # type: ignore
- qos_profile = rclpy.qos.qos_profile_sensor_data
- if (
- self.topic == "/map"
- ): # overfitting to husarion TODO(maciejmajek): find a better way
- qos_profile = QoSProfile(
- reliability=QoSReliabilityPolicy.RELIABLE,
- history=QoSHistoryPolicy.KEEP_ALL,
- durability=QoSDurabilityPolicy.TRANSIENT_LOCAL,
- lifespan=Duration(seconds=0),
- deadline=Duration(seconds=0),
- liveliness=QoSLivelinessPolicy.AUTOMATIC,
- liveliness_lease_duration=Duration(seconds=0),
- )
- success, msg = wait_for_message(
- self.message_type,
- node,
- self.topic,
- qos_profile=qos_profile,
- time_to_wait=self.timeout_sec,
- )
-
- if success:
- self.logger.info(
- f"Received message of type {self.message_type.__class__.__name__} from topic {self.topic}" # type: ignore
- )
- else:
- self.logger.error(
- f"Failed to receive message of type {self.message_type.__class__.__name__} from topic {self.topic}" # type: ignore
- )
-
- node.destroy_node()
- return msg
-
- def get_data(self) -> Any:
- msg = self.grab_message()
- return self.postprocess(msg)
-
-
-class SingleImageGrabber(SingleMessageGrabber):
- def __init__(
- self, topic: str, timeout_sec: int = 10, logging_level: int = logging.INFO
- ):
- self.topic = topic
- super().__init__(
- topic=topic,
- message_type=Image,
- timeout_sec=timeout_sec,
- logging_level=logging_level,
- )
-
- def postprocess(self, msg: Image) -> str:
- bridge = CvBridge()
- cv_image = cast(
- cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
- ) # type: ignore
- if cv_image.shape[-1] == 4:
- cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
- base64_image = base64.b64encode(
- bytes(cv2.imencode(".png", cv_image)[1])
- ).decode("utf-8")
- return base64_image
- else:
- cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
-
- image_data = cv2.imencode(".png", cv_image)[1].tostring() # type: ignore
- base64_image = base64.b64encode(image_data).decode("utf-8") # type: ignore
- return base64_image
diff --git a/src/rai_core/rai/tools/ros/native.py b/src/rai_core/rai/tools/ros/native.py
deleted file mode 100644
index 855410d9f..000000000
--- a/src/rai_core/rai/tools/ros/native.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import importlib
-import json
-import time
-from typing import Any, Dict, OrderedDict, Tuple, Type
-
-import rclpy
-import rclpy.callback_groups
-import rclpy.executors
-import rclpy.node
-import rclpy.qos
-import rclpy.subscription
-import rclpy.task
-import rosidl_runtime_py.set_message
-import rosidl_runtime_py.utilities
-import sensor_msgs.msg
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Field
-from rclpy.impl.rcutils_logger import RcutilsLogger
-from rosidl_runtime_py.utilities import get_namespaced_type
-
-from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str
-
-
-# --------------------- Inputs ---------------------
-class Ros2BaseInput(BaseModel):
- """Empty input for ros2 tool"""
-
-
-class Ros2MsgInterfaceInput(BaseModel):
- """Input for the show_ros2_msg_interface tool."""
-
- msg_name: str = Field(..., description="Ros2 message name in typical ros2 format.")
-
-
-class Ros2GetOneMsgFromTopicInput(BaseModel):
- """Input for the get_current_position tool."""
-
- topic: str = Field(..., description="Ros2 topic")
- msg_type: str = Field(
- ..., description="Type of ros2 message in typical ros2 format."
- )
- timeout_sec: int = Field(
- 10, description="The time in seconds to wait for a message to be received."
- )
-
-
-class PubRos2MessageToolInput(BaseModel):
- topic_name: str = Field(..., description="Ros2 topic to publish the message")
- msg_type: str = Field(
- ..., description="Type of ros2 message in typical ros2 format."
- )
- msg_args: Dict[str, Any] = Field(
- ..., description="The arguments of the service call."
- )
- rate: int = Field(10, description="The rate at which to publish the message.")
- timeout_seconds: int = Field(1, description="The timeout in seconds.")
-
-
-# --------------------- Tools ---------------------
-class Ros2BaseTool(BaseTool):
- # TODO: Make the decision between rclpy.node.Node and RaiNode
- node: rclpy.node.Node = Field(..., exclude=True, required=True)
-
- args_schema: Type[Ros2BaseInput] = Ros2BaseInput
-
- @property
- def logger(self) -> RcutilsLogger:
- return self.node.get_logger()
-
-
-class Ros2GetTopicsNamesAndTypesTool(Ros2BaseTool):
- name: str = "Ros2GetTopicsNamesAndTypes"
- description: str = "A tool for getting all ros2 topics names and types"
-
- def _run(self):
- return self.node.get_topic_names_and_types()
-
-
-class Ros2GetRobotInterfaces(Ros2BaseTool):
- name: str = "ros2_robot_interfaces"
- description: str = (
- "A tool for getting all ros2 robot interfaces: topics, services and actions"
- )
-
- def _run(self):
- return self.node.ros_discovery_info.dict()
-
-
-class Ros2ShowMsgInterfaceTool(BaseTool):
- name: str = "Ros2ShowMsgInterface"
- description: str = """A tool for showing ros2 message interface in json format.
- usage:
- ```python
- ShowRos2MsgInterface.run({"msg_name": "geometry_msgs/msg/PoseStamped"})
- ```
- """
-
- args_schema: Type[Ros2MsgInterfaceInput] = Ros2MsgInterfaceInput
-
- def _run(self, msg_name: str):
- """Show ros2 message interface in json format."""
- msg_cls: Type = rosidl_runtime_py.utilities.get_interface(msg_name)
- try:
- msg_dict: OrderedDict = rosidl_runtime_py.convert.message_to_ordereddict(
- msg_cls()
- )
- return json.dumps(msg_dict)
- except NotImplementedError:
- # For action classes that can't be instantiated
- goal_dict: OrderedDict = rosidl_runtime_py.convert.message_to_ordereddict(
- msg_cls.Goal()
- )
-
- result_dict: OrderedDict = rosidl_runtime_py.convert.message_to_ordereddict(
- msg_cls.Result()
- )
-
- feedback_dict: OrderedDict = (
- rosidl_runtime_py.convert.message_to_ordereddict(msg_cls.Feedback())
- )
- return json.dumps(
- {"goal": goal_dict, "result": result_dict, "feedback": feedback_dict}
- )
-
-
-class Ros2PubMessageTool(Ros2BaseTool):
- name: str = "PubRos2MessageTool"
- description: str = """A tool for publishing a message to a ros2 topic
-
- By default 10 messages are published for 1 second. If you want to publish multiple messages, you can specify 'rate' and 'timeout_sec'.
- Example usage:
-
- ```python
- tool = Ros2PubMessageTool()
- tool.run(
- {
- "topic_name": "/some_topic",
- "msg_type": "geometry_msgs/Point",
- "msg_args": {"x": 0.0, "y": 0.0, "z": 0.0},
- "rate" : 10,
- "timeout_sec" : 1
- }
- )
-
- ```
- """
-
- args_schema: Type[PubRos2MessageToolInput] = PubRos2MessageToolInput
-
- def _build_msg(
- self, msg_type: str, msg_args: Dict[str, Any]
- ) -> Tuple[object, Type]:
- msg_cls: Type = import_message_from_str(msg_type)
- msg = msg_cls()
- rosidl_runtime_py.set_message.set_message_fields(msg, msg_args)
- return msg, msg_cls
-
- def _run(
- self,
- topic_name: str,
- msg_type: str,
- msg_args: Dict[str, Any],
- rate: int = 10,
- timeout_seconds: int = 1,
- ):
- """Publishes a message to the specified topic."""
- if "/msg/" not in msg_type:
- raise ValueError("msg_name must contain 'msg' in the name.")
- msg, msg_cls = self._build_msg(msg_type, msg_args)
-
- publisher = self.node.create_publisher(
- msg_cls, topic_name, 10, callback_group=self.node.callback_group
- ) # TODO(boczekbartek): infer qos profile from topic info
-
- def callback():
- publisher.publish(msg)
- self.logger.info(f"Published message '{msg}' to topic '{topic_name}'")
-
- ts = time.perf_counter()
- timer = self.node.create_timer(
- 1.0 / rate, callback, callback_group=self.node.callback_group
- )
-
- while time.perf_counter() - ts < timeout_seconds:
- time.sleep(0.1)
-
- timer.cancel()
- timer.destroy()
-
- self.logger.info(
- f"Published messages for {timeout_seconds}s to topic '{topic_name}' with rate {rate}"
- )
- return
-
-
-class TopicInput(Ros2BaseInput):
- topic_name: str = Field(..., description="Ros2 topic name")
-
-
-class GetMsgFromTopic(Ros2BaseTool):
- name: str = "get_msg_from_topic"
- description: str = "Get message from topic"
- args_schema: Type[TopicInput] = TopicInput
- response_format: str = "content_and_artifact"
-
- def _run(self, topic_name: str):
- msg = self.node.get_raw_message_from_topic(topic_name)
- if type(msg) is sensor_msgs.msg.Image:
- img = convert_ros_img_to_base64(msg)
- return "Got image", {"images": [img]}
- else:
- return str(msg), {}
-
-
-class Ros2GenericServiceCallerInput(BaseModel):
- service_name: str = Field(..., description="Name of the ROS2 service to call")
- service_type: str = Field(
- ..., description="Type of the ROS2 service in typical ros2 format"
- )
- request_args: Dict[str, Any] = Field(
- ..., description="Arguments for the service request"
- )
-
-
-class Ros2GenericServiceCaller(Ros2BaseTool):
- name: str = "Ros2GenericServiceCaller"
- description: str = "A tool for calling any ROS2 service dynamically."
-
- args_schema: Type[Ros2GenericServiceCallerInput] = Ros2GenericServiceCallerInput
-
- def _build_request(self, service_type: str, request_args: Dict[str, Any]) -> Any:
- srv_module, _, srv_name = service_type.split("/")
- srv_class = getattr(importlib.import_module(f"{srv_module}.srv"), srv_name)
- request = srv_class.Request()
- rosidl_runtime_py.set_message.set_message_fields(request, request_args)
- return request
-
- def _run(self, service_name: str, service_type: str, request_args: Dict[str, Any]):
- if not service_name.startswith("/"):
- service_name = f"/{service_name}"
-
- try:
- request = self._build_request(service_type, request_args)
- except Exception as e:
- return f"Failed to build service request: {e}"
- namespaced_type = get_namespaced_type(service_type)
- client = self.node.create_client(
- rosidl_runtime_py.import_message.import_message_from_namespaced_type(
- namespaced_type
- ),
- service_name,
- )
-
- if not client.wait_for_service(timeout_sec=1.0):
- return f"Service '{service_name}' is not available"
-
- future = client.call_async(request)
- rclpy.spin_until_future_complete(self.node, future)
-
- if future.result() is not None:
- return str(future.result())
- else:
- return f"Service call to '{service_name}' failed"
-
-
-class GetCameraImage(Ros2BaseTool):
- name: str = "get_camera_image"
- description: str = "get image from robots camera"
- response_format: str = "content_and_artifact"
- args_schema: Type[TopicInput] = TopicInput
-
- def _run(self, topic_name: str):
- msg = self.node.get_raw_message_from_topic(topic_name)
- img = convert_ros_img_to_base64(msg)
- return "Got image", {"images": [img]}
diff --git a/src/rai_core/rai/tools/ros/native_actions.py b/src/rai_core/rai/tools/ros/native_actions.py
deleted file mode 100644
index c3833a3d7..000000000
--- a/src/rai_core/rai/tools/ros/native_actions.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import Any, Dict, Optional, Tuple, Type
-
-import rosidl_runtime_py.set_message
-import rosidl_runtime_py.utilities
-from action_msgs.msg import GoalStatus
-from pydantic import BaseModel, Field
-from rclpy.action.client import ActionClient
-from rosidl_runtime_py import message_to_ordereddict
-
-from rai.tools.ros.native import Ros2BaseInput, Ros2BaseTool
-from rai.tools.ros.utils import get_transform
-
-
-# --------------------- Inputs ---------------------
-class Ros2ActionRunnerInput(BaseModel):
- action_name: str = Field(..., description="Name of the action")
- action_type: str = Field(..., description="Type of the action")
- action_goal_args: Dict[str, Any] = Field(
- ..., description="Dictionary with arguments for the action goal message"
- )
-
-
-class ActionUidInput(BaseModel):
- uid: str = Field(..., description="Action uid.")
-
-
-class OptionalActionUidInput(BaseModel):
- uid: Optional[str] = Field(
- None,
- description="Optional action uid. If None - results from all submitted actions will be returned.",
- )
-
-
-# --------------------- Tools ---------------------
-class Ros2GetActionNamesAndTypesTool(Ros2BaseTool):
- name: str = "Ros2GetActionNamesAndTypes"
- description: str = "A tool for getting all ros2 actions names and types"
-
- def _run(self):
- return self.node.ros_discovery_info.actions_and_types
-
-
-class Ros2BaseActionTool(Ros2BaseTool):
- pass
-
-
-class Ros2RunActionSync(Ros2BaseTool):
- name: str = "Ros2RunAction"
- description: str = "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"
-
- args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput
-
- def _build_msg(
- self, msg_type: str, msg_args: Dict[str, Any]
- ) -> Tuple[object, Type]:
- """
- Import message and create it. Return both ready message and message class.
-
- msgs args can have two formats:
- { "goal" : {arg 1 : xyz, ... } or {arg 1 : xyz, ... }
- """
-
- msg_cls: Type = rosidl_runtime_py.utilities.get_interface(msg_type)
- msg = msg_cls.Goal()
-
- if "goal" in msg_args:
- msg_args = msg_args["goal"]
- rosidl_runtime_py.set_message.set_message_fields(msg, msg_args)
- return msg, msg_cls
-
- def _run(
- self, action_name: str, action_type: str, action_goal_args: Dict[str, Any]
- ):
- if action_name[0] != "/":
- action_name = "/" + action_name
- self.node.get_logger().info(f"Action name corrected to: {action_name}")
-
- try:
- goal_msg, msg_cls = self._build_msg(action_type, action_goal_args)
- except Exception as e:
- return f"Failed to build message: {e}"
-
- client = ActionClient(self.node, msg_cls, action_name)
-
- retries = 0
- while not client.wait_for_server(timeout_sec=1.0):
- retries += 1
- if retries > 5:
- raise Exception(
- f"Action server '{action_name}' is not available. Make sure `action_name` is correct..."
- )
- self.node.get_logger().info(
- f"'{action_name}' action server not available, waiting..."
- )
-
- self.node.get_logger().info(f"Sending action message: {goal_msg}")
- result = client.send_goal(goal_msg)
- self.node.get_logger().info("Action finished and result received!")
-
- if result is not None:
- status = result.status
- else:
- status = GoalStatus.STATUS_UNKNOWN
-
- if status == GoalStatus.STATUS_SUCCEEDED:
- res = f"Action succeeded, {result.result}"
- elif status == GoalStatus.STATUS_ABORTED:
- res = f"Action aborted, {result.result}"
- elif status == GoalStatus.STATUS_CANCELED:
- res = f"Action canceled: {result.result}"
- else:
- res = "Action failed"
-
- self.node.get_logger().info(res)
- return res
-
-
-class Ros2RunActionAsync(Ros2BaseActionTool):
- name: str = "Ros2RunAction"
- description: str = """A tool for running a ros2 action.
- Always check action interface before setting action_goal_args."""
-
- args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput
-
- def _run(
- self, action_name: str, action_type: str, action_goal_args: Dict[str, Any]
- ):
- return self.node.run_action(action_name, action_type, action_goal_args)
-
-
-class Ros2IsActionComplete(Ros2BaseActionTool):
- name: str = "Ros2IsActionComplete"
- description: str = "A tool for checking if submitted ros2 actions is complete"
-
- args_schema: Type[Ros2BaseInput] = Ros2BaseInput
-
- def _run(self) -> bool:
- return self.node.is_task_complete()
-
-
-class Ros2GetActionResult(Ros2BaseActionTool):
- name: str = "Ros2GetActionResult"
- description: str = "A tool for checking the result of submitted ros2 action"
-
- args_schema: Type[Ros2BaseInput] = Ros2BaseInput
-
- def _run(self) -> bool:
- return self.node.get_task_result()
-
-
-class Ros2CancelAction(Ros2BaseActionTool):
- name: str = "Ros2CancelAction"
- description: str = "Cancel submitted action"
-
- args_schema: Type[Ros2BaseInput] = Ros2BaseInput
-
- def _run(self) -> bool:
- return self.node.cancel_task()
-
-
-class Ros2GetLastActionFeedback(Ros2BaseActionTool):
- name: str = "Ros2GetLastActionFeedback"
- description: str = "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."
-
- args_schema: Type[Ros2BaseInput] = Ros2BaseInput
-
- def _run(self) -> str:
- return str(self.node.action_feedback)
-
-
-class GetTransformInput(BaseModel):
- target_frame: str = Field(default="map", description="Target frame")
- source_frame: str = Field(default="body_link", description="Source frame")
-
-
-class GetTransformTool(Ros2BaseActionTool):
- name: str = "GetTransform"
- description: str = "Get transform between two frames"
-
- args_schema: Type[GetTransformInput] = GetTransformInput
-
- def _run(self, target_frame: str = "map", source_frame: str = "body_link") -> dict:
- return message_to_ordereddict(
- get_transform(self.node, target_frame, source_frame)
- )
diff --git a/src/rai_core/rai/tools/ros/tools.py b/src/rai_core/rai/tools/ros/tools.py
index 508c174a3..8d41a9bfa 100644
--- a/src/rai_core/rai/tools/ros/tools.py
+++ b/src/rai_core/rai/tools/ros/tools.py
@@ -13,23 +13,13 @@
# limitations under the License.
-import base64
import json
import logging
import time
-from typing import Any, Dict, Type, cast
+from typing import Any, Dict, Type
-import cv2
-import numpy as np
-from geometry_msgs.msg import Point, Quaternion, TransformStamped
from langchain_core.tools import BaseTool
-from nav_msgs.msg import OccupancyGrid
from pydantic import BaseModel, Field
-from tf_transformations import euler_from_quaternion
-
-from rai.tools.ros.deprecated import SingleMessageGrabber
-from rai.tools.ros.native import TopicInput
-from rai.tools.utils import TF2TransformFetcher
logger = logging.getLogger(__name__)
@@ -92,179 +82,3 @@ def update_map_database(
with open(self.map_database, "w") as file:
json.dump(map_database, file, indent=2)
-
-
-class GetOccupancyGridTool(BaseTool):
- """Get the current map as an image with the robot's position marked on it (red dot)."""
-
- name: str = "GetOccupancyGridTool"
- description: str = "A tool for getting the current map as an image with the robot's position marked on it."
-
- args_schema: Type[TopicInput] = TopicInput
-
- image_width: int = 1500
- debug: bool = False
-
- def _postprocess_msg(self, map_msg: OccupancyGrid, transform: TransformStamped):
- width = cast(int, map_msg.info.width)
- height = cast(int, map_msg.info.height)
- resolution = cast(float, map_msg.info.resolution)
- origin_position = cast(Point, map_msg.info.origin.position)
- origin_orientation = cast(Quaternion, map_msg.info.origin.orientation)
-
- data = np.array(map_msg.data).reshape((height, width))
-
- # Convert the OccupancyGrid values to grayscale image (0-255)
- # the final image shape is (self.image_width, self.image_width), scale to fit
- scale = self.image_width / max(width, height)
- width = int(width * scale)
- height = int(height * scale)
- data = cv2.resize(data, (width, height), interpolation=cv2.INTER_NEAREST)
- resolution = resolution / scale
- image = np.zeros_like(data, dtype=np.uint8)
- image[data == -1] = 127 # Unknown space
- image[data == 0] = 255 # Free space
- image[data > 0] = 0 # Occupied space
-
- # Calculate robot's position in the image
- robot_x = cast(
- float, (transform.transform.translation.x - origin_position.x) / resolution
- )
-
- robot_y = cast(
- float, (transform.transform.translation.y - origin_position.y) / resolution
- )
-
- _, _, yaw = euler_from_quaternion(
- [
- origin_orientation.x, # type: ignore
- origin_orientation.y, # type: ignore
- origin_orientation.z, # type: ignore
- origin_orientation.w, # type: ignore
- ]
- )
- # Rotate the robot's position based on the yaw angle
- rotated_x = robot_x * np.cos(yaw) - robot_y * np.sin(yaw)
- rotated_y = robot_x * np.sin(yaw) + robot_y * np.cos(yaw)
- robot_x = int(rotated_x)
- robot_y = int(rotated_y)
-
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
-
- # Draw the robot's position as an arrow
- if 0 <= robot_x < width and 0 <= robot_y < height:
- _, _, yaw = euler_from_quaternion(
- [
- transform.transform.rotation.x, # type: ignore
- transform.transform.rotation.y, # type: ignore
- transform.transform.rotation.z, # type: ignore
- transform.transform.rotation.w, # type: ignore
- ]
- )
- arrow_length = 100
- arrow_end_x = int(robot_x + arrow_length * np.cos(yaw))
- arrow_end_y = int(robot_y + arrow_length * np.sin(yaw))
- cv2.arrowedLine(
- image, (robot_x, robot_y), (arrow_end_x, arrow_end_y), (0, 0, 255), 5
- )
-
- image = cv2.flip(image, 1)
-
- step_size_m: float = 2.0 # Step size for grid lines in meters, adjust as needed
- step_size_pixels = int(step_size_m / resolution)
- # print(step_size_pixels, scale)
- for x in range(0, width, step_size_pixels):
- cv2.line(
- img=image,
- pt1=(x, 0),
- pt2=(x, height),
- color=(200, 200, 200),
- thickness=1,
- )
- cv2.putText(
- img=image,
- text=f"{x * resolution + origin_position.x:.1f}",
- org=(x, 30),
- fontFace=cv2.FONT_HERSHEY_SIMPLEX,
- fontScale=1.0,
- color=(0, 0, 0),
- thickness=1,
- lineType=cv2.LINE_AA,
- )
- for y in range(0, height, step_size_pixels):
- cv2.line(
- img=image,
- pt1=(0, y),
- pt2=(width, y),
- color=(200, 200, 200),
- thickness=1,
- )
- cv2.putText(
- img=image,
- text=f"{y * resolution + origin_position.y:.1f}",
- org=(15, y + 35),
- fontFace=cv2.FONT_HERSHEY_SIMPLEX,
- fontScale=1.0,
- color=(0, 0, 0),
- thickness=1,
- lineType=cv2.LINE_AA,
- )
- # Encode into PNG base64
- _, buffer = cv2.imencode(".png", image)
- return image
-
- if self.debug:
- cv2.imwrite("map.png", image)
- return base64.b64encode(buffer.tobytes()).decode("utf-8")
-
- def _run(self, topic_name: str):
- """Gets the current map from the specified topic."""
- map_grabber = SingleMessageGrabber(topic_name, OccupancyGrid, timeout_sec=10)
- tf_grabber = TF2TransformFetcher(target_frame="map", source_frame="base_link")
-
- map_msg = map_grabber.get_data()
- transform = tf_grabber.get_data()
-
- if map_msg is None or transform is None:
- return {"content": "Failed to get the map, wrong topic?"}
-
- base64_image = self._postprocess_msg(map_msg, transform)
- return {"content": "Map grabbed successfully", "images": [base64_image]}
-
-
-class GetCurrentPositionToolInput(BaseModel):
- """Input for the get_current_position tool."""
-
-
-class GetCurrentPositionTool(BaseTool):
- """Get the current position and rotation of the robot."""
-
- name: str = "GetCurrentPositionTool"
- description: str = "A tool for getting the current position of the robot."
-
- args_schema: Type[GetCurrentPositionToolInput] = GetCurrentPositionToolInput
-
- def _run(self):
- """Gets the current position from the specified topic."""
- tf_grabber = TF2TransformFetcher(target_frame="map", source_frame="base_link")
- transform_stamped = tf_grabber.get_data()
- position = transform_stamped.transform.translation # type: ignore
- orientation = transform_stamped.transform.rotation # type: ignore
- _, _, yaw = euler_from_quaternion(
- [
- orientation.x, # type: ignore
- orientation.y, # type: ignore
- orientation.z, # type: ignore
- orientation.w, # type: ignore
- ]
- )
- return {
- "content": str(
- {
- "x": position.x, # type: ignore
- "y": position.y, # type: ignore
- "z": position.z, # type: ignore
- "yaw": yaw,
- }
- ),
- }
diff --git a/src/rai_core/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py
index 23b3ffc8f..3752f6a7b 100644
--- a/src/rai_core/rai/tools/ros2/__init__.py
+++ b/src/rai_core/rai/tools/ros2/__init__.py
@@ -12,37 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .actions import (
+from .generic import (
+ CallROS2ServiceTool,
CancelROS2ActionTool,
GetROS2ActionsNamesAndTypesTool,
- ROS2ActionToolkit,
- StartROS2ActionTool,
-)
-from .services import (
- CallROS2ServiceTool,
- GetROS2ServicesNamesAndTypesTool,
- ROS2ServicesToolkit,
-)
-from .toolkit import ROS2Toolkit
-from .topics import (
GetROS2ImageTool,
GetROS2MessageInterfaceTool,
+ GetROS2ServicesNamesAndTypesTool,
GetROS2TopicsNamesAndTypesTool,
GetROS2TransformTool,
PublishROS2MessageTool,
ReceiveROS2MessageTool,
+ ROS2ActionToolkit,
+ ROS2ServicesToolkit,
+ ROS2Toolkit,
ROS2TopicsToolkit,
+ StartROS2ActionTool,
+)
+from .nav2 import (
+ CancelNavigateToPoseTool,
+ GetNavigateToPoseFeedbackTool,
+ GetNavigateToPoseResultTool,
+ Nav2Toolkit,
+ NavigateToPoseTool,
+)
+from .simple import (
+ GetROS2ImageConfiguredTool,
+ GetROS2TransformConfiguredTool,
)
__all__ = [
"CallROS2ServiceTool",
+ "CancelNavigateToPoseTool",
"CancelROS2ActionTool",
+ "GetNavigateToPoseFeedbackTool",
+ "GetNavigateToPoseResultTool",
"GetROS2ActionsNamesAndTypesTool",
+ "GetROS2ImageConfiguredTool",
"GetROS2ImageTool",
"GetROS2MessageInterfaceTool",
"GetROS2ServicesNamesAndTypesTool",
"GetROS2TopicsNamesAndTypesTool",
+ "GetROS2TransformConfiguredTool",
"GetROS2TransformTool",
+ "Nav2Toolkit",
+ "NavigateToPoseTool",
"PublishROS2MessageTool",
"ROS2ActionToolkit",
"ROS2ServicesToolkit",
diff --git a/src/rai_core/rai/tools/ros/cli.py b/src/rai_core/rai/tools/ros2/cli.py
similarity index 100%
rename from src/rai_core/rai/tools/ros/cli.py
rename to src/rai_core/rai/tools/ros2/cli.py
diff --git a/src/rai_core/rai/tools/ros2/generic/__init__.py b/src/rai_core/rai/tools/ros2/generic/__init__.py
new file mode 100644
index 000000000..c3924b22a
--- /dev/null
+++ b/src/rai_core/rai/tools/ros2/generic/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .actions import (
+ CancelROS2ActionTool,
+ GetROS2ActionsNamesAndTypesTool,
+ ROS2ActionToolkit,
+ StartROS2ActionTool,
+)
+from .services import (
+ CallROS2ServiceTool,
+ GetROS2ServicesNamesAndTypesTool,
+ ROS2ServicesToolkit,
+)
+from .toolkit import ROS2Toolkit
+from .topics import (
+ GetROS2ImageTool,
+ GetROS2MessageInterfaceTool,
+ GetROS2TopicsNamesAndTypesTool,
+ GetROS2TransformTool,
+ PublishROS2MessageTool,
+ ReceiveROS2MessageTool,
+ ROS2TopicsToolkit,
+)
+
+__all__ = [
+ "CallROS2ServiceTool",
+ "CancelROS2ActionTool",
+ "GetROS2ActionsNamesAndTypesTool",
+ "GetROS2ImageTool",
+ "GetROS2MessageInterfaceTool",
+ "GetROS2ServicesNamesAndTypesTool",
+ "GetROS2TopicsNamesAndTypesTool",
+ "GetROS2TransformTool",
+ "PublishROS2MessageTool",
+ "ROS2ActionToolkit",
+ "ROS2ServicesToolkit",
+ "ROS2Toolkit",
+ "ROS2TopicsToolkit",
+ "ReceiveROS2MessageTool",
+ "StartROS2ActionTool",
+]
diff --git a/src/rai_core/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/generic/actions.py
similarity index 96%
rename from src/rai_core/rai/tools/ros2/actions.py
rename to src/rai_core/rai/tools/ros2/generic/actions.py
index b6a19be22..c7bde3a8d 100644
--- a/src/rai_core/rai/tools/ros2/actions.py
+++ b/src/rai_core/rai/tools/ros2/generic/actions.py
@@ -28,6 +28,7 @@
from langchain_core.tools import BaseTool # type: ignore
from langchain_core.utils import stringify_dict
from pydantic import BaseModel, Field
+from rclpy.action import CancelResponse
from rai.communication.ros2 import ROS2ARIConnector, ROS2ARIMessage
from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit
@@ -253,14 +254,12 @@ class CancelROS2ActionTool(BaseROS2Tool):
description: str = "Cancel a ROS2 action"
args_schema: Type[CancelROS2ActionToolInput] = CancelROS2ActionToolInput
- internal_action_id_mapping: Dict[str, str] = Field(
- default_factory=get_internal_action_id_mapping
- )
-
def _run(self, action_id: str) -> str:
- external_action_id = self.internal_action_id_mapping[action_id]
- self.connector.terminate_action(external_action_id)
- return f"Action {action_id} cancelled"
+ response = self.connector.terminate_action(action_id)
+ if response == CancelResponse.ACCEPT:
+ return f"Action {action_id} cancelled."
+ else:
+ return f"Action {action_id} could not be cancelled."
class GetROS2ActionIDsToolInput(BaseModel):
diff --git a/src/rai_core/rai/tools/ros2/services.py b/src/rai_core/rai/tools/ros2/generic/services.py
similarity index 100%
rename from src/rai_core/rai/tools/ros2/services.py
rename to src/rai_core/rai/tools/ros2/generic/services.py
diff --git a/src/rai_core/rai/tools/ros2/toolkit.py b/src/rai_core/rai/tools/ros2/generic/toolkit.py
similarity index 88%
rename from src/rai_core/rai/tools/ros2/toolkit.py
rename to src/rai_core/rai/tools/ros2/generic/toolkit.py
index 406ecd944..e368b41dd 100644
--- a/src/rai_core/rai/tools/ros2/toolkit.py
+++ b/src/rai_core/rai/tools/ros2/generic/toolkit.py
@@ -16,14 +16,18 @@
from langchain_core.tools import BaseTool
-from rai.tools.ros2.actions import ROS2ActionToolkit
from rai.tools.ros2.base import BaseROS2Toolkit
-from rai.tools.ros2.services import ROS2ServicesToolkit
-from rai.tools.ros2.topics import ROS2TopicsToolkit
class ROS2Toolkit(BaseROS2Toolkit):
def get_tools(self) -> List[BaseTool]:
+ # lazy import to avoid circular import
+ from rai.tools.ros2.generic import (
+ ROS2ActionToolkit,
+ ROS2ServicesToolkit,
+ ROS2TopicsToolkit,
+ )
+
return [
*ROS2TopicsToolkit(
connector=self.connector,
diff --git a/src/rai_core/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/generic/topics.py
similarity index 100%
rename from src/rai_core/rai/tools/ros2/topics.py
rename to src/rai_core/rai/tools/ros2/generic/topics.py
diff --git a/src/rai_core/rai/tools/ros2/nav2.py b/src/rai_core/rai/tools/ros2/nav2.py
new file mode 100644
index 000000000..a606e84a1
--- /dev/null
+++ b/src/rai_core/rai/tools/ros2/nav2.py
@@ -0,0 +1,304 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import time
+from typing import List, Optional, Type, cast
+
+import cv2
+import numpy as np
+from geometry_msgs.msg import Point, PoseStamped, Quaternion, TransformStamped
+from langchain_core.tools import BaseTool
+from nav2_msgs.action import NavigateToPose
+from nav_msgs.msg import OccupancyGrid
+from pydantic import BaseModel, Field
+from rclpy.action import ActionClient
+from tf_transformations import euler_from_quaternion, quaternion_from_euler
+
+from rai.communication.ros2 import ROS2ARIMessage
+from rai.communication.ros2.connectors import ROS2ARIConnector
+from rai.messages.multimodal import MultimodalArtifact
+from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit
+
+action_client: Optional[ActionClient] = None
+current_action_id: Optional[str] = None
+current_feedback: Optional[NavigateToPose.Feedback] = None
+current_result: Optional[NavigateToPose.Result] = None
+
+
+class Nav2Toolkit(BaseROS2Toolkit):
+ connector: ROS2ARIConnector
+ frame_id: str = Field(
+ default="map", description="The frame id of the Nav2 stack (map, odom, etc.)"
+ )
+ action_name: str = Field(
+ default="navigate_to_pose", description="The name of the NavigateToPose action"
+ )
+
+ def get_tools(self) -> List[BaseTool]:
+ return [
+ NavigateToPoseTool(
+ connector=self.connector,
+ frame_id=self.frame_id,
+ action_name=self.action_name,
+ ),
+ CancelNavigateToPoseTool(connector=self.connector),
+ GetNavigateToPoseFeedbackTool(connector=self.connector),
+ GetNavigateToPoseResultTool(connector=self.connector),
+ ]
+
+
+class NavigateToPoseToolInput(BaseModel):
+ x: float = Field(..., description="The x coordinate of the pose")
+ y: float = Field(..., description="The y coordinate of the pose")
+ z: float = Field(..., description="The z coordinate of the pose")
+ yaw: float = Field(..., description="The yaw angle of the pose")
+
+
+class NavigateToPoseTool(BaseROS2Tool):
+ name: str = "navigate_to_pose"
+ description: str = "Navigate to a specific pose"
+
+ args_schema: Type[NavigateToPoseToolInput] = NavigateToPoseToolInput
+
+ frame_id: str = Field(
+ default="map", description="The frame id of the Nav2 stack (map, odom, etc.)"
+ )
+ action_name: str = Field(
+ default="navigate_to_pose", description="The name of the NavigateToPose action"
+ )
+
+ def on_feedback(self, feedback: NavigateToPose.Feedback) -> None:
+ global current_feedback
+ current_feedback = feedback
+
+ def on_done(self, result: NavigateToPose.Result) -> None:
+ global current_result
+ current_result = result
+
+ def _run(self, x: float, y: float, z: float, yaw: float) -> str:
+ global action_client
+ if action_client is None:
+ action_client = ActionClient(
+ self.connector.node, NavigateToPose, self.action_name
+ )
+
+        quat = quaternion_from_euler(0, 0, yaw)
+
+ goal = {
+ "pose": {
+ "header": {
+ "frame_id": self.frame_id,
+ "stamp": self.connector.node.get_clock().now().to_msg(),
+ },
+ "pose": {
+ "position": {"x": x, "y": y, "z": z},
+ "orientation": {
+ "x": quat[0],
+ "y": quat[1],
+ "z": quat[2],
+ "w": quat[3],
+ },
+ },
+ }
+ }
+
+ msg = ROS2ARIMessage(payload=goal)
+ action_id = self.connector.start_action(
+ action_data=msg,
+ target=self.action_name,
+ msg_type="nav2_msgs/action/NavigateToPose",
+ on_feedback=self.on_feedback,
+ on_done=self.on_done,
+ )
+ global current_action_id
+ current_action_id = action_id
+
+ return "Navigating to pose"
+
+
+class GetNavigateToPoseFeedbackTool(BaseROS2Tool):
+ name: str = "get_navigate_to_pose_feedback"
+ description: str = "Get the feedback of the navigate to pose action"
+
+ def _run(self) -> str:
+ global current_feedback
+ return str(current_feedback)
+
+
+class GetNavigateToPoseResultTool(BaseROS2Tool):
+ name: str = "get_navigate_to_pose_result"
+ description: str = "Get the result of the navigate to pose action"
+
+ def _run(self) -> str:
+ global current_result
+ if current_result is None:
+ return "Action is not done yet"
+        return str(current_result)
+
+
+class CancelNavigateToPoseTool(BaseROS2Tool):
+ name: str = "cancel_navigate_to_pose"
+ description: str = "Cancel the navigate to pose action"
+
+ def _run(self) -> str:
+ global current_action_id
+ if current_action_id is None:
+ return "No action to cancel"
+ self.connector.terminate_action(current_action_id)
+ return "Action cancelled"
+
+
+class GetOccupancyGridTool(BaseROS2Tool):
+    """Get the current map as an image with the robot's position marked on it (red arrow)."""
+
+ name: str = "GetOccupancyGridTool"
+ description: str = "A tool for getting the current map as an image with the robot's position marked on it."
+
+ response_format: str = "content_and_artifact"
+ image_width: int = 1500
+ debug: bool = False
+
+ def _postprocess_msg(self, map_msg: OccupancyGrid, transform: TransformStamped):
+ width = cast(int, map_msg.info.width)
+ height = cast(int, map_msg.info.height)
+ resolution = cast(float, map_msg.info.resolution)
+ origin_position = cast(Point, map_msg.info.origin.position)
+ origin_orientation = cast(Quaternion, map_msg.info.origin.orientation)
+
+ data = np.array(map_msg.data).reshape((height, width))
+
+        # Convert the OccupancyGrid values to a grayscale image (0-255),
+        # scaled so that the longer dimension equals self.image_width
+ scale = self.image_width / max(width, height)
+ width = int(width * scale)
+ height = int(height * scale)
+ data = cv2.resize(data, (width, height), interpolation=cv2.INTER_NEAREST)
+ resolution = resolution / scale
+ image = np.zeros_like(data, dtype=np.uint8)
+ image[data == -1] = 127 # Unknown space
+ image[data == 0] = 255 # Free space
+ image[data > 0] = 0 # Occupied space
+
+ # Calculate robot's position in the image
+ robot_x = cast(
+ float, (transform.transform.translation.x - origin_position.x) / resolution
+ )
+
+ robot_y = cast(
+ float, (transform.transform.translation.y - origin_position.y) / resolution
+ )
+
+ _, _, yaw = euler_from_quaternion(
+ [
+ origin_orientation.x, # type: ignore
+ origin_orientation.y, # type: ignore
+ origin_orientation.z, # type: ignore
+ origin_orientation.w, # type: ignore
+ ]
+ )
+        # Rotate the robot's position by the map origin's yaw angle
+ rotated_x = robot_x * np.cos(yaw) - robot_y * np.sin(yaw)
+ rotated_y = robot_x * np.sin(yaw) + robot_y * np.cos(yaw)
+ robot_x = int(rotated_x)
+ robot_y = int(rotated_y)
+
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+ # Draw the robot's position as an arrow
+ if 0 <= robot_x < width and 0 <= robot_y < height:
+ _, _, yaw = euler_from_quaternion(
+ [
+ transform.transform.rotation.x, # type: ignore
+ transform.transform.rotation.y, # type: ignore
+ transform.transform.rotation.z, # type: ignore
+ transform.transform.rotation.w, # type: ignore
+ ]
+ )
+ arrow_length = 100
+ arrow_end_x = int(robot_x + arrow_length * np.cos(yaw))
+ arrow_end_y = int(robot_y + arrow_length * np.sin(yaw))
+ cv2.arrowedLine(
+ image, (robot_x, robot_y), (arrow_end_x, arrow_end_y), (0, 0, 255), 5
+ )
+
+ image = cv2.flip(image, 1)
+
+ step_size_m: float = 2.0 # Step size for grid lines in meters, adjust as needed
+ step_size_pixels = int(step_size_m / resolution)
+ for x in range(0, width, step_size_pixels):
+ cv2.line(
+ img=image,
+ pt1=(x, 40),
+ pt2=(x, height),
+ color=(200, 200, 200),
+ thickness=1,
+ )
+ cv2.putText(
+ img=image,
+ text=f"{x * resolution + origin_position.x:.1f}",
+ org=(x, 30),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=1.5,
+ color=(0, 0, 255),
+ thickness=2,
+ lineType=cv2.LINE_AA,
+ )
+ for y in range(0, height, step_size_pixels):
+ cv2.line(
+ img=image,
+ pt1=(0, y),
+ pt2=(width, y),
+ color=(200, 200, 200),
+ thickness=1,
+ )
+ cv2.putText(
+ img=image,
+ text=f"{y * resolution + origin_position.y:.1f}",
+ org=(15, y + 35),
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+ fontScale=1.5,
+ color=(0, 0, 255),
+ thickness=2,
+ lineType=cv2.LINE_AA,
+ )
+ # Encode into PNG base64
+ _, buffer = cv2.imencode(".png", image)
+
+ if self.debug:
+ cv2.imwrite(f"map{time.time()}.png", image)
+ return base64.b64encode(buffer.tobytes()).decode("utf-8")
+
+ def _run(self):
+        """Gets the current map from the /map topic and the robot's pose from TF."""
+ map_msg = self.connector.receive_message("/map", timeout_sec=10).payload
+ transform = self.connector.get_transform(
+ target_frame="map", source_frame="base_link", timeout_sec=10
+ )
+
+        if map_msg is None or transform is None:
+            return "Failed to get the map, wrong topic?", MultimodalArtifact(
+                images=[], audios=[]
+            )
+
+ base64_image = self._postprocess_msg(map_msg, transform)
+ return "Map grabbed successfully", MultimodalArtifact(
+ images=[base64_image], audios=[]
+ )
diff --git a/src/rai_core/rai/tools/ros2/simple.py b/src/rai_core/rai/tools/ros2/simple.py
new file mode 100644
index 000000000..1dfac60b1
--- /dev/null
+++ b/src/rai_core/rai/tools/ros2/simple.py
@@ -0,0 +1,67 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module provides streamlined ROS2 tools designed for enhanced usability.
+Unlike ROS2Toolkit, these tools are purpose-built for specific use cases
+rather than offering generic functionality across topics, services, and actions.
+For generic ROS2 functionality, use ROS2Toolkit.
+"""
+
+from typing import Any, Literal
+
+from pydantic import Field
+
+from rai.tools.ros2.base import BaseROS2Tool
+from rai.tools.ros2.generic.topics import (
+ GetROS2ImageTool,
+ GetROS2TransformTool,
+)
+
+
+class GetROS2ImageConfiguredTool(BaseROS2Tool):
+ name: str = "get_ros2_camera_image"
+ description: str = "Get the current image from the camera"
+ response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"
+
+ topic: str = Field(..., description="The topic to get the image from")
+
+ def model_post_init(self, __context: Any) -> None:
+ if not self.is_readable(topic=self.topic):
+ raise ValueError(f"Bad configuration: topic {self.topic} is not readable")
+
+ def _run(self) -> Any:
+ tool = GetROS2ImageTool(
+ connector=self.connector,
+ )
+ return tool._run(topic=self.topic)
+
+
+class GetROS2TransformConfiguredTool(BaseROS2Tool):
+ name: str = "get_ros2_robot_position"
+ description: str = "Get the robot's position"
+
+ source_frame: str = Field(..., description="The source frame")
+ target_frame: str = Field(..., description="The target frame")
+ timeout_sec: float = Field(default=5.0, description="The timeout in seconds")
+
+ def _run(self) -> Any:
+ tool = GetROS2TransformTool(
+ connector=self.connector,
+ )
+ return tool._run(
+ source_frame=self.source_frame,
+ target_frame=self.target_frame,
+ timeout_sec=self.timeout_sec,
+ )
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py
index 62656c5a8..b65690a34 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py
@@ -16,9 +16,9 @@
import numpy as np
import sensor_msgs.msg
+from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from rai.communication.ros2.connectors import ROS2ARIConnector
-from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
from rai.tools.ros.utils import convert_ros_img_to_ndarray
from rai.utils.ros_async import get_future_result
from rclpy.exceptions import (
@@ -32,7 +32,7 @@
# --------------------- Inputs ---------------------
-class Ros2GetDetectionInput(Ros2BaseInput):
+class Ros2GetDetectionInput(BaseModel):
camera_topic: str = Field(
...,
description="Ros2 topic for the camera image containing image to run detection on.",
@@ -42,7 +42,7 @@ class Ros2GetDetectionInput(Ros2BaseInput):
)
-class GetDistanceToObjectsInput(Ros2BaseInput):
+class GetDistanceToObjectsInput(BaseModel):
camera_topic: str = Field(
...,
description="Ros2 topic for the camera image containing image to run detection on.",
@@ -78,7 +78,7 @@ class DistanceMeasurement(NamedTuple):
# --------------------- Tools ---------------------
-class GroundingDinoBaseTool(Ros2BaseTool):
+class GroundingDinoBaseTool(BaseTool):
connector: ROS2ARIConnector = Field(..., exclude=True)
box_threshold: float = Field(default=0.35, description="Box threshold for GDINO")
@@ -211,10 +211,12 @@ def _run(
camera_img_msg = self._get_image_message(camera_topic)
depth_img_msg = self._get_image_message(depth_topic)
future = self._call_gdino_node(camera_img_msg, object_names)
- logger = self.node.get_logger()
+ logger = self.connector.node.get_logger()
try:
- threshold = self.node.get_parameter("outlier_sigma_threshold").value
+ threshold = self.connector.node.get_parameter(
+ "outlier_sigma_threshold"
+ ).value
if not isinstance(threshold, float):
logger.error(
f"Parameter outlier_sigma_threshold was set badly: {type(threshold)}: {threshold} expected float. Using default value 1.0"
@@ -227,7 +229,9 @@ def _run(
threshold = 1.0
try:
- conversion_ratio = self.node.get_parameter("conversion_ratio").value
+ conversion_ratio = self.connector.node.get_parameter(
+ "conversion_ratio"
+ ).value
if not isinstance(conversion_ratio, float):
logger.error(
f"Parameter conversion_ratio was set badly: {type(conversion_ratio)}: {conversion_ratio} expected float. Using default value 0.001"
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
index 043802d8f..a6390e1e1 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
@@ -19,9 +19,8 @@
import rclpy
import sensor_msgs.msg
from langchain_core.tools import BaseTool
-from pydantic import Field
+from pydantic import BaseModel, Field
from rai.communication.ros2.connectors import ROS2ARIConnector
-from rai.tools.ros import Ros2BaseInput
from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray
from rai.utils.ros_async import get_future_result
from rclpy import Future
@@ -36,7 +35,7 @@
# --------------------- Inputs ---------------------
-class GetSegmentationInput(Ros2BaseInput):
+class GetSegmentationInput(BaseModel):
camera_topic: str = Field(
...,
description="Ros2 topic for the camera image containing image to run detection on.",
@@ -46,7 +45,7 @@ class GetSegmentationInput(Ros2BaseInput):
)
-class GetGrabbingPointInput(Ros2BaseInput):
+class GetGrabbingPointInput(BaseModel):
camera_topic: str = Field(
...,
description="Ros2 topic for the camera image containing image to run detection on.",
@@ -96,7 +95,7 @@ def _call_gdino_node(
) -> Future:
cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME)
while not cli.wait_for_service(timeout_sec=1.0):
- self.node.get_logger().info(
+ self.connector.node.get_logger().info(
f"service {GDINO_SERVICE_NAME} not available, waiting again..."
)
req = RAIGroundingDino.Request()
@@ -113,7 +112,7 @@ def _call_gsam_node(
):
cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment")
while not cli.wait_for_service(timeout_sec=1.0):
- self.node.get_logger().info(
+ self.connector.node.get_logger().info(
"service grounded_sam_segment not available, waiting again..."
)
req = RAIGroundedSam.Request()
diff --git a/tests/communication/ros2/helpers.py b/tests/communication/ros2/helpers.py
index 0c86e54cb..8f80fb1b1 100644
--- a/tests/communication/ros2/helpers.py
+++ b/tests/communication/ros2/helpers.py
@@ -109,6 +109,8 @@ def handle_test_message(self, msg: Any) -> None:
class TestActionServer(Node):
+ __test__ = False
+
def __init__(self, action_name: str):
super().__init__("test_action_server")
self.action_server = ActionServer(
@@ -120,6 +122,7 @@ def __init__(self, action_name: str):
cancel_callback=self.cancel_callback,
callback_group=ReentrantCallbackGroup(),
)
+ self.cancelled: bool = False
def handle_test_action(
self, goal_handle: ServerGoalHandle
@@ -147,10 +150,13 @@ def goal_accepted(self, goal_handle: ServerGoalHandle) -> GoalResponse:
def cancel_callback(self, cancel_request) -> CancelResponse:
self.get_logger().info("Got cancel request")
+ self.cancelled = True
return CancelResponse.ACCEPT
class TestActionClient(Node):
+ __test__ = False
+
def __init__(self):
super().__init__("navigate_to_pose_client")
self._action_client = ActionClient(self, NavigateToPose, "navigate_to_pose")
@@ -188,6 +194,8 @@ def feedback_callback(self, feedback_msg):
class TestServiceClient(Node):
+ __test__ = False
+
def __init__(self):
super().__init__("set_bool_client")
self.client = self.create_client(SetBool, "set_bool")
diff --git a/tests/communication/ros2/test_connectors.py b/tests/communication/ros2/test_connectors.py
index 31c559698..8fa81ae5c 100644
--- a/tests/communication/ros2/test_connectors.py
+++ b/tests/communication/ros2/test_connectors.py
@@ -233,7 +233,7 @@ def test_ros2ari_connector_create_service(
service_client = TestServiceClient()
executors, threads = multi_threaded_spinner([service_client])
service_client.send_request()
- time.sleep(0.02)
+ time.sleep(0.2)
assert mock_callback.called
finally:
connector.shutdown()
diff --git a/tests/core/test_tool_runner.py b/tests/core/test_tool_runner.py
index 0c6971d26..0c330e212 100644
--- a/tests/core/test_tool_runner.py
+++ b/tests/core/test_tool_runner.py
@@ -19,7 +19,7 @@
from rai.agents.tool_runner import ToolRunner
from rai.messages import HumanMultimodalMessage, ToolMultimodalMessage
from rai.messages.utils import preprocess_image
-from rai.tools.ros.cli import ros2_topic
+from rai.tools.ros2.cli import ros2_topic
@tool(response_format="content_and_artifact")
diff --git a/tests/messages/test_transport.py b/tests/messages/test_transport.py
index fbee56d67..c7d1de475 100644
--- a/tests/messages/test_transport.py
+++ b/tests/messages/test_transport.py
@@ -57,6 +57,8 @@ def get_qos_profiles() -> List[str]:
class TestPublisher(Node):
+ __test__ = False
+
def __init__(self, qos_profile: QoSProfile):
super().__init__("test_publisher_" + str(uuid.uuid4()).replace("-", ""))
diff --git a/tests/tools/ros2/test_action_tools.py b/tests/tools/ros2/test_action_tools.py
index 76bf6ea45..40dee7c1a 100644
--- a/tests/tools/ros2/test_action_tools.py
+++ b/tests/tools/ros2/test_action_tools.py
@@ -22,8 +22,8 @@
pytest.skip("ROS2 is not installed", allow_module_level=True)
-from rai.communication.ros2.connectors import ROS2ARIConnector
-from rai.tools.ros2 import StartROS2ActionTool
+from rai.communication.ros2 import ROS2ARIConnector
+from rai.tools.ros2 import CancelROS2ActionTool, StartROS2ActionTool
from tests.communication.ros2.helpers import (
TestActionServer,
@@ -85,3 +85,27 @@ def test_action_call_tool_with_writable_action(
finally:
shutdown_executors_and_threads(executors, threads)
+
+
+def test_cancel_action_tool(ros_setup: None, request: pytest.FixtureRequest) -> None:
+ action_name = f"{request.node.originalname}_action" # type: ignore
+ connector = ROS2ARIConnector()
+ server = TestActionServer(action_name=action_name)
+ executors, threads = multi_threaded_spinner([server])
+ start_tool = StartROS2ActionTool(connector=connector)
+ cancel_tool = CancelROS2ActionTool(connector=connector)
+ try:
+ response = start_tool._run( # type: ignore
+ action_name=action_name,
+ action_type="nav2_msgs/action/NavigateToPose",
+ action_args={},
+ )
+ action_id = response.split("Action started with ID:")[1].strip()
+
+ response = cancel_tool._run( # type: ignore
+ action_id=action_id,
+ )
+ assert server.cancelled
+
+ finally:
+ shutdown_executors_and_threads(executors, threads)
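The new cancel test recovers the action id by splitting `StartROS2ActionTool`'s textual response on the `"Action started with ID:"` marker. The convention the test relies on can be sketched as a small helper (hypothetical, not part of the patch):

```python
def parse_action_id(response: str) -> str:
    """Extract the id from an 'Action started with ID: <id>' response."""
    marker = "Action started with ID:"
    if marker not in response:
        raise ValueError(f"Unexpected response: {response!r}")
    return response.split(marker)[1].strip()


print(parse_action_id("Action started with ID: abc-123"))  # -> abc-123
```

Raising on a missing marker makes the test fail with a clear message if the start tool's response format ever changes, instead of an opaque `IndexError` from the bare `split(...)[1]`.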