Skip to content

feat: add support for JSON-RPC batch requests and responses #191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
*
* @author Christian Tzolov
* @author Dariusz Jędrzejczyk
* @author Jihoon Kim
*/
public class McpClientSession implements McpSession {

Expand Down Expand Up @@ -136,6 +137,14 @@ private void handle(McpSchema.JSONRPCMessage message) {
sink.success(response);
}
}
else if (message instanceof McpSchema.JSONRPCBatchResponse batchResponse) {
logger.debug("Received Batch Response: {}", batchResponse);
batchResponse.responses().forEach(jsonrpcMessage -> {
if (jsonrpcMessage instanceof McpSchema.JSONRPCResponse response) {
this.handle(response);
}
});
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
handleIncomingRequest(request).onErrorResume(error -> {
Expand All @@ -145,6 +154,17 @@ else if (message instanceof McpSchema.JSONRPCRequest request) {
return this.transport.sendMessage(errorResponse).then(Mono.empty());
}).flatMap(this.transport::sendMessage).subscribe();
}
else if (message instanceof McpSchema.JSONRPCBatchRequest batchRequest) {
logger.debug("Received Batch Request: {}", batchRequest);
batchRequest.messages().forEach(jsonrpcMessage -> {
if (jsonrpcMessage instanceof McpSchema.JSONRPCRequest request) {
this.handle(request);
}
else if (jsonrpcMessage instanceof McpSchema.JSONRPCNotification notification) {
this.handle(notification);
}
});
}
else if (message instanceof McpSchema.JSONRPCNotification notification) {
logger.debug("Received notification: {}", notification);
handleIncomingNotification(notification)
Expand Down
80 changes: 69 additions & 11 deletions mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeInfo.As;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
Expand All @@ -29,6 +30,7 @@
* Context Protocol Schema</a>.
*
* @author Christian Tzolov
* @author Jihoon Kim
*/
public final class McpSchema {

Expand Down Expand Up @@ -140,42 +142,78 @@ public sealed interface Request
};

/**
* Deserializes a JSON string into a JSONRPCMessage object.
* Deserializes a JSON string into a JSONRPCMessage object. Handles both single and
* batch JSON-RPC messages.
* @param objectMapper The ObjectMapper instance to use for deserialization
* @param jsonText The JSON string to deserialize
* @return A JSONRPCMessage instance using either the {@link JSONRPCRequest},
* {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes.
* @return A JSONRPCMessage instance, either a {@link JSONRPCRequest},
* {@link JSONRPCNotification}, {@link JSONRPCResponse}, or
* {@link JSONRPCBatchRequest}, or {@link JSONRPCBatchResponse} based on the JSON
* structure.
* @throws IOException If there's an error during deserialization
* @throws IllegalArgumentException If the JSON structure doesn't match any known
* message type
*/
public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText)
throws IOException {

logger.debug("Received JSON message: {}", jsonText);

var map = objectMapper.readValue(jsonText, MAP_TYPE_REF);
JsonNode rootNode = objectMapper.readTree(jsonText);

// Check if it's a batch request/response
if (rootNode.isArray()) {
// Batch processing
List<JSONRPCMessage> messages = new ArrayList<>();
for (JsonNode node : rootNode) {
Map<String, Object> map = objectMapper.convertValue(node, MAP_TYPE_REF);
messages.add(convertToJsonRpcMessage(map, objectMapper));
}

// Determine message type based on specific JSON structure
// If it's a batch response, return JSONRPCBatchResponse
if (messages.get(0) instanceof JSONRPCResponse) {
return new JSONRPCBatchResponse(messages);
}
// If it's a batch request, return JSONRPCBatchRequest
else {
return new JSONRPCBatchRequest(messages);
}
}

// Single message processing
Map<String, Object> map = objectMapper.readValue(jsonText, MAP_TYPE_REF);
return convertToJsonRpcMessage(map, objectMapper);
}

/**
* Converts a map into a specific JSON-RPC message type. Based on the map's structure,
* this method determines whether the message is a {@link JSONRPCRequest},
* {@link JSONRPCNotification}, or {@link JSONRPCResponse}.
* @param map The map representing the JSON structure
* @param objectMapper The ObjectMapper instance to use for deserialization
* @return The corresponding JSONRPCMessage instance (could be {@link JSONRPCRequest},
* {@link JSONRPCNotification}, or {@link JSONRPCResponse})
* @throws IllegalArgumentException If the map structure doesn't match any known
* message type
*/
private static JSONRPCMessage convertToJsonRpcMessage(Map<String, Object> map, ObjectMapper objectMapper) {
if (map.containsKey("method") && map.containsKey("id")) {
return objectMapper.convertValue(map, JSONRPCRequest.class);
}
else if (map.containsKey("method") && !map.containsKey("id")) {
else if (map.containsKey("method")) {
return objectMapper.convertValue(map, JSONRPCNotification.class);
}
else if (map.containsKey("result") || map.containsKey("error")) {
return objectMapper.convertValue(map, JSONRPCResponse.class);
}

throw new IllegalArgumentException("Cannot deserialize JSONRPCMessage: " + jsonText);
throw new IllegalArgumentException("Unknown JSON-RPC message type: " + map);
}

// ---------------------------
// JSON-RPC Message Types
// ---------------------------
public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse {

String jsonrpc();
public sealed interface JSONRPCMessage
permits JSONRPCRequest, JSONRPCBatchRequest, JSONRPCNotification, JSONRPCResponse, JSONRPCBatchResponse {

}

Expand All @@ -188,6 +226,26 @@ public record JSONRPCRequest( // @formatter:off
@JsonProperty("params") Object params) implements JSONRPCMessage {
} // @formatter:on

public record JSONRPCBatchRequest(List<JSONRPCMessage> messages) implements JSONRPCMessage {
public JSONRPCBatchRequest {
boolean valid = messages.stream()
.allMatch(message -> message instanceof JSONRPCRequest || message instanceof JSONRPCNotification);
if (!valid) {
throw new IllegalArgumentException(
"Only JSONRPCRequest or JSONRPCNotification are allowed in batch request.");
}
}
}

public record JSONRPCBatchResponse(List<JSONRPCMessage> responses) implements JSONRPCMessage {
public JSONRPCBatchResponse {
boolean valid = responses.stream().allMatch(response -> response instanceof JSONRPCResponse);
if (!valid) {
throw new IllegalArgumentException("Only JSONRPCResponse are allowed in batch response.");
}
}
}

@JsonInclude(JsonInclude.Include.NON_ABSENT)
@JsonIgnoreProperties(ignoreUnknown = true)
public record JSONRPCNotification( // @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.core.publisher.Sinks;
Expand Down Expand Up @@ -167,6 +168,13 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
}
return Mono.empty();
}
else if (message instanceof McpSchema.JSONRPCBatchResponse batchResponse) {
logger.debug("Received Batch Response: {}", batchResponse);
return Flux.fromIterable(batchResponse.responses())
.filter(jsonrpcMessage -> jsonrpcMessage instanceof McpSchema.JSONRPCResponse)
.flatMap(this::handle)
.then();
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
return handleIncomingRequest(request).onErrorResume(error -> {
Expand All @@ -177,6 +185,21 @@ else if (message instanceof McpSchema.JSONRPCRequest request) {
return this.transport.sendMessage(errorResponse).then(Mono.empty());
}).flatMap(this.transport::sendMessage);
}
else if (message instanceof McpSchema.JSONRPCBatchRequest batchRequest) {
logger.debug("Received Batch Request: {}", batchRequest);
return Flux.fromIterable(batchRequest.messages()).flatMap(jsonrpcMessage -> {
if (jsonrpcMessage instanceof McpSchema.JSONRPCRequest request) {
return this.handle(request);
}
else if (jsonrpcMessage instanceof McpSchema.JSONRPCNotification notification) {
return this.handle(notification);
}
else {
logger.warn("Unsupported message in batch request: {}", jsonrpcMessage);
return Mono.empty();
}
}).then();
}
else if (message instanceof McpSchema.JSONRPCNotification notification) {
// TODO handle errors for communication to without initialization
// happening first
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ public McpSchema.JSONRPCMessage getLastSentMessage() {
return !sent.isEmpty() ? sent.get(sent.size() - 1) : null;
}

public McpSchema.JSONRPCBatchResponse getSentMessagesAsBatchResponse() {
return new McpSchema.JSONRPCBatchResponse(sent);
}

private volatile boolean connected = false;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
package io.modelcontextprotocol.spec;

import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.MockMcpClientTransport;
Expand All @@ -26,6 +29,7 @@
* request-response correlation, and notification processing.
*
* @author Christian Tzolov
* @author Jihoon Kim
*/
class McpClientSessionTests {

Expand Down Expand Up @@ -155,6 +159,40 @@ void testRequestHandling() {
assertThat(response.error()).isNull();
}

@Test
void testBatchRequestHandling() {
String echoMessage1 = "Hello MCP 1!";
String echoMessage2 = "Hello MCP 2!";

// Request handler: echoes the input
Map<String, McpClientSession.RequestHandler<?>> requestHandlers = Map.of(ECHO_METHOD,
params -> Mono.just(params));
transport = new MockMcpClientTransport();
session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of());

// Simulate incoming batch request
McpSchema.JSONRPCRequest request1 = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD,
"batch-id-1", echoMessage1);
McpSchema.JSONRPCRequest request2 = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD,
"batch-id-2", echoMessage2);
McpSchema.JSONRPCBatchRequest batchRequest = new McpSchema.JSONRPCBatchRequest(List.of(request1, request2));
transport.simulateIncomingMessage(batchRequest);

// Wait for async processing
McpSchema.JSONRPCBatchResponse batchResponse = transport.getSentMessagesAsBatchResponse();
List<McpSchema.JSONRPCMessage> responses = batchResponse.responses();

assertThat(responses).hasSize(2);
assertThat(responses).allMatch(resp -> resp instanceof McpSchema.JSONRPCResponse);

Map<Object, McpSchema.JSONRPCResponse> responseMap = responses.stream()
.map(resp -> (McpSchema.JSONRPCResponse) resp)
.collect(Collectors.toMap(McpSchema.JSONRPCResponse::id, Function.identity()));

assertThat(responseMap.get("batch-id-1").result()).isEqualTo(echoMessage1);
assertThat(responseMap.get("batch-id-2").result()).isEqualTo(echoMessage2);
}

@Test
void testNotificationHandling() {
Sinks.One<Object> receivedParams = Sinks.one();
Expand Down
Loading