Skip to content

draft: handler router, tool router, with axum style. #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,22 +542,22 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
// generate wrapped tool function
let tool_call_fn = {
// wrapper function have the same sig:
// async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>)
// async fn #tool_tool_call(context: rmcp::handler::server::tool::ToolCallContext<Self>)
// -> std::result::Result<rmcp::model::CallToolResult, rmcp::Error>
//
// and the block part should be like:
// {
// use rmcp::handler::server::tool::*;
// let (t0, context) = <T0>::from_tool_call_context_part(context)?;
// let (t1, context) = <T1>::from_tool_call_context_part(context)?;
// let t0 = <T0>::from_tool_call_context_part(&mut context)?;
// let t1 = <T1>::from_tool_call_context_part(&mut context)?;
// ...
// let (tn, context) = <Tn>::from_tool_call_context_part(context)?;
// let tn = <Tn>::from_tool_call_context_part(&mut context)?;
// // for params
// ... expand helper types here
// let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?;
// let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?;
// let __#TOOL_ToolCallParam { param_0, param_1, param_2, .. } = parse_json_object(__rmcp_tool_req)?;
// // for aggr
// let (Parameters(aggr), context) = <Parameters<AggrType>>::from_tool_call_context_part(context)?;
// let Parameters(aggr) = <Parameters<AggrType>>::from_tool_call_context_part(&mut context)?;
// Self::#tool_ident(to, param_0, t1, param_1, ..., param_2, tn, aggr).await.into_call_tool_result()
//
// }
Expand All @@ -584,14 +584,14 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
let pat = &pat_type.pat;
let ty = &pat_type.ty;
quote! {
let (#pat, context) = <#ty>::from_tool_call_context_part(context)?;
let #pat = <#ty>::from_tool_call_context_part(&mut context)?;
}
}
FnArg::Receiver(r) => {
let ty = r.ty.clone();
let pat = receiver_ident();
quote! {
let (#pat, context) = <#ty>::from_tool_call_context_part(context)?;
let #pat = <#ty>::from_tool_call_context_part(&mut context)?;
}
}
};
Expand All @@ -605,7 +605,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
ToolParams::Aggregated { rust_type } => {
let PatType { pat, ty, .. } = rust_type;
quote! {
let (Parameters(#pat), context) = <Parameters<#ty>>::from_tool_call_context_part(context)?;
let Parameters(#pat) = <Parameters<#ty>>::from_tool_call_context_part(&mut context)?;
}
}
ToolParams::Params { attrs } => {
Expand All @@ -615,7 +615,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
let params_ident = attrs.iter().map(|attr| &attr.ident).collect::<Vec<_>>();
quote! {
#param_type
let (__rmcp_tool_req, context) = rmcp::model::JsonObject::from_tool_call_context_part(context)?;
let __rmcp_tool_req = rmcp::model::JsonObject::from_tool_call_context_part(&mut context)?;
let #temp_param_type_name {
#(#params_ident,)*
} = parse_json_object(__rmcp_tool_req)?;
Expand Down Expand Up @@ -669,7 +669,7 @@ pub(crate) fn tool_fn_item(attr: TokenStream, mut input_fn: ItemFn) -> syn::Resu
.collect::<Vec<_>>();
quote! {
#(#raw_fn_attr)*
#raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<'_, Self>)
#raw_fn_vis async fn #tool_call_fn_ident(context: rmcp::handler::server::tool::ToolCallContext<Self>)
-> std::result::Result<rmcp::model::CallToolResult, rmcp::Error> {
use rmcp::handler::server::tool::*;
#trivial_arg_extraction_part
Expand Down
1 change: 1 addition & 0 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
mod resource;
pub mod tool;
pub mod wrapper;
pub mod router;
impl<H: ServerHandler> Service<RoleServer> for H {
async fn handle_request(
&self,
Expand Down
94 changes: 94 additions & 0 deletions crates/rmcp/src/handler/server/router.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use std::sync::Arc;

use tool::{IntoToolRoute, ToolRoute};

use crate::{
RoleServer, Service,
model::{ClientRequest, ListToolsResult, ServerResult},
};

use super::ServerHandler;

pub mod tool;

pub struct Router<S> {
pub tool_router: tool::ToolRouter<S>,
pub service: Arc<S>,
}

impl<S> Router<S>
where
S: ServerHandler,
{
pub fn new(service: S) -> Self {
Self {
tool_router: tool::ToolRouter::new(),
service: Arc::new(service),
}
}

pub fn with_tool<R, A>(mut self, route: R) -> Self
where
R: IntoToolRoute<S, A>,
{
self.tool_router.add(route.into_tool_route());
self
}

pub fn with_tools(mut self, routes: impl IntoIterator<Item = ToolRoute<S>>) -> Self
{
for route in routes {
self.tool_router.add(route);
}
self
}
}

impl<S> Service<RoleServer> for Router<S>
where
S: ServerHandler,
{
async fn handle_notification(
&self,
notification: <RoleServer as crate::service::ServiceRole>::PeerNot,
) -> Result<(), crate::Error> {
self.service.handle_notification(notification).await
}
async fn handle_request(
&self,
request: <RoleServer as crate::service::ServiceRole>::PeerReq,
context: crate::service::RequestContext<RoleServer>,
) -> Result<<RoleServer as crate::service::ServiceRole>::Resp, crate::Error> {
match request {
ClientRequest::CallToolRequest(request) => {
if self.tool_router.has(request.params.name.as_ref())
|| !self.tool_router.transparent_when_not_found
{
let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
self.service.clone(),
request.params,
context,
);
let result = self.tool_router.call(tool_call_context).await?;
Ok(ServerResult::CallToolResult(result))
} else {
self.service
.handle_request(ClientRequest::CallToolRequest(request), context)
.await
}
}
ClientRequest::ListToolsRequest(_) => {
let tools = self.tool_router.list_all();
Ok(ServerResult::ListToolsResult(ListToolsResult {
tools,
next_cursor: None,
}))
}
rest => self.service.handle_request(rest, context).await,
}
}

fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
self.service.get_info()
}
}
212 changes: 212 additions & 0 deletions crates/rmcp/src/handler/server/router/tool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use std::borrow::Cow;

use futures::future::BoxFuture;
use schemars::JsonSchema;

use crate::model::{CallToolResult, Tool, ToolAnnotations};

use crate::handler::server::tool::{
CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type,
};

pub struct ToolRoute<S> {
#[allow(clippy::type_complexity)]
pub call: Box<DynCallToolHandler<S>>,
pub attr: crate::model::Tool,
}

impl<S: Send + Sync + 'static> ToolRoute<S> {
pub fn new<C, A>(attr: impl Into<Tool>, call: C) -> Self
where
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<C as CallToolHandler<S, A>>::Fut: 'static,
{
Self {
call: Box::new(move |context: ToolCallContext<S>| {
let call = call.clone();
Box::pin(async move { context.invoke(call).await })
}),
attr: attr.into(),
}
}
pub fn new_dyn<C>(attr: impl Into<Tool>, call: C) -> Self
where
C: Fn(ToolCallContext<S>) -> BoxFuture<'static, Result<CallToolResult, crate::Error>>
+ Send
+ Sync
+ 'static,
{
Self {
call: Box::new(call),
attr: attr.into(),
}
}
pub fn name(&self) -> &str {
&self.attr.name
}
}

pub trait IntoToolRoute<S, A> {
fn into_tool_route(self) -> ToolRoute<S>;
}

impl<S, C, A, T> IntoToolRoute<S, A> for (T, C)
where
S: Send + Sync + 'static,
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
T: Into<Tool>,
<C as CallToolHandler<S, A>>::Fut: 'static,
{
fn into_tool_route(self) -> ToolRoute<S> {
ToolRoute::new(self.0.into(), self.1)
}
}

impl<S> IntoToolRoute<S, ()> for ToolRoute<S>
where
S: Send + Sync + 'static,
{
fn into_tool_route(self) -> ToolRoute<S> {
self
}
}

pub struct ToolAttrGenerateFunctionAdapter;
impl<S, F> IntoToolRoute<S, ToolAttrGenerateFunctionAdapter> for F
where
S: Send + Sync + 'static,
F: Fn() -> ToolRoute<S>,
{
fn into_tool_route(self) -> ToolRoute<S> {
(self)()
}
}

pub trait CallToolHandlerExt<S, A>: Sized
where
Self: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<Self as CallToolHandler<S, A>>::Fut: 'static,
{
fn name(self, name: impl Into<Cow<'static, str>>) -> WithToolAttr<Self, S, A>;
}

impl<C, S, A> CallToolHandlerExt<S, A> for C
where
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<C as CallToolHandler<S, A>>::Fut: 'static,
{
fn name(self, name: impl Into<Cow<'static, str>>) -> WithToolAttr<Self, S, A> {
WithToolAttr {
attr: Tool::new(
name.into(),
"",
schema_for_type::<crate::model::JsonObject>(),
),
call: self,
_marker: std::marker::PhantomData,
}
}
}

pub struct WithToolAttr<C, S, A>
where
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<C as CallToolHandler<S, A>>::Fut: 'static,
{
pub attr: crate::model::Tool,
pub call: C,
pub _marker: std::marker::PhantomData<fn(S, A)>,
}

impl<C, S, A> IntoToolRoute<S, A> for WithToolAttr<C, S, A>
where
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<C as CallToolHandler<S, A>>::Fut: 'static,
S: Send + Sync + 'static,
{
fn into_tool_route(self) -> ToolRoute<S> {
ToolRoute::new(self.attr, self.call)
}
}

impl<C, S, A> WithToolAttr<C, S, A>
where
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<C as CallToolHandler<S, A>>::Fut: 'static,
{
pub fn description(mut self, description: impl Into<Cow<'static, str>>) -> Self {
self.attr.description = Some(description.into());
self
}
pub fn parameters<T: JsonSchema>(mut self) -> Self {
self.attr.input_schema = schema_for_type::<T>().into();
self
}
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {
self.attr.input_schema = crate::model::object(schema).into();
self
}
pub fn annotation(mut self, annotation: impl Into<ToolAnnotations>) -> Self {
self.attr.annotations = Some(annotation.into());
self
}
}

#[derive(Default)]
pub struct ToolRouter<S> {
#[allow(clippy::type_complexity)]
pub map: std::collections::HashMap<Cow<'static, str>, ToolRoute<S>>,

pub transparent_when_not_found: bool,
}

impl<S> IntoIterator for ToolRouter<S> {
type Item = ToolRoute<S>;
type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, ToolRoute<S>>;

fn into_iter(self) -> Self::IntoIter {
self.map.into_values()
}
}

impl<S> ToolRouter<S>
where
S: Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
map: std::collections::HashMap::new(),
transparent_when_not_found: false,
}
}
pub fn with<C, A>(mut self, attr: crate::model::Tool, call: C) -> Self
where
C: CallToolHandler<S, A> + Send + Sync + Clone + 'static,
<C as CallToolHandler<S, A>>::Fut: 'static,
{
self.add(ToolRoute::new(attr, call));
self
}

pub fn add(&mut self, item: ToolRoute<S>) {
self.map.insert(item.attr.name.clone(), item);
}

pub fn remove<H, A>(&mut self, name: &str) {
self.map.remove(name);
}
pub fn has(&self, name: &str) -> bool {
self.map.contains_key(name)
}
pub async fn call(&self, context: ToolCallContext<S>) -> Result<CallToolResult, crate::Error> {
let item = self
.map
.get(context.name())
.ok_or_else(|| crate::Error::invalid_params("tool not found", None))?;
(item.call)(context).await
}

pub fn list_all(&self) -> Vec<crate::model::Tool> {
self.map.values().map(|item| item.attr.clone()).collect()
}
}
Loading
Loading