From d00c2156c1d4a4e33cae0b325aa40c24e12b6d83 Mon Sep 17 00:00:00 2001 From: Alone88 Date: Tue, 8 Apr 2025 11:39:12 +0800 Subject: [PATCH 1/4] feat: add custom sse router --- examples/custom_sse_pattern/main.go | 57 ++++++++++++++ server/sse.go | 118 +++++++++++++++++++++++++--- 2 files changed, 163 insertions(+), 12 deletions(-) create mode 100644 examples/custom_sse_pattern/main.go diff --git a/examples/custom_sse_pattern/main.go b/examples/custom_sse_pattern/main.go new file mode 100644 index 00000000..58f5788c --- /dev/null +++ b/examples/custom_sse_pattern/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// Custom context function for SSE connections +func customContextFunc(ctx context.Context, r *http.Request) context.Context { + params := server.GetRouteParams(ctx) + log.Printf("SSE Connection Established - Route Parameters: %+v", params) + log.Printf("Request Path: %s", r.URL.Path) + return ctx +} + +// Message handler for simulating message sending +func messageHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get channel parameter from context + channel := server.GetRouteParam(ctx, "channel") + log.Printf("Processing Message - Channel Parameter: %s", channel) + + if channel == "" { + return mcp.NewToolResultText("Failed to get channel parameter"), nil + } + + message := fmt.Sprintf("Message sent to channel: %s", channel) + return mcp.NewToolResultText(message), nil +} + +func main() { + // Create MCP Server + mcpServer := server.NewMCPServer("test-server", "1.0.0") + + // Register test tool + mcpServer.AddTool(mcp.NewTool("send_message"), messageHandler) + + // Create SSE Server with custom route pattern + sseServer := server.NewSSEServer(mcpServer, + server.WithBaseURL("http://localhost:8080"), + server.WithSSEPattern("/:channel/sse"), + server.WithSSEContextFunc(customContextFunc), + ) + + // Start server + log.Printf("Server started on port :8080") + log.Printf("Test URL: http://localhost:8080/test/sse") + log.Printf("Test URL: http://localhost:8080/news/sse") + + if err := sseServer.Start(":8080"); err != nil { + log.Fatalf("Server error: %v", err) + } +} diff --git a/server/sse.go b/server/sse.go index 6e6a13fe..d27173b5 100644 --- a/server/sse.go +++ b/server/sse.go @@ -24,6 +24,7 @@ type sseSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool + routeParams RouteParams // Store route parameters in session } // SSEContextFunc is a function that takes an existing context and the current @@ -31,6 +32,28 @@ type sseSession struct { // content. This can be used to inject context values from headers, for example. type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context +// RouteParamsKey is the key type for storing route parameters in context +type RouteParamsKey struct{} + +// RouteParams stores path parameters +type RouteParams map[string]string + +// GetRouteParam retrieves a route parameter from context +func GetRouteParam(ctx context.Context, key string) string { + if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok { + return params[key] + } + return "" +} + +// GetRouteParams retrieves all route parameters from context +func GetRouteParams(ctx context.Context) RouteParams { + if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok { + return params + } + return RouteParams{} +} + func (s *sseSession) SessionID() string { return s.sessionID } @@ -58,6 +81,7 @@ type SSEServer struct { messageEndpoint string useFullURLForMessageEndpoint bool sseEndpoint string + ssePattern string sessions sync.Map srv *http.Server contextFunc SSEContextFunc @@ -123,6 +147,13 @@ func WithSSEEndpoint(endpoint string) SSEOption { } } +// WithSSEPattern sets the SSE endpoint pattern with route parameters +func WithSSEPattern(pattern string) SSEOption { + return func(s *SSEServer) { + s.ssePattern = pattern + } +} + // WithHTTPServer sets the HTTP server instance func WithHTTPServer(srv *http.Server) SSEOption { return func(s *SSEServer) { @@ -130,7 +161,7 @@ func WithHTTPServer(srv *http.Server) SSEOption { } } -// WithContextFunc sets a function that will be called to customise the context +// WithSSEContextFunc sets a function that will be called to customise the context // to the server using the incoming request. func WithSSEContextFunc(fn SSEContextFunc) SSEOption { return func(s *SSEServer) { @@ -222,12 +253,21 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { eventQueue: make(chan string, 100), // Buffer for events sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), + routeParams: GetRouteParams(r.Context()), // Store route parameters from context } s.sessions.Store(sessionID, session) defer s.sessions.Delete(sessionID) - if err := s.server.RegisterSession(r.Context(), session); err != nil { + // Create base context with session + ctx := s.server.WithContext(r.Context(), session) + + // Apply custom context function if set + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + if err := s.server.RegisterSession(ctx, session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) return } @@ -249,7 +289,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } case <-session.done: return - case <-r.Context().Done(): + case <-ctx.Done(): return } } @@ -266,7 +306,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { // Write the event to the response fmt.Fprint(w, event) flusher.Flush() - case <-r.Context().Done(): + case <-ctx.Done(): close(session.done) return } @@ -304,8 +344,15 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { } session := sessionI.(*sseSession) - // Set the client context before handling the message + // Create base context with session ctx := s.server.WithContext(r.Context(), session) + + // Add stored route parameters to context + if len(session.routeParams) > 0 { + ctx = context.WithValue(ctx, RouteParamsKey{}, session.routeParams) + } + + // Apply custom context function if set if s.contextFunc != nil { ctx = s.contextFunc(ctx, r) } @@ -317,7 +364,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } - // Process message through MCPServer + // Process message through MCPServer with the context containing route parameters response := s.server.HandleMessage(ctx, rawMessage) // Only send response if there is one (not for notifications) @@ -384,6 +431,7 @@ func (s *SSEServer) SendEventToSession( return fmt.Errorf("event queue full") } } + func (s *SSEServer) GetUrlPath(input string) (string, error) { parse, err := url.Parse(input) if err != nil { @@ -395,6 +443,7 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) { func (s *SSEServer) CompleteSseEndpoint() string { return s.baseURL + s.basePath + s.sseEndpoint } + func (s *SSEServer) CompleteSsePath() string { path, err := s.GetUrlPath(s.CompleteSseEndpoint()) if err != nil { @@ -406,6 +455,7 @@ func (s *SSEServer) CompleteSsePath() string { func (s *SSEServer) CompleteMessageEndpoint() string { return s.baseURL + s.basePath + s.messageEndpoint } + func (s *SSEServer) CompleteMessagePath() string { path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) if err != nil { @@ -417,17 +467,61 @@ func (s *SSEServer) CompleteMessagePath() string { // ServeHTTP implements the http.Handler interface. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path - // Use exact path matching rather than Contains - ssePath := s.CompleteSsePath() - if ssePath != "" && path == ssePath { - s.handleSSE(w, r) - return - } messagePath := s.CompleteMessagePath() + + // Handle message endpoint if messagePath != "" && path == messagePath { s.handleMessage(w, r) return } + // Handle SSE endpoint with route parameters + if s.ssePattern != "" { + // Try pattern matching if pattern is set + fullPattern := s.basePath + s.ssePattern + matches, params := matchPath(fullPattern, path) + if matches { + // Create new context with route parameters + ctx := context.WithValue(r.Context(), RouteParamsKey{}, params) + s.handleSSE(w, r.WithContext(ctx)) + return + } + // If pattern is set but doesn't match, return 404 + http.NotFound(w, r) + return + } + + // If no pattern is set, use the default SSE endpoint + ssePath := s.CompleteSsePath() + if ssePath != "" && path == ssePath { + s.handleSSE(w, r) + return + } + http.NotFound(w, r) } + +// matchPath checks if the given path matches the pattern and extracts parameters +// pattern format: /user/:id/profile/:type +func matchPath(pattern, path string) (bool, RouteParams) { + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + if len(patternParts) != len(pathParts) { + return false, nil + } + + params := make(RouteParams) + for i, part := range patternParts { + if strings.HasPrefix(part, ":") { + // This is a parameter + paramName := strings.TrimPrefix(part, ":") + params[paramName] = pathParts[i] + } else if part != pathParts[i] { + // Static part doesn't match + return false, nil + } + } + + return true, params +} From b19b21e55eb09a51132cd6057bdf9041197b7fc8 Mon Sep 17 00:00:00 2001 From: Alone88 Date: Tue, 8 Apr 2025 11:39:40 +0800 Subject: [PATCH 2/4] feat: add custom sse router test --- server/sse_test.go | 133 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/server/sse_test.go b/server/sse_test.go index 111c5845..f46a9077 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -739,4 +739,137 @@ func TestSSEServer(t *testing.T) { } } }) + + t.Run("Can handle custom route parameters", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0", + WithResourceCapabilities(true, true), + ) + + // Add a test tool that uses route parameters + mcpServer.AddTool(mcp.NewTool("test_route"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + channel := GetRouteParam(ctx, "channel") + if channel == "" { + return nil, fmt.Errorf("channel parameter not found") + } + return mcp.NewToolResultText(fmt.Sprintf("Channel: %s", channel)), nil + }) + + // Create SSE server with custom route pattern + testServer := NewTestServer(mcpServer, + WithSSEPattern("/:channel/sse"), + WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { + return ctx + }), + ) + defer testServer.Close() + + // Connect to SSE endpoint with channel parameter + sseResp, err := http.Get(fmt.Sprintf("%s/test-channel/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + buf := make([]byte, 1024) + n, err := sseResp.Body.Read(buf) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + endpointEvent := string(buf[:n]) + if !strings.Contains(endpointEvent, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEvent) + } + + // Extract message endpoint URL + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send initialize request + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + requestBody, err := json.Marshal(initRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + // Call the test tool + toolRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "test_route", + }, + } + + requestBody, err = json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal tool request: %v", err) + } + + resp, err = http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send tool request: %v", err) + } + defer resp.Body.Close() + + // Verify response + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + result, ok := response["result"].(map[string]interface{}) + if !ok { + t.Fatalf("Expected result object, got: %v", response) + } + + content, ok := result["content"].([]interface{}) + if !ok || len(content) == 0 { + t.Fatalf("Expected content array, got: %v", result) + } + + textObj, ok := content[0].(map[string]interface{}) + if !ok { + t.Fatalf("Expected text object, got: %v", content[0]) + } + + text, ok := textObj["text"].(string) + if !ok { + t.Fatalf("Expected text string, got: %v", textObj["text"]) + } + + expectedText := "Channel: test-channel" + if text != expectedText { + t.Errorf("Expected text %q, got %q", expectedText, text) + } + }) } From b14de59e2e9f006b3f1da2c0776c4e54d74f390e Mon Sep 17 00:00:00 2001 From: Alone88 Date: Fri, 11 Apr 2025 23:56:47 +0800 Subject: [PATCH 3/4] Update server/sse.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- server/sse.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/sse.go b/server/sse.go index da56d9c2..2bfc4215 100644 --- a/server/sse.go +++ b/server/sse.go @@ -85,7 +85,7 @@ type SSEServer struct { ssePattern string sessions sync.Map srv *http.Server - contextFunc SSEContextFun + contextFunc SSEContextFunc keepAlive bool keepAliveInterval time.Duration } From 9266ee5576f4c4c15f01858569f9a3cf7c2f86b4 Mon Sep 17 00:00:00 2001 From: Alone88 Date: Sun, 27 Apr 2025 01:56:02 +0000 Subject: [PATCH 4/4] fix: go fmt style --- server/sse.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/server/sse.go b/server/sse.go index c129d53e..1d740088 100644 --- a/server/sse.go +++ b/server/sse.go @@ -77,18 +77,18 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - useFullURLForMessageEndpoint bool - messageEndpoint string - sseEndpoint string - ssePattern string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc - keepAlive bool - keepAliveInterval time.Duration + server *MCPServer + baseURL string + basePath string + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + ssePattern string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc + keepAlive bool + keepAliveInterval time.Duration mu sync.RWMutex }