diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 6eca3475..c19fb579 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -34,6 +34,7 @@ * * @author Christian Tzolov * @author Dariusz Jędrzejczyk + * @author Jihoon Kim */ public class McpClientSession implements McpSession { @@ -136,6 +137,14 @@ private void handle(McpSchema.JSONRPCMessage message) { sink.success(response); } } + else if (message instanceof McpSchema.JSONRPCBatchResponse batchResponse) { + logger.debug("Received Batch Response: {}", batchResponse); + batchResponse.responses().forEach(jsonrpcMessage -> { + if (jsonrpcMessage instanceof McpSchema.JSONRPCResponse response) { + this.handle(response); + } + }); + } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); handleIncomingRequest(request).onErrorResume(error -> { @@ -145,6 +154,17 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { return this.transport.sendMessage(errorResponse).then(Mono.empty()); }).flatMap(this.transport::sendMessage).subscribe(); } + else if (message instanceof McpSchema.JSONRPCBatchRequest batchRequest) { + logger.debug("Received Batch Request: {}", batchRequest); + batchRequest.messages().forEach(jsonrpcMessage -> { + if (jsonrpcMessage instanceof McpSchema.JSONRPCRequest request) { + this.handle(request); + } + else if (jsonrpcMessage instanceof McpSchema.JSONRPCNotification notification) { + this.handle(notification); + } + }); + } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); handleIncomingNotification(notification) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e77edb3b..94ce12b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; @@ -29,6 +30,7 @@ * Context Protocol Schema. * * @author Christian Tzolov + * @author Jihoon Kim */ public final class McpSchema { @@ -140,42 +142,78 @@ public sealed interface Request }; /** - * Deserializes a JSON string into a JSONRPCMessage object. + * Deserializes a JSON string into a JSONRPCMessage object. Handles both single and + * batch JSON-RPC messages. * @param objectMapper The ObjectMapper instance to use for deserialization * @param jsonText The JSON string to deserialize - * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, - * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. + * @return A JSONRPCMessage instance, either a {@link JSONRPCRequest}, + * {@link JSONRPCNotification}, {@link JSONRPCResponse}, or + * {@link JSONRPCBatchRequest}, or {@link JSONRPCBatchResponse} based on the JSON + * structure. * @throws IOException If there's an error during deserialization * @throws IllegalArgumentException If the JSON structure doesn't match any known * message type */ public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) throws IOException { - logger.debug("Received JSON message: {}", jsonText); - var map = objectMapper.readValue(jsonText, MAP_TYPE_REF); + JsonNode rootNode = objectMapper.readTree(jsonText); + + // Check if it's a batch request/response + if (rootNode.isArray()) { + // Batch processing + List messages = new ArrayList<>(); + for (JsonNode node : rootNode) { + Map map = objectMapper.convertValue(node, MAP_TYPE_REF); + messages.add(convertToJsonRpcMessage(map, objectMapper)); + } - // Determine message type based on specific JSON structure + // If it's a batch response, return JSONRPCBatchResponse + if (messages.get(0) instanceof JSONRPCResponse) { + return new JSONRPCBatchResponse(messages); + } + // If it's a batch request, return JSONRPCBatchRequest + else { + return new JSONRPCBatchRequest(messages); + } + } + + // Single message processing + Map map = objectMapper.readValue(jsonText, MAP_TYPE_REF); + return convertToJsonRpcMessage(map, objectMapper); + } + + /** + * Converts a map into a specific JSON-RPC message type. Based on the map's structure, + * this method determines whether the message is a {@link JSONRPCRequest}, + * {@link JSONRPCNotification}, or {@link JSONRPCResponse}. + * @param map The map representing the JSON structure + * @param objectMapper The ObjectMapper instance to use for deserialization + * @return The corresponding JSONRPCMessage instance (could be {@link JSONRPCRequest}, + * {@link JSONRPCNotification}, or {@link JSONRPCResponse}) + * @throws IllegalArgumentException If the map structure doesn't match any known + * message type + */ + private static JSONRPCMessage convertToJsonRpcMessage(Map map, ObjectMapper objectMapper) { if (map.containsKey("method") && map.containsKey("id")) { return objectMapper.convertValue(map, JSONRPCRequest.class); } - else if (map.containsKey("method") && !map.containsKey("id")) { + else if (map.containsKey("method")) { return objectMapper.convertValue(map, JSONRPCNotification.class); } else if (map.containsKey("result") || map.containsKey("error")) { return objectMapper.convertValue(map, JSONRPCResponse.class); } - throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText); + throw new IllegalArgumentException("Unknown JSON-RPC message type: " + map); } // --------------------------- // JSON-RPC Message Types // --------------------------- - public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { - - String jsonrpc(); + public sealed interface JSONRPCMessage + permits JSONRPCRequest, JSONRPCBatchRequest, JSONRPCNotification, JSONRPCResponse, JSONRPCBatchResponse { } @@ -188,6 +226,26 @@ public record JSONRPCRequest( // @formatter:off @JsonProperty("params") Object params) implements JSONRPCMessage { } // @formatter:on + public record JSONRPCBatchRequest(List messages) implements JSONRPCMessage { + public JSONRPCBatchRequest { + boolean valid = messages.stream() + .allMatch(message -> message instanceof JSONRPCRequest || message instanceof JSONRPCNotification); + if (!valid) { + throw new IllegalArgumentException( + "Only JSONRPCRequest or JSONRPCNotification are allowed in batch request."); + } + } + } + + public record JSONRPCBatchResponse(List responses) implements JSONRPCMessage { + public JSONRPCBatchResponse { + boolean valid = responses.stream().allMatch(response -> response instanceof JSONRPCResponse); + if (!valid) { + throw new IllegalArgumentException("Only JSONRPCResponse are allowed in batch response."); + } + } + } + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCNotification( // @formatter:off diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 64315095..83efd344 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -11,6 +11,7 @@ import io.modelcontextprotocol.server.McpAsyncServerExchange; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; @@ -167,6 +168,13 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } return Mono.empty(); } + else if (message instanceof McpSchema.JSONRPCBatchResponse batchResponse) { + logger.debug("Received Batch Response: {}", batchResponse); + return Flux.fromIterable(batchResponse.responses()) + .filter(jsonrpcMessage -> jsonrpcMessage instanceof McpSchema.JSONRPCResponse) + .flatMap(this::handle) + .then(); + } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); return handleIncomingRequest(request).onErrorResume(error -> { @@ -177,6 +185,21 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { return this.transport.sendMessage(errorResponse).then(Mono.empty()); }).flatMap(this.transport::sendMessage); } + else if (message instanceof McpSchema.JSONRPCBatchRequest batchRequest) { + logger.debug("Received Batch Request: {}", batchRequest); + return Flux.fromIterable(batchRequest.messages()).flatMap(jsonrpcMessage -> { + if (jsonrpcMessage instanceof McpSchema.JSONRPCRequest request) { + return this.handle(request); + } + else if (jsonrpcMessage instanceof McpSchema.JSONRPCNotification notification) { + return this.handle(notification); + } + else { + logger.warn("Unsupported message in batch request: {}", jsonrpcMessage); + return Mono.empty(); + } + }).then(); + } else if (message instanceof McpSchema.JSONRPCNotification notification) { // TODO handle errors for communication to without initialization // happening first diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index 482d0aac..6dea96c8 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -63,6 +63,10 @@ public McpSchema.JSONRPCMessage getLastSentMessage() { return !sent.isEmpty() ? sent.get(sent.size() - 1) : null; } + public McpSchema.JSONRPCBatchResponse getSentMessagesAsBatchResponse() { + return new McpSchema.JSONRPCBatchResponse(sent); + } + private volatile boolean connected = false; @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e..00360098 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -5,7 +5,10 @@ package io.modelcontextprotocol.spec; import java.time.Duration; +import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.MockMcpClientTransport; @@ -26,6 +29,7 @@ * request-response correlation, and notification processing. * * @author Christian Tzolov + * @author Jihoon Kim */ class McpClientSessionTests { @@ -155,6 +159,40 @@ void testRequestHandling() { assertThat(response.error()).isNull(); } + @Test + void testBatchRequestHandling() { + String echoMessage1 = "Hello MCP 1!"; + String echoMessage2 = "Hello MCP 2!"; + + // Request handler: echoes the input + Map> requestHandlers = Map.of(ECHO_METHOD, + params -> Mono.just(params)); + transport = new MockMcpClientTransport(); + session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); + + // Simulate incoming batch request + McpSchema.JSONRPCRequest request1 = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "batch-id-1", echoMessage1); + McpSchema.JSONRPCRequest request2 = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "batch-id-2", echoMessage2); + McpSchema.JSONRPCBatchRequest batchRequest = new McpSchema.JSONRPCBatchRequest(List.of(request1, request2)); + transport.simulateIncomingMessage(batchRequest); + + // Wait for async processing + McpSchema.JSONRPCBatchResponse batchResponse = transport.getSentMessagesAsBatchResponse(); + List responses = batchResponse.responses(); + + assertThat(responses).hasSize(2); + assertThat(responses).allMatch(resp -> resp instanceof McpSchema.JSONRPCResponse); + + Map responseMap = responses.stream() + .map(resp -> (McpSchema.JSONRPCResponse) resp) + .collect(Collectors.toMap(McpSchema.JSONRPCResponse::id, Function.identity())); + + assertThat(responseMap.get("batch-id-1").result()).isEqualTo(echoMessage1); + assertThat(responseMap.get("batch-id-2").result()).isEqualTo(echoMessage2); + } + @Test void testNotificationHandling() { Sinks.One receivedParams = Sinks.one(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index a41fc095..c892c997 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -22,6 +22,7 @@ /** * @author Christian Tzolov + * @author Jihoon Kim */ public class McpSchemaTests { @@ -208,6 +209,152 @@ void testJSONRPCResponseWithError() throws Exception { {"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid request"}}""")); } + @Test + void testJSONRPCBatchRequest() throws Exception { + Map params1 = Map.of("key1", "value1"); + Map params2 = Map.of("key2", "value2"); + + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", 1, + params1); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + "method2", params2); + + // Serialize the request and notification to JSON + String batchRequestJson = mapper.writeValueAsString(List.of(request, notification)); + + // Use the deserializeJsonRpcMessage method + McpSchema.JSONRPCBatchRequest batchRequest = (McpSchema.JSONRPCBatchRequest) McpSchema + .deserializeJsonRpcMessage(mapper, batchRequestJson); + + // Assertions + assertThat(batchRequest.messages()).hasSize(2); + assertThat(batchRequest.messages().get(0)).isInstanceOf(McpSchema.JSONRPCRequest.class); + assertThat(batchRequest.messages().get(1)).isInstanceOf(McpSchema.JSONRPCNotification.class); + } + + @Test + void testJSONRPCBatchResponse() throws Exception { + // The JSON string for batch response containing both results and errors + String batchResponseJson = """ + [ + {"jsonrpc": "2.0", "result": 7, "id": "1"}, + {"jsonrpc": "2.0", "result": 19, "id": "2"}, + {"jsonrpc": "2.0", "result": ["hello", 5], "id": "9"} + ] + """; + + // Deserialize the batch response JSON string + McpSchema.JSONRPCBatchResponse batchResponse = (McpSchema.JSONRPCBatchResponse) McpSchema + .deserializeJsonRpcMessage(mapper, batchResponseJson); + + // Assertions + assertThat(batchResponse.responses()).hasSize(3); + assertThat(batchResponse.responses().get(0)).isInstanceOf(McpSchema.JSONRPCResponse.class); + assertThat(batchResponse.responses().get(1)).isInstanceOf(McpSchema.JSONRPCResponse.class); + assertThat(batchResponse.responses().get(2)).isInstanceOf(McpSchema.JSONRPCResponse.class); + } + + @Test + void testJSONRPCBatchResponseWithError() throws Exception { + // The JSON string for batch response containing both results and errors + String batchResponseJson = """ + [ + {"jsonrpc": "2.0", "result": 7, "id": "1"}, + {"jsonrpc": "2.0", "result": 19, "id": "2"}, + {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": null}, + {"jsonrpc": "2.0", "error": {"code": -32601, "message": "Method not found"}, "id": "5"}, + {"jsonrpc": "2.0", "result": ["hello", 5], "id": "9"} + ] + """; + + // Deserialize the batch response JSON string + McpSchema.JSONRPCBatchResponse batchResponse = (McpSchema.JSONRPCBatchResponse) McpSchema + .deserializeJsonRpcMessage(mapper, batchResponseJson); + + // Assertions + assertThat(batchResponse.responses()).hasSize(5); + + // Check the first response (id: "1") with result + McpSchema.JSONRPCResponse firstResponse = (McpSchema.JSONRPCResponse) batchResponse.responses().get(0); + assertThat(firstResponse.error()).isNull(); // Ensure error is null + assertThat(firstResponse.result()).isEqualTo(7); + + // Check the second response (id: "2") with result + McpSchema.JSONRPCResponse secondResponse = (McpSchema.JSONRPCResponse) batchResponse.responses().get(1); + assertThat(secondResponse.error()).isNull(); // Ensure error is null + assertThat(secondResponse.result()).isEqualTo(19); + + // Check the third response (id: null) with error + McpSchema.JSONRPCResponse thirdResponse = (McpSchema.JSONRPCResponse) batchResponse.responses().get(2); + assertThat(thirdResponse.result()).isNull(); // Ensure result is null + assertThat(thirdResponse.error()).isNotNull(); // Ensure error is not null + assertThat(thirdResponse.error().code()).isEqualTo(-32600); + assertThat(thirdResponse.error().message()).isEqualTo("Invalid Request"); + + // Check the fourth response (id: "5") with error + McpSchema.JSONRPCResponse fourthResponse = (McpSchema.JSONRPCResponse) batchResponse.responses().get(3); + assertThat(fourthResponse.result()).isNull(); // Ensure result is null + assertThat(fourthResponse.error()).isNotNull(); // Ensure error is not null + assertThat(fourthResponse.error().code()).isEqualTo(-32601); + assertThat(fourthResponse.error().message()).isEqualTo("Method not found"); + + // Check the fifth response (id: "9") with result + McpSchema.JSONRPCResponse fifthResponse = (McpSchema.JSONRPCResponse) batchResponse.responses().get(4); + assertThat(fifthResponse.error()).isNull(); // Ensure error is null + assertThat(fifthResponse.result()).isEqualTo(List.of("hello", 5)); // Ensure + // result + // matches + } + + @Test + void testValidJSONRPCBatchRequest() { + // Create valid messages: JSONRPCRequest and JSONRPCNotification + McpSchema.JSONRPCRequest validRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", + 1, Map.of("key", "value")); + McpSchema.JSONRPCNotification validNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + "notification_method", Map.of("key", "value")); + + // Create a valid batch request + McpSchema.JSONRPCBatchRequest validBatchRequest = new McpSchema.JSONRPCBatchRequest( + List.of(validRequest, validNotification)); + } + + @Test + void testInvalidJSONRPCBatchRequest() { + // Create an invalid message: a JSONRPCResponse which is not allowed in a request + McpSchema.JSONRPCResponse invalidResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, + Map.of("result", "value"), null); + + // Test that an exception is thrown when trying to create a batch request with + // invalid messages + assertThatThrownBy(() -> new McpSchema.JSONRPCBatchRequest(List.of(invalidResponse))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Only JSONRPCRequest or JSONRPCNotification are allowed in batch request."); + } + + @Test + void testValidJSONRPCBatchResponse() { + // Create a valid response: JSONRPCResponse + McpSchema.JSONRPCResponse validResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, + Map.of("result", "value"), null); + + // Create a valid batch response + McpSchema.JSONRPCBatchResponse validBatchResponse = new McpSchema.JSONRPCBatchResponse(List.of(validResponse)); + } + + @Test + void testInvalidJSONRPCBatchResponse() { + // Create an invalid message: JSONRPCRequest which is not allowed in a response + McpSchema.JSONRPCRequest invalidRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", + 1, Map.of("key", "value")); + + // Test that an exception is thrown when trying to create a batch response with + // invalid messages + assertThatThrownBy(() -> new McpSchema.JSONRPCBatchResponse(List.of(invalidRequest))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Only JSONRPCResponse are allowed in batch response."); + } + // Initialization Tests @Test