Skip to content

Commit dd4d2d6

Browse files
committed
draft: axum style router
1 parent ef99baf commit dd4d2d6

File tree

6 files changed

+573
-115
lines changed

6 files changed

+573
-115
lines changed

crates/rmcp-macros/src/tool.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -542,22 +542,22 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
542542
// generate wrapped tool function
543543
let tool_call_fn = {
544544
// wrapper function have the same sig:
545-
// async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>)
545+
// async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<Self>)
546546
// -> std::result::Result<rmcp::model::CallToolResult, rmcp::Error>
547547
//
548548
// and the block part should be like:
549549
// {
550550
// use rmcp::handler::server::tool::*;
551-
// let (t0, context) = <T0>::from_tool_call_context_part(context)?;
552-
// let (t1, context) = <T1>::from_tool_call_context_part(context)?;
551+
// let t0 = <T0>::from_tool_call_context_part(&mut context)?;
552+
// let t1 = <T1>::from_tool_call_context_part(&mut context)?;
553553
// ...
554-
// let (tn, context) = <Tn>::from_tool_call_context_part(context)?;
554+
// let tn = <Tn>::from_tool_call_context_part(&mut context)?;
555555
// // for params
556556
// ... expand helper types here
557-
// let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?;
557+
// let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?;
558558
// let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?;
559559
// // for aggr
560-
// let (Parameters(aggr), context) = <Parameters<AggrType>>::from_tool_call_context_part(context)?;
560+
// let Parameters(aggr) = <Parameters<AggrType>>::from_tool_call_context_part(&mut context)?;
561561
// Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result()
562562
//
563563
// }
@@ -584,14 +584,14 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
584584
let pat = &pat_type.pat;
585585
let ty = &pat_type.ty;
586586
quote! {
587-
let (#pat, context) = <#ty>::from_tool_call_context_part(context)?;
587+
let #pat = <#ty>::from_tool_call_context_part(&mut context)?;
588588
}
589589
}
590590
FnArg::Receiver(r) => {
591591
let ty = r.ty.clone();
592592
let pat = receiver_ident();
593593
quote! {
594-
let (#pat, context) = <#ty>::from_tool_call_context_part(context)?;
594+
let #pat = <#ty>::from_tool_call_context_part(&mut context)?;
595595
}
596596
}
597597
};
@@ -605,7 +605,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
605605
ToolParams::Aggregated { rust_type } => {
606606
let PatType { pat, ty, .. } = rust_type;
607607
quote! {
608-
let (Parameters(#pat), context) = <Parameters<#ty>>::from_tool_call_context_part(context)?;
608+
let Parameters(#pat) = <Parameters<#ty>>::from_tool_call_context_part(&mut context)?;
609609
}
610610
}
611611
ToolParams::Params { attrs } => {
@@ -615,7 +615,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
615615
let params_ident = attrs.iter().map(|attr| &attr.ident).collect::<Vec<_>>();
616616
quote! {
617617
#param_type
618-
let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?;
618+
let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?;
619619
let #temp_param_type_name {
620620
#(#params_ident,)*
621621
} = parse_json_object(__rmcp_tool_req)?;
@@ -669,7 +669,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
669669
.collect::<Vec<_>>();
670670
quote! {
671671
#(#raw_fn_attr)*
672-
#raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>)
672+
#raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<Self>)
673673
-> std::result::Result<rmcp::model::CallToolResult, rmcp::Error> {
674674
use rmcp::handler::server::tool::*;
675675
#trivial_arg_extraction_part

crates/rmcp/src/handler/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::{
77
mod resource;
88
pub mod tool;
99
pub mod wrapper;
10+
pub mod router;
1011
impl<H: ServerHandler> Service<RoleServer> for H {
1112
async fn handle_request(
1213
&self,
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
use std::sync::Arc;
2+
3+
use tool::{IntoToolRoute, ToolRoute};
4+
5+
use crate::{
6+
RoleServer, Service,
7+
model::{ClientRequest, ListToolsResult, ServerResult},
8+
};
9+
10+
use super::ServerHandler;
11+
12+
pub mod tool;
13+
14+
pub struct Router<S> {
15+
pub tool_router: tool::ToolRouter<S>,
16+
pub service: Arc<S>,
17+
}
18+
19+
impl<S> Router<S>
20+
where
21+
S: ServerHandler,
22+
{
23+
pub fn new(service: S) -> Self {
24+
Self {
25+
tool_router: tool::ToolRouter::new(),
26+
service: Arc::new(service),
27+
}
28+
}
29+
30+
pub fn with_tool<R, A>(mut self, route: R) -> Self
31+
where
32+
R: IntoToolRoute<S, A>,
33+
{
34+
self.tool_router.add(route.into_tool_route());
35+
self
36+
}
37+
38+
pub fn with_tools(mut self, routes: impl IntoIterator<Item = ToolRoute<S>>) -> Self
39+
{
40+
for route in routes {
41+
self.tool_router.add(route);
42+
}
43+
self
44+
}
45+
}
46+
47+
impl<S> Service<RoleServer> for Router<S>
48+
where
49+
S: ServerHandler,
50+
{
51+
async fn handle_notification(
52+
&self,
53+
notification: <RoleServer as crate::service::ServiceRole>::PeerNot,
54+
) -> Result<(), crate::Error> {
55+
self.service.handle_notification(notification).await
56+
}
57+
async fn handle_request(
58+
&self,
59+
request: <RoleServer as crate::service::ServiceRole>::PeerReq,
60+
context: crate::service::RequestContext<RoleServer>,
61+
) -> Result<<RoleServer as crate::service::ServiceRole>::Resp, crate::Error> {
62+
match request {
63+
ClientRequest::CallToolRequest(request) => {
64+
if self.tool_router.has(request.params.name.as_ref())
65+
|| !self.tool_router.transparent_when_not_found
66+
{
67+
let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
68+
self.service.clone(),
69+
request.params,
70+
context,
71+
);
72+
let result = self.tool_router.call(tool_call_context).await?;
73+
Ok(ServerResult::CallToolResult(result))
74+
} else {
75+
self.service
76+
.handle_request(ClientRequest::CallToolRequest(request), context)
77+
.await
78+
}
79+
}
80+
ClientRequest::ListToolsRequest(_) => {
81+
let tools = self.tool_router.list_all();
82+
Ok(ServerResult::ListToolsResult(ListToolsResult {
83+
tools,
84+
next_cursor: None,
85+
}))
86+
}
87+
rest => self.service.handle_request(rest, context).await,
88+
}
89+
}
90+
91+
fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
92+
self.service.get_info()
93+
}
94+
}
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
use std::borrow::Cow;
2+
3+
use futures::future::BoxFuture;
4+
use schemars::JsonSchema;
5+
6+
use crate::model::{CallToolResult, Tool, ToolAnnotations};
7+
8+
use crate::handler::server::tool::{
9+
CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type,
10+
};
11+
12+
pub struct ToolRoute<S> {
13+
#[allow(clippy::type_complexity)]
14+
pub call: Box<DynCallToolHandler<S>>,
15+
pub attr: crate::model::Tool,
16+
}
17+
18+
impl<S: Send + Sync + 'static> ToolRoute<S> {
19+
pub fn new<C, A>(attr: impl Into<Tool>, call: C) -> Self
20+
where
21+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
22+
<C as CallToolHandler<S, A>>::Fut: 'static,
23+
{
24+
Self {
25+
call: Box::new(move |context: ToolCallContext<S>| {
26+
let call = call.clone();
27+
Box::pin(async move { context.invoke(call).await })
28+
}),
29+
attr: attr.into(),
30+
}
31+
}
32+
pub fn new_dyn<C>(attr: impl Into<Tool>, call: C) -> Self
33+
where
34+
C: Fn(ToolCallContext<S>) -> BoxFuture<'static, Result<CallToolResult, crate::Error>>
35+
+ Send
36+
+ Sync
37+
+ 'static,
38+
{
39+
Self {
40+
call: Box::new(call),
41+
attr: attr.into(),
42+
}
43+
}
44+
pub fn name(&self) -> &str {
45+
&self.attr.name
46+
}
47+
}
48+
49+
pub trait IntoToolRoute<S, A> {
50+
fn into_tool_route(self) -> ToolRoute<S>;
51+
}
52+
53+
impl<S, C, A, T> IntoToolRoute<S, A> for (T, C)
54+
where
55+
S: Send + Sync + 'static,
56+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
57+
T: Into<Tool>,
58+
<C as CallToolHandler<S, A>>::Fut: 'static,
59+
{
60+
fn into_tool_route(self) -> ToolRoute<S> {
61+
ToolRoute::new(self.0.into(), self.1)
62+
}
63+
}
64+
65+
impl<S> IntoToolRoute<S, ()> for ToolRoute<S>
66+
where
67+
S: Send + Sync + 'static,
68+
{
69+
fn into_tool_route(self) -> ToolRoute<S> {
70+
self
71+
}
72+
}
73+
74+
pub struct ToolAttrGenerateFunctionAdapter;
75+
impl<S, F> IntoToolRoute<S, ToolAttrGenerateFunctionAdapter> for F
76+
where
77+
S: Send + Sync + 'static,
78+
F: Fn() -> ToolRoute<S>,
79+
{
80+
fn into_tool_route(self) -> ToolRoute<S> {
81+
(self)()
82+
}
83+
}
84+
85+
pub trait CallToolHandlerExt<S, A>: Sized
86+
where
87+
Self: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
88+
<Self as CallToolHandler<S, A>>::Fut: 'static,
89+
{
90+
fn name(self, name: impl Into<Cow<'static, str>>) -> WithToolAttr<Self, S, A>;
91+
}
92+
93+
impl<C, S, A> CallToolHandlerExt<S, A> for C
94+
where
95+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
96+
<C as CallToolHandler<S, A>>::Fut: 'static,
97+
{
98+
fn name(self, name: impl Into<Cow<'static, str>>) -> WithToolAttr<Self, S, A> {
99+
WithToolAttr {
100+
attr: Tool::new(
101+
name.into(),
102+
"",
103+
schema_for_type::<crate::model::JsonObject>(),
104+
),
105+
call: self,
106+
_marker: std::marker::PhantomData,
107+
}
108+
}
109+
}
110+
111+
pub struct WithToolAttr<C, S, A>
112+
where
113+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
114+
<C as CallToolHandler<S, A>>::Fut: 'static,
115+
{
116+
pub attr: crate::model::Tool,
117+
pub call: C,
118+
pub _marker: std::marker::PhantomData<fn(S, A)>,
119+
}
120+
121+
impl<C, S, A> IntoToolRoute<S, A> for WithToolAttr<C, S, A>
122+
where
123+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
124+
<C as CallToolHandler<S, A>>::Fut: 'static,
125+
S: Send + Sync + 'static,
126+
{
127+
fn into_tool_route(self) -> ToolRoute<S> {
128+
ToolRoute::new(self.attr, self.call)
129+
}
130+
}
131+
132+
impl<C, S, A> WithToolAttr<C, S, A>
133+
where
134+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
135+
<C as CallToolHandler<S, A>>::Fut: 'static,
136+
{
137+
pub fn description(mut self, description: impl Into<Cow<'static, str>>) -> Self {
138+
self.attr.description = Some(description.into());
139+
self
140+
}
141+
pub fn parameters<T: JsonSchema>(mut self) -> Self {
142+
self.attr.input_schema = schema_for_type::<T>().into();
143+
self
144+
}
145+
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {
146+
self.attr.input_schema = crate::model::object(schema).into();
147+
self
148+
}
149+
pub fn annotation(mut self, annotation: impl Into<ToolAnnotations>) -> Self {
150+
self.attr.annotations = Some(annotation.into());
151+
self
152+
}
153+
}
154+
155+
#[derive(Default)]
156+
pub struct ToolRouter<S> {
157+
#[allow(clippy::type_complexity)]
158+
pub map: std::collections::HashMap<Cow<'static, str>, ToolRoute<S>>,
159+
160+
pub transparent_when_not_found: bool,
161+
}
162+
163+
impl<S> IntoIterator for ToolRouter<S> {
164+
type Item = ToolRoute<S>;
165+
type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, ToolRoute<S>>;
166+
167+
fn into_iter(self) -> Self::IntoIter {
168+
self.map.into_values()
169+
}
170+
}
171+
172+
impl<S> ToolRouter<S>
173+
where
174+
S: Send + Sync + 'static,
175+
{
176+
pub fn new() -> Self {
177+
Self {
178+
map: std::collections::HashMap::new(),
179+
transparent_when_not_found: false,
180+
}
181+
}
182+
pub fn with<C, A>(mut self, attr: crate::model::Tool, call: C) -> Self
183+
where
184+
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
185+
<C as CallToolHandler<S, A>>::Fut: 'static,
186+
{
187+
self.add(ToolRoute::new(attr, call));
188+
self
189+
}
190+
191+
pub fn add(&mut self, item: ToolRoute<S>) {
192+
self.map.insert(item.attr.name.clone(), item);
193+
}
194+
195+
pub fn remove<H, A>(&mut self, name: &str) {
196+
self.map.remove(name);
197+
}
198+
pub fn has(&self, name: &str) -> bool {
199+
self.map.contains_key(name)
200+
}
201+
pub async fn call(&self, context: ToolCallContext<S>) -> Result<CallToolResult, crate::Error> {
202+
let item = self
203+
.map
204+
.get(context.name())
205+
.ok_or_else(|| crate::Error::invalid_params("tool not found", None))?;
206+
(item.call)(context).await
207+
}
208+
209+
pub fn list_all(&self) -> Vec<crate::model::Tool> {
210+
self.map.values().map(|item| item.attr.clone()).collect()
211+
}
212+
}

0 commit comments

Comments
 (0)