Skip to content

feat(sse): Add support for dynamic route parameters in SSE server #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions examples/custom_sse_pattern/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"context"
"fmt"
"log"
"net/http"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// Custom context function for SSE connections
func customContextFunc(ctx context.Context, r *http.Request) context.Context {
params := server.GetRouteParams(ctx)
log.Printf("SSE Connection Established - Route Parameters: %+v", params)
log.Printf("Request Path: %s", r.URL.Path)
return ctx
}

// Message handler for simulating message sending
func messageHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Get channel parameter from context
channel := server.GetRouteParam(ctx, "channel")
log.Printf("Processing Message - Channel Parameter: %s", channel)

if channel == "" {
return mcp.NewToolResultText("Failed to get channel parameter"), nil
}

message := fmt.Sprintf("Message sent to channel: %s", channel)
return mcp.NewToolResultText(message), nil
}

func main() {
// Create MCP Server
mcpServer := server.NewMCPServer("test-server", "1.0.0")

// Register test tool
mcpServer.AddTool(mcp.NewTool("send_message"), messageHandler)

// Create SSE Server with custom route pattern
sseServer := server.NewSSEServer(mcpServer,
server.WithBaseURL("http://localhost:8080"),
server.WithSSEPattern("/:channel/sse"),
server.WithSSEContextFunc(customContextFunc),
)

// Start server
log.Printf("Server started on port :8080")
log.Printf("Test URL: http://localhost:8080/test/sse")
log.Printf("Test URL: http://localhost:8080/news/sse")

if err := sseServer.Start(":8080"); err != nil {
log.Fatalf("Server error: %v", err)
}
}
121 changes: 105 additions & 16 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,36 @@ type sseSession struct {
requestID atomic.Int64
notificationChannel chan mcp.JSONRPCNotification
initialized atomic.Bool
routeParams RouteParams // Store route parameters in session
}

// SSEContextFunc is a function that takes an existing context and the current
// request and returns a potentially modified context based on the request
// content. This can be used to inject context values from headers, for example.
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context

// RouteParamsKey is the key type for storing route parameters in context
type RouteParamsKey struct{}

// RouteParams stores path parameters
type RouteParams map[string]string

// GetRouteParam retrieves a route parameter from context
func GetRouteParam(ctx context.Context, key string) string {
if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok {
return params[key]
}
return ""
}

// GetRouteParams retrieves all route parameters from context
func GetRouteParams(ctx context.Context) RouteParams {
if params, ok := ctx.Value(RouteParamsKey{}).(RouteParams); ok {
return params
}
return RouteParams{}
}

func (s *sseSession) SessionID() string {
return s.sessionID
}
Expand Down Expand Up @@ -60,12 +83,12 @@ type SSEServer struct {
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
ssePattern string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc

keepAlive bool
keepAliveInterval time.Duration
keepAlive bool
keepAliveInterval time.Duration

mu sync.RWMutex
}
Expand Down Expand Up @@ -130,6 +153,13 @@ func WithSSEEndpoint(endpoint string) SSEOption {
}
}

// WithSSEPattern sets the SSE endpoint pattern with route parameters
func WithSSEPattern(pattern string) SSEOption {
return func(s *SSEServer) {
s.ssePattern = pattern
}
}

// WithHTTPServer sets the HTTP server instance
func WithHTTPServer(srv *http.Server) SSEOption {
return func(s *SSEServer) {
Expand All @@ -150,8 +180,7 @@ func WithKeepAlive(keepAlive bool) SSEOption {
}
}

// WithContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
// WithSSEContextFunc sets a function that will be called to customise the context
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
return func(s *SSEServer) {
s.contextFunc = fn
Expand Down Expand Up @@ -247,12 +276,21 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
eventQueue: make(chan string, 100), // Buffer for events
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
routeParams: GetRouteParams(r.Context()), // Store route parameters from context
}

s.sessions.Store(sessionID, session)
defer s.sessions.Delete(sessionID)

if err := s.server.RegisterSession(r.Context(), session); err != nil {
// Create base context with session
ctx := s.server.WithContext(r.Context(), session)

// Apply custom context function if set
if s.contextFunc != nil {
ctx = s.contextFunc(ctx, r)
}

if err := s.server.RegisterSession(ctx, session); err != nil {
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError)
return
}
Expand All @@ -274,7 +312,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
case <-session.done:
return
case <-r.Context().Done():
case <-ctx.Done():
return
}
}
Expand Down Expand Up @@ -318,7 +356,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
// Write the event to the response
fmt.Fprint(w, event)
flusher.Flush()
case <-r.Context().Done():
case <-ctx.Done():
close(session.done)
return
case <-session.done:
Expand Down Expand Up @@ -357,8 +395,15 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
}
session := sessionI.(*sseSession)

// Set the client context before handling the message
// Create base context with session
ctx := s.server.WithContext(r.Context(), session)

// Add stored route parameters to context
if len(session.routeParams) > 0 {
ctx = context.WithValue(ctx, RouteParamsKey{}, session.routeParams)
}

// Apply custom context function if set
if s.contextFunc != nil {
ctx = s.contextFunc(ctx, r)
}
Expand All @@ -370,7 +415,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
return
}

// Process message through MCPServer
// Process message through MCPServer with the context containing route parameters
response := s.server.HandleMessage(ctx, rawMessage)

// Only send response if there is one (not for notifications)
Expand Down Expand Up @@ -473,17 +518,61 @@ func (s *SSEServer) CompleteMessagePath() string {
// ServeHTTP implements the http.Handler interface.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
// Use exact path matching rather than Contains
ssePath := s.CompleteSsePath()
if ssePath != "" && path == ssePath {
s.handleSSE(w, r)
return
}
messagePath := s.CompleteMessagePath()

// Handle message endpoint
if messagePath != "" && path == messagePath {
s.handleMessage(w, r)
return
}

// Handle SSE endpoint with route parameters
if s.ssePattern != "" {
// Try pattern matching if pattern is set
fullPattern := s.basePath + s.ssePattern
matches, params := matchPath(fullPattern, path)
if matches {
// Create new context with route parameters
ctx := context.WithValue(r.Context(), RouteParamsKey{}, params)
s.handleSSE(w, r.WithContext(ctx))
return
}
// If pattern is set but doesn't match, return 404
http.NotFound(w, r)
return
}

// If no pattern is set, use the default SSE endpoint
ssePath := s.CompleteSsePath()
if ssePath != "" && path == ssePath {
s.handleSSE(w, r)
return
}

http.NotFound(w, r)
}

// matchPath checks if the given path matches the pattern and extracts parameters
// pattern format: /user/:id/profile/:type
func matchPath(pattern, path string) (bool, RouteParams) {
patternParts := strings.Split(strings.Trim(pattern, "/"), "/")
pathParts := strings.Split(strings.Trim(path, "/"), "/")

if len(patternParts) != len(pathParts) {
return false, nil
}

params := make(RouteParams)
for i, part := range patternParts {
if strings.HasPrefix(part, ":") {
// This is a parameter
paramName := strings.TrimPrefix(part, ":")
params[paramName] = pathParts[i]
} else if part != pathParts[i] {
// Static part doesn't match
return false, nil
}
}

return true, params
}
Loading