Skip to content

feat: Add request with context support. #226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.*;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -100,6 +95,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv

private McpServerSession.Factory sessionFactory;

private McpContextFactory mcpContextFactory;

/**
* Map of active client sessions, keyed by session ID.
*/
Expand Down Expand Up @@ -169,6 +166,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}

@Override
public void setMcpContextFactory(final McpContextFactory mcpContextFactory) {
this.mcpContextFactory = mcpContextFactory;
}

/**
* Broadcasts a JSON-RPC message to all connected clients through their SSE
* connections. The message is serialized to JSON and sent as a server-sent event to
Expand Down Expand Up @@ -261,7 +263,7 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
.body(Flux.<ServerSentEvent<?>>create(sink -> {
WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink);

McpServerSession session = sessionFactory.create(sessionTransport);
McpServerSession session = sessionFactory.create(sessionTransport, createContext(request));
String sessionId = session.getId();

logger.debug("Created new SSE connection for session: {}", sessionId);
Expand All @@ -280,6 +282,18 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
}), ServerSentEvent.class);
}

private McpContext createContext(final ServerRequest request) {
// create a context form the request
McpContext context;
if (mcpContextFactory != null) {
context = mcpContextFactory.create(request);
}
else {
context = McpContext.empty();
}
return context;
}

/**
* Handles incoming JSON-RPC messages from clients. Deserializes the message and
* processes it through the configured message handler.
Expand Down Expand Up @@ -314,14 +328,16 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
return request.bodyToMono(String.class).flatMap(body -> {
try {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
return session.handle(message).flatMap(response -> ServerResponse.ok().build()).onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
// TODO: instead of signalling the error, just respond with 200 OK
// - the error is signalled on the SSE connection
// return ServerResponse.ok().build();
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
return session.handle(message, createContext(request))
.flatMap(response -> ServerResponse.ok().build())
.onErrorResume(error -> {
logger.error("Error processing message: {}", error.getMessage());
// TODO: instead of signalling the error, just respond with 200 OK
// - the error is signalled on the SSE connection
// return ServerResponse.ok().build();
return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
.bodyValue(new McpError(error.getMessage()));
});
}
catch (IllegalArgumentException | IOException e) {
logger.error("Failed to deserialize message: {}", e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import io.modelcontextprotocol.server.TestUtil;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import io.modelcontextprotocol.spec.McpContext;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.*;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -767,9 +767,11 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
));

AtomicReference<CompleteRequest> samplingRequest = new AtomicReference<>();
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (mcpSyncServerExchange,
request) -> {
samplingRequest.set(request);
AtomicReference<McpContext> mcpContext = new AtomicReference<>();
BiFunction<McpSyncServerExchange, McpServerFeatures.RequestWithContext<CompleteRequest>, CompleteResult> completionHandler = (
mcpSyncServerExchange, reqWithContext) -> {
samplingRequest.set(reqWithContext.request());
mcpContext.set(reqWithContext.mcpContext());
return completionResponse;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@

import java.io.IOException;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.*;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -97,6 +92,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi

private McpServerSession.Factory sessionFactory;

private McpContextFactory mcpContextFactory;

/**
* Map of active client sessions, keyed by session ID.
*/
Expand Down Expand Up @@ -169,6 +166,11 @@ public void setSessionFactory(McpServerSession.Factory sessionFactory) {
this.sessionFactory = sessionFactory;
}

@Override
public void setMcpContextFactory(McpContextFactory mcpContextFactory) {
this.mcpContextFactory = mcpContextFactory;
}

/**
* Broadcasts a notification to all connected clients through their SSE connections.
* The message is serialized to JSON and sent as an SSE event with type "message". If
Expand Down Expand Up @@ -263,7 +265,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
});

WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder);
McpServerSession session = sessionFactory.create(sessionTransport);
McpServerSession session = sessionFactory.create(sessionTransport, createContext(request));
this.sessions.put(sessionId, session);

try {
Expand All @@ -284,6 +286,18 @@ private ServerResponse handleSseConnection(ServerRequest request) {
}
}

private McpContext createContext(final ServerRequest request) {
// create a context form the request
McpContext context;
if (mcpContextFactory != null) {
context = mcpContextFactory.create(request);
}
else {
context = McpContext.empty();
}
return context;
}

/**
* Handles incoming JSON-RPC messages from clients. This method:
* <ul>
Expand Down Expand Up @@ -316,7 +330,8 @@ private ServerResponse handleMessage(ServerRequest request) {
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);

// Process the message through the session's handle method
session.handle(message).block(); // Block for WebMVC compatibility
session.handle(message, createContext(request)).block(); // Block for WebMVC
// compatibility

return ServerResponse.ok().build();
}
Expand Down
Loading