From e1e5c9665c2ee22cbaf02539e556089ba6b569b4 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Thu, 8 May 2025 14:41:46 +0800 Subject: [PATCH] feat: Add request with context support. Signed-off-by: He-Pin --- .../WebFluxSseServerTransportProvider.java | 46 +++++++++----- .../WebFluxSseIntegrationTests.java | 10 ++-- .../WebMvcSseServerTransportProvider.java | 31 +++++++--- .../server/McpAsyncServer.java | 60 ++++++++++--------- .../server/McpServer.java | 4 +- .../server/McpServerFeatures.java | 24 ++++---- ...HttpServletSseServerTransportProvider.java | 30 +++++++--- .../StdioServerTransportProvider.java | 29 ++++++--- .../modelcontextprotocol/spec/McpContext.java | 29 +++++++++ .../spec/McpContextFactory.java | 9 +++ .../spec/McpContextKey.java | 10 ++++ .../spec/McpServerSession.java | 48 ++++++++++----- .../spec/McpServerTransportProvider.java | 6 ++ .../MockMcpServerTransportProvider.java | 15 +++-- .../StdioServerTransportProviderTests.java | 11 ++-- 15 files changed, 250 insertions(+), 112 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpContext.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpContextFactory.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpContextKey.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..0ecd58f3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,16 +1,11 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,6 +95,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv private McpServerSession.Factory sessionFactory; + private McpContextFactory mcpContextFactory; + /** * Map of active client sessions, keyed by session ID. */ @@ -169,6 +166,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; } + @Override + public void setMcpContextFactory(final McpContextFactory mcpContextFactory) { + this.mcpContextFactory = mcpContextFactory; + } + /** * Broadcasts a JSON-RPC message to all connected clients through their SSE * connections. The message is serialized to JSON and sent as a server-sent event to @@ -261,7 +263,7 @@ private Mono handleSseConnection(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionTransport, createContext(request)); String sessionId = session.getId(); logger.debug("Created new SSE connection for session: {}", sessionId); @@ -280,6 +282,18 @@ private Mono handleSseConnection(ServerRequest request) { }), ServerSentEvent.class); } + private McpContext createContext(final ServerRequest request) { + // create a context form the request + McpContext context; + if (mcpContextFactory != null) { + context = mcpContextFactory.create(request); + } + else { + context = McpContext.empty(); + } + return context; + } + /** * Handles incoming JSON-RPC messages from clients. Deserializes the message and * processes it through the configured message handler. @@ -314,14 +328,16 @@ private Mono handleMessage(ServerRequest request) { return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> { - logger.error("Error processing message: {}", error.getMessage()); - // TODO: instead of signalling the error, just respond with 200 OK - // - the error is signalled on the SSE connection - // return ServerResponse.ok().build(); - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) - .bodyValue(new McpError(error.getMessage())); - }); + return session.handle(message, createContext(request)) + .flatMap(response -> ServerResponse.ok().build()) + .onErrorResume(error -> { + logger.error("Error processing message: {}", error.getMessage()); + // TODO: instead of signalling the error, just respond with 200 OK + // - the error is signalled on the SSE connection + // return ServerResponse.ok().build(); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError(error.getMessage())); + }); } catch (IllegalArgumentException | IOException e) { logger.error("Failed to deserialize message: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2ba04746..71bb2fef 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -23,10 +23,10 @@ import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.*; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; @@ -767,9 +767,11 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { )); AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (mcpSyncServerExchange, - request) -> { - samplingRequest.set(request); + AtomicReference mcpContext = new AtomicReference<>(); + BiFunction, CompleteResult> completionHandler = ( + mcpSyncServerExchange, reqWithContext) -> { + samplingRequest.set(reqWithContext.request()); + mcpContext.set(reqWithContext.mcpContext()); return completionResponse; }; diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..4447caa6 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,17 +6,12 @@ import java.io.IOException; import java.time.Duration; -import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -97,6 +92,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private McpServerSession.Factory sessionFactory; + private McpContextFactory mcpContextFactory; + /** * Map of active client sessions, keyed by session ID. */ @@ -169,6 +166,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; } + @Override + public void setMcpContextFactory(McpContextFactory mcpContextFactory) { + this.mcpContextFactory = mcpContextFactory; + } + /** * Broadcasts a notification to all connected clients through their SSE connections. * The message is serialized to JSON and sent as an SSE event with type "message". If @@ -263,7 +265,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { }); WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionTransport, createContext(request)); this.sessions.put(sessionId, session); try { @@ -284,6 +286,18 @@ private ServerResponse handleSseConnection(ServerRequest request) { } } + private McpContext createContext(final ServerRequest request) { + // create a context form the request + McpContext context; + if (mcpContextFactory != null) { + context = mcpContextFactory.create(request); + } + else { + context = McpContext.empty(); + } + return context; + } + /** * Handles incoming JSON-RPC messages from clients. This method: *
    @@ -316,7 +330,8 @@ private ServerResponse handleMessage(ServerRequest request) { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); // Process the message through the session's handle method - session.handle(message).block(); // Block for WebMVC compatibility + session.handle(message, createContext(request)).block(); // Block for WebMVC + // compatibility return ServerResponse.ok().build(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 3c112ad7..719391e5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -17,17 +17,13 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; import io.modelcontextprotocol.util.Utils; @@ -302,7 +298,7 @@ private static class AsyncServerImpl extends McpAsyncServer { // Initialize request handlers for standard MCP methods // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params, context) -> Mono.just(Map.of())); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { @@ -335,7 +331,8 @@ private static class AsyncServerImpl extends McpAsyncServer { Map notificationHandlers = new HashMap<>(); - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, + (exchange, params, ctx) -> Mono.empty()); List, Mono>> rootsChangeConsumers = features .rootsChangeConsumers(); @@ -350,16 +347,17 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider + .setSessionFactory((transport, mcpContext) -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, transport, this::asyncInitializeRequestHandler, + this::asyncInitNotificationHandler, requestHandlers, notificationHandlers)); } // --------------------------------------- // Lifecycle Management // --------------------------------------- private Mono asyncInitializeRequestHandler( - McpSchema.InitializeRequest initializeRequest) { + McpSchema.InitializeRequest initializeRequest, McpContext mcpContext) { return Mono.defer(() -> { logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", initializeRequest.protocolVersion(), initializeRequest.capabilities(), @@ -387,6 +385,10 @@ private Mono asyncInitializeRequestHandler( }); } + private Mono asyncInitNotificationHandler(McpContext mcpContext) { + return Mono.empty(); + } + public McpSchema.ServerCapabilities getServerCapabilities() { return this.serverCapabilities; } @@ -407,7 +409,7 @@ public void close() { private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { - return (exchange, params) -> exchange.listRoots() + return (exchange, params, ctx) -> exchange.listRoots() .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) .onErrorResume(error -> { @@ -482,7 +484,7 @@ public Mono notifyToolsListChanged() { } private McpServerSession.RequestHandler toolsListRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); @@ -490,7 +492,7 @@ private McpServerSession.RequestHandler toolsListRequ } private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { }); @@ -502,8 +504,8 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( if (toolSpecification.isEmpty()) { return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + var reqWithContext = new McpServerFeatures.RequestWithContext<>(callToolRequest, context); + return toolSpecification.map(tool -> tool.call().apply(exchange, reqWithContext)) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -563,7 +565,7 @@ public Mono notifyResourcesListChanged() { } private McpServerSession.RequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { var resourceList = this.resources.values() .stream() .map(McpServerFeatures.AsyncResourceSpecification::resource) @@ -573,7 +575,7 @@ private McpServerSession.RequestHandler resources } private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono + return (exchange, params, context) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); } @@ -597,7 +599,7 @@ private List getResourceTemplates() { } private McpServerSession.RequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { }); @@ -610,8 +612,8 @@ private McpServerSession.RequestHandler resourcesR .matches(resourceUri)) .findFirst() .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - - return specification.readHandler().apply(exchange, resourceRequest); + var reqWithContext = new McpServerFeatures.RequestWithContext<>(resourceRequest, context); + return specification.readHandler().apply(exchange, reqWithContext); }; } @@ -679,7 +681,7 @@ public Mono notifyPromptsListChanged() { } private McpServerSession.RequestHandler promptsListRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, // new TypeReference() { @@ -695,7 +697,7 @@ private McpServerSession.RequestHandler promptsList } private McpServerSession.RequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { }); @@ -705,8 +707,8 @@ private McpServerSession.RequestHandler promptsGetReq if (specification == null) { return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); } - - return specification.promptHandler().apply(exchange, promptRequest); + var reqWithContext = new McpServerFeatures.RequestWithContext<>(promptRequest, context); + return specification.promptHandler().apply(exchange, reqWithContext); }; } @@ -730,7 +732,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } private McpServerSession.RequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { return Mono.defer(() -> { SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, @@ -749,7 +751,7 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { } private McpServerSession.RequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { + return (exchange, params, context) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { @@ -801,8 +803,8 @@ private McpServerSession.RequestHandler completionComp if (specification == null) { return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); } - - return specification.completionHandler().apply(exchange, request); + var reqWithContext = new McpServerFeatures.RequestWithContext<>(request, context); + return specification.completionHandler().apply(exchange, reqWithContext); }; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d6ec2cc3..083beb60 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -323,7 +323,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * @throws IllegalArgumentException if tool or handler is null */ public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { + BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -814,7 +814,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * @throws IllegalArgumentException if tool or handler is null */ public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { + BiFunction, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d4..122857aa 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -11,6 +11,7 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; +import io.modelcontextprotocol.spec.McpContext; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; @@ -237,7 +238,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * connected client. The second arguments is a map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - BiFunction, Mono> call) { + BiFunction, Mono> call) { static AsyncToolSpecification fromSync(SyncToolSpecification tool) { // FIXME: This is temporary, proper validation should be implemented @@ -245,12 +246,15 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { return null; } return new AsyncToolSpecification(tool.tool(), - (exchange, map) -> Mono - .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), map)) + (exchange, reqWithContext) -> Mono + .fromCallable(() -> tool.call().apply(new McpSyncServerExchange(exchange), reqWithContext)) .subscribeOn(Schedulers.boundedElastic())); } } + public record RequestWithContext(Request request, McpContext mcpContext) { + } + /** * Specification of a resource with its asynchronous handler function. Resources * provide context to AI models by exposing data such as: @@ -279,7 +283,7 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. */ public record AsyncResourceSpecification(McpSchema.Resource resource, - BiFunction> readHandler) { + BiFunction, Mono> readHandler) { static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { // FIXME: This is temporary, proper validation should be implemented @@ -325,7 +329,7 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource) { * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. */ public record AsyncPromptSpecification(McpSchema.Prompt prompt, - BiFunction> promptHandler) { + BiFunction, Mono> promptHandler) { static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { // FIXME: This is temporary, proper validation should be implemented @@ -356,7 +360,7 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt) { * argument is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. */ public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, - BiFunction> completionHandler) { + BiFunction, Mono> completionHandler) { /** * Converts a synchronous {@link SyncCompletionSpecification} into an @@ -413,7 +417,7 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * client. The second arguments is a map of arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { + BiFunction, McpSchema.CallToolResult> call) { } /** @@ -445,7 +449,7 @@ public record SyncToolSpecification(McpSchema.Tool tool, * {@link io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest}. */ public record SyncResourceSpecification(McpSchema.Resource resource, - BiFunction readHandler) { + BiFunction, McpSchema.ReadResourceResult> readHandler) { } /** @@ -480,7 +484,7 @@ public record SyncResourceSpecification(McpSchema.Resource resource, * {@link io.modelcontextprotocol.spec.McpSchema.GetPromptRequest}. */ public record SyncPromptSpecification(McpSchema.Prompt prompt, - BiFunction promptHandler) { + BiFunction, McpSchema.GetPromptResult> promptHandler) { } /** @@ -493,7 +497,7 @@ public record SyncPromptSpecification(McpSchema.Prompt prompt, * is a {@link io.modelcontextprotocol.spec.McpSchema.CompleteRequest}. */ public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, - BiFunction completionHandler) { + BiFunction, McpSchema.CompleteResult> completionHandler) { } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..4534b32a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -13,11 +13,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.util.Assert; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; @@ -103,6 +99,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Session factory for creating new sessions */ private McpServerSession.Factory sessionFactory; + private McpContextFactory mcpContextFactory; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -153,6 +151,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) { this.sessionFactory = sessionFactory; } + @Override + public void setMcpContextFactory(final McpContextFactory mcpContextFactory) { + this.mcpContextFactory = mcpContextFactory; + } + /** * Broadcasts a notification to all connected clients. * @param method The method name for the notification @@ -219,13 +222,25 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) writer); // Create a new session using the session factory - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(sessionTransport, createContext(request)); this.sessions.put(sessionId, session); // Send initial endpoint event this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } + private McpContext createContext(final HttpServletRequest request) { + // create a context form the request + McpContext context; + if (mcpContextFactory != null) { + context = mcpContextFactory.create(request); + } + else { + context = McpContext.empty(); + } + return context; + } + /** * Handles POST requests for client messages. *

    @@ -289,7 +304,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + session.handle(message, createContext(request)).block(); // Block for Servlet + // compatibility response.setStatus(HttpServletResponse.SC_OK); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..58057419 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -18,12 +18,8 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +52,11 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private final Sinks.One inboundReady = Sinks.one(); + // Create a single session for the stdio connection + private final StdioMcpSessionTransport transport = new StdioMcpSessionTransport(); + + private McpContextFactory mcpContextFactory; + /** * Creates a new StdioServerTransportProvider with a default ObjectMapper and System * streams. @@ -92,12 +93,22 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { - // Create a single session for the stdio connection - var transport = new StdioMcpSessionTransport(); - this.session = sessionFactory.create(transport); + this.session = sessionFactory.create(transport, createContext()); transport.initProcessing(); } + protected McpContext createContext() { + if (this.mcpContextFactory != null) { + return this.mcpContextFactory.create(null); + } + return McpContext.empty(); + } + + @Override + public void setMcpContextFactory(final McpContextFactory mcpContextFactory) { + this.mcpContextFactory = mcpContextFactory; + } + @Override public Mono notifyClients(String method, Object params) { if (this.session == null) { @@ -186,7 +197,7 @@ private void initProcessing() { } private void handleIncomingMessages() { - this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + this.inboundSink.asFlux().flatMap(message -> session.handle(message, createContext())).doOnTerminate(() -> { // The outbound processing will dispose its scheduler upon completion this.outboundSink.tryEmitComplete(); this.inboundScheduler.dispose(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpContext.java new file mode 100644 index 00000000..7eed7417 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpContext.java @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class McpContext { + + private final Map, Object> contextMap = new ConcurrentHashMap<>(); + + @SuppressWarnings("unchecked") + T read(McpContextKey contextKey, T defaultValue) { + return (T) contextMap.getOrDefault(contextKey, defaultValue); + } + + McpContext write(McpContextKey contextKey, T value) { + contextMap.put(contextKey, value); + return this; + } + + McpContext remove(McpContextKey contextKey) { + contextMap.remove(contextKey); + return this; + } + + public static McpContext empty() { + return new McpContext(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpContextFactory.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpContextFactory.java new file mode 100644 index 00000000..922860c7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpContextFactory.java @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.spec; + +public interface McpContextFactory { + + default McpContext create(Object contextParam) { + return new McpContext(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpContextKey.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpContextKey.java new file mode 100644 index 00000000..33741606 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpContextKey.java @@ -0,0 +1,10 @@ +package io.modelcontextprotocol.spec; + +/** + * A Marker interface for the Model Context Protocol (MCP) context key. + */ +public interface McpContextKey { + + String getName(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d85..cc827e29 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -48,6 +48,8 @@ public class McpServerSession implements McpSession { private final AtomicReference clientInfo = new AtomicReference<>(); + private final AtomicReference initContext = new AtomicReference<>(); + private static final int STATE_UNINITIALIZED = 0; private static final int STATE_INITIALIZING = 1; @@ -98,10 +100,13 @@ public String getId() { * Spec * @param clientCapabilities the capabilities the connected client provides * @param clientInfo the information about the connected client + * @param mcpContext the context of the request */ - public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpContext mcpContext) { this.clientCapabilities.lazySet(clientCapabilities); this.clientInfo.lazySet(clientInfo); + this.initContext.lazySet(mcpContext); } private String generateRequestId() { @@ -151,9 +156,10 @@ public Mono sendNotification(String method, Object params) { * {@link io.modelcontextprotocol.server.McpSyncServer}) via * {@link McpServerSession.Factory} that the server creates. * @param message the incoming JSON-RPC message + * @param mcpContext the context of the request * @return a Mono that completes when the message is processed */ - public Mono handle(McpSchema.JSONRPCMessage message) { + public Mono handle(McpSchema.JSONRPCMessage message, McpContext mcpContext) { return Mono.defer(() -> { // TODO handle errors for communication to without initialization happening // first @@ -170,7 +176,7 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(request, mcpContext).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); @@ -183,7 +189,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { // happening first logger.debug("Received notification: {}", notification); // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) + return handleIncomingNotification(notification, mcpContext) .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { @@ -196,9 +202,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param mcpContext The context of the request * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpContext mcpContext) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { @@ -208,8 +216,8 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR }); this.state.lazySet(STATE_INITIALIZING); - this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); - resultMono = this.initRequestHandler.handle(initializeRequest); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo(), mcpContext); + resultMono = this.initRequestHandler.handle(initializeRequest, mcpContext); } else { // TODO handle errors for communication to this session without @@ -222,7 +230,8 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(exchange, request.params(), mcpContext)); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -236,14 +245,15 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR /** * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. * @param notification The incoming JSON-RPC notification + * @param mcpContext The context of the request * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, McpContext mcpContext) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); - return this.initNotificationHandler.handle(); + return this.initNotificationHandler.handle(mcpContext); } var handler = notificationHandlers.get(notification.method()); @@ -251,7 +261,8 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti logger.error("No handler registered for notification method: {}", notification.method()); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(exchange, notification.params(), mcpContext)); }); } @@ -280,9 +291,10 @@ public interface InitRequestHandler { /** * Handles the initialization request. * @param initializeRequest the initialization request by the client + * @param mcpContext the context of the request * @return a Mono that will emit the result of the initialization */ - Mono handle(McpSchema.InitializeRequest initializeRequest); + Mono handle(McpSchema.InitializeRequest initializeRequest, McpContext mcpContext); } @@ -293,9 +305,10 @@ public interface InitNotificationHandler { /** * Specifies an action to take upon successful initialization. + * @param mcpContext the context of the request * @return a Mono that will complete when the initialization is acted upon. */ - Mono handle(); + Mono handle(McpContext mcpContext); } @@ -309,9 +322,10 @@ public interface NotificationHandler { * @param exchange the exchange associated with the client that allows calling * back to the connected client or inspecting its capabilities. * @param params the parameters of the notification. + * @param mcpContext the context of the request * @return a Mono that completes once the notification is handled. */ - Mono handle(McpAsyncServerExchange exchange, Object params); + Mono handle(McpAsyncServerExchange exchange, Object params, McpContext mcpContext); } @@ -328,9 +342,10 @@ public interface RequestHandler { * @param exchange the exchange associated with the client that allows calling * back to the connected client or inspecting its capabilities. * @param params the parameters of the request. + * @param mcpContext the context of the request * @return a Mono that will emit the response to the request. */ - Mono handle(McpAsyncServerExchange exchange, Object params); + Mono handle(McpAsyncServerExchange exchange, Object params, McpContext mcpContext); } @@ -344,9 +359,10 @@ public interface Factory { /** * Creates a new 1:1 representation of the client-server interaction. * @param sessionTransport the transport to use for communication with the client. + * @param mcpContext the context of the request * @return a new server session. */ - McpServerSession create(McpServerTransport sessionTransport); + McpServerSession create(McpServerTransport sessionTransport, McpContext mcpContext); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 5fdbd7ab..e7886a18 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -39,6 +39,12 @@ public interface McpServerTransportProvider { */ void setSessionFactory(McpServerSession.Factory sessionFactory); + /** + * Set the MCP context factory that will be used to create the MCP context. This + * method must be called before `setSessionFactory`. + */ + void setMcpContextFactory(McpContextFactory mcpContextFactory); + /** * Sends a notification to all connected clients. * @param method the name of the notification method to be called on the clients diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..8f43ae0e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -17,10 +17,8 @@ import java.util.Map; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.spec.McpServerSession.Factory; -import io.modelcontextprotocol.spec.McpServerTransportProvider; import reactor.core.publisher.Mono; /** @@ -32,6 +30,8 @@ public class MockMcpServerTransportProvider implements McpServerTransportProvide private final MockMcpServerTransport transport; + private McpContextFactory mcpContextFactory; + public MockMcpServerTransportProvider(MockMcpServerTransport transport) { this.transport = transport; } @@ -42,8 +42,13 @@ public MockMcpServerTransport getTransport() { @Override public void setSessionFactory(Factory sessionFactory) { + var ctx = mcpContextFactory != null ? mcpContextFactory.create(null) : McpContext.empty(); + session = sessionFactory.create(transport, ctx); + } - session = sessionFactory.create(transport); + @Override + public void setMcpContextFactory(McpContextFactory mcpContextFactory) { + this.mcpContextFactory = mcpContextFactory; } @Override @@ -57,7 +62,7 @@ public Mono closeGracefully() { } public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { - session.handle(message).subscribe(); + session.handle(message, McpContext.empty()).subscribe(); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..cc4386fc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -15,10 +15,7 @@ import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.*; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -71,7 +68,7 @@ void setUp() { sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior - when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(sessionFactory.create(any(McpServerTransport.class), any(McpContext.class))).thenReturn(mockSession); when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); @@ -110,9 +107,9 @@ void shouldHandleIncomingMessages() throws Exception { AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); - McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession.Factory realSessionFactory = (transport, ctx) -> { McpServerSession session = mock(McpServerSession.class); - when(session.handle(any())).thenAnswer(invocation -> { + when(session.handle(any(), any(McpContext.class))).thenAnswer(invocation -> { capturedMessage.set(invocation.getArgument(0)); messageLatch.countDown(); return Mono.empty();