From 82150cabdde8b70c317f3e72866ea5672a2252ae Mon Sep 17 00:00:00 2001 From: taobaorun Date: Wed, 13 Aug 2025 11:44:18 +0800 Subject: [PATCH 1/2] if Tool Impl or else throws McpError, use its values for the RPC error --- .../WebFluxStreamableIntegrationTests.java | 81 +++++++++++++++--- .../WebMvcStreamableIntegrationTests.java | 82 ++++++++++++++++--- ...stractMcpClientServerIntegrationTests.java | 62 +++++++------- .../spec/McpStreamableServerSession.java | 11 ++- 4 files changed, 180 insertions(+), 56 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java index 9eba0e57c..9f7021938 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -4,29 +4,39 @@ package io.modelcontextprotocol; -import java.time.Duration; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Timeout; -import org.springframework.http.server.reactive.HttpHandler; -import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RouterFunctions; - import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServer.AsyncSpecification; import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; +import java.time.Duration; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + @Timeout(15) class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -88,4 +98,53 @@ public void after() { } } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallThrowMcpError(String clientType) { + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("toolThrowMcpError") + .description("toolThrowMcpError description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + throw new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError(50000, "test exception message", Map.of("a", "b"))); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThatThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("toolThrowMcpError", Map.of()))) + .isInstanceOf(McpError.class) + .hasMessage("test exception message") + .satisfies(ex -> { + McpError mcpError = (McpError) ex; + assertThat(mcpError.getJsonRpcError()).isNotNull(); + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(50000); + }); + + } + + mcpServer.close(); + } + } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index 16012e7d9..a0367ca27 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -3,32 +3,39 @@ */ package io.modelcontextprotocol.server; -import static org.assertj.core.api.Assertions.assertThat; - -import java.time.Duration; - +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; +import reactor.core.scheduler.Schedulers; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.time.Duration; +import java.util.Map; -import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.server.McpServer.AsyncSpecification; -import io.modelcontextprotocol.server.McpServer.SyncSpecification; -import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; -import reactor.core.scheduler.Schedulers; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; @Timeout(15) class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { @@ -139,4 +146,53 @@ protected void prepareClients(int port, String mcpEndpoint) { .requestTimeout(Duration.ofHours(10))); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallThrowMcpError(String clientType) { + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("toolThrowMcpError") + .description("toolThrowMcpError description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + throw new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError(50000, "test exception message", Map.of("a", "b"))); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + assertThatThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("toolThrowMcpError", Map.of()))) + .isInstanceOf(McpError.class) + .hasMessage("test exception message") + .satisfies(ex -> { + McpError mcpError = (McpError) ex; + assertThat(mcpError.getJsonRpcError()).isNotNull(); + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(50000); + }); + + } + + mcpServer.close(); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 8e041d91e..07001b7a4 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,33 +4,6 @@ package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; @@ -56,9 +29,36 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -759,7 +759,8 @@ void testToolCallSuccess(String clientType) { HttpResponse response = HttpClient.newHttpClient() .send(HttpRequest.newBuilder() .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README" + + ".md")) .GET() .build(), HttpResponse.BodyHandlers.ofString()); String responseBody = response.body(); @@ -844,7 +845,8 @@ void testToolListChangeHandlingSuccess(String clientType) { HttpResponse response = HttpClient.newHttpClient() .send(HttpRequest.newBuilder() .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README" + + ".md")) .GET() .build(), HttpResponse.BodyHandlers.ofString()); String responseBody = response.body(); @@ -1056,7 +1058,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { @ValueSource(strings = { "httpclient", "webflux" }) void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token + // token CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications List receivedNotifications = new CopyOnWriteArrayList<>(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index ef7967c1e..098906c0d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -177,9 +177,16 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) .onErrorResume(e -> { + McpSchema.JSONRPCResponse.JSONRPCError error; + if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) { + error = mcpError.getJsonRpcError(); + } + else { + error = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), null); + } var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - e.getMessage(), null)); + null, error); return Mono.just(errorResponse); }) .flatMap(transport::sendMessage) From 5984017f7862d3291a142e16709f399b635c7bff Mon Sep 17 00:00:00 2001 From: taobaorun Date: Wed, 13 Aug 2025 13:18:46 +0800 Subject: [PATCH 2/2] Remove unnecessary changes --- ...stractMcpClientServerIntegrationTests.java | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 07001b7a4..8e041d91e 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,6 +4,33 @@ package io.modelcontextprotocol; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; @@ -29,36 +56,9 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import net.javacrumbs.jsonunit.core.Option; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; - -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; - public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); @@ -759,8 +759,7 @@ void testToolCallSuccess(String clientType) { HttpResponse response = HttpClient.newHttpClient() .send(HttpRequest.newBuilder() .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README" - + ".md")) + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) .GET() .build(), HttpResponse.BodyHandlers.ofString()); String responseBody = response.body(); @@ -845,8 +844,7 @@ void testToolListChangeHandlingSuccess(String clientType) { HttpResponse response = HttpClient.newHttpClient() .send(HttpRequest.newBuilder() .uri(URI.create( - "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README" - + ".md")) + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) .GET() .build(), HttpResponse.BodyHandlers.ofString()); String responseBody = response.body(); @@ -1058,7 +1056,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { @ValueSource(strings = { "httpclient", "webflux" }) void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token + // token CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications List receivedNotifications = new CopyOnWriteArrayList<>();