From a1e6fd776da6ef69e6c84eb1b9526445a7b238ae Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Wed, 23 Apr 2025 16:54:22 +0200 Subject: [PATCH 01/15] use rmcp crate instead of mcp client rust, first version --- refact-agent/engine/Cargo.toml | 7 +- .../engine/src/integrations/integr_mcp.rs | 166 +++++++++--------- 2 files changed, 82 insertions(+), 91 deletions(-) diff --git a/refact-agent/engine/Cargo.toml b/refact-agent/engine/Cargo.toml index 15d1ba2a6..7cc44d7f2 100644 --- a/refact-agent/engine/Cargo.toml +++ b/refact-agent/engine/Cargo.toml @@ -54,6 +54,7 @@ regex = "1.9.5" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls-webpki-roots", "charset", "http2"] } reqwest-eventsource = "0.6.0" resvg = "0.44.0" +rmcp = { "version" = "0.1.5", features = ["client", "transport-child-process", "transport-sse"] } ropey = "1.6" rusqlite = { version = "0.31.0", features = ["bundled"] } rust-embed = "8.5.0" @@ -97,9 +98,3 @@ uuid = { version = "1", features = ["v4", "serde"] } walkdir = "2.3" which = "7.0.1" zerocopy = "0.8.14" - -# There you can use a local copy: -mcp_client_rs = { git = "https://github.com/smallcloudai/mcp_client_rust.git" } -#mcp_client_rs = { path = "../../../mcp_client_rust" } - - diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index 4f4e5a60b..c922778ce 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -7,9 +7,9 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; -use mcp_client_rs::client::Client as MCPClient; +use rmcp::{RoleClient, ServiceExt, service::RunningService}; use tokio::task::{AbortHandle, JoinHandle}; -use mcp_client_rs::client::ClientBuilder; +use rmcp::model::{CallToolRequestParam, Tool as McpTool}; use crate::global_context::GlobalContext; use crate::at_commands::at_commands::AtCommandsContext; @@ -30,8 +30,8 @@ pub struct SettingsMCP { pub struct ToolMCP { pub common: IntegrationCommon, pub config_path: String, - pub mcp_client: Arc>, - pub mcp_tool: mcp_client_rs::Tool, + pub mcp_client: Arc>>>, + pub mcp_tool: McpTool, } #[derive(Default)] @@ -46,8 +46,8 @@ pub struct SessionMCP { pub debug_name: String, pub config_path: String, // to check if expired or not pub launched_cfg: SettingsMCP, // a copy to compare against IntegrationMCP::cfg, to see if anything has changed - pub mcp_client: Option>>, - pub mcp_tools: Vec, + pub mcp_client: Option>>>>, + pub mcp_tools: Vec, pub startup_task_handles: Option<(Arc>>>, AbortHandle)>, pub logs: Arc>>, // Store log messages } @@ -103,25 +103,30 @@ async fn _add_log_entry(session_logs: Arc>>, entry: String) { async fn _session_kill_process( debug_name: &str, - mcp_client: Arc>, + mcp_client: Arc>>>, session_logs: Arc>>, ) { tracing::info!("Stopping MCP Server for {}", debug_name); _add_log_entry(session_logs.clone(), "Stopping MCP Server".to_string()).await; - let client_result = { + let client_to_cancel = { let mut mcp_client_locked = mcp_client.lock().await; - mcp_client_locked.shutdown().await + mcp_client_locked.take() }; - if let Err(e) = client_result { - let error_msg = format!("Failed to stop MCP: {:?}", e); - tracing::error!("{} for {}", error_msg, debug_name); - _add_log_entry(session_logs, error_msg).await; - } else { - let success_msg = "MCP server stopped".to_string(); - tracing::info!("{} for {}", success_msg, debug_name); - _add_log_entry(session_logs, success_msg).await; + if let Some(client) = client_to_cancel { + match client.cancel().await { + Ok(reason) => { + let success_msg = format!("MCP server stopped: {:?}", reason); + tracing::info!("{} for {}", success_msg, debug_name); + _add_log_entry(session_logs, success_msg).await; + }, + Err(e) => { + let error_msg = format!("Failed to stop MCP: {:?}", e); + tracing::error!("{} for {}", error_msg, debug_name); + _add_log_entry(session_logs, error_msg).await; + } + } } } @@ -205,16 +210,14 @@ async fn _session_apply_settings( } }; - let mut client_builder = ClientBuilder::new(&parsed_args[0]); - for arg in parsed_args.iter().skip(1) { - client_builder = client_builder.arg(arg); - } + let mut command = tokio::process::Command::new(&parsed_args[0]); + command.args(&parsed_args[1..]); for (key, value) in &new_cfg_clone.mcp_env { - client_builder = client_builder.env(key, value); + command.env(key, value); } - let (mut client, imp, caps) = match client_builder.spawn().await { - Ok(r) => r, + let transport = match rmcp::transport::TokioChildProcess::new(&mut command) { + Ok(t) => t, Err(e) => { let err_msg = format!("Failed to init process: {}", e); tracing::error!("{err_msg} for {debug_name}"); @@ -222,33 +225,21 @@ async fn _session_apply_settings( return; } }; - if let Err(e) = client.initialize(imp, caps).await { - let err_msg = format!("Failed to init server: {}", e); - tracing::error!("{err_msg} for {debug_name}"); - _add_log_entry(logs.clone(), err_msg).await; - if let Ok(error_log) = client.get_stderr(None).await { - _add_log_entry(logs.clone(), error_log).await; + + let client = match ().serve(transport).await { + Ok(c) => c, + Err(e) => { + let err_msg = format!("Failed to init server: {}", e); + tracing::error!("{err_msg} for {debug_name}"); + _add_log_entry(logs.clone(), err_msg).await; + return; } - return; }; - // let set_result = client.request( - // "logging/setLevel", - // Some(serde_json::json!({ "level": "debug" })), - // ).await; - // match set_result { - // Ok(_) => { - // tracing::info!("MCP START SESSION (2) set log level success"); - // } - // Err(e) => { - // tracing::info!("MCP START SESSION (2) failed to set log level: {:?}", e); - // } - // } - tracing::info!("MCP START SESSION (2) {:?}", debug_name); _add_log_entry(logs.clone(), "Listing tools".to_string()).await; - let tools_result = match client.list_tools().await { + let tools_result = match client.list_tools(None).await { Ok(result) => { let success_msg = format!("Successfully listed {} tools", result.tools.len()); tracing::info!("{} for {}", success_msg, debug_name); @@ -258,14 +249,11 @@ async fn _session_apply_settings( let err_msg = format!("Failed to list tools: {:?}", tools_error); tracing::error!("{} for {}", err_msg, debug_name); _add_log_entry(logs.clone(), err_msg).await; - if let Ok(error_log) = client.get_stderr(None).await { - _add_log_entry(logs.clone(), error_log).await; - } return; } }; - - let new_mcp_client = Arc::new(AMutex::new(client)); + + let new_mcp_client = Arc::new(AMutex::new(Some(client))); let tools_len = { tracing::info!("MCP START SESSION (3) {:?}", debug_name); @@ -416,38 +404,48 @@ impl Tool for ToolMCP { _add_log_entry(session_logs.clone(), format!("Executing tool '{}' with arguments: {:?}", self.mcp_tool.name, json_args)).await; let result_probably = { - let mut mcp_client_locked = self.mcp_client.lock().await; - mcp_client_locked.call_tool(self.mcp_tool.name.as_str(), json_args).await + let mcp_client_locked = self.mcp_client.lock().await; + if let Some(client) = &*mcp_client_locked { + client.call_tool(CallToolRequestParam { + name: self.mcp_tool.name.clone(), + arguments: match json_args { + serde_json::Value::Object(map) => Some(map), + _ => None, + }, + }).await + } else { + return Err("MCP client is not available".to_string()); + } }; let tool_output = match result_probably { Ok(result) => { - if result.is_error { + if result.is_error.unwrap_or(false) { let error_msg = format!("Tool execution error: {:?}", result.content); _add_log_entry(session_logs.clone(), error_msg.clone()).await; return Err(error_msg); } - if let Some(mcp_client_rs::MessageContent::Text { text }) = result.content.get(0) { - let success_msg = format!("Tool '{}' executed successfully", self.mcp_tool.name); - _add_log_entry(session_logs.clone(), success_msg).await; - text.clone() + if let Some(content) = result.content.get(0) { + if let rmcp::model::RawContent::Text(text_content) = &content.raw { + let text = text_content.text.clone(); + let success_msg = format!("Tool '{}' executed successfully", self.mcp_tool.name); + _add_log_entry(session_logs.clone(), success_msg).await; + text + } else { + let error_msg = format!("Unexpected tool output format: {:?}", result.content); + tracing::error!("{}", error_msg); + _add_log_entry(session_logs.clone(), error_msg.clone()).await; + return Err("Unexpected tool output format".to_string()); + } } else { - let error_msg = format!("Unexpected tool output format: {:?}", result.content); - tracing::error!("{}", error_msg); - _add_log_entry(session_logs.clone(), error_msg.clone()).await; - return Err("Unexpected tool output format".to_string()); + String::new() } } Err(e) => { let error_msg = format!("Failed to call tool: {:?}", e); tracing::error!("{}", error_msg); _add_log_entry(session_logs.clone(), error_msg).await; - - let error_log = self.mcp_client.lock().await.get_stderr(None).await; - if let Ok(error_log) = error_log { - _add_log_entry(session_logs.clone(), error_log).await; - } return Err(e.to_string()); } }; @@ -489,25 +487,23 @@ impl Tool for ToolMCP { let mut parameters = vec![]; let mut parameters_required = vec![]; - if let serde_json::Value::Object(schema) = &self.mcp_tool.input_schema { - if let Some(serde_json::Value::Object(properties)) = schema.get("properties") { - for (name, prop) in properties { - if let serde_json::Value::Object(prop_obj) = prop { - let param_type = prop_obj.get("type").and_then(|v| v.as_str()).unwrap_or("string").to_string(); - let description = prop_obj.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(); - parameters.push(ToolParam { - name: name.clone(), - param_type, - description, - }); - } + if let Some(serde_json::Value::Object(properties)) = self.mcp_tool.input_schema.get("properties") { + for (name, prop) in properties { + if let serde_json::Value::Object(prop_obj) = prop { + let param_type = prop_obj.get("type").and_then(|v| v.as_str()).unwrap_or("string").to_string(); + let description = prop_obj.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(); + parameters.push(ToolParam { + name: name.clone(), + param_type, + description, + }); } } - if let Some(serde_json::Value::Array(required)) = schema.get("required") { - for req in required { - if let Some(req_str) = req.as_str() { - parameters_required.push(req_str.to_string()); - } + } + if let Some(serde_json::Value::Array(required)) = self.mcp_tool.input_schema.get("required") { + for req in required { + if let Some(req_str) = req.as_str() { + parameters_required.push(req_str.to_string()); } } } @@ -516,7 +512,7 @@ impl Tool for ToolMCP { name: self.tool_name(), agentic: true, experimental: false, - description: self.mcp_tool.description.clone(), + description: self.mcp_tool.description.to_string(), parameters, parameters_required, } @@ -541,7 +537,7 @@ impl Tool for ToolMCP { ) -> Result { let command = self.mcp_tool.name.clone(); tracing::info!("MCP command_to_match_against_confirm_deny() returns {:?}", command); - Ok(command) + Ok(command.to_string()) } fn confirm_deny_rules(&self) -> Option { From fd9edbec447765a5f7f71e6c3f37c9525dad1ca5 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Wed, 23 Apr 2025 22:32:49 +0200 Subject: [PATCH 02/15] add logs from stderr of the process, by using a file to redirect stderr there --- .../src/http/routers/v1/v1_integrations.rs | 20 ++++-- .../engine/src/integrations/integr_mcp.rs | 64 +++++++++++++++++-- .../src/integrations/process_io_utils.rs | 26 +++++++- 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/refact-agent/engine/src/http/routers/v1/v1_integrations.rs b/refact-agent/engine/src/http/routers/v1/v1_integrations.rs index 2fac3f4b0..608d57276 100644 --- a/refact-agent/engine/src/http/routers/v1/v1_integrations.rs +++ b/refact-agent/engine/src/http/routers/v1/v1_integrations.rs @@ -206,20 +206,32 @@ pub async fn handle_v1_integrations_mcp_logs( let session = gcx.read().await.integration_sessions.get(&session_key).cloned() .ok_or(ScratchError::new(StatusCode::NOT_FOUND, format!("session {} not found", session_key)))?; - let logs_arc = { + let (logs_arc, stderr_file_path, stderr_cursor) = { let mut session_locked = session.lock().await; let session_downcasted = session_locked.as_any_mut().downcast_mut::() .ok_or(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, "Session is not a MCP session".to_string()))?; - session_downcasted.logs.clone() + ( + session_downcasted.logs.clone(), + session_downcasted.stderr_file_path.clone(), + session_downcasted.stderr_cursor.clone(), + ) }; - let logs = logs_arc.lock().await.clone(); + if let Some(stderr_path) = &stderr_file_path { + if let Err(e) = crate::integrations::integr_mcp::update_logs_from_stderr( + stderr_path, + stderr_cursor, + logs_arc.clone() + ).await { + tracing::warn!("Failed to read stderr file: {}", e); + } + } return Ok(Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/json") .body(Body::from(serde_json::json!({ - "logs": logs, + "logs": logs_arc.lock().await.clone(), }).to_string())) .unwrap()) } diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index c922778ce..fa4d27e16 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -1,8 +1,10 @@ use std::any::Any; use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use std::sync::Weak; use std::future::Future; +use std::process::Stdio; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex as AMutex; @@ -10,13 +12,16 @@ use tokio::sync::RwLock as ARwLock; use rmcp::{RoleClient, ServiceExt, service::RunningService}; use tokio::task::{AbortHandle, JoinHandle}; use rmcp::model::{CallToolRequestParam, Tool as McpTool}; +use tempfile::NamedTempFile; +use crate::custom_error::MapErrToString; use crate::global_context::GlobalContext; use crate::at_commands::at_commands::AtCommandsContext; use crate::tools::tools_description::{Tool, ToolDesc, ToolParam}; use crate::call_validation::{ChatMessage, ChatContent, ContextEnum}; use crate::integrations::integr_abstract::{IntegrationTrait, IntegrationCommon, IntegrationConfirmation}; use crate::integrations::sessions::IntegrationSession; +use crate::integrations::process_io_utils::read_file_with_cursor; #[derive(Deserialize, Serialize, Clone, Default, PartialEq, Debug)] @@ -50,6 +55,8 @@ pub struct SessionMCP { pub mcp_tools: Vec, pub startup_task_handles: Option<(Arc>>>, AbortHandle)>, pub logs: Arc>>, // Store log messages + pub stderr_file_path: Option, // Path to the temporary file for stderr + pub stderr_cursor: Arc>, // Position in the file where we last read from } impl IntegrationSession for SessionMCP { @@ -63,7 +70,7 @@ impl IntegrationSession for SessionMCP { fn try_stop(&mut self, self_arc: Arc>>) -> Box + Send> { Box::new(async move { - let (debug_name, client, logs, startup_task_handles) = { + let (debug_name, client, logs, startup_task_handles, stderr_file) = { let mut session_locked = self_arc.lock().await; let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); ( @@ -71,6 +78,7 @@ impl IntegrationSession for SessionMCP { session_downcasted.mcp_client.clone(), session_downcasted.logs.clone(), session_downcasted.startup_task_handles.clone(), + session_downcasted.stderr_file_path.clone(), ) }; @@ -82,6 +90,11 @@ impl IntegrationSession for SessionMCP { if let Some(client) = client { _session_kill_process(&debug_name, client, logs).await; } + if let Some(stderr_file) = &stderr_file { + if let Err(e) = tokio::fs::remove_file(stderr_file).await { + tracing::error!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e); + } + } "".to_string() }) @@ -101,6 +114,19 @@ async fn _add_log_entry(session_logs: Arc>>, entry: String) { } } +pub async fn update_logs_from_stderr( + stderr_file_path: &PathBuf, + stderr_cursor: Arc>, + session_logs: Arc>> +) -> Result<(), String> { + let (buffer, bytes_read) = read_file_with_cursor(stderr_file_path, stderr_cursor.clone()).await + .map_err_with_prefix("Failed to read file:")?; + if bytes_read > 0 && !buffer.trim().is_empty() { + _add_log_entry(session_logs, buffer.trim().to_string()).await; + } + Ok(()) +} + async fn _session_kill_process( debug_name: &str, mcp_client: Arc>>>, @@ -149,6 +175,8 @@ async fn _session_apply_settings( mcp_tools: Vec::new(), startup_task_handles: None, logs: Arc::new(AMutex::new(Vec::new())), + stderr_file_path: None, + stderr_cursor: Arc::new(AMutex::new(0)), }))); tracing::info!("MCP START SESSION {:?}", session_key); gcx_write.integration_sessions.insert(session_key.clone(), new_session.clone()); @@ -175,14 +203,16 @@ async fn _session_apply_settings( } let startup_task_join_handle = tokio::spawn(async move { - let (mcp_client, logs, debug_name) = { + let (mcp_client, logs, debug_name, stderr_file) = { let mut session_locked = session_arc_clone.lock().await; - let mcp_sesion = session_locked.as_any_mut().downcast_mut::().unwrap(); - mcp_sesion.launched_cfg = new_cfg_clone.clone(); + let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); + mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); + mcp_session.launched_cfg = new_cfg_clone.clone(); ( - std::mem::take(&mut mcp_sesion.mcp_client), - mcp_sesion.logs.clone(), - mcp_sesion.debug_name.clone(), + std::mem::take(&mut mcp_session.mcp_client), + mcp_session.logs.clone(), + mcp_session.debug_name.clone(), + std::mem::take(&mut mcp_session.stderr_file_path), ) }; @@ -191,6 +221,11 @@ async fn _session_apply_settings( if let Some(mcp_client) = mcp_client { _session_kill_process(&debug_name, mcp_client, logs.clone()).await; } + if let Some(stderr_file) = &stderr_file { + if let Err(e) = tokio::fs::remove_file(stderr_file).await { + tracing::error!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e); + } + } let parsed_args = match shell_words::split(&new_cfg_clone.mcp_command) { Ok(args) => { @@ -216,6 +251,21 @@ async fn _session_apply_settings( command.env(key, value); } + match NamedTempFile::new().map(|f| f.keep()) { + Ok(Ok((file, path))) => { + { + let mut session_locked = session_arc_clone.lock().await; + let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); + + mcp_session.stderr_file_path = Some(path.clone()); + mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); + } + command.stderr(Stdio::from(file)); + }, + Ok(Err(e)) => tracing::error!("Failed to persist stderr file for {debug_name}: {e}"), + Err(e) => tracing::error!("Failed to create stderr file for {debug_name}: {e}"), + } + let transport = match rmcp::transport::TokioChildProcess::new(&mut command) { Ok(t) => t, Err(e) => { diff --git a/refact-agent/engine/src/integrations/process_io_utils.rs b/refact-agent/engine/src/integrations/process_io_utils.rs index 42860d49c..b38f2d912 100644 --- a/refact-agent/engine/src/integrations/process_io_utils.rs +++ b/refact-agent/engine/src/integrations/process_io_utils.rs @@ -1,10 +1,13 @@ use futures::future::try_join3; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::process::{Child, ChildStdin, Command}; +use tokio::sync::Mutex as AMutex; use tokio::time::Duration; +use std::path::Path; use std::pin::Pin; use std::process::Output; +use std::sync::Arc; use std::time::Instant; use std::process::Stdio; use tracing::error; @@ -77,6 +80,25 @@ pub async fn blocking_read_until_token_or_timeout< Ok((String::from_utf8_lossy(&output).to_string(), String::from_utf8_lossy(&error).to_string(), have_the_token)) } +pub async fn read_file_with_cursor( + file_path: &Path, + cursor: Arc>, +) -> Result<(String, usize), String> { + let file = tokio::fs::OpenOptions::new().read(true).open(file_path).await + .map_err(|e| format!("Failed to read file: {}", e))?; + let mut cursor_locked = cursor.lock().await; + let mut file = tokio::io::BufReader::new(file); + file.seek(tokio::io::SeekFrom::Start(*cursor_locked)).await + .map_err(|e| format!("Failed to seek: {}", e))?; + let mut buffer = String::new(); + let bytes_read = file.read_to_string(&mut buffer).await + .map_err(|e| format!("Failed to read to buffer: {}", e))?; + if bytes_read > 0 { + *cursor_locked += bytes_read as u64; + } + Ok((buffer, bytes_read)) +} + pub async fn is_someone_listening_on_that_tcp_port(port: u16, timeout: tokio::time::Duration) -> bool { match tokio::time::timeout(timeout, TcpStream::connect(&format!("127.0.0.1:{}", port))).await { Ok(Ok(_)) => true, // Connection successful @@ -177,4 +199,4 @@ pub async fn execute_command(mut cmd: Command, timeout_secs: u64, cmd_str: &str) ).await .map_err(|_| format!("command '{cmd_str}' timed out after {timeout_secs} seconds"))? .map_err(|e| format!("command '{cmd_str}' failed to execute: {e}")) -} \ No newline at end of file +} From b0b28b01bb7ae74292023b09e6a11d0b44c41f0b Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Thu, 24 Apr 2025 22:48:58 +0200 Subject: [PATCH 03/15] fix: add timeout to requests, init and stop --- .../engine/src/integrations/integr_mcp.rs | 63 ++++++++++++++----- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index fa4d27e16..677b2622d 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -9,8 +9,10 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; -use rmcp::{RoleClient, ServiceExt, service::RunningService}; +use tokio::time::timeout; +use tokio::time::Duration; use tokio::task::{AbortHandle, JoinHandle}; +use rmcp::{RoleClient, ServiceExt, service::RunningService}; use rmcp::model::{CallToolRequestParam, Tool as McpTool}; use tempfile::NamedTempFile; @@ -23,6 +25,9 @@ use crate::integrations::integr_abstract::{IntegrationTrait, IntegrationCommon, use crate::integrations::sessions::IntegrationSession; use crate::integrations::process_io_utils::read_file_with_cursor; +const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(30); +const MCP_SERVER_INIT_TIMEOUT: Duration = Duration::from_secs(60); +const MCP_SERVER_STOP_TIMEOUT: Duration = Duration::from_secs(3); #[derive(Deserialize, Serialize, Clone, Default, PartialEq, Debug)] pub struct SettingsMCP { @@ -141,16 +146,21 @@ async fn _session_kill_process( }; if let Some(client) = client_to_cancel { - match client.cancel().await { - Ok(reason) => { + match timeout(MCP_SERVER_STOP_TIMEOUT, client.cancel()).await { + Ok(Ok(reason)) => { let success_msg = format!("MCP server stopped: {:?}", reason); tracing::info!("{} for {}", success_msg, debug_name); _add_log_entry(session_logs, success_msg).await; }, - Err(e) => { + Ok(Err(e)) => { let error_msg = format!("Failed to stop MCP: {:?}", e); tracing::error!("{} for {}", error_msg, debug_name); _add_log_entry(session_logs, error_msg).await; + }, + Err(_) => { + let error_msg = format!("MCP server stop operation timed out after {} seconds", MCP_SERVER_STOP_TIMEOUT.as_secs()); + tracing::error!("{} for {}", error_msg, debug_name); + _add_log_entry(session_logs, error_msg).await; } } } @@ -276,30 +286,42 @@ async fn _session_apply_settings( } }; - let client = match ().serve(transport).await { - Ok(c) => c, - Err(e) => { + let client = match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { + Ok(Ok(c)) => c, + Ok(Err(e)) => { let err_msg = format!("Failed to init server: {}", e); tracing::error!("{err_msg} for {debug_name}"); _add_log_entry(logs.clone(), err_msg).await; return; + }, + Err(_) => { + let err_msg = format!("Server initialization timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs()); + tracing::error!("{err_msg} for {debug_name}"); + _add_log_entry(logs.clone(), err_msg).await; + return; } }; tracing::info!("MCP START SESSION (2) {:?}", debug_name); _add_log_entry(logs.clone(), "Listing tools".to_string()).await; - let tools_result = match client.list_tools(None).await { - Ok(result) => { + let tools_result = match timeout(MCP_REQUEST_TIMEOUT, client.list_tools(None)).await { + Ok(Ok(result)) => { let success_msg = format!("Successfully listed {} tools", result.tools.len()); tracing::info!("{} for {}", success_msg, debug_name); result }, - Err(tools_error) => { + Ok(Err(tools_error)) => { let err_msg = format!("Failed to list tools: {:?}", tools_error); tracing::error!("{} for {}", err_msg, debug_name); _add_log_entry(logs.clone(), err_msg).await; return; + }, + Err(_) => { + let err_msg = format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs()); + tracing::error!("{} for {}", err_msg, debug_name); + _add_log_entry(logs.clone(), err_msg).await; + return; } }; @@ -456,13 +478,20 @@ impl Tool for ToolMCP { let result_probably = { let mcp_client_locked = self.mcp_client.lock().await; if let Some(client) = &*mcp_client_locked { - client.call_tool(CallToolRequestParam { - name: self.mcp_tool.name.clone(), - arguments: match json_args { - serde_json::Value::Object(map) => Some(map), - _ => None, - }, - }).await + match timeout(MCP_REQUEST_TIMEOUT, + client.call_tool(CallToolRequestParam { + name: self.mcp_tool.name.clone(), + arguments: match json_args { + serde_json::Value::Object(map) => Some(map), + _ => None, + }, + }) + ).await { + Ok(result) => result, + Err(_) => Err(rmcp::service::ServiceError::Timeout { + timeout: MCP_REQUEST_TIMEOUT + }), + } } else { return Err("MCP client is not available".to_string()); } From bc2145c79c4d8e5d9aa4c3ae8c058d62fad767c4 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Wed, 30 Apr 2025 23:41:32 +0200 Subject: [PATCH 04/15] add remote sse mcp servers, configure url and headers, default headers --- .../engine/src/integrations/integr_mcp.rs | 274 +++++++++++------- .../engine/src/integrations/yaml_schema.rs | 2 +- 2 files changed, 171 insertions(+), 105 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index 677b2622d..b787f3f1d 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -15,6 +15,7 @@ use tokio::task::{AbortHandle, JoinHandle}; use rmcp::{RoleClient, ServiceExt, service::RunningService}; use rmcp::model::{CallToolRequestParam, Tool as McpTool}; use tempfile::NamedTempFile; +use tracing::Level; use crate::custom_error::MapErrToString; use crate::global_context::GlobalContext; @@ -29,12 +30,22 @@ const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(30); const MCP_SERVER_INIT_TIMEOUT: Duration = Duration::from_secs(60); const MCP_SERVER_STOP_TIMEOUT: Duration = Duration::from_secs(3); -#[derive(Deserialize, Serialize, Clone, Default, PartialEq, Debug)] +#[derive(Deserialize, Serialize, Clone, PartialEq, Default, Debug)] pub struct SettingsMCP { - #[serde(rename = "command")] + #[serde(default = "default_server_transport")] + pub server_transport: String, + #[serde(rename = "command", default)] pub mcp_command: String, #[serde(default, rename = "env")] pub mcp_env: HashMap, + #[serde(default)] + pub url: String, + #[serde(default)] + pub headers: HashMap, +} + +fn default_server_transport() -> String { + "stdio".to_string() } pub struct ToolMCP { @@ -79,8 +90,8 @@ impl IntegrationSession for SessionMCP { let mut session_locked = self_arc.lock().await; let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); ( - session_downcasted.debug_name.clone(), - session_downcasted.mcp_client.clone(), + session_downcasted.debug_name.clone(), + session_downcasted.mcp_client.clone(), session_downcasted.logs.clone(), session_downcasted.startup_task_handles.clone(), session_downcasted.stderr_file_path.clone(), @@ -92,7 +103,7 @@ impl IntegrationSession for SessionMCP { abort_handle.abort(); } - if let Some(client) = client { + if let Some(client) = client { _session_kill_process(&debug_name, client, logs).await; } if let Some(stderr_file) = &stderr_file { @@ -109,7 +120,7 @@ impl IntegrationSession for SessionMCP { async fn _add_log_entry(session_logs: Arc>>, entry: String) { let timestamp = chrono::Local::now().format("%H:%M:%S%.3f").to_string(); let log_entry = format!("[{}] {}", timestamp, entry); - + let mut session_logs_locked = session_logs.lock().await; session_logs_locked.extend(log_entry.lines().into_iter().map(|s| s.to_string())); @@ -120,8 +131,8 @@ async fn _add_log_entry(session_logs: Arc>>, entry: String) { } pub async fn update_logs_from_stderr( - stderr_file_path: &PathBuf, - stderr_cursor: Arc>, + stderr_file_path: &PathBuf, + stderr_cursor: Arc>, session_logs: Arc>> ) -> Result<(), String> { let (buffer, bytes_read) = read_file_with_cursor(stderr_file_path, stderr_cursor.clone()).await @@ -133,18 +144,18 @@ pub async fn update_logs_from_stderr( } async fn _session_kill_process( - debug_name: &str, - mcp_client: Arc>>>, + debug_name: &str, + mcp_client: Arc>>>, session_logs: Arc>>, ) { tracing::info!("Stopping MCP Server for {}", debug_name); _add_log_entry(session_logs.clone(), "Stopping MCP Server".to_string()).await; - + let client_to_cancel = { let mut mcp_client_locked = mcp_client.lock().await; mcp_client_locked.take() }; - + if let Some(client) = client_to_cancel { match timeout(MCP_SERVER_STOP_TIMEOUT, client.cancel()).await { Ok(Ok(reason)) => { @@ -211,7 +222,7 @@ async fn _session_apply_settings( return; } } - + let startup_task_join_handle = tokio::spawn(async move { let (mcp_client, logs, debug_name, stderr_file) = { let mut session_locked = session_arc_clone.lock().await; @@ -225,110 +236,149 @@ async fn _session_apply_settings( std::mem::take(&mut mcp_session.stderr_file_path), ) }; - - _add_log_entry(logs.clone(), "Applying new settings".to_string()).await; + + let log = async |level: Level, msg: String| { + match level { + Level::ERROR => tracing::error!("{msg} for {debug_name}"), + Level::WARN => tracing::warn!("{msg} for {debug_name}"), + _ => tracing::info!("{msg} for {debug_name}"), + } + _add_log_entry(logs.clone(), msg).await; + }; + + log(Level::INFO, "Applying new settings".to_string()).await; if let Some(mcp_client) = mcp_client { _session_kill_process(&debug_name, mcp_client, logs.clone()).await; } if let Some(stderr_file) = &stderr_file { if let Err(e) = tokio::fs::remove_file(stderr_file).await { - tracing::error!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e); + log(Level::ERROR, format!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e)).await; } } - let parsed_args = match shell_words::split(&new_cfg_clone.mcp_command) { - Ok(args) => { - if args.is_empty() { - let error_msg = "Empty command".to_string(); - tracing::info!("{error_msg} for {debug_name}"); - _add_log_entry(logs.clone(), error_msg).await; - return; + let client = match new_cfg_clone.server_transport.to_lowercase().trim() { + "stdio" => { + let parsed_args = match shell_words::split(&new_cfg_clone.mcp_command) { + Ok(args) => { + if args.is_empty() { + log(Level::ERROR, "Empty command".to_string()).await; + return; + } + args + } + Err(e) => { + log(Level::ERROR, format!("Failed to parse command: {}", e)).await; + return; + } + }; + + let mut command = tokio::process::Command::new(&parsed_args[0]); + command.args(&parsed_args[1..]); + for (key, value) in &new_cfg_clone.mcp_env { + command.env(key, value); } - args - } - Err(e) => { - let error_msg = format!("Failed to parse command: {}", e); - tracing::info!("{error_msg} for {debug_name}"); - _add_log_entry(logs.clone(), error_msg).await; - return; - } - }; - let mut command = tokio::process::Command::new(&parsed_args[0]); - command.args(&parsed_args[1..]); - for (key, value) in &new_cfg_clone.mcp_env { - command.env(key, value); - } + match NamedTempFile::new().map(|f| f.keep()) { + Ok(Ok((file, path))) => { + { + let mut session_locked = session_arc_clone.lock().await; + let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); + + mcp_session.stderr_file_path = Some(path.clone()); + mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); + } + command.stderr(Stdio::from(file)); + }, + Ok(Err(e)) => tracing::error!("Failed to persist stderr file for {debug_name}: {e}"), + Err(e) => tracing::error!("Failed to create stderr file for {debug_name}: {e}"), + } - match NamedTempFile::new().map(|f| f.keep()) { - Ok(Ok((file, path))) => { - { - let mut session_locked = session_arc_clone.lock().await; - let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); - - mcp_session.stderr_file_path = Some(path.clone()); - mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); + let transport = match rmcp::transport::TokioChildProcess::new(&mut command) { + Ok(t) => t, + Err(e) => { + log(Level::ERROR, format!("Failed to init Tokio child process: {}", e)).await; + return; + } + }; + match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { + Ok(Ok(client)) => client, + Ok(Err(e)) => { + log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; + return; + }, + Err(_) => { + log(Level::ERROR, format!("Request timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs())).await; + return; + } } - command.stderr(Stdio::from(file)); }, - Ok(Err(e)) => tracing::error!("Failed to persist stderr file for {debug_name}: {e}"), - Err(e) => tracing::error!("Failed to create stderr file for {debug_name}: {e}"), - } + "sse" => { + if new_cfg_clone.url.is_empty() { + log(Level::ERROR, "URL is required for MCP with SSE transport".to_string()).await; + return; + } - let transport = match rmcp::transport::TokioChildProcess::new(&mut command) { - Ok(t) => t, - Err(e) => { - let err_msg = format!("Failed to init process: {}", e); - tracing::error!("{err_msg} for {debug_name}"); - _add_log_entry(logs.clone(), err_msg).await; - return; + let mut header_map = reqwest::header::HeaderMap::new(); + for (k, v) in &new_cfg_clone.headers { + match (reqwest::header::HeaderName::from_bytes(k.as_bytes()), + reqwest::header::HeaderValue::from_str(v), + ) { + (Ok(name), Ok(value)) => { + header_map.insert(name, value); + } + _ => log(Level::WARN, format!("Invalid header: {}: {}", k, v)).await, + } + } + let client = match reqwest::Client::builder().default_headers(header_map).build() { + Ok(c) => c, + Err(e) => { + log(Level::ERROR, format!("Failed to build reqwest client: {}", e)).await; + return; + } + }; + let transport = match rmcp::transport::SseTransport::start_with_client(&new_cfg_clone.url, client).await { + Ok(t) => t, + Err(e) => { + log(Level::ERROR, format!("Failed to init SSE transport: {}", e)).await; + return; + } + }; + match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { + Ok(Ok(client)) => client, + Ok(Err(e)) => { + log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; + return; + }, + Err(_) => { + log(Level::ERROR, format!("Request timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs())).await; + return; + } + } } - }; - - let client = match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { - Ok(Ok(c)) => c, - Ok(Err(e)) => { - let err_msg = format!("Failed to init server: {}", e); - tracing::error!("{err_msg} for {debug_name}"); - _add_log_entry(logs.clone(), err_msg).await; - return; - }, - Err(_) => { - let err_msg = format!("Server initialization timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs()); - tracing::error!("{err_msg} for {debug_name}"); - _add_log_entry(logs.clone(), err_msg).await; + _ => { + log(Level::ERROR, format!("Unsupported server transport: {}", new_cfg_clone.server_transport)).await; return; } }; - tracing::info!("MCP START SESSION (2) {:?}", debug_name); - _add_log_entry(logs.clone(), "Listing tools".to_string()).await; - + log(Level::INFO, "Listing tools".to_string()).await; + let tools_result = match timeout(MCP_REQUEST_TIMEOUT, client.list_tools(None)).await { - Ok(Ok(result)) => { - let success_msg = format!("Successfully listed {} tools", result.tools.len()); - tracing::info!("{} for {}", success_msg, debug_name); - result - }, + Ok(Ok(result)) => result, Ok(Err(tools_error)) => { - let err_msg = format!("Failed to list tools: {:?}", tools_error); - tracing::error!("{} for {}", err_msg, debug_name); - _add_log_entry(logs.clone(), err_msg).await; + log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; return; }, Err(_) => { - let err_msg = format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs()); - tracing::error!("{} for {}", err_msg, debug_name); - _add_log_entry(logs.clone(), err_msg).await; + log(Level::ERROR, format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs())).await; return; } }; - + let new_mcp_client = Arc::new(AMutex::new(Some(client))); - + let tools_len = { - tracing::info!("MCP START SESSION (3) {:?}", debug_name); let mut session_locked = session_arc_clone.lock().await; let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); @@ -337,12 +387,10 @@ async fn _session_apply_settings( session_downcasted.mcp_tools.len() }; - - let setup_msg = format!("MCP session setup complete with {tools_len} tools"); - tracing::info!("{} for {}", setup_msg, debug_name); - _add_log_entry(logs.clone(), setup_msg).await; + + log(Level::INFO, format!("MCP session setup complete with {tools_len} tools")).await; }); - + let startup_task_abort_handle = startup_task_join_handle.abort_handle(); session_downcasted.startup_task_handles = Some( (Arc::new(AMutex::new(Some(startup_task_join_handle))), startup_task_abort_handle) @@ -358,7 +406,7 @@ async fn _session_wait_startup_task( let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); session_downcasted.startup_task_handles.clone() }; - + if let Some((join_handler_arc, _)) = startup_task_handles { let mut join_handler_locked = join_handler_arc.lock().await; if let Some(join_handler) = join_handler_locked.take() { @@ -392,7 +440,7 @@ impl IntegrationTrait for IntegrationMCP { async fn integr_tools(&self, _integr_name: &str) -> Vec> { let session_key = format!("{}", self.config_path); - + let gcx = match self.gcx_option.clone() { Some(gcx_weak) => match gcx_weak.upgrade() { Some(gcx) => gcx, @@ -406,7 +454,7 @@ impl IntegrationTrait for IntegrationMCP { return vec![]; } }; - + let session_maybe = gcx.read().await.integration_sessions.get(&session_key).cloned(); let session = match session_maybe { Some(session) => session, @@ -472,7 +520,7 @@ impl Tool for ToolMCP { let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); session_downcasted.logs.clone() }; - + _add_log_entry(session_logs.clone(), format!("Executing tool '{}' with arguments: {:?}", self.mcp_tool.name, json_args)).await; let result_probably = { @@ -488,15 +536,15 @@ impl Tool for ToolMCP { }) ).await { Ok(result) => result, - Err(_) => Err(rmcp::service::ServiceError::Timeout { - timeout: MCP_REQUEST_TIMEOUT + Err(_) => Err(rmcp::service::ServiceError::Timeout { + timeout: MCP_REQUEST_TIMEOUT }), } } else { return Err("MCP client is not available".to_string()); } }; - + let tool_output = match result_probably { Ok(result) => { if result.is_error.unwrap_or(false) { @@ -504,7 +552,7 @@ impl Tool for ToolMCP { _add_log_entry(session_logs.clone(), error_msg.clone()).await; return Err(error_msg); } - + if let Some(content) = result.content.get(0) { if let rmcp::model::RawContent::Text(text_content) = &content.raw { let text = text_content.text.clone(); @@ -630,15 +678,33 @@ impl Tool for ToolMCP { pub const MCP_INTEGRATION_SCHEMA: &str = r#" fields: + server_transport: + f_type: enum + f_enum_values: ["stdio", "sse"] + f_default: "stdio" + f_desc: "The transport protocol to use. 'stdio' for local processes, 'sse' for remote servers using Server-Sent Events." command: f_type: string - f_desc: "The MCP command to execute, like `npx -y `, `/my/path/venv/python -m `, or `docker run -i --rm `. On Windows, use `npx.cmd` or `npm.cmd` instead of `npx` or `npm`." + f_desc: "The MCP command to execute (for stdio transport), like `npx -y `, `/my/path/venv/python -m `, or `docker run -i --rm `. On Windows, use `npx.cmd` or `npm.cmd` instead of `npx` or `npm`." env: f_type: string_to_string_map + f_desc: "Environment variables to pass to the MCP command (for stdio transport)." + url: + f_type: string + f_desc: "The URL of the MCP server (for sse transport), e.g., 'https://api.example.com/mcp/sse'." + headers: + f_type: string_to_string_map + f_desc: "HTTP headers to include in requests to the MCP server (for sse transport)." + f_default: + User-Agent: "Refact.ai (+https://github.com/smallcloudai/refact)" + Accept: text/event-stream + Content-Type: application/json description: | - You can add almost any MCP (Model Context Protocol) server here! This supports local MCP servers, - with remote servers coming up as the specificion gets updated. You can read more - here https://www.anthropic.com/news/model-context-protocol + You can add almost any MCP (Model Context Protocol) server here! This supports both local MCP servers (stdio) + and remote MCP servers (sse). You can read more about MCP here: https://www.anthropic.com/news/model-context-protocol + + For local servers, use server_transport="stdio" and provide the command to execute. + For remote servers, use server_transport="sse" and provide the URL of the server. available: on_your_laptop_possible: true when_isolated_possible: true diff --git a/refact-agent/engine/src/integrations/yaml_schema.rs b/refact-agent/engine/src/integrations/yaml_schema.rs index 2c6a93679..f2d9a6aa1 100644 --- a/refact-agent/engine/src/integrations/yaml_schema.rs +++ b/refact-agent/engine/src/integrations/yaml_schema.rs @@ -18,7 +18,7 @@ pub struct ISchemaField { #[serde(default, skip_serializing_if="is_default")] pub f_desc: String, #[serde(default, skip_serializing_if="is_default")] - pub f_default: String, + pub f_default: serde_json::Value, #[serde(default, skip_serializing_if="is_default")] pub f_placeholder: String, #[serde(default, skip_serializing_if="is_default")] From 6e8ff00872e8c4fc73e8cb5eb9bec3042b84168c Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Thu, 1 May 2025 16:36:25 +0200 Subject: [PATCH 05/15] fix: list tools with pagination if needed --- .../engine/src/integrations/integr_mcp.rs | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index b787f3f1d..ac3a8ace9 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -6,6 +6,7 @@ use std::sync::Weak; use std::future::Future; use std::process::Stdio; use async_trait::async_trait; +use rmcp::model::PaginatedRequestParamInner; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; @@ -364,26 +365,34 @@ async fn _session_apply_settings( log(Level::INFO, "Listing tools".to_string()).await; - let tools_result = match timeout(MCP_REQUEST_TIMEOUT, client.list_tools(None)).await { - Ok(Ok(result)) => result, - Ok(Err(tools_error)) => { - log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; - return; - }, - Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs())).await; - return; - } - }; - - let new_mcp_client = Arc::new(AMutex::new(Some(client))); + // List tools, with pagination if needed + let mut all_tools = Vec::new(); + let mut cursor = None; + loop { + let tools_result = match timeout(MCP_REQUEST_TIMEOUT, + client.list_tools(Some(PaginatedRequestParamInner { cursor: cursor.clone() })) + ).await { + Ok(Ok(result)) => result, + Ok(Err(tools_error)) => { + log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; + return; + }, + Err(_) => { + log(Level::ERROR, format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs())).await; + return; + } + }; + all_tools.extend(tools_result.tools); + cursor = tools_result.next_cursor; + if cursor.is_none() { break; } + } let tools_len = { let mut session_locked = session_arc_clone.lock().await; let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); - session_downcasted.mcp_client = Some(new_mcp_client); - session_downcasted.mcp_tools = tools_result.tools; + session_downcasted.mcp_client = Some(Arc::new(AMutex::new(Some(client)))); + session_downcasted.mcp_tools = all_tools; session_downcasted.mcp_tools.len() }; From 739a16ab29ca30ab9ac68b515f527e99aac016f6 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Thu, 1 May 2025 23:44:57 +0200 Subject: [PATCH 06/15] use list all tools instead of manual pagination in mcp --- .../engine/src/integrations/integr_mcp.rs | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index ac3a8ace9..0c37adbab 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -305,7 +305,7 @@ async fn _session_apply_settings( match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { Ok(Ok(client)) => client, Ok(Err(e)) => { - log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; + log(Level::ERROR, format!("Failed to init stdio server: {}", e)).await; return; }, Err(_) => { @@ -365,34 +365,25 @@ async fn _session_apply_settings( log(Level::INFO, "Listing tools".to_string()).await; - // List tools, with pagination if needed - let mut all_tools = Vec::new(); - let mut cursor = None; - loop { - let tools_result = match timeout(MCP_REQUEST_TIMEOUT, - client.list_tools(Some(PaginatedRequestParamInner { cursor: cursor.clone() })) - ).await { - Ok(Ok(result)) => result, - Ok(Err(tools_error)) => { - log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; - return; - }, - Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs())).await; - return; - } - }; - all_tools.extend(tools_result.tools); - cursor = tools_result.next_cursor; - if cursor.is_none() { break; } - } + let tools = match timeout(MCP_REQUEST_TIMEOUT, client.list_all_tools()).await { + Ok(Ok(result)) => result, + Ok(Err(tools_error)) => { + log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; + return; + }, + Err(_) => { + log(Level::ERROR, format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs())).await; + return; + } + }; + let tools_len = tools.len(); - let tools_len = { + { let mut session_locked = session_arc_clone.lock().await; let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); session_downcasted.mcp_client = Some(Arc::new(AMutex::new(Some(client)))); - session_downcasted.mcp_tools = all_tools; + session_downcasted.mcp_tools = tools; session_downcasted.mcp_tools.len() }; From ed80e030e2dfcfa2246be76fd63de396ce3a27a1 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Tue, 6 May 2025 09:51:51 +0200 Subject: [PATCH 07/15] use for of rust sdk for rmcp, that handles killing zombie processes --- refact-agent/engine/Cargo.toml | 5 ++++- .../engine/src/integrations/integr_mcp.rs | 20 +++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/refact-agent/engine/Cargo.toml b/refact-agent/engine/Cargo.toml index 7cc44d7f2..cb1dcb563 100644 --- a/refact-agent/engine/Cargo.toml +++ b/refact-agent/engine/Cargo.toml @@ -54,7 +54,6 @@ regex = "1.9.5" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls-webpki-roots", "charset", "http2"] } reqwest-eventsource = "0.6.0" resvg = "0.44.0" -rmcp = { "version" = "0.1.5", features = ["client", "transport-child-process", "transport-sse"] } ropey = "1.6" rusqlite = { version = "0.31.0", features = ["bundled"] } rust-embed = "8.5.0" @@ -98,3 +97,7 @@ uuid = { version = "1", features = ["v4", "serde"] } walkdir = "2.3" which = "7.0.1" zerocopy = "0.8.14" + +# There you can use a local copy +# rmcp = { path = "../../../rust-sdk/crates/rmcp/", "features" = ["client", "transport-child-process", "transport-sse"] } +rmcp = { git = "https://github.com/smallcloudai/rust-sdk", branch = "cleanup-zombie-processes-for-child-process-client", features = ["client", "transport-child-process", "transport-sse"] } \ No newline at end of file diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index 0c37adbab..b091a7ee7 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -6,7 +6,8 @@ use std::sync::Weak; use std::future::Future; use std::process::Stdio; use async_trait::async_trait; -use rmcp::model::PaginatedRequestParamInner; +use rmcp::transport::sse::ReqwestSseClient; +use rmcp::transport::SseTransport; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex as AMutex; use tokio::sync::RwLock as ARwLock; @@ -295,7 +296,7 @@ async fn _session_apply_settings( Err(e) => tracing::error!("Failed to create stderr file for {debug_name}: {e}"), } - let transport = match rmcp::transport::TokioChildProcess::new(&mut command) { + let transport = match rmcp::transport::TokioChildProcess::new(command) { Ok(t) => t, Err(e) => { log(Level::ERROR, format!("Failed to init Tokio child process: {}", e)).await; @@ -331,14 +332,21 @@ async fn _session_apply_settings( _ => log(Level::WARN, format!("Invalid header: {}: {}", k, v)).await, } } - let client = match reqwest::Client::builder().default_headers(header_map).build() { - Ok(c) => c, + let reqwest_client = match reqwest::Client::builder().default_headers(header_map).build() { + Ok(reqwest_client) => reqwest_client, Err(e) => { log(Level::ERROR, format!("Failed to build reqwest client: {}", e)).await; return; } }; - let transport = match rmcp::transport::SseTransport::start_with_client(&new_cfg_clone.url, client).await { + let sse_client = match ReqwestSseClient::new_with_client(&new_cfg_clone.url, reqwest_client).await { + Ok(sse_client) => sse_client, + Err(e) => { + log(Level::ERROR, format!("Failed to init SSE client: {}", e)).await; + return; + }, + }; + let transport = match SseTransport::start_with_client(sse_client).await { Ok(t) => t, Err(e) => { log(Level::ERROR, format!("Failed to init SSE transport: {}", e)).await; @@ -639,7 +647,7 @@ impl Tool for ToolMCP { name: self.tool_name(), agentic: true, experimental: false, - description: self.mcp_tool.description.to_string(), + description: self.mcp_tool.description.to_owned().unwrap_or_default().to_string(), parameters, parameters_required, } From bcf780693bcbe4d3a607a09e2e429f006fa890c8 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Tue, 6 May 2025 19:58:20 +0200 Subject: [PATCH 08/15] don't configure server_transport, make it implicit from if url or command is specified, to match config with claude desktop --- .../engine/src/integrations/integr_mcp.rs | 134 ++++++++---------- 1 file changed, 61 insertions(+), 73 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index b091a7ee7..81819034a 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -34,20 +34,14 @@ const MCP_SERVER_STOP_TIMEOUT: Duration = Duration::from_secs(3); #[derive(Deserialize, Serialize, Clone, PartialEq, Default, Debug)] pub struct SettingsMCP { - #[serde(default = "default_server_transport")] - pub server_transport: String, #[serde(rename = "command", default)] pub mcp_command: String, #[serde(default, rename = "env")] pub mcp_env: HashMap, - #[serde(default)] - pub url: String, - #[serde(default)] - pub headers: HashMap, -} - -fn default_server_transport() -> String { - "stdio".to_string() + #[serde(default, rename = "url")] + pub mcp_url: String, + #[serde(default, rename = "headers")] + pub mcp_headers: HashMap, } pub struct ToolMCP { @@ -259,9 +253,58 @@ async fn _session_apply_settings( } } - let client = match new_cfg_clone.server_transport.to_lowercase().trim() { - "stdio" => { - let parsed_args = match shell_words::split(&new_cfg_clone.mcp_command) { + let client = match (new_cfg_clone.mcp_url.trim(), new_cfg_clone.mcp_command.trim()) { + ("", "") => { + log(Level::ERROR, "Url and command are both empty, set up either url for sse protocol, or command for stdio protocol".to_string()).await; + return; + }, + (url, "") => { + let mut header_map = reqwest::header::HeaderMap::new(); + for (k, v) in &new_cfg_clone.mcp_headers { + match (reqwest::header::HeaderName::from_bytes(k.as_bytes()), + reqwest::header::HeaderValue::from_str(v), + ) { + (Ok(name), Ok(value)) => { + header_map.insert(name, value); + } + _ => log(Level::WARN, format!("Invalid header: {}: {}", k, v)).await, + } + } + let reqwest_client = match reqwest::Client::builder().default_headers(header_map).build() { + Ok(reqwest_client) => reqwest_client, + Err(e) => { + log(Level::ERROR, format!("Failed to build reqwest client: {}", e)).await; + return; + } + }; + let sse_client = match ReqwestSseClient::new_with_client(url, reqwest_client).await { + Ok(sse_client) => sse_client, + Err(e) => { + log(Level::ERROR, format!("Failed to init SSE client: {}", e)).await; + return; + }, + }; + let transport = match SseTransport::start_with_client(sse_client).await { + Ok(t) => t, + Err(e) => { + log(Level::ERROR, format!("Failed to init SSE transport: {}", e)).await; + return; + } + }; + match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { + Ok(Ok(client)) => client, + Ok(Err(e)) => { + log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; + return; + }, + Err(_) => { + log(Level::ERROR, format!("Request timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs())).await; + return; + } + } + }, + ("", command) => { + let parsed_args = match shell_words::split(&command) { Ok(args) => { if args.is_empty() { log(Level::ERROR, "Empty command".to_string()).await; @@ -315,60 +358,10 @@ async fn _session_apply_settings( } } }, - "sse" => { - if new_cfg_clone.url.is_empty() { - log(Level::ERROR, "URL is required for MCP with SSE transport".to_string()).await; - return; - } - - let mut header_map = reqwest::header::HeaderMap::new(); - for (k, v) in &new_cfg_clone.headers { - match (reqwest::header::HeaderName::from_bytes(k.as_bytes()), - reqwest::header::HeaderValue::from_str(v), - ) { - (Ok(name), Ok(value)) => { - header_map.insert(name, value); - } - _ => log(Level::WARN, format!("Invalid header: {}: {}", k, v)).await, - } - } - let reqwest_client = match reqwest::Client::builder().default_headers(header_map).build() { - Ok(reqwest_client) => reqwest_client, - Err(e) => { - log(Level::ERROR, format!("Failed to build reqwest client: {}", e)).await; - return; - } - }; - let sse_client = match ReqwestSseClient::new_with_client(&new_cfg_clone.url, reqwest_client).await { - Ok(sse_client) => sse_client, - Err(e) => { - log(Level::ERROR, format!("Failed to init SSE client: {}", e)).await; - return; - }, - }; - let transport = match SseTransport::start_with_client(sse_client).await { - Ok(t) => t, - Err(e) => { - log(Level::ERROR, format!("Failed to init SSE transport: {}", e)).await; - return; - } - }; - match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { - Ok(Ok(client)) => client, - Ok(Err(e)) => { - log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; - return; - }, - Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs())).await; - return; - } - } - } - _ => { - log(Level::ERROR, format!("Unsupported server transport: {}", new_cfg_clone.server_transport)).await; + (_url, _command) => { + log(Level::ERROR, "Url and command cannot be specified at the same time, set up either url for sse protocol, or command for stdio protocol".to_string()).await; return; - } + }, }; log(Level::INFO, "Listing tools".to_string()).await; @@ -686,11 +679,6 @@ impl Tool for ToolMCP { pub const MCP_INTEGRATION_SCHEMA: &str = r#" fields: - server_transport: - f_type: enum - f_enum_values: ["stdio", "sse"] - f_default: "stdio" - f_desc: "The transport protocol to use. 'stdio' for local processes, 'sse' for remote servers using Server-Sent Events." command: f_type: string f_desc: "The MCP command to execute (for stdio transport), like `npx -y `, `/my/path/venv/python -m `, or `docker run -i --rm `. On Windows, use `npx.cmd` or `npm.cmd` instead of `npx` or `npm`." @@ -711,8 +699,8 @@ description: | You can add almost any MCP (Model Context Protocol) server here! This supports both local MCP servers (stdio) and remote MCP servers (sse). You can read more about MCP here: https://www.anthropic.com/news/model-context-protocol - For local servers, use server_transport="stdio" and provide the command to execute. - For remote servers, use server_transport="sse" and provide the URL of the server. + For servers using stdio protocol, provide the command to execute, and optionally, set the environment variables. + For remote using sse protocol, provide the URL of the server, and optionally, add more headers. available: on_your_laptop_possible: true when_isolated_possible: true From 6f1287f3a8706e81fa8bf1f47b1c17f6b29c9201 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Wed, 7 May 2025 12:46:09 +0200 Subject: [PATCH 09/15] feat: add images support to mcp --- .../engine/src/integrations/integr_mcp.rs | 100 +++++++++++++----- 1 file changed, 73 insertions(+), 27 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index 81819034a..e975946d2 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -6,6 +6,7 @@ use std::sync::Weak; use std::future::Future; use std::process::Stdio; use async_trait::async_trait; +use rmcp::model::RawContent; use rmcp::transport::sse::ReqwestSseClient; use rmcp::transport::SseTransport; use serde::{Deserialize, Serialize}; @@ -19,9 +20,11 @@ use rmcp::model::{CallToolRequestParam, Tool as McpTool}; use tempfile::NamedTempFile; use tracing::Level; +use crate::caps::resolve_chat_model; use crate::custom_error::MapErrToString; use crate::global_context::GlobalContext; use crate::at_commands::at_commands::AtCommandsContext; +use crate::scratchpads::multimodality::MultimodalElement; use crate::tools::tools_description::{Tool, ToolDesc, ToolParam}; use crate::call_validation::{ChatMessage, ChatContent, ContextEnum}; use crate::integrations::integr_abstract::{IntegrationTrait, IntegrationCommon, IntegrationConfirmation}; @@ -504,13 +507,22 @@ impl Tool for ToolMCP { args: &HashMap, ) -> Result<(bool, Vec), String> { let session_key = format!("{}", self.config_path); - let gcx = ccx.lock().await.global_context.clone(); - let session_option = gcx.read().await.integration_sessions.get(&session_key).cloned(); - if session_option.is_none() { + let (gcx, current_model) = { + let ccx_locked = ccx.lock().await; + (ccx_locked.global_context.clone(), ccx_locked.current_model.clone()) + }; + let (session_maybe, caps_maybe) = { + let gcx_locked = gcx.read().await; + (gcx_locked.integration_sessions.get(&session_key).cloned(), gcx_locked.caps.clone()) + }; + if session_maybe.is_none() { tracing::error!("No session for {:?}, strange (2)", session_key); return Err(format!("No session for {:?}", session_key)); } - let session = session_option.unwrap(); + let session = session_maybe.unwrap(); + let model_supports_multimodality = caps_maybe.is_some_and(|caps| { + resolve_chat_model(caps, ¤t_model).is_ok_and(|m| m.supports_multimodality) + }); _session_wait_startup_task(session.clone()).await; let json_args = serde_json::json!(args); @@ -546,7 +558,7 @@ impl Tool for ToolMCP { } }; - let tool_output = match result_probably { + let result_message = match result_probably { Ok(result) => { if result.is_error.unwrap_or(false) { let error_msg = format!("Tool execution error: {:?}", result.content); @@ -554,21 +566,63 @@ impl Tool for ToolMCP { return Err(error_msg); } - if let Some(content) = result.content.get(0) { - if let rmcp::model::RawContent::Text(text_content) = &content.raw { - let text = text_content.text.clone(); - let success_msg = format!("Tool '{}' executed successfully", self.mcp_tool.name); - _add_log_entry(session_logs.clone(), success_msg).await; - text - } else { - let error_msg = format!("Unexpected tool output format: {:?}", result.content); - tracing::error!("{}", error_msg); - _add_log_entry(session_logs.clone(), error_msg.clone()).await; - return Err("Unexpected tool output format".to_string()); + let mut elements = Vec::new(); + for content in result.content { + match content.raw { + RawContent::Text(text_content) => { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: text_content.text, + }) + } + RawContent::Image(image_content) => { + if model_supports_multimodality { + let mime_type = if image_content.mime_type.starts_with("image/") { + image_content.mime_type + } else { + format!("image/{}", image_content.mime_type) + }; + elements.push(MultimodalElement { + m_type: mime_type, + m_content: image_content.data, + }) + } else { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: "Server returned an image, but model does not support multimodality".to_string(), + }) + } + }, + RawContent::Audio(_) => { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: "Server returned audio, which is not supported".to_string(), + }) + }, + RawContent::Resource(_) => { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: "Server returned resource, which is not supported".to_string(), + }) + }, } - } else { - String::new() } + + let content = if elements.iter().all(|el| el.m_type == "text") { + ChatContent::SimpleText( + elements.into_iter().map(|el| el.m_content).collect::>().join("\n\n") + ) + } else { + ChatContent::Multimodal(elements) + }; + + ContextEnum::ChatMessage(ChatMessage { + role: "tool".to_string(), + content, + tool_calls: None, + tool_call_id: tool_call_id.clone(), + ..Default::default() + }) } Err(e) => { let error_msg = format!("Failed to call tool: {:?}", e); @@ -578,15 +632,7 @@ impl Tool for ToolMCP { } }; - let result = vec![ContextEnum::ChatMessage(ChatMessage { - role: "tool".to_string(), - content: ChatContent::SimpleText(tool_output), - tool_calls: None, - tool_call_id: tool_call_id.clone(), - ..Default::default() - })]; - - Ok((false, result)) + Ok((false, vec![result_message])) } fn tool_depends_on(&self) -> Vec { From de84a1312ba21d8108836971dc7b9a2bedf263f6 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Wed, 7 May 2025 22:34:54 +0200 Subject: [PATCH 10/15] add extra configurable fields for init and request timeout --- .../engine/src/integrations/integr_mcp.rs | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index e975946d2..7ab61c966 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -7,6 +7,7 @@ use std::future::Future; use std::process::Stdio; use async_trait::async_trait; use rmcp::model::RawContent; +use rmcp::serve_client; use rmcp::transport::sse::ReqwestSseClient; use rmcp::transport::SseTransport; use serde::{Deserialize, Serialize}; @@ -15,7 +16,7 @@ use tokio::sync::RwLock as ARwLock; use tokio::time::timeout; use tokio::time::Duration; use tokio::task::{AbortHandle, JoinHandle}; -use rmcp::{RoleClient, ServiceExt, service::RunningService}; +use rmcp::{RoleClient, service::RunningService}; use rmcp::model::{CallToolRequestParam, Tool as McpTool}; use tempfile::NamedTempFile; use tracing::Level; @@ -30,10 +31,7 @@ use crate::call_validation::{ChatMessage, ChatContent, ContextEnum}; use crate::integrations::integr_abstract::{IntegrationTrait, IntegrationCommon, IntegrationConfirmation}; use crate::integrations::sessions::IntegrationSession; use crate::integrations::process_io_utils::read_file_with_cursor; - -const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(30); -const MCP_SERVER_INIT_TIMEOUT: Duration = Duration::from_secs(60); -const MCP_SERVER_STOP_TIMEOUT: Duration = Duration::from_secs(3); +use crate::integrations::utils::{serialize_num_to_str, deserialize_str_to_num}; #[derive(Deserialize, Serialize, Clone, PartialEq, Default, Debug)] pub struct SettingsMCP { @@ -45,13 +43,21 @@ pub struct SettingsMCP { pub mcp_url: String, #[serde(default, rename = "headers")] pub mcp_headers: HashMap, + #[serde(default = "default_init_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] + pub init_timeout: u64, + #[serde(default = "default_request_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] + pub request_timeout: u64, } +fn default_init_timeout() -> u64 { 60 } +fn default_request_timeout() -> u64 { 30 } + pub struct ToolMCP { pub common: IntegrationCommon, pub config_path: String, pub mcp_client: Arc>>>, pub mcp_tool: McpTool, + pub request_timeout: u64, } #[derive(Default)] @@ -156,7 +162,7 @@ async fn _session_kill_process( }; if let Some(client) = client_to_cancel { - match timeout(MCP_SERVER_STOP_TIMEOUT, client.cancel()).await { + match timeout(Duration::from_secs(3), client.cancel()).await { Ok(Ok(reason)) => { let success_msg = format!("MCP server stopped: {:?}", reason); tracing::info!("{} for {}", success_msg, debug_name); @@ -168,7 +174,7 @@ async fn _session_kill_process( _add_log_entry(session_logs, error_msg).await; }, Err(_) => { - let error_msg = format!("MCP server stop operation timed out after {} seconds", MCP_SERVER_STOP_TIMEOUT.as_secs()); + let error_msg = "MCP server stop operation timed out after 3 seconds".to_string(); tracing::error!("{} for {}", error_msg, debug_name); _add_log_entry(session_logs, error_msg).await; } @@ -294,14 +300,14 @@ async fn _session_apply_settings( return; } }; - match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { + match timeout(Duration::from_secs(new_cfg_clone.init_timeout), serve_client((), transport)).await { Ok(Ok(client)) => client, Ok(Err(e)) => { log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; return; }, Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs())).await; + log(Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.init_timeout)).await; return; } } @@ -349,14 +355,14 @@ async fn _session_apply_settings( return; } }; - match timeout(MCP_SERVER_INIT_TIMEOUT, ().serve(transport)).await { + match timeout(Duration::from_secs(new_cfg_clone.init_timeout), serve_client((), transport)).await { Ok(Ok(client)) => client, Ok(Err(e)) => { log(Level::ERROR, format!("Failed to init stdio server: {}", e)).await; return; }, Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", MCP_SERVER_INIT_TIMEOUT.as_secs())).await; + log(Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.init_timeout)).await; return; } } @@ -369,14 +375,14 @@ async fn _session_apply_settings( log(Level::INFO, "Listing tools".to_string()).await; - let tools = match timeout(MCP_REQUEST_TIMEOUT, client.list_all_tools()).await { + let tools = match timeout(Duration::from_secs(new_cfg_clone.request_timeout), client.list_all_tools()).await { Ok(Ok(result)) => result, Ok(Err(tools_error)) => { log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; return; }, Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", MCP_REQUEST_TIMEOUT.as_secs())).await; + log(Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.request_timeout)).await; return; } }; @@ -482,6 +488,7 @@ impl IntegrationTrait for IntegrationMCP { config_path: self.config_path.clone(), mcp_client: session_downcasted.mcp_client.clone().unwrap(), mcp_tool: tool.clone(), + request_timeout: self.cfg.request_timeout, })); } } @@ -539,7 +546,7 @@ impl Tool for ToolMCP { let result_probably = { let mcp_client_locked = self.mcp_client.lock().await; if let Some(client) = &*mcp_client_locked { - match timeout(MCP_REQUEST_TIMEOUT, + match timeout(Duration::from_secs(self.request_timeout), client.call_tool(CallToolRequestParam { name: self.mcp_tool.name.clone(), arguments: match json_args { @@ -549,9 +556,9 @@ impl Tool for ToolMCP { }) ).await { Ok(result) => result, - Err(_) => Err(rmcp::service::ServiceError::Timeout { - timeout: MCP_REQUEST_TIMEOUT - }), + Err(_) => {Err(rmcp::service::ServiceError::Timeout { + timeout: Duration::from_secs(self.request_timeout), + })}, } } else { return Err("MCP client is not available".to_string()); @@ -741,6 +748,16 @@ fields: User-Agent: "Refact.ai (+https://github.com/smallcloudai/refact)" Accept: text/event-stream Content-Type: application/json + init_timeout: + f_type: string_short + f_desc: "Timeout in seconds for MCP server initialization." + f_default: "60" + f_extra: true + request_timeout: + f_type: string_short + f_desc: "Timeout in seconds for MCP requests." + f_default: "30" + f_extra: true description: | You can add almost any MCP (Model Context Protocol) server here! This supports both local MCP servers (stdio) and remote MCP servers (sse). You can read more about MCP here: https://www.anthropic.com/news/model-context-protocol From 243f0614232abb044b75a2f42b8645015ab819a2 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Wed, 7 May 2025 22:45:16 +0200 Subject: [PATCH 11/15] fix: add default headers if not set --- refact-agent/engine/src/integrations/integr_mcp.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs index 7ab61c966..36e2c4f35 100644 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ b/refact-agent/engine/src/integrations/integr_mcp.rs @@ -41,7 +41,7 @@ pub struct SettingsMCP { pub mcp_env: HashMap, #[serde(default, rename = "url")] pub mcp_url: String, - #[serde(default, rename = "headers")] + #[serde(default = "default_headers", rename = "headers")] pub mcp_headers: HashMap, #[serde(default = "default_init_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] pub init_timeout: u64, @@ -51,6 +51,13 @@ pub struct SettingsMCP { fn default_init_timeout() -> u64 { 60 } fn default_request_timeout() -> u64 { 30 } +fn default_headers() -> HashMap { + HashMap::from([ + ("User-Agent".to_string(), "Refact.ai (+https://github.com/smallcloudai/refact)".to_string()), + ("Accept".to_string(), "text/event-stream".to_string()), + ("Content-Type".to_string(), "application/json".to_string()), + ]) +} pub struct ToolMCP { pub common: IntegrationCommon, From 64c7d2173ea4733e8dfd6a8f53d4b763ab6794de Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Thu, 8 May 2025 13:19:00 +0200 Subject: [PATCH 12/15] fix: load mcp integrations at startup to have servers ready for chats --- refact-agent/engine/src/files_in_workspace.rs | 6 ++++++ .../engine/src/integrations/setting_up_integrations.rs | 7 +++++++ refact-agent/engine/src/main.rs | 4 ++++ 3 files changed, 17 insertions(+) diff --git a/refact-agent/engine/src/files_in_workspace.rs b/refact-agent/engine/src/files_in_workspace.rs index 32b18f546..5dae623f6 100644 --- a/refact-agent/engine/src/files_in_workspace.rs +++ b/refact-agent/engine/src/files_in_workspace.rs @@ -16,6 +16,7 @@ use tracing::info; use crate::files_correction::{canonical_path, CommandSimplifiedDirExt}; use crate::git::operations::git_ls_files; use crate::global_context::GlobalContext; +use crate::integrations::running_integrations::load_integrations; use crate::telemetry; use crate::file_filter::{is_valid_file, SOURCE_FILE_EXTENSIONS}; use crate::ast::ast_indexer_thread::ast_indexer_enqueue_files; @@ -626,6 +627,8 @@ pub async fn on_workspaces_init(gcx: Arc>) -> i32 { // Called from lsp and lsp_like // Not called from main.rs as part of initialization + let allow_experimental = gcx.read().await.cmdline.experimental; + watcher_init(gcx.clone()).await; let files_enqueued = enqueue_all_files_from_workspace_folders(gcx.clone(), false, false).await; @@ -634,6 +637,9 @@ pub async fn on_workspaces_init(gcx: Arc>) -> i32 crate::git::checkpoints::init_shadow_repos_if_needed(gcx_clone).await; }); + // Start or connect to mcp servers + let _ = load_integrations(gcx.clone(), allow_experimental, &["**/mcp_*".to_string()]).await; + files_enqueued } diff --git a/refact-agent/engine/src/integrations/setting_up_integrations.rs b/refact-agent/engine/src/integrations/setting_up_integrations.rs index 501229b83..10e4d4a45 100644 --- a/refact-agent/engine/src/integrations/setting_up_integrations.rs +++ b/refact-agent/engine/src/integrations/setting_up_integrations.rs @@ -11,6 +11,7 @@ use tokio::io::AsyncWriteExt; use crate::custom_error::YamlError; use crate::global_context::GlobalContext; use crate::files_correction::any_glob_matches_path; +use crate::integrations::running_integrations::load_integrations; // use crate::tools::tools_description::Tool; // use crate::yaml_configs::create_configs::{integrations_enabled_cfg, read_yaml_into_value}; @@ -537,6 +538,7 @@ pub async fn integration_config_save( integr_config_path: &String, integr_values: &serde_json::Value, ) -> Result<(), String> { + let allow_experimental = gcx.read().await.cmdline.experimental; let config_path = crate::files_correction::canonical_path(integr_config_path); let (integr_name, _project_path) = crate::integrations::setting_up_integrations::split_path_into_project_and_integration(&config_path) .map_err(|e| format!("Failed to split path: {}", e))?; @@ -570,6 +572,11 @@ pub async fn integration_config_save( format!("Failed to write to {}: {}", config_path.display(), e) })?; + // If it is an mcp integration, ensure we restart or reconnect to the server + if config_path.file_name().and_then(|f| f.to_str()).is_some_and(|f| f.starts_with("mcp_")) { + let _ = load_integrations(gcx.clone(), allow_experimental, &["**/mcp_*".to_string()]).await; + } + Ok(()) } diff --git a/refact-agent/engine/src/main.rs b/refact-agent/engine/src/main.rs index 7e3da95af..34b00de8a 100644 --- a/refact-agent/engine/src/main.rs +++ b/refact-agent/engine/src/main.rs @@ -3,6 +3,7 @@ use std::env; use std::panic; use files_correction::canonical_path; +use integrations::running_integrations; use tokio::task::JoinHandle; use tracing::{info, Level}; use tracing_appender; @@ -181,6 +182,9 @@ async fn main() { crate::git::checkpoints::init_shadow_repos_if_needed(gcx_clone).await; }); + // Start or connect to mcp servers + let _ = running_integrations::load_integrations(gcx.clone(), cmdline.experimental, &["**/mcp_*".to_string()]).await; + // not really needed, but it's nice to have an error message sooner if there's one let _caps = crate::global_context::try_load_caps_quickly_if_not_present(gcx.clone(), 0).await; From 1b53b15d4e7e9a9c8ace1ad12bdbad1e8c305406 Mon Sep 17 00:00:00 2001 From: Humberto Yusta Date: Fri, 9 May 2025 13:20:45 +0200 Subject: [PATCH 13/15] refactor integr mcp to multiple files --- .../src/http/routers/v1/v1_integrations.rs | 4 +- .../engine/src/integrations/integr_mcp.rs | 788 ------------------ .../engine/src/integrations/mcp/integr_mcp.rs | 357 ++++++++ .../src/integrations/mcp/mcp_schema.yaml | 47 ++ .../engine/src/integrations/mcp/mod.rs | 7 + .../src/integrations/mcp/session_mcp.rs | 144 ++++ .../engine/src/integrations/mcp/tool_mcp.rs | 254 ++++++ refact-agent/engine/src/integrations/mod.rs | 4 +- 8 files changed, 813 insertions(+), 792 deletions(-) delete mode 100644 refact-agent/engine/src/integrations/integr_mcp.rs create mode 100644 refact-agent/engine/src/integrations/mcp/integr_mcp.rs create mode 100644 refact-agent/engine/src/integrations/mcp/mcp_schema.yaml create mode 100644 refact-agent/engine/src/integrations/mcp/mod.rs create mode 100644 refact-agent/engine/src/integrations/mcp/session_mcp.rs create mode 100644 refact-agent/engine/src/integrations/mcp/tool_mcp.rs diff --git a/refact-agent/engine/src/http/routers/v1/v1_integrations.rs b/refact-agent/engine/src/http/routers/v1/v1_integrations.rs index 608d57276..11634d17b 100644 --- a/refact-agent/engine/src/http/routers/v1/v1_integrations.rs +++ b/refact-agent/engine/src/http/routers/v1/v1_integrations.rs @@ -12,7 +12,7 @@ use rust_embed::RustEmbed; use crate::custom_error::ScratchError; use crate::global_context::GlobalContext; use crate::integrations::setting_up_integrations::split_path_into_project_and_integration; -use crate::integrations::integr_mcp::SessionMCP; +use crate::integrations::mcp::session_mcp::SessionMCP; pub async fn handle_v1_integrations( @@ -218,7 +218,7 @@ pub async fn handle_v1_integrations_mcp_logs( }; if let Some(stderr_path) = &stderr_file_path { - if let Err(e) = crate::integrations::integr_mcp::update_logs_from_stderr( + if let Err(e) = crate::integrations::mcp::session_mcp::update_logs_from_stderr( stderr_path, stderr_cursor, logs_arc.clone() diff --git a/refact-agent/engine/src/integrations/integr_mcp.rs b/refact-agent/engine/src/integrations/integr_mcp.rs deleted file mode 100644 index 36e2c4f35..000000000 --- a/refact-agent/engine/src/integrations/integr_mcp.rs +++ /dev/null @@ -1,788 +0,0 @@ -use std::any::Any; -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::Weak; -use std::future::Future; -use std::process::Stdio; -use async_trait::async_trait; -use rmcp::model::RawContent; -use rmcp::serve_client; -use rmcp::transport::sse::ReqwestSseClient; -use rmcp::transport::SseTransport; -use serde::{Deserialize, Serialize}; -use tokio::sync::Mutex as AMutex; -use tokio::sync::RwLock as ARwLock; -use tokio::time::timeout; -use tokio::time::Duration; -use tokio::task::{AbortHandle, JoinHandle}; -use rmcp::{RoleClient, service::RunningService}; -use rmcp::model::{CallToolRequestParam, Tool as McpTool}; -use tempfile::NamedTempFile; -use tracing::Level; - -use crate::caps::resolve_chat_model; -use crate::custom_error::MapErrToString; -use crate::global_context::GlobalContext; -use crate::at_commands::at_commands::AtCommandsContext; -use crate::scratchpads::multimodality::MultimodalElement; -use crate::tools::tools_description::{Tool, ToolDesc, ToolParam}; -use crate::call_validation::{ChatMessage, ChatContent, ContextEnum}; -use crate::integrations::integr_abstract::{IntegrationTrait, IntegrationCommon, IntegrationConfirmation}; -use crate::integrations::sessions::IntegrationSession; -use crate::integrations::process_io_utils::read_file_with_cursor; -use crate::integrations::utils::{serialize_num_to_str, deserialize_str_to_num}; - -#[derive(Deserialize, Serialize, Clone, PartialEq, Default, Debug)] -pub struct SettingsMCP { - #[serde(rename = "command", default)] - pub mcp_command: String, - #[serde(default, rename = "env")] - pub mcp_env: HashMap, - #[serde(default, rename = "url")] - pub mcp_url: String, - #[serde(default = "default_headers", rename = "headers")] - pub mcp_headers: HashMap, - #[serde(default = "default_init_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] - pub init_timeout: u64, - #[serde(default = "default_request_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] - pub request_timeout: u64, -} - -fn default_init_timeout() -> u64 { 60 } -fn default_request_timeout() -> u64 { 30 } -fn default_headers() -> HashMap { - HashMap::from([ - ("User-Agent".to_string(), "Refact.ai (+https://github.com/smallcloudai/refact)".to_string()), - ("Accept".to_string(), "text/event-stream".to_string()), - ("Content-Type".to_string(), "application/json".to_string()), - ]) -} - -pub struct ToolMCP { - pub common: IntegrationCommon, - pub config_path: String, - pub mcp_client: Arc>>>, - pub mcp_tool: McpTool, - pub request_timeout: u64, -} - -#[derive(Default)] -pub struct IntegrationMCP { - pub gcx_option: Option>>, // need default to zero, to have access to all the virtual functions and then set it up - pub cfg: SettingsMCP, - pub common: IntegrationCommon, - pub config_path: String, -} - -pub struct SessionMCP { - pub debug_name: String, - pub config_path: String, // to check if expired or not - pub launched_cfg: SettingsMCP, // a copy to compare against IntegrationMCP::cfg, to see if anything has changed - pub mcp_client: Option>>>>, - pub mcp_tools: Vec, - pub startup_task_handles: Option<(Arc>>>, AbortHandle)>, - pub logs: Arc>>, // Store log messages - pub stderr_file_path: Option, // Path to the temporary file for stderr - pub stderr_cursor: Arc>, // Position in the file where we last read from -} - -impl IntegrationSession for SessionMCP { - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } - - fn is_expired(&self) -> bool { - !std::path::Path::new(&self.config_path).exists() - } - - fn try_stop(&mut self, self_arc: Arc>>) -> Box + Send> { - Box::new(async move { - let (debug_name, client, logs, startup_task_handles, stderr_file) = { - let mut session_locked = self_arc.lock().await; - let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); - ( - session_downcasted.debug_name.clone(), - session_downcasted.mcp_client.clone(), - session_downcasted.logs.clone(), - session_downcasted.startup_task_handles.clone(), - session_downcasted.stderr_file_path.clone(), - ) - }; - - if let Some((_, abort_handle)) = startup_task_handles { - _add_log_entry(logs.clone(), "Aborted startup task".to_string()).await; - abort_handle.abort(); - } - - if let Some(client) = client { - _session_kill_process(&debug_name, client, logs).await; - } - if let Some(stderr_file) = &stderr_file { - if let Err(e) = tokio::fs::remove_file(stderr_file).await { - tracing::error!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e); - } - } - - "".to_string() - }) - } -} - -async fn _add_log_entry(session_logs: Arc>>, entry: String) { - let timestamp = chrono::Local::now().format("%H:%M:%S%.3f").to_string(); - let log_entry = format!("[{}] {}", timestamp, entry); - - let mut session_logs_locked = session_logs.lock().await; - session_logs_locked.extend(log_entry.lines().into_iter().map(|s| s.to_string())); - - if session_logs_locked.len() > 100 { - let excess = session_logs_locked.len() - 100; - session_logs_locked.drain(0..excess); - } -} - -pub async fn update_logs_from_stderr( - stderr_file_path: &PathBuf, - stderr_cursor: Arc>, - session_logs: Arc>> -) -> Result<(), String> { - let (buffer, bytes_read) = read_file_with_cursor(stderr_file_path, stderr_cursor.clone()).await - .map_err_with_prefix("Failed to read file:")?; - if bytes_read > 0 && !buffer.trim().is_empty() { - _add_log_entry(session_logs, buffer.trim().to_string()).await; - } - Ok(()) -} - -async fn _session_kill_process( - debug_name: &str, - mcp_client: Arc>>>, - session_logs: Arc>>, -) { - tracing::info!("Stopping MCP Server for {}", debug_name); - _add_log_entry(session_logs.clone(), "Stopping MCP Server".to_string()).await; - - let client_to_cancel = { - let mut mcp_client_locked = mcp_client.lock().await; - mcp_client_locked.take() - }; - - if let Some(client) = client_to_cancel { - match timeout(Duration::from_secs(3), client.cancel()).await { - Ok(Ok(reason)) => { - let success_msg = format!("MCP server stopped: {:?}", reason); - tracing::info!("{} for {}", success_msg, debug_name); - _add_log_entry(session_logs, success_msg).await; - }, - Ok(Err(e)) => { - let error_msg = format!("Failed to stop MCP: {:?}", e); - tracing::error!("{} for {}", error_msg, debug_name); - _add_log_entry(session_logs, error_msg).await; - }, - Err(_) => { - let error_msg = "MCP server stop operation timed out after 3 seconds".to_string(); - tracing::error!("{} for {}", error_msg, debug_name); - _add_log_entry(session_logs, error_msg).await; - } - } - } -} - -async fn _session_apply_settings( - gcx: Arc>, - config_path: String, - new_cfg: SettingsMCP, -) { - let session_key = format!("{}", config_path); - - let session_arc = { - let mut gcx_write = gcx.write().await; - let session = gcx_write.integration_sessions.get(&session_key).cloned(); - if session.is_none() { - let new_session: Arc>> = Arc::new(AMutex::new(Box::new(SessionMCP { - debug_name: session_key.clone(), - config_path: config_path.clone(), - launched_cfg: new_cfg.clone(), - mcp_client: None, - mcp_tools: Vec::new(), - startup_task_handles: None, - logs: Arc::new(AMutex::new(Vec::new())), - stderr_file_path: None, - stderr_cursor: Arc::new(AMutex::new(0)), - }))); - tracing::info!("MCP START SESSION {:?}", session_key); - gcx_write.integration_sessions.insert(session_key.clone(), new_session.clone()); - new_session - } else { - session.unwrap() - } - }; - - let new_cfg_clone = new_cfg.clone(); - let session_arc_clone = session_arc.clone(); - - { - let mut session_locked = session_arc.lock().await; - let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); - - // If it's same config, and there is an mcp client, or startup task is running, skip - if new_cfg == session_downcasted.launched_cfg { - if session_downcasted.mcp_client.is_some() || session_downcasted.startup_task_handles.as_ref().map_or( - false, |h| !h.1.is_finished() - ) { - return; - } - } - - let startup_task_join_handle = tokio::spawn(async move { - let (mcp_client, logs, debug_name, stderr_file) = { - let mut session_locked = session_arc_clone.lock().await; - let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); - mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); - mcp_session.launched_cfg = new_cfg_clone.clone(); - ( - std::mem::take(&mut mcp_session.mcp_client), - mcp_session.logs.clone(), - mcp_session.debug_name.clone(), - std::mem::take(&mut mcp_session.stderr_file_path), - ) - }; - - let log = async |level: Level, msg: String| { - match level { - Level::ERROR => tracing::error!("{msg} for {debug_name}"), - Level::WARN => tracing::warn!("{msg} for {debug_name}"), - _ => tracing::info!("{msg} for {debug_name}"), - } - _add_log_entry(logs.clone(), msg).await; - }; - - log(Level::INFO, "Applying new settings".to_string()).await; - - if let Some(mcp_client) = mcp_client { - _session_kill_process(&debug_name, mcp_client, logs.clone()).await; - } - if let Some(stderr_file) = &stderr_file { - if let Err(e) = tokio::fs::remove_file(stderr_file).await { - log(Level::ERROR, format!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e)).await; - } - } - - let client = match (new_cfg_clone.mcp_url.trim(), new_cfg_clone.mcp_command.trim()) { - ("", "") => { - log(Level::ERROR, "Url and command are both empty, set up either url for sse protocol, or command for stdio protocol".to_string()).await; - return; - }, - (url, "") => { - let mut header_map = reqwest::header::HeaderMap::new(); - for (k, v) in &new_cfg_clone.mcp_headers { - match (reqwest::header::HeaderName::from_bytes(k.as_bytes()), - reqwest::header::HeaderValue::from_str(v), - ) { - (Ok(name), Ok(value)) => { - header_map.insert(name, value); - } - _ => log(Level::WARN, format!("Invalid header: {}: {}", k, v)).await, - } - } - let reqwest_client = match reqwest::Client::builder().default_headers(header_map).build() { - Ok(reqwest_client) => reqwest_client, - Err(e) => { - log(Level::ERROR, format!("Failed to build reqwest client: {}", e)).await; - return; - } - }; - let sse_client = match ReqwestSseClient::new_with_client(url, reqwest_client).await { - Ok(sse_client) => sse_client, - Err(e) => { - log(Level::ERROR, format!("Failed to init SSE client: {}", e)).await; - return; - }, - }; - let transport = match SseTransport::start_with_client(sse_client).await { - Ok(t) => t, - Err(e) => { - log(Level::ERROR, format!("Failed to init SSE transport: {}", e)).await; - return; - } - }; - match timeout(Duration::from_secs(new_cfg_clone.init_timeout), serve_client((), transport)).await { - Ok(Ok(client)) => client, - Ok(Err(e)) => { - log(Level::ERROR, format!("Failed to init SSE server: {}", e)).await; - return; - }, - Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.init_timeout)).await; - return; - } - } - }, - ("", command) => { - let parsed_args = match shell_words::split(&command) { - Ok(args) => { - if args.is_empty() { - log(Level::ERROR, "Empty command".to_string()).await; - return; - } - args - } - Err(e) => { - log(Level::ERROR, format!("Failed to parse command: {}", e)).await; - return; - } - }; - - let mut command = tokio::process::Command::new(&parsed_args[0]); - command.args(&parsed_args[1..]); - for (key, value) in &new_cfg_clone.mcp_env { - command.env(key, value); - } - - match NamedTempFile::new().map(|f| f.keep()) { - Ok(Ok((file, path))) => { - { - let mut session_locked = session_arc_clone.lock().await; - let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); - - mcp_session.stderr_file_path = Some(path.clone()); - mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); - } - command.stderr(Stdio::from(file)); - }, - Ok(Err(e)) => tracing::error!("Failed to persist stderr file for {debug_name}: {e}"), - Err(e) => tracing::error!("Failed to create stderr file for {debug_name}: {e}"), - } - - let transport = match rmcp::transport::TokioChildProcess::new(command) { - Ok(t) => t, - Err(e) => { - log(Level::ERROR, format!("Failed to init Tokio child process: {}", e)).await; - return; - } - }; - match timeout(Duration::from_secs(new_cfg_clone.init_timeout), serve_client((), transport)).await { - Ok(Ok(client)) => client, - Ok(Err(e)) => { - log(Level::ERROR, format!("Failed to init stdio server: {}", e)).await; - return; - }, - Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.init_timeout)).await; - return; - } - } - }, - (_url, _command) => { - log(Level::ERROR, "Url and command cannot be specified at the same time, set up either url for sse protocol, or command for stdio protocol".to_string()).await; - return; - }, - }; - - log(Level::INFO, "Listing tools".to_string()).await; - - let tools = match timeout(Duration::from_secs(new_cfg_clone.request_timeout), client.list_all_tools()).await { - Ok(Ok(result)) => result, - Ok(Err(tools_error)) => { - log(Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; - return; - }, - Err(_) => { - log(Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.request_timeout)).await; - return; - } - }; - let tools_len = tools.len(); - - { - let mut session_locked = session_arc_clone.lock().await; - let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); - - session_downcasted.mcp_client = Some(Arc::new(AMutex::new(Some(client)))); - session_downcasted.mcp_tools = tools; - - session_downcasted.mcp_tools.len() - }; - - log(Level::INFO, format!("MCP session setup complete with {tools_len} tools")).await; - }); - - let startup_task_abort_handle = startup_task_join_handle.abort_handle(); - session_downcasted.startup_task_handles = Some( - (Arc::new(AMutex::new(Some(startup_task_join_handle))), startup_task_abort_handle) - ); - } -} - -async fn _session_wait_startup_task( - session_arc: Arc>>, -) { - let startup_task_handles = { - let mut session_locked = session_arc.lock().await; - let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); - session_downcasted.startup_task_handles.clone() - }; - - if let Some((join_handler_arc, _)) = startup_task_handles { - let mut join_handler_locked = join_handler_arc.lock().await; - if let Some(join_handler) = join_handler_locked.take() { - let _ = join_handler.await; - } - } -} - -#[async_trait] -impl IntegrationTrait for IntegrationMCP { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - async fn integr_settings_apply(&mut self, gcx: Arc>, config_path: String, value: &serde_json::Value) -> Result<(), serde_json::Error> { - self.gcx_option = Some(Arc::downgrade(&gcx)); - self.cfg = serde_json::from_value(value.clone())?; - self.common = serde_json::from_value(value.clone())?; - self.config_path = config_path; - _session_apply_settings(gcx.clone(), self.config_path.clone(), self.cfg.clone()).await; // possibly saves coroutine in session - Ok(()) - } - - fn integr_settings_as_json(&self) -> serde_json::Value { - serde_json::to_value(&self.cfg).unwrap() - } - - fn integr_common(&self) -> IntegrationCommon { - self.common.clone() - } - - async fn integr_tools(&self, _integr_name: &str) -> Vec> { - let session_key = format!("{}", self.config_path); - - let gcx = match self.gcx_option.clone() { - Some(gcx_weak) => match gcx_weak.upgrade() { - Some(gcx) => gcx, - None => { - tracing::error!("Error: System is shutting down"); - return vec![]; - } - }, - None => { - tracing::error!("Error: MCP is not set up yet"); - return vec![]; - } - }; - - let session_maybe = gcx.read().await.integration_sessions.get(&session_key).cloned(); - let session = match session_maybe { - Some(session) => session, - None => { - tracing::error!("No session for {:?}, strange (1)", session_key); - return vec![]; - } - }; - - let mut result: Vec> = vec![]; - { - let mut session_locked = session.lock().await; - let session_downcasted: &mut SessionMCP = session_locked.as_any_mut().downcast_mut::().unwrap(); - if session_downcasted.mcp_client.is_none() { - tracing::error!("No mcp_client for {:?}, strange (2)", session_key); - return vec![]; - } - for tool in session_downcasted.mcp_tools.iter() { - result.push(Box::new(ToolMCP { - common: self.common.clone(), - config_path: self.config_path.clone(), - mcp_client: session_downcasted.mcp_client.clone().unwrap(), - mcp_tool: tool.clone(), - request_timeout: self.cfg.request_timeout, - })); - } - } - - result - } - - fn integr_schema(&self) -> &str { - MCP_INTEGRATION_SCHEMA - } -} - -#[async_trait] -impl Tool for ToolMCP { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - async fn tool_execute( - &mut self, - ccx: Arc>, - tool_call_id: &String, - args: &HashMap, - ) -> Result<(bool, Vec), String> { - let session_key = format!("{}", self.config_path); - let (gcx, current_model) = { - let ccx_locked = ccx.lock().await; - (ccx_locked.global_context.clone(), ccx_locked.current_model.clone()) - }; - let (session_maybe, caps_maybe) = { - let gcx_locked = gcx.read().await; - (gcx_locked.integration_sessions.get(&session_key).cloned(), gcx_locked.caps.clone()) - }; - if session_maybe.is_none() { - tracing::error!("No session for {:?}, strange (2)", session_key); - return Err(format!("No session for {:?}", session_key)); - } - let session = session_maybe.unwrap(); - let model_supports_multimodality = caps_maybe.is_some_and(|caps| { - resolve_chat_model(caps, ¤t_model).is_ok_and(|m| m.supports_multimodality) - }); - _session_wait_startup_task(session.clone()).await; - - let json_args = serde_json::json!(args); - tracing::info!("\n\nMCP CALL tool '{}' with arguments: {:?}", self.mcp_tool.name, json_args); - - let session_logs = { - let mut session_locked = session.lock().await; - let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); - session_downcasted.logs.clone() - }; - - _add_log_entry(session_logs.clone(), format!("Executing tool '{}' with arguments: {:?}", self.mcp_tool.name, json_args)).await; - - let result_probably = { - let mcp_client_locked = self.mcp_client.lock().await; - if let Some(client) = &*mcp_client_locked { - match timeout(Duration::from_secs(self.request_timeout), - client.call_tool(CallToolRequestParam { - name: self.mcp_tool.name.clone(), - arguments: match json_args { - serde_json::Value::Object(map) => Some(map), - _ => None, - }, - }) - ).await { - Ok(result) => result, - Err(_) => {Err(rmcp::service::ServiceError::Timeout { - timeout: Duration::from_secs(self.request_timeout), - })}, - } - } else { - return Err("MCP client is not available".to_string()); - } - }; - - let result_message = match result_probably { - Ok(result) => { - if result.is_error.unwrap_or(false) { - let error_msg = format!("Tool execution error: {:?}", result.content); - _add_log_entry(session_logs.clone(), error_msg.clone()).await; - return Err(error_msg); - } - - let mut elements = Vec::new(); - for content in result.content { - match content.raw { - RawContent::Text(text_content) => { - elements.push(MultimodalElement { - m_type: "text".to_string(), - m_content: text_content.text, - }) - } - RawContent::Image(image_content) => { - if model_supports_multimodality { - let mime_type = if image_content.mime_type.starts_with("image/") { - image_content.mime_type - } else { - format!("image/{}", image_content.mime_type) - }; - elements.push(MultimodalElement { - m_type: mime_type, - m_content: image_content.data, - }) - } else { - elements.push(MultimodalElement { - m_type: "text".to_string(), - m_content: "Server returned an image, but model does not support multimodality".to_string(), - }) - } - }, - RawContent::Audio(_) => { - elements.push(MultimodalElement { - m_type: "text".to_string(), - m_content: "Server returned audio, which is not supported".to_string(), - }) - }, - RawContent::Resource(_) => { - elements.push(MultimodalElement { - m_type: "text".to_string(), - m_content: "Server returned resource, which is not supported".to_string(), - }) - }, - } - } - - let content = if elements.iter().all(|el| el.m_type == "text") { - ChatContent::SimpleText( - elements.into_iter().map(|el| el.m_content).collect::>().join("\n\n") - ) - } else { - ChatContent::Multimodal(elements) - }; - - ContextEnum::ChatMessage(ChatMessage { - role: "tool".to_string(), - content, - tool_calls: None, - tool_call_id: tool_call_id.clone(), - ..Default::default() - }) - } - Err(e) => { - let error_msg = format!("Failed to call tool: {:?}", e); - tracing::error!("{}", error_msg); - _add_log_entry(session_logs.clone(), error_msg).await; - return Err(e.to_string()); - } - }; - - Ok((false, vec![result_message])) - } - - fn tool_depends_on(&self) -> Vec { - vec![] - } - - fn tool_description(&self) -> ToolDesc { - // self.mcp_tool.input_schema = Object { - // "properties": Object { - // "a": Object { - // "title": String("A"), - // "type": String("integer") - // }, - // "b": Object { - // "title": String("B"), - // "type": String("integer") - // } - // }, - // "required": Array [ - // String("a"), - // String("b") - // ], - // "title": String("addArguments"), - // "type": String("object") - // } - let mut parameters = vec![]; - let mut parameters_required = vec![]; - - if let Some(serde_json::Value::Object(properties)) = self.mcp_tool.input_schema.get("properties") { - for (name, prop) in properties { - if let serde_json::Value::Object(prop_obj) = prop { - let param_type = prop_obj.get("type").and_then(|v| v.as_str()).unwrap_or("string").to_string(); - let description = prop_obj.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(); - parameters.push(ToolParam { - name: name.clone(), - param_type, - description, - }); - } - } - } - if let Some(serde_json::Value::Array(required)) = self.mcp_tool.input_schema.get("required") { - for req in required { - if let Some(req_str) = req.as_str() { - parameters_required.push(req_str.to_string()); - } - } - } - - ToolDesc { - name: self.tool_name(), - agentic: true, - experimental: false, - description: self.mcp_tool.description.to_owned().unwrap_or_default().to_string(), - parameters, - parameters_required, - } - } - - fn tool_name(&self) -> String { - let yaml_name = std::path::Path::new(&self.config_path) - .file_stem() - .and_then(|name| name.to_str()) - .unwrap_or("unknown"); - let sanitized_yaml_name = format!("{}_{}", yaml_name, self.mcp_tool.name) - .chars() - .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' }) - .collect::(); - sanitized_yaml_name - } - - async fn command_to_match_against_confirm_deny( - &self, - _ccx: Arc>, - _args: &HashMap, - ) -> Result { - let command = self.mcp_tool.name.clone(); - tracing::info!("MCP command_to_match_against_confirm_deny() returns {:?}", command); - Ok(command.to_string()) - } - - fn confirm_deny_rules(&self) -> Option { - Some(self.common.confirmation.clone()) - } - - fn has_config_path(&self) -> Option { - Some(self.config_path.clone()) - } -} - -pub const MCP_INTEGRATION_SCHEMA: &str = r#" -fields: - command: - f_type: string - f_desc: "The MCP command to execute (for stdio transport), like `npx -y `, `/my/path/venv/python -m `, or `docker run -i --rm `. On Windows, use `npx.cmd` or `npm.cmd` instead of `npx` or `npm`." - env: - f_type: string_to_string_map - f_desc: "Environment variables to pass to the MCP command (for stdio transport)." - url: - f_type: string - f_desc: "The URL of the MCP server (for sse transport), e.g., 'https://api.example.com/mcp/sse'." - headers: - f_type: string_to_string_map - f_desc: "HTTP headers to include in requests to the MCP server (for sse transport)." - f_default: - User-Agent: "Refact.ai (+https://github.com/smallcloudai/refact)" - Accept: text/event-stream - Content-Type: application/json - init_timeout: - f_type: string_short - f_desc: "Timeout in seconds for MCP server initialization." - f_default: "60" - f_extra: true - request_timeout: - f_type: string_short - f_desc: "Timeout in seconds for MCP requests." - f_default: "30" - f_extra: true -description: | - You can add almost any MCP (Model Context Protocol) server here! This supports both local MCP servers (stdio) - and remote MCP servers (sse). You can read more about MCP here: https://www.anthropic.com/news/model-context-protocol - - For servers using stdio protocol, provide the command to execute, and optionally, set the environment variables. - For remote using sse protocol, provide the URL of the server, and optionally, add more headers. -available: - on_your_laptop_possible: true - when_isolated_possible: true -confirmation: - ask_user_default: ["*"] - deny_default: [] -smartlinks: - - sl_label: "Test" - sl_chat: - - role: "user" - content: > - 🔧 Your job is to test %CURRENT_CONFIG%. Tools that this MCP server has created should be visible to you. Don't search anything, it should be visible as - a tools already. Run one and express happiness. If something does wrong, or you don't see the tools, ask user if they want to fix it by rewriting the config. - sl_enable_only_with_tool: true -"#; diff --git a/refact-agent/engine/src/integrations/mcp/integr_mcp.rs b/refact-agent/engine/src/integrations/mcp/integr_mcp.rs new file mode 100644 index 000000000..0c08b8567 --- /dev/null +++ b/refact-agent/engine/src/integrations/mcp/integr_mcp.rs @@ -0,0 +1,357 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Weak; +use std::process::Stdio; +use async_trait::async_trait; +use tokio::sync::Mutex as AMutex; +use tokio::sync::RwLock as ARwLock; +use tokio::time::timeout; +use tokio::time::Duration; +use rmcp::transport::sse::ReqwestSseClient; +use rmcp::transport::SseTransport; +use rmcp::serve_client; +use serde::{Deserialize, Serialize}; +use tempfile::NamedTempFile; + +use crate::global_context::GlobalContext; +use crate::integrations::integr_abstract::{IntegrationTrait, IntegrationCommon}; +use crate::integrations::utils::{serialize_num_to_str, deserialize_str_to_num}; +use super::session_mcp::{SessionMCP, _add_log_entry, _session_kill_process}; +use super::tool_mcp::ToolMCP; +use super::MCP_INTEGRATION_SCHEMA; + +#[derive(Deserialize, Serialize, Clone, PartialEq, Default, Debug)] +pub struct SettingsMCP { + #[serde(rename = "command", default)] + pub mcp_command: String, + #[serde(default, rename = "env")] + pub mcp_env: HashMap, + #[serde(default, rename = "url")] + pub mcp_url: String, + #[serde(default = "default_headers", rename = "headers")] + pub mcp_headers: HashMap, + #[serde(default = "default_init_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] + pub init_timeout: u64, + #[serde(default = "default_request_timeout", serialize_with = "serialize_num_to_str", deserialize_with = "deserialize_str_to_num")] + pub request_timeout: u64, +} + +fn default_init_timeout() -> u64 { 60 } +fn default_request_timeout() -> u64 { 30 } +fn default_headers() -> HashMap { + HashMap::from([ + ("User-Agent".to_string(), "Refact.ai (+https://github.com/smallcloudai/refact)".to_string()), + ("Accept".to_string(), "text/event-stream".to_string()), + ("Content-Type".to_string(), "application/json".to_string()), + ]) +} + +#[derive(Default)] +pub struct IntegrationMCP { + pub gcx_option: Option>>, // need default to zero, to have access to all the virtual functions and then set it up + pub cfg: SettingsMCP, + pub common: IntegrationCommon, + pub config_path: String, +} + +pub async fn _session_apply_settings( + gcx: Arc>, + config_path: String, + new_cfg: SettingsMCP, +) { + let session_key = format!("{}", config_path); + + let session_arc = { + let mut gcx_write = gcx.write().await; + let session = gcx_write.integration_sessions.get(&session_key).cloned(); + if session.is_none() { + let new_session: Arc>> = Arc::new(AMutex::new(Box::new(SessionMCP { + debug_name: session_key.clone(), + config_path: config_path.clone(), + launched_cfg: new_cfg.clone(), + mcp_client: None, + mcp_tools: Vec::new(), + startup_task_handles: None, + logs: Arc::new(AMutex::new(Vec::new())), + stderr_file_path: None, + stderr_cursor: Arc::new(AMutex::new(0)), + }))); + tracing::info!("MCP START SESSION {:?}", session_key); + gcx_write.integration_sessions.insert(session_key.clone(), new_session.clone()); + new_session + } else { + session.unwrap() + } + }; + + let new_cfg_clone = new_cfg.clone(); + let session_arc_clone = session_arc.clone(); + + { + let mut session_locked = session_arc.lock().await; + let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); + + // If it's same config, and there is an mcp client, or startup task is running, skip + if new_cfg == session_downcasted.launched_cfg { + if session_downcasted.mcp_client.is_some() || session_downcasted.startup_task_handles.as_ref().map_or( + false, |h| !h.1.is_finished() + ) { + return; + } + } + + let startup_task_join_handle = tokio::spawn(async move { + let (mcp_client, logs, debug_name, stderr_file) = { + let mut session_locked = session_arc_clone.lock().await; + let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); + mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); + mcp_session.launched_cfg = new_cfg_clone.clone(); + ( + std::mem::take(&mut mcp_session.mcp_client), + mcp_session.logs.clone(), + mcp_session.debug_name.clone(), + std::mem::take(&mut mcp_session.stderr_file_path), + ) + }; + + let log = async |level: tracing::Level, msg: String| { + match level { + tracing::Level::ERROR => tracing::error!("{msg} for {debug_name}"), + tracing::Level::WARN => tracing::warn!("{msg} for {debug_name}"), + _ => tracing::info!("{msg} for {debug_name}"), + } + _add_log_entry(logs.clone(), msg).await; + }; + + log(tracing::Level::INFO, "Applying new settings".to_string()).await; + + if let Some(mcp_client) = mcp_client { + _session_kill_process(&debug_name, mcp_client, logs.clone()).await; + } + if let Some(stderr_file) = &stderr_file { + if let Err(e) = tokio::fs::remove_file(stderr_file).await { + log(tracing::Level::ERROR, format!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e)).await; + } + } + + let client = match (new_cfg_clone.mcp_url.trim(), new_cfg_clone.mcp_command.trim()) { + ("", "") => { + log(tracing::Level::ERROR, "Url and command are both empty, set up either url for sse protocol, or command for stdio protocol".to_string()).await; + return; + }, + (url, "") => { + let mut header_map = reqwest::header::HeaderMap::new(); + for (k, v) in &new_cfg_clone.mcp_headers { + match (reqwest::header::HeaderName::from_bytes(k.as_bytes()), + reqwest::header::HeaderValue::from_str(v), + ) { + (Ok(name), Ok(value)) => { + header_map.insert(name, value); + } + _ => log(tracing::Level::WARN, format!("Invalid header: {}: {}", k, v)).await, + } + } + let reqwest_client = match reqwest::Client::builder().default_headers(header_map).build() { + Ok(reqwest_client) => reqwest_client, + Err(e) => { + log(tracing::Level::ERROR, format!("Failed to build reqwest client: {}", e)).await; + return; + } + }; + let sse_client = match ReqwestSseClient::new_with_client(url, reqwest_client).await { + Ok(sse_client) => sse_client, + Err(e) => { + log(tracing::Level::ERROR, format!("Failed to init SSE client: {}", e)).await; + return; + }, + }; + let transport = match SseTransport::start_with_client(sse_client).await { + Ok(t) => t, + Err(e) => { + log(tracing::Level::ERROR, format!("Failed to init SSE transport: {}", e)).await; + return; + } + }; + match timeout(Duration::from_secs(new_cfg_clone.init_timeout), serve_client((), transport)).await { + Ok(Ok(client)) => client, + Ok(Err(e)) => { + log(tracing::Level::ERROR, format!("Failed to init SSE server: {}", e)).await; + return; + }, + Err(_) => { + log(tracing::Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.init_timeout)).await; + return; + } + } + }, + ("", command) => { + let parsed_args = match shell_words::split(&command) { + Ok(args) => { + if args.is_empty() { + log(tracing::Level::ERROR, "Empty command".to_string()).await; + return; + } + args + } + Err(e) => { + log(tracing::Level::ERROR, format!("Failed to parse command: {}", e)).await; + return; + } + }; + + let mut command = tokio::process::Command::new(&parsed_args[0]); + command.args(&parsed_args[1..]); + for (key, value) in &new_cfg_clone.mcp_env { + command.env(key, value); + } + + match NamedTempFile::new().map(|f| f.keep()) { + Ok(Ok((file, path))) => { + { + let mut session_locked = session_arc_clone.lock().await; + let mcp_session = session_locked.as_any_mut().downcast_mut::().unwrap(); + + mcp_session.stderr_file_path = Some(path.clone()); + mcp_session.stderr_cursor = Arc::new(AMutex::new(0)); + } + command.stderr(Stdio::from(file)); + }, + Ok(Err(e)) => tracing::error!("Failed to persist stderr file for {debug_name}: {e}"), + Err(e) => tracing::error!("Failed to create stderr file for {debug_name}: {e}"), + } + + let transport = match rmcp::transport::TokioChildProcess::new(command) { + Ok(t) => t, + Err(e) => { + log(tracing::Level::ERROR, format!("Failed to init Tokio child process: {}", e)).await; + return; + } + }; + match timeout(Duration::from_secs(new_cfg_clone.init_timeout), serve_client((), transport)).await { + Ok(Ok(client)) => client, + Ok(Err(e)) => { + log(tracing::Level::ERROR, format!("Failed to init stdio server: {}", e)).await; + return; + }, + Err(_) => { + log(tracing::Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.init_timeout)).await; + return; + } + } + }, + (_url, _command) => { + log(tracing::Level::ERROR, "Url and command cannot be specified at the same time, set up either url for sse protocol, or command for stdio protocol".to_string()).await; + return; + }, + }; + + log(tracing::Level::INFO, "Listing tools".to_string()).await; + + let tools = match timeout(Duration::from_secs(new_cfg_clone.request_timeout), client.list_all_tools()).await { + Ok(Ok(result)) => result, + Ok(Err(tools_error)) => { + log(tracing::Level::ERROR, format!("Failed to list tools: {:?}", tools_error)).await; + return; + }, + Err(_) => { + log(tracing::Level::ERROR, format!("Request timed out after {} seconds", new_cfg_clone.request_timeout)).await; + return; + } + }; + let tools_len = tools.len(); + + { + let mut session_locked = session_arc_clone.lock().await; + let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); + + session_downcasted.mcp_client = Some(Arc::new(AMutex::new(Some(client)))); + session_downcasted.mcp_tools = tools; + + session_downcasted.mcp_tools.len() + }; + + log(tracing::Level::INFO, format!("MCP session setup complete with {tools_len} tools")).await; + }); + + let startup_task_abort_handle = startup_task_join_handle.abort_handle(); + session_downcasted.startup_task_handles = Some( + (Arc::new(AMutex::new(Some(startup_task_join_handle))), startup_task_abort_handle) + ); + } +} + +#[async_trait] +impl IntegrationTrait for IntegrationMCP { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn integr_settings_apply(&mut self, gcx: Arc>, config_path: String, value: &serde_json::Value) -> Result<(), serde_json::Error> { + self.gcx_option = Some(Arc::downgrade(&gcx)); + self.cfg = serde_json::from_value(value.clone())?; + self.common = serde_json::from_value(value.clone())?; + self.config_path = config_path; + _session_apply_settings(gcx.clone(), self.config_path.clone(), self.cfg.clone()).await; // possibly saves coroutine in session + Ok(()) + } + + fn integr_settings_as_json(&self) -> serde_json::Value { + serde_json::to_value(&self.cfg).unwrap() + } + + fn integr_common(&self) -> IntegrationCommon { + self.common.clone() + } + + async fn integr_tools(&self, _integr_name: &str) -> Vec> { + let session_key = format!("{}", self.config_path); + + let gcx = match self.gcx_option.clone() { + Some(gcx_weak) => match gcx_weak.upgrade() { + Some(gcx) => gcx, + None => { + tracing::error!("Error: System is shutting down"); + return vec![]; + } + }, + None => { + tracing::error!("Error: MCP is not set up yet"); + return vec![]; + } + }; + + let session_maybe = gcx.read().await.integration_sessions.get(&session_key).cloned(); + let session = match session_maybe { + Some(session) => session, + None => { + tracing::error!("No session for {:?}, strange (1)", session_key); + return vec![]; + } + }; + + let mut result: Vec> = vec![]; + { + let mut session_locked = session.lock().await; + let session_downcasted: &mut SessionMCP = session_locked.as_any_mut().downcast_mut::().unwrap(); + if session_downcasted.mcp_client.is_none() { + tracing::error!("No mcp_client for {:?}, strange (2)", session_key); + return vec![]; + } + for tool in session_downcasted.mcp_tools.iter() { + result.push(Box::new(ToolMCP { + common: self.common.clone(), + config_path: self.config_path.clone(), + mcp_client: session_downcasted.mcp_client.clone().unwrap(), + mcp_tool: tool.clone(), + request_timeout: self.cfg.request_timeout, + })); + } + } + + result + } + + fn integr_schema(&self) -> &str { + MCP_INTEGRATION_SCHEMA + } +} diff --git a/refact-agent/engine/src/integrations/mcp/mcp_schema.yaml b/refact-agent/engine/src/integrations/mcp/mcp_schema.yaml new file mode 100644 index 000000000..58d9a4674 --- /dev/null +++ b/refact-agent/engine/src/integrations/mcp/mcp_schema.yaml @@ -0,0 +1,47 @@ +fields: + command: + f_type: string + f_desc: "The MCP command to execute (for stdio transport), like `npx -y `, `/my/path/venv/python -m `, or `docker run -i --rm `. On Windows, use `npx.cmd` or `npm.cmd` instead of `npx` or `npm`." + env: + f_type: string_to_string_map + f_desc: "Environment variables to pass to the MCP command (for stdio transport)." + url: + f_type: string + f_desc: "The URL of the MCP server (for sse transport), e.g., 'https://api.example.com/mcp/sse'." + headers: + f_type: string_to_string_map + f_desc: "HTTP headers to include in requests to the MCP server (for sse transport)." + f_default: + User-Agent: "Refact.ai (+https://github.com/smallcloudai/refact)" + Accept: text/event-stream + Content-Type: application/json + init_timeout: + f_type: string_short + f_desc: "Timeout in seconds for MCP server initialization." + f_default: "60" + f_extra: true + request_timeout: + f_type: string_short + f_desc: "Timeout in seconds for MCP requests." + f_default: "30" + f_extra: true +description: | + You can add almost any MCP (Model Context Protocol) server here! This supports both local MCP servers (stdio) + and remote MCP servers (sse). You can read more about MCP here: https://www.anthropic.com/news/model-context-protocol + + For servers using stdio protocol, provide the command to execute, and optionally, set the environment variables. + For remote using sse protocol, provide the URL of the server, and optionally, add more headers. +available: + on_your_laptop_possible: true + when_isolated_possible: true +confirmation: + ask_user_default: ["*"] + deny_default: [] +smartlinks: + - sl_label: "Test" + sl_chat: + - role: "user" + content: > + 🔧 Your job is to test %CURRENT_CONFIG%. Tools that this MCP server has created should be visible to you. Don't search anything, it should be visible as + a tools already. Run one and express happiness. If something does wrong, or you don't see the tools, ask user if they want to fix it by rewriting the config. + sl_enable_only_with_tool: true diff --git a/refact-agent/engine/src/integrations/mcp/mod.rs b/refact-agent/engine/src/integrations/mcp/mod.rs new file mode 100644 index 000000000..a04a2c9cf --- /dev/null +++ b/refact-agent/engine/src/integrations/mcp/mod.rs @@ -0,0 +1,7 @@ +pub mod integr_mcp; +pub mod tool_mcp; +pub mod session_mcp; + +pub use integr_mcp::IntegrationMCP; + +pub const MCP_INTEGRATION_SCHEMA: &str = include_str!("mcp_schema.yaml"); diff --git a/refact-agent/engine/src/integrations/mcp/session_mcp.rs b/refact-agent/engine/src/integrations/mcp/session_mcp.rs new file mode 100644 index 000000000..e1c479620 --- /dev/null +++ b/refact-agent/engine/src/integrations/mcp/session_mcp.rs @@ -0,0 +1,144 @@ +use std::any::Any; +use std::path::PathBuf; +use std::sync::Arc; +use std::future::Future; +use tokio::sync::Mutex as AMutex; +use tokio::task::{AbortHandle, JoinHandle}; +use rmcp::{RoleClient, service::RunningService}; +use rmcp::model::Tool as McpTool; +use tokio::time::{timeout, Duration}; + +use crate::integrations::sessions::IntegrationSession; +use crate::integrations::process_io_utils::read_file_with_cursor; +use super::integr_mcp::SettingsMCP; + +pub struct SessionMCP { + pub debug_name: String, + pub config_path: String, // to check if expired or not + pub launched_cfg: SettingsMCP, // a copy to compare against IntegrationMCP::cfg, to see if anything has changed + pub mcp_client: Option>>>>, + pub mcp_tools: Vec, + pub startup_task_handles: Option<(Arc>>>, AbortHandle)>, + pub logs: Arc>>, // Store log messages + pub stderr_file_path: Option, // Path to the temporary file for stderr + pub stderr_cursor: Arc>, // Position in the file where we last read from +} + +impl IntegrationSession for SessionMCP { + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn is_expired(&self) -> bool { + !std::path::Path::new(&self.config_path).exists() + } + + fn try_stop(&mut self, self_arc: Arc>>) -> Box + Send> { + Box::new(async move { + let (debug_name, client, logs, startup_task_handles, stderr_file) = { + let mut session_locked = self_arc.lock().await; + let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); + ( + session_downcasted.debug_name.clone(), + session_downcasted.mcp_client.clone(), + session_downcasted.logs.clone(), + session_downcasted.startup_task_handles.clone(), + session_downcasted.stderr_file_path.clone(), + ) + }; + + if let Some((_, abort_handle)) = startup_task_handles { + _add_log_entry(logs.clone(), "Aborted startup task".to_string()).await; + abort_handle.abort(); + } + + if let Some(client) = client { + _session_kill_process(&debug_name, client, logs).await; + } + if let Some(stderr_file) = &stderr_file { + if let Err(e) = tokio::fs::remove_file(stderr_file).await { + tracing::error!("Failed to remove {}: {}", stderr_file.to_string_lossy(), e); + } + } + + "".to_string() + }) + } +} + +pub async fn _add_log_entry(session_logs: Arc>>, entry: String) { + let timestamp = chrono::Local::now().format("%H:%M:%S%.3f").to_string(); + let log_entry = format!("[{}] {}", timestamp, entry); + + let mut session_logs_locked = session_logs.lock().await; + session_logs_locked.extend(log_entry.lines().into_iter().map(|s| s.to_string())); + + if session_logs_locked.len() > 100 { + let excess = session_logs_locked.len() - 100; + session_logs_locked.drain(0..excess); + } +} + +pub async fn update_logs_from_stderr( + stderr_file_path: &PathBuf, + stderr_cursor: Arc>, + session_logs: Arc>> +) -> Result<(), String> { + let (buffer, bytes_read) = read_file_with_cursor(stderr_file_path, stderr_cursor.clone()).await + .map_err(|e| format!("Failed to read file: {}", e))?; + if bytes_read > 0 && !buffer.trim().is_empty() { + _add_log_entry(session_logs, buffer.trim().to_string()).await; + } + Ok(()) +} + +pub async fn _session_kill_process( + debug_name: &str, + mcp_client: Arc>>>, + session_logs: Arc>>, +) { + tracing::info!("Stopping MCP Server for {}", debug_name); + _add_log_entry(session_logs.clone(), "Stopping MCP Server".to_string()).await; + + let client_to_cancel = { + let mut mcp_client_locked = mcp_client.lock().await; + mcp_client_locked.take() + }; + + if let Some(client) = client_to_cancel { + match timeout(Duration::from_secs(3), client.cancel()).await { + Ok(Ok(reason)) => { + let success_msg = format!("MCP server stopped: {:?}", reason); + tracing::info!("{} for {}", success_msg, debug_name); + _add_log_entry(session_logs, success_msg).await; + }, + Ok(Err(e)) => { + let error_msg = format!("Failed to stop MCP: {:?}", e); + tracing::error!("{} for {}", error_msg, debug_name); + _add_log_entry(session_logs, error_msg).await; + }, + Err(_) => { + let error_msg = "MCP server stop operation timed out after 3 seconds".to_string(); + tracing::error!("{} for {}", error_msg, debug_name); + _add_log_entry(session_logs, error_msg).await; + } + } + } +} + +pub async fn _session_wait_startup_task( + session_arc: Arc>>, +) { + let startup_task_handles = { + let mut session_locked = session_arc.lock().await; + let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); + session_downcasted.startup_task_handles.clone() + }; + + if let Some((join_handler_arc, _)) = startup_task_handles { + let mut join_handler_locked = join_handler_arc.lock().await; + if let Some(join_handler) = join_handler_locked.take() { + let _ = join_handler.await; + } + } +} diff --git a/refact-agent/engine/src/integrations/mcp/tool_mcp.rs b/refact-agent/engine/src/integrations/mcp/tool_mcp.rs new file mode 100644 index 000000000..a067559dd --- /dev/null +++ b/refact-agent/engine/src/integrations/mcp/tool_mcp.rs @@ -0,0 +1,254 @@ + +use std::collections::HashMap; +use std::sync::Arc; +use async_trait::async_trait; +use rmcp::model::{RawContent, CallToolRequestParam, Tool as McpTool}; +use rmcp::{RoleClient, service::RunningService}; +use tokio::sync::Mutex as AMutex; +use tokio::time::timeout; +use tokio::time::Duration; + +use crate::caps::resolve_chat_model; +use crate::at_commands::at_commands::AtCommandsContext; +use crate::scratchpads::multimodality::MultimodalElement; +use crate::tools::tools_description::{Tool, ToolDesc, ToolParam}; +use crate::call_validation::{ChatMessage, ChatContent, ContextEnum}; +use crate::integrations::integr_abstract::{IntegrationCommon, IntegrationConfirmation}; +use super::session_mcp::{_add_log_entry, _session_wait_startup_task}; + +pub struct ToolMCP { + pub common: IntegrationCommon, + pub config_path: String, + pub mcp_client: Arc>>>, + pub mcp_tool: McpTool, + pub request_timeout: u64, +} + +#[async_trait] +impl Tool for ToolMCP { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn tool_execute( + &mut self, + ccx: Arc>, + tool_call_id: &String, + args: &HashMap, + ) -> Result<(bool, Vec), String> { + let session_key = format!("{}", self.config_path); + let (gcx, current_model) = { + let ccx_locked = ccx.lock().await; + (ccx_locked.global_context.clone(), ccx_locked.current_model.clone()) + }; + let (session_maybe, caps_maybe) = { + let gcx_locked = gcx.read().await; + (gcx_locked.integration_sessions.get(&session_key).cloned(), gcx_locked.caps.clone()) + }; + if session_maybe.is_none() { + tracing::error!("No session for {:?}, strange (2)", session_key); + return Err(format!("No session for {:?}", session_key)); + } + let session = session_maybe.unwrap(); + let model_supports_multimodality = caps_maybe.is_some_and(|caps| { + resolve_chat_model(caps, ¤t_model).is_ok_and(|m| m.supports_multimodality) + }); + _session_wait_startup_task(session.clone()).await; + + let json_args = serde_json::json!(args); + tracing::info!("\n\nMCP CALL tool '{}' with arguments: {:?}", self.mcp_tool.name, json_args); + + let session_logs = { + let mut session_locked = session.lock().await; + let session_downcasted = session_locked.as_any_mut().downcast_mut::().unwrap(); + session_downcasted.logs.clone() + }; + + _add_log_entry(session_logs.clone(), format!("Executing tool '{}' with arguments: {:?}", self.mcp_tool.name, json_args)).await; + + let result_probably = { + let mcp_client_locked = self.mcp_client.lock().await; + if let Some(client) = &*mcp_client_locked { + match timeout(Duration::from_secs(self.request_timeout), + client.call_tool(CallToolRequestParam { + name: self.mcp_tool.name.clone(), + arguments: match json_args { + serde_json::Value::Object(map) => Some(map), + _ => None, + }, + }) + ).await { + Ok(result) => result, + Err(_) => {Err(rmcp::service::ServiceError::Timeout { + timeout: Duration::from_secs(self.request_timeout), + })}, + } + } else { + return Err("MCP client is not available".to_string()); + } + }; + + let result_message = match result_probably { + Ok(result) => { + if result.is_error.unwrap_or(false) { + let error_msg = format!("Tool execution error: {:?}", result.content); + _add_log_entry(session_logs.clone(), error_msg.clone()).await; + return Err(error_msg); + } + + let mut elements = Vec::new(); + for content in result.content { + match content.raw { + RawContent::Text(text_content) => { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: text_content.text, + }) + } + RawContent::Image(image_content) => { + if model_supports_multimodality { + let mime_type = if image_content.mime_type.starts_with("image/") { + image_content.mime_type + } else { + format!("image/{}", image_content.mime_type) + }; + elements.push(MultimodalElement { + m_type: mime_type, + m_content: image_content.data, + }) + } else { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: "Server returned an image, but model does not support multimodality".to_string(), + }) + } + }, + RawContent::Audio(_) => { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: "Server returned audio, which is not supported".to_string(), + }) + }, + RawContent::Resource(_) => { + elements.push(MultimodalElement { + m_type: "text".to_string(), + m_content: "Server returned resource, which is not supported".to_string(), + }) + }, + } + } + + let content = if elements.iter().all(|el| el.m_type == "text") { + ChatContent::SimpleText( + elements.into_iter().map(|el| el.m_content).collect::>().join("\n\n") + ) + } else { + ChatContent::Multimodal(elements) + }; + + ContextEnum::ChatMessage(ChatMessage { + role: "tool".to_string(), + content, + tool_calls: None, + tool_call_id: tool_call_id.clone(), + ..Default::default() + }) + } + Err(e) => { + let error_msg = format!("Failed to call tool: {:?}", e); + tracing::error!("{}", error_msg); + _add_log_entry(session_logs.clone(), error_msg).await; + return Err(e.to_string()); + } + }; + + Ok((false, vec![result_message])) + } + + fn tool_depends_on(&self) -> Vec { + vec![] + } + + fn tool_description(&self) -> ToolDesc { + // self.mcp_tool.input_schema = Object { + // "properties": Object { + // "a": Object { + // "title": String("A"), + // "type": String("integer") + // }, + // "b": Object { + // "title": String("B"), + // "type": String("integer") + // } + // }, + // "required": Array [ + // String("a"), + // String("b") + // ], + // "title": String("addArguments"), + // "type": String("object") + // } + let mut parameters = vec![]; + let mut parameters_required = vec![]; + + if let Some(serde_json::Value::Object(properties)) = self.mcp_tool.input_schema.get("properties") { + for (name, prop) in properties { + if let serde_json::Value::Object(prop_obj) = prop { + let param_type = prop_obj.get("type").and_then(|v| v.as_str()).unwrap_or("string").to_string(); + let description = prop_obj.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(); + parameters.push(ToolParam { + name: name.clone(), + param_type, + description, + }); + } + } + } + if let Some(serde_json::Value::Array(required)) = self.mcp_tool.input_schema.get("required") { + for req in required { + if let Some(req_str) = req.as_str() { + parameters_required.push(req_str.to_string()); + } + } + } + + ToolDesc { + name: self.tool_name(), + agentic: true, + experimental: false, + description: self.mcp_tool.description.to_owned().unwrap_or_default().to_string(), + parameters, + parameters_required, + } + } + + fn tool_name(&self) -> String { + let yaml_name = std::path::Path::new(&self.config_path) + .file_stem() + .and_then(|name| name.to_str()) + .unwrap_or("unknown"); + let sanitized_yaml_name = format!("{}_{}", yaml_name, self.mcp_tool.name) + .chars() + .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' }) + .collect::(); + sanitized_yaml_name + } + + async fn command_to_match_against_confirm_deny( + &self, + _ccx: Arc>, + _args: &HashMap, + ) -> Result { + let command = self.mcp_tool.name.clone(); + tracing::info!("MCP command_to_match_against_confirm_deny() returns {:?}", command); + Ok(command.to_string()) + } + + fn confirm_deny_rules(&self) -> Option { + Some(self.common.confirmation.clone()) + } + + fn has_config_path(&self) -> Option { + Some(self.config_path.clone()) + } +} diff --git a/refact-agent/engine/src/integrations/mod.rs b/refact-agent/engine/src/integrations/mod.rs index 39a0463be..5264ac6cc 100644 --- a/refact-agent/engine/src/integrations/mod.rs +++ b/refact-agent/engine/src/integrations/mod.rs @@ -18,7 +18,7 @@ pub mod integr_mysql; pub mod integr_cmdline; pub mod integr_cmdline_service; pub mod integr_shell; -pub mod integr_mcp; +pub mod mcp; pub mod process_io_utils; pub mod docker; @@ -52,7 +52,7 @@ pub fn integration_from_name(n: &str) -> Result) }, mcp if mcp.starts_with("mcp_") => { - Ok(Box::new(integr_mcp::IntegrationMCP {..Default::default()}) as Box) + Ok(Box::new(mcp::IntegrationMCP {..Default::default()}) as Box) }, "isolation" => Ok(Box::new(docker::integr_isolation::IntegrationIsolation {..Default::default()}) as Box), _ => Err(format!("Unknown integration name: {}", n)), From 9b82286a15367366a1c0eebb844581491d57526b Mon Sep 17 00:00:00 2001 From: alashchev17 Date: Fri, 9 May 2025 17:27:51 +0200 Subject: [PATCH 14/15] feat: added headers support --- .../IntegrationForm/FormFields.tsx | 4 + .../IntegrationForm/IntegrationForm.tsx | 5 ++ ...ntVariablesTable.tsx => KeyValueTable.tsx} | 34 ++++---- .../IntegrationsView/IntegrationsView.tsx | 2 + .../hooks/useFormAvailability.ts | 10 +++ .../IntegrationsView/hooks/useIntegrations.ts | 14 +++- .../Integrations/IntegrationFormField.tsx | 79 +++++++++++++++---- .../gui/src/services/refact/integrations.ts | 12 ++- refact-agent/gui/src/services/refact/types.ts | 8 ++ 9 files changed, 132 insertions(+), 36 deletions(-) rename refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/{EnvironmentVariablesTable.tsx => KeyValueTable.tsx} (90%) diff --git a/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/FormFields.tsx b/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/FormFields.tsx index 24eefa8ba..8249991e0 100644 --- a/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/FormFields.tsx +++ b/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/FormFields.tsx @@ -24,6 +24,7 @@ type FormFieldsProps = { onToolParameters: (data: ToolParameterEntity[]) => void; onArguments: (updatedArgs: string[]) => void; onEnvs: (updatedEnvs: Record) => void; + onHeaders: (updatedHeaders: Record) => void; }; export const FormFields: FC = ({ @@ -34,6 +35,7 @@ export const FormFields: FC = ({ onToolParameters, onArguments, onEnvs, + onHeaders, }) => { const { integr_config_path, @@ -56,6 +58,7 @@ export const FormFields: FC = ({ onToolParameters={onToolParameters} onArguments={onArguments} onEnvs={onEnvs} + onHeaders={onHeaders} /> ))} {Object.keys(extraFields).map((fieldKey) => ( @@ -71,6 +74,7 @@ export const FormFields: FC = ({ onToolParameters={onToolParameters} onArguments={onArguments} onEnvs={onEnvs} + onHeaders={onHeaders} /> ))} diff --git a/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/IntegrationForm.tsx b/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/IntegrationForm.tsx index 9dd353a05..57d1ace1d 100644 --- a/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/IntegrationForm.tsx +++ b/refact-agent/gui/src/components/IntegrationsView/IntegrationForm/IntegrationForm.tsx @@ -48,6 +48,7 @@ type IntegrationFormProps = { setMCPEnvironmentVariables: React.Dispatch< React.SetStateAction> >; + setHeaders: React.Dispatch>>; setToolParameters: React.Dispatch< React.SetStateAction >; @@ -76,6 +77,7 @@ export const IntegrationForm: FC = ({ setConfirmationRules, setMCPArguments, setMCPEnvironmentVariables, + setHeaders, setToolParameters, handleSwitchIntegration, }) => { @@ -93,12 +95,14 @@ export const IntegrationForm: FC = ({ handleToolParameters, handleMCPArguments, handleMCPEnvironmentVariables, + handleHeaders, } = useFormAvailability({ setAvailabilityValues, setConfirmationRules, setToolParameters, setMCPArguments, setMCPEnvironmentVariables, + setHeaders, }); // Set initial values from integration data @@ -199,6 +203,7 @@ export const IntegrationForm: FC = ({ onToolParameters={handleToolParameters} onArguments={handleMCPArguments} onEnvs={handleMCPEnvironmentVariables} + onHeaders={handleHeaders} /> diff --git a/refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/EnvironmentVariablesTable.tsx b/refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/KeyValueTable.tsx similarity index 90% rename from refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/EnvironmentVariablesTable.tsx rename to refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/KeyValueTable.tsx index 7ac98d507..112f855d9 100644 --- a/refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/EnvironmentVariablesTable.tsx +++ b/refact-agent/gui/src/components/IntegrationsView/IntegrationsTable/KeyValueTable.tsx @@ -13,21 +13,25 @@ import styles from "./ConfirmationTable.module.css"; import { debugIntegrations } from "../../../debugConfig"; import { MCPEnvs } from "../../../services/refact"; -type EnvironmentVariablesTableProps = { - initialData: MCPEnvs; - onMCPEnvironmentVariables: (data: MCPEnvs) => void; +type KeyValueTableProps = { + initialData: Record; + onChange: (data: Record) => void; + columnNames?: string[]; + emptyMessage?: string; }; -type EnvVarRow = { +type KeyValueRow = { key: string; value: string; originalKey: string; order: number; }; -export const EnvironmentVariablesTable: FC = ({ +export const KeyValueTable: FC = ({ initialData, - onMCPEnvironmentVariables, + onChange, + columnNames = ["Key", "Value"], + emptyMessage, }) => { const [nextOrder, setNextOrder] = useState( () => Object.keys(initialData).length, @@ -106,14 +110,14 @@ export const EnvironmentVariablesTable: FC = ({ }; useEffect(() => { - onMCPEnvironmentVariables(data); - }, [data, onMCPEnvironmentVariables]); + onChange(data); + }, [data, onChange]); const tableData = useMemo( () => Object.entries(data) .map( - ([key, value]): EnvVarRow => ({ + ([key, value]): KeyValueRow => ({ key, value, originalKey: key, @@ -125,14 +129,14 @@ export const EnvironmentVariablesTable: FC = ({ ); useEffect(() => { - debugIntegrations(`[DEBUG MCP]: envs table data: `, tableData); + debugIntegrations(`[DEBUG]: KeyValueTable data changed: `, tableData); }, [tableData]); - const columns = useMemo[]>( + const columns = useMemo[]>( () => [ { id: "key", - header: "Environment Variable", + header: columnNames[0], cell: ({ row }) => ( = ({ }, { id: "value", - header: "Value", + header: columnNames[1], cell: ({ row }) => ( = ({ )) ) : ( - - No environment variables set yet - + {emptyMessage} )} diff --git a/refact-agent/gui/src/components/IntegrationsView/IntegrationsView.tsx b/refact-agent/gui/src/components/IntegrationsView/IntegrationsView.tsx index 58d3a75c6..228681ec9 100644 --- a/refact-agent/gui/src/components/IntegrationsView/IntegrationsView.tsx +++ b/refact-agent/gui/src/components/IntegrationsView/IntegrationsView.tsx @@ -65,6 +65,7 @@ export const IntegrationsView: FC = ({ setToolParameters, setMCPArguments, setMCPEnvironmentVariables, + setHeaders, isDisabledIntegrationForm, isApplyingIntegrationForm, isDeletingIntegration, @@ -126,6 +127,7 @@ export const IntegrationsView: FC = ({ setConfirmationRules={setConfirmationRules} setMCPArguments={setMCPArguments} setMCPEnvironmentVariables={setMCPEnvironmentVariables} + setHeaders={setHeaders} setToolParameters={setToolParameters} handleSwitchIntegration={handleNavigateToIntegrationSetup} /> diff --git a/refact-agent/gui/src/components/IntegrationsView/hooks/useFormAvailability.ts b/refact-agent/gui/src/components/IntegrationsView/hooks/useFormAvailability.ts index b38d10caf..1bb44ca4f 100644 --- a/refact-agent/gui/src/components/IntegrationsView/hooks/useFormAvailability.ts +++ b/refact-agent/gui/src/components/IntegrationsView/hooks/useFormAvailability.ts @@ -17,6 +17,7 @@ type UseFormAvailabilityProps = { setMCPEnvironmentVariables: React.Dispatch< React.SetStateAction> >; + setHeaders: React.Dispatch>>; }; export const useFormAvailability = ({ @@ -25,6 +26,7 @@ export const useFormAvailability = ({ setToolParameters, setMCPArguments, setMCPEnvironmentVariables, + setHeaders, }: UseFormAvailabilityProps) => { const handleAvailabilityChange = useCallback( (fieldName: string, value: boolean) => { @@ -64,11 +66,19 @@ export const useFormAvailability = ({ [setMCPEnvironmentVariables], ); + const handleHeaders = useCallback( + (updatedHeaders: Record) => { + setHeaders(updatedHeaders); + }, + [setHeaders], + ); + return { handleAvailabilityChange, handleConfirmationChange, handleToolParameters, handleMCPArguments, handleMCPEnvironmentVariables, + handleHeaders, }; }; diff --git a/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts b/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts index aaabbea02..fa3432d69 100644 --- a/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts +++ b/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts @@ -38,6 +38,7 @@ import { IntegrationWithIconRecordAndAddress, IntegrationWithIconResponse, isDetailMessage, + isDictionary, isMCPArgumentsArray, isMCPEnvironmentsDict, isNotConfiguredIntegrationWithIconRecord, @@ -285,6 +286,8 @@ export const useIntegrations = ({ const [MCPEnvironmentVariables, setMCPEnvironmentVariables] = useState({}); + const [headers, setHeaders] = useState>({}); + const [toolParameters, setToolParameters] = useState< ToolParameterEntity[] | null >(null); @@ -414,6 +417,10 @@ export const useIntegrations = ({ ? !isEqual(currentIntegrationValues.env, MCPEnvironmentVariables) : false; + const headersChanged = isDictionary(currentIntegrationValues.env) + ? !isEqual(currentIntegrationValues.headers, headers) + : false; + const confirmationRulesChanged = !isEqual( confirmationRules, currentIntegrationValues.confirmation, @@ -423,7 +430,8 @@ export const useIntegrations = ({ confirmationRulesChanged || toolParametersChanged || MCPArgumentsChanged || - MCPEnvironmentVariablesChanged; + MCPEnvironmentVariablesChanged || + headersChanged; // Manually collecting data from the form const formElement = document.getElementById( @@ -488,6 +496,7 @@ export const useIntegrations = ({ currentIntegration, MCPArguments, MCPEnvironmentVariables, + headers, ]); const handleSetCurrentIntegrationSchema = ( @@ -575,6 +584,7 @@ export const useIntegrations = ({ if (currentIntegration.integr_name.includes("mcp")) { formValues.env = MCPEnvironmentVariables; formValues.args = MCPArguments; + formValues.headers = headers; } if (!currentIntegrationSchema.confirmation.not_applicable) { formValues.confirmation = confirmationRules; @@ -616,6 +626,7 @@ export const useIntegrations = ({ toolParameters, MCPArguments, MCPEnvironmentVariables, + headers, ], ); @@ -960,6 +971,7 @@ export const useIntegrations = ({ setToolParameters, setMCPArguments, setMCPEnvironmentVariables, + setHeaders, isDisabledIntegrationForm, isApplyingIntegrationForm, isDeletingIntegration, diff --git a/refact-agent/gui/src/features/Integrations/IntegrationFormField.tsx b/refact-agent/gui/src/features/Integrations/IntegrationFormField.tsx index 770305f41..d395a4e94 100644 --- a/refact-agent/gui/src/features/Integrations/IntegrationFormField.tsx +++ b/refact-agent/gui/src/features/Integrations/IntegrationFormField.tsx @@ -14,8 +14,8 @@ import styles from "./IntegrationFormField.module.css"; import { areToolParameters, + isDictionary, isMCPArgumentsArray, - isMCPEnvironmentsDict, type Integration, type IntegrationField, type IntegrationPrimitive, @@ -23,40 +23,57 @@ import { type ToolParameterEntity, } from "../../services/refact"; import { ArgumentsTable } from "../../components/IntegrationsView/IntegrationsTable/ArgumentsTable"; -import { EnvironmentVariablesTable } from "../../components/IntegrationsView/IntegrationsTable/EnvironmentVariablesTable"; +import { KeyValueTable } from "../../components/IntegrationsView/IntegrationsTable/KeyValueTable"; -type FieldType = "string" | "bool" | "int" | "tool" | "output"; +type FieldType = + | "string" + | "string_to_string_map" + | "bool" + | "int" + | "tool" + | "output"; // Helper functions const isFieldType = (value: string): value is FieldType => { - return ["string", "bool", "int", "tool", "output"].includes(value); + return [ + "string_to_string_map", + "string", + "bool", + "int", + "tool", + "output", + ].includes(value); }; const getDefaultValue = ({ field, values, fieldKey, - f_type, + // f_type, + f_type_raw, }: { field: IntegrationField>; values: Integration["integr_values"]; fieldKey: string; f_type: FieldType; -}): string | number | boolean | undefined => { + f_type_raw: string; +}): string | number | boolean | Record | undefined => { // First check if we have a value in the current values if (values && fieldKey in values) { return values[fieldKey]?.toString(); } // Otherwise use the default value based on type - switch (f_type) { + switch (f_type_raw) { case "int": return Number(field.f_default); case "bool": return Boolean(field.f_default); - case "tool": - case "output": + case "tool_parameters": + case "output_filter": return JSON.stringify(field.f_default); + case "string_to_string_map": + return field.f_default as Record; default: return field.f_default?.toString(); } @@ -74,12 +91,13 @@ type IntegrationFormFieldProps = { onToolParameters: (data: ToolParameterEntity[]) => void; onArguments: (updatedArgs: string[]) => void; onEnvs: (updatedEnvs: Record) => void; + onHeaders: (updatedHeaders: Record) => void; }; type CommonFieldProps = { id: string; name: string; - defaultValue?: string | number | boolean; + defaultValue?: string | number | boolean | Record; placeholder?: string; }; @@ -93,6 +111,7 @@ const FieldContent: FC<{ onToolParameters: (data: ToolParameterEntity[]) => void; onArguments: (updatedArgs: string[]) => void; onEnvs: (updatedEnvs: Record) => void; + onHeaders: (updatedHeaders: Record) => void; }> = ({ f_type, commonProps, @@ -102,6 +121,7 @@ const FieldContent: FC<{ onToolParameters, onArguments, onEnvs, + onHeaders, }) => { switch (f_type) { case "bool": { @@ -152,15 +172,32 @@ const FieldContent: FC<{ ); } if (f_size === "to_string_map") { - const valuesForTable = values?.[fieldKey]; - const tableData = isMCPEnvironmentsDict(valuesForTable) - ? valuesForTable - : {}; + const valuesForTable = values?.[fieldKey] ?? commonProps.defaultValue; + const tableData = isDictionary(valuesForTable) ? valuesForTable : {}; + + const columnsMapToArray: Record = { + env: ["Environment Variable", "Value"], + headers: ["Header Name", "Value"], + }; + const emptyMessageMap: Record = { + env: "No environment variables specified yet", + headers: "No headers specified yet", + }; + + const changeHandlersMap: Record< + string, + (updatedField: Record) => void + > = { + env: onEnvs, + headers: onHeaders, + }; return ( - ); } @@ -226,13 +263,20 @@ export const IntegrationFormField: FC = ({ onToolParameters, onArguments, onEnvs, + onHeaders, }) => { const splittedType = field.f_type.toString().split("_"); const [f_type_raw, ...rest] = splittedType; const f_size = rest.join("_"); const f_type = isFieldType(f_type_raw) ? f_type_raw : "string"; - const defaultValue = getDefaultValue({ field, values, fieldKey, f_type }); + const defaultValue = getDefaultValue({ + field, + values, + fieldKey, + f_type, + f_type_raw: field.f_type as string, + }); const commonProps = { id: fieldKey, @@ -273,6 +317,7 @@ export const IntegrationFormField: FC = ({ onToolParameters={onToolParameters} onArguments={onArguments} onEnvs={onEnvs} + onHeaders={onHeaders} /> {field.f_desc && ( diff --git a/refact-agent/gui/src/services/refact/integrations.ts b/refact-agent/gui/src/services/refact/integrations.ts index a2a347f2e..41f9b3adc 100644 --- a/refact-agent/gui/src/services/refact/integrations.ts +++ b/refact-agent/gui/src/services/refact/integrations.ts @@ -500,7 +500,7 @@ export type IntegrationField = { f_type: T; f_desc?: string; f_placeholder?: T; // should match f_type - f_default?: T; + f_default?: T | Record; f_label?: string; f_extra?: boolean; // rather the field is hidden by default or not smartlinks?: SmartLink[]; @@ -535,7 +535,15 @@ function isIntegrationField( if ("f_placeholder" in json && !isPrimitive(json.f_placeholder)) { return false; } - if ("f_default" in json && !isPrimitive(json.f_default)) { + if ( + "f_default" in json && + json.f_default !== undefined && + !( + isPrimitive(json.f_default) || + (typeof json.f_default === "object" && + Object.values(json.f_default).every(isPrimitive)) + ) + ) { return false; } if ("smartlinks" in json && !Array.isArray(json.smartlinks)) { diff --git a/refact-agent/gui/src/services/refact/types.ts b/refact-agent/gui/src/services/refact/types.ts index 8a45b0108..852cb452b 100644 --- a/refact-agent/gui/src/services/refact/types.ts +++ b/refact-agent/gui/src/services/refact/types.ts @@ -753,3 +753,11 @@ export function isMCPEnvironmentsDict(json: unknown): json is MCPEnvs { return Object.values(json).every((value) => typeof value === "string"); } + +export function isDictionary(json: unknown): json is Record { + if (!json) return false; + if (typeof json !== "object") return false; + if (Array.isArray(json)) return false; + + return Object.values(json).every((value) => typeof value === "string"); +} From d40a6925377a5a92d239ab930714051925d84de9 Mon Sep 17 00:00:00 2001 From: alashchev17 Date: Mon, 12 May 2025 18:46:58 +0200 Subject: [PATCH 15/15] fix: typechecking headers instead of envs --- .../src/components/IntegrationsView/hooks/useIntegrations.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts b/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts index fa3432d69..05be005a2 100644 --- a/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts +++ b/refact-agent/gui/src/components/IntegrationsView/hooks/useIntegrations.ts @@ -417,7 +417,7 @@ export const useIntegrations = ({ ? !isEqual(currentIntegrationValues.env, MCPEnvironmentVariables) : false; - const headersChanged = isDictionary(currentIntegrationValues.env) + const headersChanged = isDictionary(currentIntegrationValues.headers) ? !isEqual(currentIntegrationValues.headers, headers) : false;