diff --git a/client/call.go b/client/call.go index 10db8f1..3d42099 100644 --- a/client/call.go +++ b/client/call.go @@ -3,6 +3,7 @@ package client import ( "context" "encoding/json" + "errors" "fmt" "strconv" "sync/atomic" @@ -220,7 +221,7 @@ func (client *Client) sendNotification4Initialized(ctx context.Context) error { // Responsible for request and response assembly func (client *Client) callServer(ctx context.Context, method protocol.Method, params protocol.ClientRequest) (json.RawMessage, error) { if !client.ready.Load() && (method != protocol.Initialize && method != protocol.Ping) { - return nil, fmt.Errorf("client not ready") + return nil, errors.New("callServer: client not ready") } requestID := strconv.FormatInt(atomic.AddInt64(&client.requestID, 1), 10) diff --git a/client/client.go b/client/client.go index 1a4c053..5a2262b 100644 --- a/client/client.go +++ b/client/client.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "sync" "time" cmap "github.com/orcaman/concurrent-map/v2" @@ -47,7 +48,8 @@ type Client struct { requestID int64 - ready *pkg.AtomicBool + ready *pkg.AtomicBool + initializationMu sync.Mutex clientInfo *protocol.Implementation clientCapabilities *protocol.ClientCapabilities diff --git a/client/send.go b/client/send.go index 5814aab..53be744 100644 --- a/client/send.go +++ b/client/send.go @@ -3,8 +3,10 @@ package client import ( "context" "encoding/json" + "errors" "fmt" + "github.com/ThinkInAIXYZ/go-mcp/pkg" "github.com/ThinkInAIXYZ/go-mcp/protocol" ) @@ -20,8 +22,13 @@ func (client *Client) sendMsgWithRequest(ctx context.Context, requestID protocol return err } - if err := client.transport.Send(ctx, message); err != nil { - return fmt.Errorf("sendRequest: transport send: %w", err) + if err = client.transport.Send(ctx, message); err != nil { + if !errors.Is(err, pkg.ErrSessionClosed) { + return fmt.Errorf("sendRequest: transport send: %w", err) + } + if err = client.againInitialization(ctx); err != nil { + return err + } } return nil } @@ -38,7 +45,7 @@ func (client *Client) sendMsgWithResponse(ctx context.Context, requestID protoco return err } - if err := client.transport.Send(ctx, message); err != nil { + if err = client.transport.Send(ctx, message); err != nil { return fmt.Errorf("sendResponse: transport send: %w", err) } return nil @@ -52,7 +59,7 @@ func (client *Client) sendMsgWithNotification(ctx context.Context, method protoc return err } - if err := client.transport.Send(ctx, message); err != nil { + if err = client.transport.Send(ctx, message); err != nil { return fmt.Errorf("sendNotification: transport send: %w", err) } return nil @@ -70,8 +77,25 @@ func (client *Client) sendMsgWithError(ctx context.Context, requestID protocol.R return err } - if err := client.transport.Send(ctx, message); err != nil { + if err = client.transport.Send(ctx, message); err != nil { return fmt.Errorf("sendResponse: transport send: %w", err) } return nil } + +func (client *Client) againInitialization(ctx context.Context) error { + client.ready.Store(false) + + client.initializationMu.Lock() + defer client.initializationMu.Unlock() + + if client.ready.Load() { + return nil + } + + if _, err := client.initialization(ctx, protocol.NewInitializeRequest(*client.clientInfo, *client.clientCapabilities)); err != nil { + return err + } + client.ready.Store(true) + return nil +} diff --git a/examples/current_time_server/main.go b/examples/current_time_server/main.go index 7a54f7d..3a03737 100644 --- a/examples/current_time_server/main.go +++ b/examples/current_time_server/main.go @@ -69,21 +69,27 @@ func getTransport() (t transport.ServerTransport) { addr = "127.0.0.1:8080" ) - flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\"") + flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\" or \"streamable_http\"") flag.Parse() - if mode == "stdio" { + switch mode { + case "stdio": log.Println("start current time mcp server with stdio transport") t = transport.NewStdioServerTransport() - } else { + case "sse": log.Printf("start current time mcp server with sse transport, listen %s", addr) t, _ = transport.NewSSEServerTransport(addr) + case "streamable_http": + log.Printf("start current time mcp server with streamable_http transport, listen %s", addr) + t = transport.NewStreamableHTTPServerTransport(addr) + default: + panic(fmt.Errorf("unknown mode: %s", mode)) } return t } -func currentTime(request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { +func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { req := new(currentTimeReq) if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil { return nil, err diff --git a/examples/everything/main.go b/examples/everything/main.go index 4e0b864..6c3f38a 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -63,25 +63,32 @@ func main() { } func getTransport() (t transport.ServerTransport) { - mode := "" - port := "" - flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\"") + var mode, port, stateMode string + flag.StringVar(&mode, "transport", "streamable_http", "The transport to use, should be \"stdio\" or \"sse\" or \"streamable_http\"") flag.StringVar(&port, "port", "8080", "sse server address") + flag.StringVar(&stateMode, "state_mode", "stateless", "streamable_http server state mode, should be \"stateless\" or \"stateful\"") flag.Parse() - if mode == "stdio" { + switch mode { + case "stdio": log.Println("start current time mcp server with stdio transport") t = transport.NewStdioServerTransport() - } else { + case "sse": addr := fmt.Sprintf("127.0.0.1:%s", port) log.Printf("start current time mcp server with sse transport, listen %s", addr) t, _ = transport.NewSSEServerTransport(addr) + case "streamable_http": + addr := fmt.Sprintf("127.0.0.1:%s", port) + log.Printf("start current time mcp server with streamable_http transport, listen %s", addr) + t = transport.NewStreamableHTTPServerTransport(addr, transport.WithStreamableHTTPServerTransportOptionStateMode(transport.StateMode(stateMode))) + default: + panic(fmt.Errorf("unknown mode: %s", mode)) } return t } -func currentTime(request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { +func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { req := new(currentTimeReq) if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil { return nil, err diff --git a/examples/http_handler/main.go b/examples/http_handler/main.go index 199fa5a..ae054d2 100644 --- a/examples/http_handler/main.go +++ b/examples/http_handler/main.go @@ -93,7 +93,7 @@ func main() { } } -func currentTime(request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { +func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) { req := new(currentTimeReq) if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil { return nil, err diff --git a/pkg/atomic_bool.go b/pkg/atomic.go similarity index 51% rename from pkg/atomic_bool.go rename to pkg/atomic.go index 6401d2f..b01cd39 100644 --- a/pkg/atomic_bool.go +++ b/pkg/atomic.go @@ -19,3 +19,21 @@ func (b *AtomicBool) Store(value bool) { func (b *AtomicBool) Load() bool { return b.b.Load().(bool) } + +type AtomicString struct { + b atomic.Value +} + +func NewAtomicString() *AtomicString { + b := &AtomicString{} + b.b.Store("") + return b +} + +func (b *AtomicString) Store(value string) { + b.b.Store(value) +} + +func (b *AtomicString) Load() string { + return b.b.Load().(string) +} diff --git a/pkg/context.go b/pkg/context.go new file mode 100644 index 0000000..009bcc2 --- /dev/null +++ b/pkg/context.go @@ -0,0 +1,26 @@ +package pkg + +import ( + "context" + "time" +) + +type CancelShieldContext struct { + context.Context +} + +func NewCancelShieldContext(ctx context.Context) context.Context { + return CancelShieldContext{Context: ctx} +} + +func (v CancelShieldContext) Deadline() (deadline time.Time, ok bool) { + return +} + +func (v CancelShieldContext) Done() <-chan struct{} { + return nil +} + +func (v CancelShieldContext) Err() error { + return nil +} diff --git a/pkg/errors.go b/pkg/errors.go index 3fe9eec..392d1f3 100644 --- a/pkg/errors.go +++ b/pkg/errors.go @@ -14,6 +14,7 @@ var ( ErrJSONUnmarshal = errors.New("json unmarshal error") ErrSessionHasNotInitialized = errors.New("the session has not been initialized") ErrLackSession = errors.New("lack session") + ErrSessionClosed = errors.New("session closed") ErrSendEOF = errors.New("send EOF") ) diff --git a/protocol/schema_generate_test.go b/protocol/schema_generate_test.go index 14ca6f2..1046b22 100644 --- a/protocol/schema_generate_test.go +++ b/protocol/schema_generate_test.go @@ -283,7 +283,7 @@ func compareInputSchema(a, b *InputSchema) bool { return false } - // 比较Required字段 + // compare required field if len(a.Required) != len(b.Required) { return false } @@ -331,12 +331,11 @@ func compareProperty(a, b *Property) bool { return false } - // 比较Items字段 + // compare Items field if !compareProperty(a.Items, b.Items) { return false } - - // 比较Properties字段 + // compare Properties field if len(a.Properties) != len(b.Properties) { return false } @@ -350,7 +349,7 @@ func compareProperty(a, b *Property) bool { } } - // 比较Required字段 + // compare Required field比 if len(a.Required) != len(b.Required) { return false } diff --git a/protocol/schema_validate_test.go b/protocol/schema_validate_test.go index df81aae..af6ef50 100644 --- a/protocol/schema_validate_test.go +++ b/protocol/schema_validate_test.go @@ -160,7 +160,6 @@ func Test_Validate(t *testing.T) { }, Required: []string{"string"}, }}, false}, - // 嵌套匿名结构体测试 {"nested anonymous struct", args{data: map[string]any{ "user": map[string]any{ "name": "test", diff --git a/protocol/types.go b/protocol/types.go index 8f4b1f3..fca1dcb 100644 --- a/protocol/types.go +++ b/protocol/types.go @@ -1,6 +1,11 @@ package protocol -const Version = "2024-11-05" +const Version = "2025-03-26" + +var SupportedVersion = map[string]struct{}{ + "2024-11-05": {}, + "2025-03-26": {}, +} // Method represents the JSON-RPC method name type Method string diff --git a/server/call.go b/server/call.go index 2e0535a..8a1bc5b 100644 --- a/server/call.go +++ b/server/call.go @@ -98,7 +98,7 @@ func (server *Server) SendNotification4ResourcesUpdated(ctx context.Context, not func (server *Server) callClient(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerRequest) (json.RawMessage, error) { session, ok := server.sessionManager.GetSession(sessionID) if !ok { - return nil, pkg.ErrLackSession + return nil, fmt.Errorf("callClient: %w", pkg.ErrLackSession) } requestID := strconv.FormatInt(session.IncRequestID(), 10) @@ -107,7 +107,7 @@ func (server *Server) callClient(ctx context.Context, sessionID string, method p defer session.GetReqID2respChan().Remove(requestID) if err := server.sendMsgWithRequest(ctx, sessionID, requestID, method, params); err != nil { - return nil, err + return nil, fmt.Errorf("callClient: %w", err) } select { diff --git a/server/handle.go b/server/handle.go index 86c7f77..bb48199 100644 --- a/server/handle.go +++ b/server/handle.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "fmt" @@ -8,33 +9,42 @@ import ( "github.com/ThinkInAIXYZ/go-mcp/pkg" "github.com/ThinkInAIXYZ/go-mcp/protocol" + "github.com/ThinkInAIXYZ/go-mcp/transport" ) func (server *Server) handleRequestWithPing() (*protocol.PingResult, error) { return protocol.NewPingResult(), nil } -func (server *Server) handleRequestWithInitialize(sessionID string, rawParams json.RawMessage) (*protocol.InitializeResult, error) { +func (server *Server) handleRequestWithInitialize(ctx context.Context, sessionID string, rawParams json.RawMessage) (*protocol.InitializeResult, error) { var request *protocol.InitializeRequest if err := pkg.JSONUnmarshal(rawParams, &request); err != nil { return nil, err } - if request.ProtocolVersion != protocol.Version { + if _, ok := protocol.SupportedVersion[request.ProtocolVersion]; !ok { return nil, fmt.Errorf("protocol version not supported, supported version is %v", protocol.Version) } + protocolVersion := request.ProtocolVersion - s, ok := server.sessionManager.GetSession(sessionID) - if !ok { - return nil, pkg.ErrLackSession + if midVar, ok := ctx.Value(transport.SessionIDForReturnKey{}).(*transport.SessionIDForReturn); ok { + sessionID = server.sessionManager.CreateSession() + midVar.SessionID = sessionID + } + + if sessionID != "" { + s, ok := server.sessionManager.GetSession(sessionID) + if !ok { + return nil, pkg.ErrLackSession + } + s.SetClientInfo(&request.ClientInfo, &request.Capabilities) + s.SetReceivedInitRequest() } - s.SetClientInfo(&request.ClientInfo, &request.Capabilities) - s.SetReceivedInitRequest() return &protocol.InitializeResult{ ServerInfo: *server.serverInfo, Capabilities: *server.capabilities, - ProtocolVersion: protocol.Version, + ProtocolVersion: protocolVersion, Instructions: server.instructions, }, nil } @@ -62,7 +72,7 @@ func (server *Server) handleRequestWithListPrompts(rawParams json.RawMessage) (* }, nil } -func (server *Server) handleRequestWithGetPrompt(rawParams json.RawMessage) (*protocol.GetPromptResult, error) { +func (server *Server) handleRequestWithGetPrompt(ctx context.Context, rawParams json.RawMessage) (*protocol.GetPromptResult, error) { if server.capabilities.Prompts == nil { return nil, pkg.ErrServerNotSupport } @@ -76,7 +86,7 @@ func (server *Server) handleRequestWithGetPrompt(rawParams json.RawMessage) (*pr if !ok { return nil, fmt.Errorf("missing prompt, promptName=%s", request.Name) } - return entry.handler(request) + return entry.handler(ctx, request) } func (server *Server) handleRequestWithListResources(rawParams json.RawMessage) (*protocol.ListResourcesResult, error) { @@ -125,7 +135,7 @@ func (server *Server) handleRequestWithListResourceTemplates(rawParams json.RawM }, nil } -func (server *Server) handleRequestWithReadResource(rawParams json.RawMessage) (*protocol.ReadResourceResult, error) { +func (server *Server) handleRequestWithReadResource(ctx context.Context, rawParams json.RawMessage) (*protocol.ReadResourceResult, error) { if server.capabilities.Resources == nil { return nil, pkg.ErrServerNotSupport } @@ -156,7 +166,7 @@ func (server *Server) handleRequestWithReadResource(rawParams json.RawMessage) ( if handler == nil { return nil, fmt.Errorf("missing resource, resourceName=%s", request.URI) } - return handler(request) + return handler(ctx, request) } func matchesTemplate(uri string, template *uritemplate.Template) bool { @@ -220,7 +230,7 @@ func (server *Server) handleRequestWithListTools(rawParams json.RawMessage) (*pr return &protocol.ListToolsResult{Tools: tools}, nil } -func (server *Server) handleRequestWithCallTool(rawParams json.RawMessage) (*protocol.CallToolResult, error) { +func (server *Server) handleRequestWithCallTool(ctx context.Context, rawParams json.RawMessage) (*protocol.CallToolResult, error) { if server.capabilities.Tools == nil { return nil, pkg.ErrServerNotSupport } @@ -235,10 +245,14 @@ func (server *Server) handleRequestWithCallTool(rawParams json.RawMessage) (*pro return nil, fmt.Errorf("missing tool, toolName=%s", request.Name) } - return entry.handler(request) + return entry.handler(ctx, request) } func (server *Server) handleNotifyWithInitialized(sessionID string, rawParams json.RawMessage) error { + if sessionID == "" { + return nil + } + param := &protocol.InitializedNotification{} if len(rawParams) > 0 { if err := pkg.JSONUnmarshal(rawParams, param); err != nil { diff --git a/server/receive.go b/server/receive.go index 55cdb59..e8e8d54 100644 --- a/server/receive.go +++ b/server/receive.go @@ -2,6 +2,7 @@ package server import ( "context" + "encoding/json" "errors" "fmt" @@ -11,89 +12,83 @@ import ( "github.com/ThinkInAIXYZ/go-mcp/protocol" ) -func (server *Server) receive(_ context.Context, sessionID string, msg []byte) error { - if !server.sessionManager.IsExistSession(sessionID) { - return pkg.ErrLackSession +func (server *Server) receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) { + if sessionID != "" && !server.sessionManager.IsActiveSession(sessionID) { + if server.sessionManager.IsClosedSession(sessionID) { + return nil, pkg.ErrSessionClosed + } + return nil, pkg.ErrLackSession } if !gjson.GetBytes(msg, "id").Exists() { notify := &protocol.JSONRPCNotification{} if err := pkg.JSONUnmarshal(msg, ¬ify); err != nil { - return err + return nil, err } - if notify.Method == protocol.NotificationInitialized { - if err := server.receiveNotify(sessionID, notify); err != nil { - notify.RawParams = nil // simplified log - server.logger.Errorf("receive notify:%+v error: %s", notify, err.Error()) - } - return nil + if err := server.receiveNotify(sessionID, notify); err != nil { + notify.RawParams = nil // simplified log + server.logger.Errorf("receive notify:%+v error: %s", notify, err.Error()) + return nil, err } - go func() { - defer pkg.Recover() - - if err := server.receiveNotify(sessionID, notify); err != nil { - notify.RawParams = nil // simplified log - server.logger.Errorf("receive notify:%+v error: %s", notify, err.Error()) - return - } - }() - return nil + return nil, nil } - // 判断 request和response + // case request or response if !gjson.GetBytes(msg, "method").Exists() { resp := &protocol.JSONRPCResponse{} if err := pkg.JSONUnmarshal(msg, &resp); err != nil { - return err + return nil, err + } + + if err := server.receiveResponse(sessionID, resp); err != nil { + resp.RawResult = nil // simplified log + server.logger.Errorf("receive response:%+v error: %s", resp, err.Error()) + return nil, err } - go func() { - defer pkg.Recover() - - if err := server.receiveResponse(sessionID, resp); err != nil { - resp.RawResult = nil // simplified log - server.logger.Errorf("receive response:%+v error: %s", resp, err.Error()) - return - } - }() - return nil + return nil, nil } req := &protocol.JSONRPCRequest{} if err := pkg.JSONUnmarshal(msg, &req); err != nil { - return err + return nil, err } if !req.IsValid() { - return pkg.ErrRequestInvalid + return nil, pkg.ErrRequestInvalid + } + + if sessionID != "" && req.Method != protocol.Initialize && req.Method != protocol.Ping { + if s, ok := server.sessionManager.GetSession(sessionID); !ok { + return nil, pkg.ErrLackSession + } else if !s.GetReady() { + return nil, pkg.ErrSessionHasNotInitialized + } } + server.inFlyRequest.Add(1) + if server.inShutdown.Load() { - defer server.inFlyRequest.Done() - return errors.New("server already shutdown") + server.inFlyRequest.Done() + return nil, errors.New("server already shutdown") } - go func() { + ch := make(chan []byte, 1) + go func(ctx context.Context) { defer pkg.Recover() defer server.inFlyRequest.Done() + defer close(ch) - if err := server.receiveRequest(sessionID, req); err != nil { - req.RawParams = nil // simplified log - server.logger.Errorf("receive request:%+v error: %s", req, err.Error()) + resp := server.receiveRequest(ctx, sessionID, req) + message, err := json.Marshal(resp) + if err != nil { + server.logger.Errorf("receive json marshal response:%+v error: %s", resp, err.Error()) return } - }() - - return nil + ch <- message + }(pkg.NewCancelShieldContext(ctx)) + return ch, nil } -func (server *Server) receiveRequest(sessionID string, request *protocol.JSONRPCRequest) error { - if request.Method != protocol.Initialize && request.Method != protocol.Ping { - if s, ok := server.sessionManager.GetSession(sessionID); !ok { - return pkg.ErrLackSession - } else if !s.GetReady() { - return pkg.ErrSessionHasNotInitialized - } - } - +func (server *Server) receiveRequest(ctx context.Context, sessionID string, request *protocol.JSONRPCRequest) *protocol.JSONRPCResponse { if request.Method != protocol.Ping { server.sessionManager.UpdateSessionLastActiveAt(sessionID) } @@ -107,17 +102,17 @@ func (server *Server) receiveRequest(sessionID string, request *protocol.JSONRPC case protocol.Ping: result, err = server.handleRequestWithPing() case protocol.Initialize: - result, err = server.handleRequestWithInitialize(sessionID, request.RawParams) + result, err = server.handleRequestWithInitialize(ctx, sessionID, request.RawParams) case protocol.PromptsList: result, err = server.handleRequestWithListPrompts(request.RawParams) case protocol.PromptsGet: - result, err = server.handleRequestWithGetPrompt(request.RawParams) + result, err = server.handleRequestWithGetPrompt(ctx, request.RawParams) case protocol.ResourcesList: result, err = server.handleRequestWithListResources(request.RawParams) case protocol.ResourceListTemplates: result, err = server.handleRequestWithListResourceTemplates(request.RawParams) case protocol.ResourcesRead: - result, err = server.handleRequestWithReadResource(request.RawParams) + result, err = server.handleRequestWithReadResource(ctx, request.RawParams) case protocol.ResourcesSubscribe: result, err = server.handleRequestWithSubscribeResourceChange(sessionID, request.RawParams) case protocol.ResourcesUnsubscribe: @@ -125,33 +120,35 @@ func (server *Server) receiveRequest(sessionID string, request *protocol.JSONRPC case protocol.ToolsList: result, err = server.handleRequestWithListTools(request.RawParams) case protocol.ToolsCall: - result, err = server.handleRequestWithCallTool(request.RawParams) + result, err = server.handleRequestWithCallTool(ctx, request.RawParams) default: err = fmt.Errorf("%w: method=%s", pkg.ErrMethodNotSupport, request.Method) } - ctx := context.Background() - if err != nil { + var code int switch { case errors.Is(err, pkg.ErrMethodNotSupport): - return server.sendMsgWithError(ctx, sessionID, request.ID, protocol.MethodNotFound, err.Error()) + code = protocol.MethodNotFound case errors.Is(err, pkg.ErrRequestInvalid): - return server.sendMsgWithError(ctx, sessionID, request.ID, protocol.InvalidRequest, err.Error()) + code = protocol.InvalidRequest case errors.Is(err, pkg.ErrJSONUnmarshal): - return server.sendMsgWithError(ctx, sessionID, request.ID, protocol.ParseError, err.Error()) + code = protocol.ParseError default: - return server.sendMsgWithError(ctx, sessionID, request.ID, protocol.InternalError, err.Error()) + code = protocol.InternalError } + return protocol.NewJSONRPCErrorResponse(request.ID, code, err.Error()) } - return server.sendMsgWithResponse(ctx, sessionID, request.ID, result) + return protocol.NewJSONRPCSuccessResponse(request.ID, result) } func (server *Server) receiveNotify(sessionID string, notify *protocol.JSONRPCNotification) error { - if s, ok := server.sessionManager.GetSession(sessionID); !ok { - return pkg.ErrLackSession - } else if !s.GetReady() && notify.Method != protocol.NotificationInitialized { - return pkg.ErrSessionHasNotInitialized + if sessionID != "" { + if s, ok := server.sessionManager.GetSession(sessionID); !ok { + return pkg.ErrLackSession + } else if notify.Method != protocol.NotificationInitialized && !s.GetReady() { + return pkg.ErrSessionHasNotInitialized + } } switch notify.Method { diff --git a/server/send.go b/server/send.go index a47cc0b..f7d19ee 100644 --- a/server/send.go +++ b/server/send.go @@ -28,20 +28,6 @@ func (server *Server) sendMsgWithRequest(ctx context.Context, sessionID string, return nil } -func (server *Server) sendMsgWithResponse(ctx context.Context, sessionID string, requestID protocol.RequestID, result protocol.ServerResponse) error { - resp := protocol.NewJSONRPCSuccessResponse(requestID, result) - - message, err := json.Marshal(resp) - if err != nil { - return err - } - - if err := server.transport.Send(ctx, sessionID, message); err != nil { - return fmt.Errorf("sendResponse: transport send: %w", err) - } - return nil -} - func (server *Server) sendMsgWithNotification(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerNotify) error { notify := protocol.NewJSONRPCNotification(method, params) @@ -55,21 +41,3 @@ func (server *Server) sendMsgWithNotification(ctx context.Context, sessionID str } return nil } - -func (server *Server) sendMsgWithError(ctx context.Context, sessionID string, requestID protocol.RequestID, code int, msg string) error { - if requestID == nil { - return fmt.Errorf("requestID can't is nil") - } - - resp := protocol.NewJSONRPCErrorResponse(requestID, code, msg) - - message, err := json.Marshal(resp) - if err != nil { - return err - } - - if err := server.transport.Send(ctx, sessionID, message); err != nil { - return fmt.Errorf("sendResponse: transport send: %w", err) - } - return nil -} diff --git a/server/server.go b/server/server.go index 4af102d..95a074f 100644 --- a/server/server.go +++ b/server/server.go @@ -85,6 +85,8 @@ func NewServer(t transport.ServerTransport, opts ...Option) (*Server, error) { opt(server) } + server.sessionManager.SetLogger(server.logger) + t.SetSessionManager(server.sessionManager) return server, nil @@ -108,7 +110,7 @@ type toolEntry struct { handler ToolHandlerFunc } -type ToolHandlerFunc func(*protocol.CallToolRequest) (*protocol.CallToolResult, error) +type ToolHandlerFunc func(context.Context, *protocol.CallToolRequest) (*protocol.CallToolResult, error) func (server *Server) RegisterTool(tool *protocol.Tool, toolHandler ToolHandlerFunc) { server.tools.Store(tool.Name, &toolEntry{tool: tool, handler: toolHandler}) @@ -135,7 +137,7 @@ type promptEntry struct { handler PromptHandlerFunc } -type PromptHandlerFunc func(*protocol.GetPromptRequest) (*protocol.GetPromptResult, error) +type PromptHandlerFunc func(context.Context, *protocol.GetPromptRequest) (*protocol.GetPromptResult, error) func (server *Server) RegisterPrompt(prompt *protocol.Prompt, promptHandler PromptHandlerFunc) { server.prompts.Store(prompt.Name, &promptEntry{prompt: prompt, handler: promptHandler}) @@ -162,7 +164,7 @@ type resourceEntry struct { handler ResourceHandlerFunc } -type ResourceHandlerFunc func(*protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) +type ResourceHandlerFunc func(context.Context, *protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) func (server *Server) RegisterResource(resource *protocol.Resource, resourceHandler ResourceHandlerFunc) { server.resources.Store(resource.URI, &resourceEntry{resource: resource, handler: resourceHandler}) @@ -236,7 +238,7 @@ func (server *Server) sessionDetection(ctx context.Context, sessionID string) er return nil } - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() if _, err := server.Ping(setSessionIDToCtx(ctx, sessionID), protocol.NewPingRequest()); err != nil { diff --git a/server/server_test.go b/server/server_test.go index eaa7f1e..7dc7905 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "bufio" + "context" "encoding/json" "io" "reflect" @@ -63,7 +64,7 @@ func TestServerHandle(t *testing.T) { Type: "text", Text: "pong", } - server.RegisterTool(testTool, func(_ *protocol.CallToolRequest) (*protocol.CallToolResult, error) { + server.RegisterTool(testTool, func(_ context.Context, _ *protocol.CallToolRequest) (*protocol.CallToolResult, error) { return &protocol.CallToolResult{ Content: []protocol.Content{testToolCallContent}, }, nil @@ -84,7 +85,7 @@ func TestServerHandle(t *testing.T) { testPromptGetResponse := &protocol.GetPromptResult{ Description: "test_prompt_description", } - server.RegisterPrompt(testPrompt, func(*protocol.GetPromptRequest) (*protocol.GetPromptResult, error) { + server.RegisterPrompt(testPrompt, func(context.Context, *protocol.GetPromptRequest) (*protocol.GetPromptResult, error) { return testPromptGetResponse, nil }) @@ -99,7 +100,7 @@ func TestServerHandle(t *testing.T) { MimeType: testResource.MimeType, Text: "test", } - server.RegisterResource(testResource, func(*protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { + server.RegisterResource(testResource, func(context.Context, *protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { return &protocol.ReadResourceResult{ Contents: []protocol.ResourceContents{ testResourceContent, @@ -112,7 +113,7 @@ func TestServerHandle(t *testing.T) { URITemplate: "file:///{path}", Name: "test", } - if err := server.RegisterResourceTemplate(testResourceTemplate, func(*protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { + if err := server.RegisterResourceTemplate(testResourceTemplate, func(context.Context, *protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { return &protocol.ReadResourceResult{ Contents: []protocol.ResourceContents{ testResourceContent, @@ -361,7 +362,7 @@ func TestServerNotify(t *testing.T) { name: "test_tools_changed_notify", method: protocol.NotificationToolsListChanged, f: func() { - server.RegisterTool(testTool, func(_ *protocol.CallToolRequest) (*protocol.CallToolResult, error) { + server.RegisterTool(testTool, func(context.Context, *protocol.CallToolRequest) (*protocol.CallToolResult, error) { return &protocol.CallToolResult{ Content: []protocol.Content{testToolCallContent}, }, nil @@ -373,7 +374,7 @@ func TestServerNotify(t *testing.T) { name: "test_prompts_changed_notify", method: protocol.NotificationPromptsListChanged, f: func() { - server.RegisterPrompt(testPrompt, func(*protocol.GetPromptRequest) (*protocol.GetPromptResult, error) { + server.RegisterPrompt(testPrompt, func(context.Context, *protocol.GetPromptRequest) (*protocol.GetPromptResult, error) { return testPromptGetResponse, nil }) }, @@ -383,7 +384,7 @@ func TestServerNotify(t *testing.T) { name: "test_resources_changed_notify", method: protocol.NotificationResourcesListChanged, f: func() { - server.RegisterResource(testResource, func(*protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { + server.RegisterResource(testResource, func(context.Context, *protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { return &protocol.ReadResourceResult{ Contents: []protocol.ResourceContents{ testResourceContent, @@ -397,7 +398,7 @@ func TestServerNotify(t *testing.T) { name: "test_resources_template_changed_notify", method: protocol.NotificationResourcesListChanged, f: func() { - if err := server.RegisterResourceTemplate(testResourceTemplate, func(*protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { + if err := server.RegisterResourceTemplate(testResourceTemplate, func(context.Context, *protocol.ReadResourceRequest) (*protocol.ReadResourceResult, error) { return &protocol.ReadResourceResult{ Contents: []protocol.ResourceContents{ testResourceContent, diff --git a/server/session/manager.go b/server/session/manager.go index ff8174e..d2db05a 100644 --- a/server/session/manager.go +++ b/server/session/manager.go @@ -4,14 +4,19 @@ import ( "context" "time" + "github.com/google/uuid" + "github.com/ThinkInAIXYZ/go-mcp/pkg" ) type Manager struct { - sessions pkg.SyncMap[*State] + activeSessions pkg.SyncMap[*State] + closedSessions pkg.SyncMap[struct{}] stopHeartbeat chan struct{} + logger pkg.Logger + detection func(ctx context.Context, sessionID string) error maxIdleTime time.Duration } @@ -20,6 +25,7 @@ func NewManager(detection func(ctx context.Context, sessionID string) error) *Ma return &Manager{ detection: detection, stopHeartbeat: make(chan struct{}), + logger: pkg.DefaultLogger, } } @@ -27,42 +33,65 @@ func (m *Manager) SetMaxIdleTime(d time.Duration) { m.maxIdleTime = d } -func (m *Manager) CreateSession(sessionID string) { +func (m *Manager) SetLogger(logger pkg.Logger) { + m.logger = logger +} + +func (m *Manager) CreateSession() string { + sessionID := uuid.NewString() state := NewState() - m.sessions.Store(sessionID, state) + m.activeSessions.Store(sessionID, state) + return sessionID } -func (m *Manager) IsExistSession(sessionID string) bool { - _, has := m.sessions.Load(sessionID) +func (m *Manager) IsActiveSession(sessionID string) bool { + _, has := m.activeSessions.Load(sessionID) + return has +} + +func (m *Manager) IsClosedSession(sessionID string) bool { + _, has := m.closedSessions.Load(sessionID) return has } func (m *Manager) GetSession(sessionID string) (*State, bool) { - state, has := m.sessions.Load(sessionID) + if sessionID == "" { + return nil, false + } + state, has := m.activeSessions.Load(sessionID) if !has { return nil, false } return state, true } -func (m *Manager) SendMessage(ctx context.Context, sessionID string, message []byte) error { +func (m *Manager) OpenMessageQueueForSend(sessionID string) error { + state, has := m.GetSession(sessionID) + if !has { + return pkg.ErrLackSession + } + state.openMessageQueueForSend() + return nil +} + +func (m *Manager) EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error { state, has := m.GetSession(sessionID) if !has { return pkg.ErrLackSession } - return state.sendMessage(ctx, message) + return state.enqueueMessage(ctx, message) } -func (m *Manager) GetMessageForSend(ctx context.Context, sessionID string) ([]byte, error) { +func (m *Manager) DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) { state, has := m.GetSession(sessionID) if !has { return nil, pkg.ErrLackSession } - return state.getMessageForSend(ctx) + return state.dequeueMessage(ctx) } func (m *Manager) UpdateSessionLastActiveAt(sessionID string) { - state, ok := m.sessions.Load(sessionID) + state, ok := m.activeSessions.Load(sessionID) if !ok { return } @@ -70,21 +99,18 @@ func (m *Manager) UpdateSessionLastActiveAt(sessionID string) { } func (m *Manager) CloseSession(sessionID string) { - state, ok := m.sessions.LoadAndDelete(sessionID) + state, ok := m.activeSessions.LoadAndDelete(sessionID) if !ok { return } state.Close() + m.closedSessions.Store(sessionID, struct{}{}) } func (m *Manager) CloseAllSessions() { - m.sessions.Range(func(sessionID string, _ *State) bool { + m.activeSessions.Range(func(sessionID string, _ *State) bool { // Here we load the session again to prevent concurrency conflicts with CloseSession, which may cause repeated close chan - state, ok := m.sessions.LoadAndDelete(sessionID) - if !ok { - return true - } - state.Close() + m.CloseSession(sessionID) return true }) } @@ -99,17 +125,20 @@ func (m *Manager) StartHeartbeatAndCleanInvalidSessions() { return case <-ticker.C: now := time.Now() - m.sessions.Range(func(sessionID string, state *State) bool { + m.activeSessions.Range(func(sessionID string, state *State) bool { if m.maxIdleTime != 0 && now.Sub(state.lastActiveAt) > m.maxIdleTime { + m.logger.Infof("session expire, session id: %v", sessionID) m.CloseSession(sessionID) return true } + var err error for i := 0; i < 3; i++ { - if err := m.detection(context.Background(), sessionID); err == nil { + if err = m.detection(context.Background(), sessionID); err == nil { return true } } + m.logger.Infof("session detection fail, session id: %v, fail reason: %+v", sessionID, err) m.CloseSession(sessionID) return true }) @@ -122,12 +151,12 @@ func (m *Manager) StopHeartbeat() { } func (m *Manager) RangeSessions(f func(sessionID string, state *State) bool) { - m.sessions.Range(f) + m.activeSessions.Range(f) } func (m *Manager) IsEmpty() bool { isEmpty := true - m.sessions.Range(func(string, *State) bool { + m.activeSessions.Range(func(string, *State) bool { isEmpty = false return false }) diff --git a/server/session/state.go b/server/session/state.go index fd3ca08..4f2dab7 100644 --- a/server/session/state.go +++ b/server/session/state.go @@ -13,6 +13,8 @@ import ( "github.com/ThinkInAIXYZ/go-mcp/protocol" ) +var ErrQueueNotOpened = errors.New("queue has not been opened") + type State struct { lastActiveAt time.Time @@ -38,7 +40,6 @@ type State struct { func NewState() *State { return &State{ lastActiveAt: time.Now(), - sendChan: make(chan []byte, 64), reqID2respChan: cmap.New[chan *protocol.JSONRPCResponse](), subscribedResources: cmap.New[struct{}](), receivedInitRequest: pkg.NewAtomicBool(), @@ -85,14 +86,26 @@ func (s *State) Close() { defer s.mu.Unlock() s.closed.Store(true) - close(s.sendChan) + + if s.sendChan != nil { + close(s.sendChan) + } } func (s *State) updateLastActiveAt() { s.lastActiveAt = time.Now() } -func (s *State) sendMessage(ctx context.Context, message []byte) error { +func (s *State) openMessageQueueForSend() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.sendChan == nil { + s.sendChan = make(chan []byte, 64) + } +} + +func (s *State) enqueueMessage(ctx context.Context, message []byte) error { s.mu.RLock() defer s.mu.RUnlock() @@ -100,6 +113,10 @@ func (s *State) sendMessage(ctx context.Context, message []byte) error { return errors.New("session already closed") } + if s.sendChan == nil { + return ErrQueueNotOpened + } + select { case s.sendChan <- message: return nil @@ -108,7 +125,14 @@ func (s *State) sendMessage(ctx context.Context, message []byte) error { } } -func (s *State) getMessageForSend(ctx context.Context) ([]byte, error) { +func (s *State) dequeueMessage(ctx context.Context) ([]byte, error) { + s.mu.RLock() + if s.sendChan == nil { + s.mu.RUnlock() + return nil, ErrQueueNotOpened + } + s.mu.RUnlock() + select { case <-ctx.Done(): return nil, ctx.Err() diff --git a/tests/streamable_http_test.go b/tests/streamable_http_test.go new file mode 100644 index 0000000..6269994 --- /dev/null +++ b/tests/streamable_http_test.go @@ -0,0 +1,55 @@ +package tests + +import ( + "fmt" + "os" + "os/exec" + "strconv" + "testing" + + "github.com/ThinkInAIXYZ/go-mcp/transport" +) + +func TestStreamableHTTPWithStateless(t *testing.T) { + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + transportClient, err := transport.NewStreamableHTTPClientTransport(fmt.Sprintf("http://127.0.0.1:%d/mcp", port)) + if err != nil { + t.Fatalf("Failed to create transport client: %v", err) + } + + test(t, func() error { return runStreamableHTTPServer(port, transport.Stateless) }, transportClient) +} + +func TestStreamableHTTPWithStateful(t *testing.T) { + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + transportClient, err := transport.NewStreamableHTTPClientTransport(fmt.Sprintf("http://127.0.0.1:%d/mcp", port)) + if err != nil { + t.Fatalf("Failed to create transport client: %v", err) + } + + test(t, func() error { return runStreamableHTTPServer(port, transport.Stateful) }, transportClient) +} + +func runStreamableHTTPServer(port int, stateful transport.StateMode) error { + mockServerTrPath, err := compileMockStdioServerTr() + if err != nil { + return err + } + fmt.Println(mockServerTrPath) + + defer func(name string) { + if err := os.Remove(name); err != nil { + fmt.Printf("failed to remove mock server: %v\n", err) + } + }(mockServerTrPath) + + return exec.Command(mockServerTrPath, "-transport", "streamable_http", "-port", strconv.Itoa(port), "-state_mode", string(stateful)).Run() +} diff --git a/transport/mock_server.go b/transport/mock_server.go index 86a20b2..67e6ee0 100644 --- a/transport/mock_server.go +++ b/transport/mock_server.go @@ -10,13 +10,13 @@ import ( "github.com/ThinkInAIXYZ/go-mcp/pkg" ) -const mockSessionID = "mock" - type mockServerTransport struct { receiver serverReceiver in io.ReadCloser out io.Writer + sessionID string + sessionManager sessionManager logger pkg.Logger @@ -39,7 +39,7 @@ func (t *mockServerTransport) Run() error { ctx, cancel := context.WithCancel(context.Background()) t.cancel = cancel - t.sessionManager.CreateSession(mockSessionID) + t.sessionID = t.sessionManager.CreateSession() t.receive(ctx) @@ -87,10 +87,28 @@ func (t *mockServerTransport) receive(ctx context.Context) { case <-ctx.Done(): return default: - if err := t.receiver.Receive(ctx, mockSessionID, s.Bytes()); err != nil { + outputMsgCh, err := t.receiver.Receive(ctx, t.sessionID, s.Bytes()) + if err != nil { t.logger.Errorf("receiver failed: %v", err) continue } + + if outputMsgCh == nil { + continue + } + + go func() { + defer pkg.Recover() + + msg := <-outputMsgCh + if len(msg) == 0 { + t.logger.Errorf("handle request fail") + return + } + if err := t.Send(context.Background(), t.sessionID, msg); err != nil { + t.logger.Errorf("Failed to send message: %v", err) + } + }() } } diff --git a/transport/sse_client.go b/transport/sse_client.go index 44934ea..ff10826 100644 --- a/transport/sse_client.go +++ b/transport/sse_client.go @@ -85,7 +85,7 @@ func (t *sseClientTransport) Start() error { go func() { defer pkg.Recover() - req, err := http.NewRequest(http.MethodGet, t.serverURL.String(), nil) + req, err := http.NewRequestWithContext(t.ctx, http.MethodGet, t.serverURL.String(), nil) if err != nil { errChan <- fmt.Errorf("failed to create request: %w", err) return @@ -107,17 +107,6 @@ func (t *sseClientTransport) Start() error { return } - go func() { - defer pkg.Recover() - - <-t.ctx.Done() - - if err := resp.Body.Close(); err != nil { - t.logger.Errorf("failed to close SSE stream body: %w", err) - return - } - }() - t.readSSE(resp.Body) close(t.sseConnectClose) diff --git a/transport/sse_client_test.go b/transport/sse_client_test.go deleted file mode 100644 index bc90318..0000000 --- a/transport/sse_client_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package transport - -import ( - "net/url" - "testing" - - "github.com/ThinkInAIXYZ/go-mcp/pkg" -) - -func Test_sseClientTransport_handleSSEEvent(t1 *testing.T) { - type fields struct { - serverURL *url.URL - logger pkg.Logger - } - type args struct { - event string - data string - } - tests := []struct { - name string - fields fields - args args - want string - }{ - { - name: "1", - fields: fields{ - serverURL: func() *url.URL { - uri, err := url.Parse("https://api.baidu.com/mcp") - if err != nil { - panic(err) - } - return uri - }(), - logger: pkg.DefaultLogger, - }, - args: args{ - event: "endpoint", - data: "/sse/messages", - }, - want: "https://api.baidu.com/sse/messages", - }, - { - name: "2", - fields: fields{ - serverURL: func() *url.URL { - uri, err := url.Parse("https://api.baidu.com/mcp") - if err != nil { - panic(err) - } - return uri - }(), - logger: pkg.DefaultLogger, - }, - args: args{ - event: "endpoint", - data: "https://api.google.com/sse/messages", - }, - want: "https://api.google.com/sse/messages", - }, - } - for _, tt := range tests { - t1.Run(tt.name, func(t1 *testing.T) { - t := &sseClientTransport{ - serverURL: tt.fields.serverURL, - logger: tt.fields.logger, - endpointChan: make(chan struct{}), - } - t.handleSSEEvent(tt.args.event, tt.args.data) - if t.messageEndpoint.String() != tt.want { - t1.Errorf("handleSSEEvent() = %v, want %v", t.messageEndpoint.String(), tt.want) - } - }) - } -} diff --git a/transport/sse_server.go b/transport/sse_server.go index c32a1a8..912d657 100644 --- a/transport/sse_server.go +++ b/transport/sse_server.go @@ -12,8 +12,6 @@ import ( "sync" "time" - "github.com/google/uuid" - "github.com/ThinkInAIXYZ/go-mcp/pkg" ) @@ -150,6 +148,7 @@ func NewSSEServerTransport(addr string, opts ...SSEServerTransportOption) (Serve func NewSSEServerTransportAndHandler(messageEndpointURL string, opts ...SSEServerTransportAndHandlerOption, ) (ServerTransport, *SSEHandler, error) { //nolint:whitespace + ctx, cancel := context.WithCancel(context.Background()) t := &sseServerTransport{ @@ -183,11 +182,10 @@ func (t *sseServerTransport) Send(ctx context.Context, sessionID string, msg Mes select { case <-t.ctx.Done(): - return ctx.Err() + return t.ctx.Err() default: + return t.sessionManager.EnqueueMessageForSend(ctx, sessionID, msg) } - - return t.sessionManager.SendMessage(ctx, sessionID, msg) } func (t *sseServerTransport) SetReceiver(receiver serverReceiver) { @@ -200,10 +198,10 @@ func (t *sseServerTransport) SetSessionManager(manager sessionManager) { // handleSSE handles incoming SSE connections from clients and sends messages to them. func (t *sseServerTransport) handleSSE(w http.ResponseWriter, r *http.Request) { - defer pkg.Recover() + defer pkg.RecoverWithFunc(func(_ any) { + t.writeError(w, http.StatusInternalServerError, "Internal server error") + }) - //nolint:govet // Ignore error since we're just logging - requestCtx := r.Context() // Set headers for SSE w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -212,14 +210,13 @@ func (t *sseServerTransport) handleSSE(w http.ResponseWriter, r *http.Request) { // Create flush-supporting writer flusher, ok := w.(http.Flusher) if !ok { - http.Error(w, "Streaming not supported", http.StatusInternalServerError) + t.writeError(w, http.StatusInternalServerError, "Streaming not supported") return } w.WriteHeader(http.StatusOK) // Create an SSE connection - sessionID := uuid.New().String() - t.sessionManager.CreateSession(sessionID) + sessionID := t.sessionManager.CreateSession() defer t.sessionManager.CloseSession(sessionID) uri := fmt.Sprintf("%s?sessionID=%s", t.messageEndpointURL, sessionID) @@ -230,12 +227,18 @@ func (t *sseServerTransport) handleSSE(w http.ResponseWriter, r *http.Request) { } flusher.Flush() + if err := t.sessionManager.OpenMessageQueueForSend(sessionID); err != nil { + t.logger.Errorf("handleSSE sessionID=%s OpenMessageQueueForSend fail: %v", sessionID, err) + return + } + for { - msg, err := t.sessionManager.GetMessageForSend(requestCtx, sessionID) + msg, err := t.sessionManager.DequeueMessageForSend(r.Context(), sessionID) if err != nil { - if !errors.Is(err, pkg.ErrSendEOF) { - t.logger.Debugf("sse connect request err: %+v, sessionID=%s", err.Error(), sessionID) + if errors.Is(err, pkg.ErrSendEOF) { + return } + t.logger.Debugf("sse connect dequeueMessage err: %+v, sessionID=%s", err.Error(), sessionID) return } @@ -267,29 +270,45 @@ func (t *sseServerTransport) handleMessage(w http.ResponseWriter, r *http.Reques return } - ctx := r.Context() // Parse message as raw JSON - bs, err := io.ReadAll(r.Body) + inputMsg, err := io.ReadAll(r.Body) if err != nil { t.writeError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request: %v", err)) return } - if err = t.receiver.Receive(ctx, sessionID, bs); err != nil { + + ctx := pkg.NewCancelShieldContext(r.Context()) + outputMsgCh, err := t.receiver.Receive(ctx, sessionID, inputMsg) + if err != nil { t.writeError(w, http.StatusBadRequest, fmt.Sprintf("Failed to receive: %v", err)) return } - // Process message through MCPServer - // For notifications, just send 202 Accepted with no body - t.logger.Debugf("Received message: %s", string(bs)) - // ref: https://github.com/encode/httpx/blob/master/httpx/_status_codes.py#L8 - // in official httpx, 2xx is success + t.logger.Debugf("Received message: %s", string(inputMsg)) w.WriteHeader(http.StatusAccepted) + + if outputMsgCh == nil { + return + } + + go func() { + defer pkg.Recover() + + msg := <-outputMsgCh + if len(msg) == 0 { + t.logger.Errorf("handle request fail") + return + } + if err := t.Send(context.Background(), sessionID, msg); err != nil { + t.logger.Errorf("Failed to send message: %v", err) + } + }() } // writeError writes a JSON-RPC error response with the given error details. func (t *sseServerTransport) writeError(w http.ResponseWriter, code int, message string) { - t.logger.Errorf("sseServerTransport writeError: code: %d, message: %s", code, message) + t.logger.Errorf("sseServerTransport Error: code: %d, message: %s", code, message) + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(code) if _, err := w.Write([]byte(message)); err != nil { diff --git a/transport/sse_server_test.go b/transport/sse_server_test.go deleted file mode 100644 index 81a0366..0000000 --- a/transport/sse_server_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package transport - -import ( - "net/url" - "testing" -) - -func Test_joinPath(t *testing.T) { - type args struct { - u *url.URL - elem []string - } - tests := []struct { - name string - args args - want string - }{ - { - name: "1", - args: args{ - u: func() *url.URL { - uri, err := url.Parse("https://google.com/api/v1") - if err != nil { - panic(err) - } - return uri - }(), - elem: []string{"/test"}, - }, - want: "https://google.com/api/v1/test", - }, - { - name: "2", - args: args{ - u: func() *url.URL { - uri, err := url.Parse("/api/v1") - if err != nil { - panic(err) - } - return uri - }(), - elem: []string{"/test"}, - }, - want: "/api/v1/test", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - joinPath(tt.args.u, tt.args.elem...) - if got := tt.args.u.String(); got != tt.want { - t.Errorf("joinPath() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/transport/sse_test.go b/transport/sse_test.go index 07313a0..33e6ee5 100644 --- a/transport/sse_test.go +++ b/transport/sse_test.go @@ -5,8 +5,11 @@ import ( "log" "net" "net/http" + "net/url" "testing" "time" + + "github.com/ThinkInAIXYZ/go-mcp/pkg" ) func TestSSE(t *testing.T) { @@ -23,13 +26,13 @@ func TestSSE(t *testing.T) { } serverAddr := fmt.Sprintf("127.0.0.1:%d", port) - clientURL := fmt.Sprintf("http://%s/sse", serverAddr) + serverURL := fmt.Sprintf("http://%s/sse", serverAddr) if svr, err = NewSSEServerTransport(serverAddr); err != nil { t.Fatalf("NewSSEServerTransport failed: %v", err) } - if client, err = NewSSEClientTransport(clientURL); err != nil { + if client, err = NewSSEClientTransport(serverURL); err != nil { t.Fatalf("NewSSEClientTransport failed: %v", err) } @@ -101,3 +104,119 @@ func getAvailablePort() (int, error) { port := addr.Addr().(*net.TCPAddr).Port return port, nil } + +func Test_joinPath(t *testing.T) { + type args struct { + u *url.URL + elem []string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "1", + args: args{ + u: func() *url.URL { + uri, err := url.Parse("https://google.com/api/v1") + if err != nil { + panic(err) + } + return uri + }(), + elem: []string{"/test"}, + }, + want: "https://google.com/api/v1/test", + }, + { + name: "2", + args: args{ + u: func() *url.URL { + uri, err := url.Parse("/api/v1") + if err != nil { + panic(err) + } + return uri + }(), + elem: []string{"/test"}, + }, + want: "/api/v1/test", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + joinPath(tt.args.u, tt.args.elem...) + if got := tt.args.u.String(); got != tt.want { + t.Errorf("joinPath() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_sseClientTransport_handleSSEEvent(t1 *testing.T) { + type fields struct { + serverURL *url.URL + logger pkg.Logger + } + type args struct { + event string + data string + } + tests := []struct { + name string + fields fields + args args + want string + }{ + { + name: "1", + fields: fields{ + serverURL: func() *url.URL { + uri, err := url.Parse("https://api.baidu.com/mcp") + if err != nil { + panic(err) + } + return uri + }(), + logger: pkg.DefaultLogger, + }, + args: args{ + event: "endpoint", + data: "/sse/messages", + }, + want: "https://api.baidu.com/sse/messages", + }, + { + name: "2", + fields: fields{ + serverURL: func() *url.URL { + uri, err := url.Parse("https://api.baidu.com/mcp") + if err != nil { + panic(err) + } + return uri + }(), + logger: pkg.DefaultLogger, + }, + args: args{ + event: "endpoint", + data: "https://api.google.com/sse/messages", + }, + want: "https://api.google.com/sse/messages", + }, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &sseClientTransport{ + serverURL: tt.fields.serverURL, + logger: tt.fields.logger, + endpointChan: make(chan struct{}), + } + t.handleSSEEvent(tt.args.event, tt.args.data) + if t.messageEndpoint.String() != tt.want { + t1.Errorf("handleSSEEvent() = %v, want %v", t.messageEndpoint.String(), tt.want) + } + }) + } +} diff --git a/transport/stdio_server.go b/transport/stdio_server.go index 8edfabe..6e73429 100644 --- a/transport/stdio_server.go +++ b/transport/stdio_server.go @@ -12,8 +12,6 @@ import ( "github.com/ThinkInAIXYZ/go-mcp/pkg" ) -const stdioSessionID = "stdio" - type StdioServerTransportOption func(*stdioServerTransport) func WithStdioServerOptionLogger(log pkg.Logger) StdioServerTransportOption { @@ -28,6 +26,7 @@ type stdioServerTransport struct { writer io.Writer sessionManager sessionManager + sessionID string logger pkg.Logger @@ -54,7 +53,7 @@ func (t *stdioServerTransport) Run() error { ctx, cancel := context.WithCancel(context.Background()) t.cancel = cancel - t.sessionManager.CreateSession(stdioSessionID) + t.sessionID = t.sessionManager.CreateSession() t.receive(ctx) @@ -108,10 +107,29 @@ func (t *stdioServerTransport) receive(ctx context.Context) { t.logger.Debugf("skipping empty message") continue } - if err := t.receiver.Receive(ctx, stdioSessionID, s.Bytes()); err != nil { + + outputMsgCh, err := t.receiver.Receive(ctx, t.sessionID, s.Bytes()) + if err != nil { t.logger.Errorf("receiver failed: %v", err) continue } + + if outputMsgCh == nil { + continue + } + + go func() { + defer pkg.Recover() + + msg := <-outputMsgCh + if len(msg) == 0 { + t.logger.Errorf("handle request fail") + return + } + if err := t.Send(context.Background(), t.sessionID, msg); err != nil { + t.logger.Errorf("Failed to send message: %v", err) + } + }() } } diff --git a/transport/streamable_http_client.go b/transport/streamable_http_client.go new file mode 100644 index 0000000..19a093d --- /dev/null +++ b/transport/streamable_http_client.go @@ -0,0 +1,282 @@ +package transport + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/ThinkInAIXYZ/go-mcp/pkg" +) + +const sessionIDHeader = "Mcp-Session-Id" + +// const eventIDHeader = "Last-Event-ID" + +type StreamableHTTPClientTransportOption func(*streamableHTTPClientTransport) + +func WithStreamableHTTPClientOptionReceiveTimeout(timeout time.Duration) StreamableHTTPClientTransportOption { + return func(t *streamableHTTPClientTransport) { + t.receiveTimeout = timeout + } +} + +func WithStreamableHTTPClientOptionHTTPClient(client *http.Client) StreamableHTTPClientTransportOption { + return func(t *streamableHTTPClientTransport) { + t.client = client + } +} + +func WithStreamableHTTPClientOptionLogger(log pkg.Logger) StreamableHTTPClientTransportOption { + return func(t *streamableHTTPClientTransport) { + t.logger = log + } +} + +type streamableHTTPClientTransport struct { + ctx context.Context + cancel context.CancelFunc + + serverURL *url.URL + receiver clientReceiver + sessionID *pkg.AtomicString + + // options + logger pkg.Logger + receiveTimeout time.Duration + client *http.Client + + sseInFlyConnect sync.WaitGroup +} + +func NewStreamableHTTPClientTransport(serverURL string, opts ...StreamableHTTPClientTransportOption) (ClientTransport, error) { + parsedURL, err := url.Parse(serverURL) + if err != nil { + return nil, fmt.Errorf("failed to parse server URL: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + + t := &streamableHTTPClientTransport{ + ctx: ctx, + cancel: cancel, + serverURL: parsedURL, + sessionID: pkg.NewAtomicString(), + logger: pkg.DefaultLogger, + receiveTimeout: time.Second * 30, + client: http.DefaultClient, + } + + for _, opt := range opts { + opt(t) + } + + return t, nil +} + +func (t *streamableHTTPClientTransport) Start() error { + // Start a GET stream for server-initiated messages + t.sseInFlyConnect.Add(1) + go func() { + defer pkg.Recover() + defer t.sseInFlyConnect.Done() + + t.startSSEStream() + }() + return nil +} + +func (t *streamableHTTPClientTransport) Send(ctx context.Context, msg Message) error { + req, err := http.NewRequestWithContext(t.ctx, http.MethodPost, t.serverURL.String(), bytes.NewReader(msg)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + if sessionID := t.sessionID.Load(); sessionID != "" { + req.Header.Set(sessionIDHeader, sessionID) + } + + resp, err := t.client.Do(req) //nolint:bodyclose + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + if resp.Header.Get("Content-Type") != "text/event-stream" { + defer resp.Body.Close() + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + if req.Header.Get(sessionIDHeader) != "" && resp.StatusCode == http.StatusNotFound { + return pkg.ErrSessionClosed + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + return fmt.Errorf("unexpected status code: %d, status: %s, body=%s", resp.StatusCode, resp.Status, body) + } + + // Handle session ID if provided in response + if respSessionID := resp.Header.Get(sessionIDHeader); respSessionID != "" { + t.sessionID.Store(respSessionID) + } + + // Handle different response types + switch resp.Header.Get("Content-Type") { + case "text/event-stream": + go func() { + defer pkg.Recover() + + t.sseInFlyConnect.Add(1) + defer t.sseInFlyConnect.Done() + + t.handleSSEStream(resp.Body) + }() + return nil + case "application/json": + if resp.StatusCode == http.StatusAccepted { // Handle immediate JSON response + return nil + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + if err = t.receiver.Receive(ctx, body); err != nil { + return fmt.Errorf("failed to process response: %w", err) + } + return nil + default: + return fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) + } +} + +func (t *streamableHTTPClientTransport) startSSEStream() { + for { + select { + case <-t.ctx.Done(): + return + case <-time.After(time.Second): + sessionID := t.sessionID.Load() + if sessionID == "" { + continue // Try again after 1 second, waiting for the POST request to initialize the SessionID to complete + } + + req, err := http.NewRequestWithContext(t.ctx, http.MethodGet, t.serverURL.String(), nil) + if err != nil { + t.logger.Errorf("failed to create SSE request: %v", err) + return + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set(sessionIDHeader, sessionID) + + resp, err := t.client.Do(req) + if err != nil { + t.logger.Errorf("failed to connect to SSE stream: %v", err) + continue + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + resp.Body.Close() + + switch resp.StatusCode { + case http.StatusMethodNotAllowed: + t.logger.Infof("server does not support SSE streaming") + return + case http.StatusNotFound: + t.logger.Infof("%+v", pkg.ErrSessionClosed) + continue // Try again after 1 second, waiting for the POST request again to initialize the SessionID to complete + default: + t.logger.Infof("unexpected status code: %d, status: %s", resp.StatusCode, resp.Status) + return + } + } + + t.handleSSEStream(resp.Body) + } + } +} + +func (t *streamableHTTPClientTransport) handleSSEStream(reader io.ReadCloser) { + defer reader.Close() + + br := bufio.NewReader(reader) + var data string + + for { + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if data != "" { + t.processSSEEvent(data) + } + break + } + select { + case <-t.ctx.Done(): + return + default: + t.logger.Errorf("SSE stream error: %v", err) + return + } + } + + line = strings.TrimRight(line, "\r\n") + + if line == "" { + // Empty line means end of event + if data != "" { + t.processSSEEvent(data) + _, data = "", "" + } + continue + } + + if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } +} + +func (t *streamableHTTPClientTransport) processSSEEvent(data string) { + ctx, cancel := context.WithTimeout(t.ctx, t.receiveTimeout) + defer cancel() + + if err := t.receiver.Receive(ctx, []byte(data)); err != nil { + t.logger.Errorf("Error processing SSE event: %v", err) + } +} + +func (t *streamableHTTPClientTransport) SetReceiver(receiver clientReceiver) { + t.receiver = receiver +} + +func (t *streamableHTTPClientTransport) Close() error { + t.cancel() + + t.sseInFlyConnect.Wait() + + if sessionID := t.sessionID.Load(); sessionID != "" { + req, err := http.NewRequest(http.MethodDelete, t.serverURL.String(), nil) + if err != nil { + return err + } + req.Header.Set(sessionIDHeader, sessionID) + resp, err := t.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + } + + return nil +} diff --git a/transport/streamable_http_server.go b/transport/streamable_http_server.go new file mode 100644 index 0000000..6f66738 --- /dev/null +++ b/transport/streamable_http_server.go @@ -0,0 +1,373 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/ThinkInAIXYZ/go-mcp/pkg" + "github.com/ThinkInAIXYZ/go-mcp/protocol" +) + +type StateMode string + +const ( + Stateful StateMode = "stateful" + Stateless StateMode = "stateless" +) + +type SessionIDForReturnKey struct{} + +type SessionIDForReturn struct { + SessionID string +} + +type StreamableHTTPServerTransportOption func(*streamableHTTPServerTransport) + +func WithStreamableHTTPServerTransportOptionLogger(logger pkg.Logger) StreamableHTTPServerTransportOption { + return func(t *streamableHTTPServerTransport) { + t.logger = logger + } +} + +func WithStreamableHTTPServerTransportOptionEndpoint(endpoint string) StreamableHTTPServerTransportOption { + return func(t *streamableHTTPServerTransport) { + t.mcpEndpoint = endpoint + } +} + +func WithStreamableHTTPServerTransportOptionStateMode(mode StateMode) StreamableHTTPServerTransportOption { + return func(t *streamableHTTPServerTransport) { + t.stateMode = mode + } +} + +type StreamableHTTPServerTransportAndHandlerOption func(*streamableHTTPServerTransport) + +func WithStreamableHTTPServerTransportAndHandlerOptionLogger(logger pkg.Logger) StreamableHTTPServerTransportAndHandlerOption { + return func(t *streamableHTTPServerTransport) { + t.logger = logger + } +} + +func WithStreamableHTTPServerTransportAndHandlerOptionStateMode(mode StateMode) StreamableHTTPServerTransportAndHandlerOption { + return func(t *streamableHTTPServerTransport) { + t.stateMode = mode + } +} + +type streamableHTTPServerTransport struct { + // ctx is the context that controls the lifecycle of the server + ctx context.Context + cancel context.CancelFunc + + httpSvr *http.Server + + stateMode StateMode + + inFlySend sync.WaitGroup + + receiver serverReceiver + + sessionManager sessionManager + + // options + logger pkg.Logger + mcpEndpoint string // The single MCP endpoint path +} + +type StreamableHTTPHandler struct { + transport *streamableHTTPServerTransport +} + +// HandleMCP handles incoming MCP requests +func (h *StreamableHTTPHandler) HandleMCP() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.transport.handleMCPEndpoint(w, r) + }) +} + +// NewStreamableHTTPServerTransportAndHandler returns transport without starting the HTTP server, +// and returns a Handler for users to start their own HTTP server externally +// eg: +// transport, handler, _ := NewStreamableHTTPServerTransportAndHandler() +// http.Handle("/mcp", handler.HandleMCP()) +// http.ListenAndServe(":8080", nil) +func NewStreamableHTTPServerTransportAndHandler( + opts ...StreamableHTTPServerTransportAndHandlerOption, +) (ServerTransport, *StreamableHTTPHandler, error) { //nolint:whitespace + + ctx, cancel := context.WithCancel(context.Background()) + + t := &streamableHTTPServerTransport{ + ctx: ctx, + cancel: cancel, + stateMode: Stateless, + logger: pkg.DefaultLogger, + } + + for _, opt := range opts { + opt(t) + } + + return t, &StreamableHTTPHandler{transport: t}, nil +} + +func NewStreamableHTTPServerTransport(addr string, opts ...StreamableHTTPServerTransportOption) ServerTransport { + ctx, cancel := context.WithCancel(context.Background()) + + t := &streamableHTTPServerTransport{ + ctx: ctx, + cancel: cancel, + stateMode: Stateless, + logger: pkg.DefaultLogger, + mcpEndpoint: "/mcp", // Default MCP endpoint + } + + for _, opt := range opts { + opt(t) + } + + mux := http.NewServeMux() + mux.HandleFunc(t.mcpEndpoint, t.handleMCPEndpoint) + + t.httpSvr = &http.Server{ + Addr: addr, + Handler: mux, + IdleTimeout: time.Minute, + } + + return t +} + +func (t *streamableHTTPServerTransport) Run() error { + if t.httpSvr == nil { + <-t.ctx.Done() + return nil + } + + if err := t.httpSvr.ListenAndServe(); err != nil { + return fmt.Errorf("failed to start HTTP server: %w", err) + } + return nil +} + +func (t *streamableHTTPServerTransport) Send(ctx context.Context, sessionID string, msg Message) error { + t.inFlySend.Add(1) + defer t.inFlySend.Done() + + select { + case <-t.ctx.Done(): + return t.ctx.Err() + default: + return t.sessionManager.EnqueueMessageForSend(ctx, sessionID, msg) + } +} + +func (t *streamableHTTPServerTransport) SetReceiver(receiver serverReceiver) { + t.receiver = receiver +} + +func (t *streamableHTTPServerTransport) SetSessionManager(manager sessionManager) { + t.sessionManager = manager +} + +func (t *streamableHTTPServerTransport) handleMCPEndpoint(w http.ResponseWriter, r *http.Request) { + defer pkg.RecoverWithFunc(func(_ any) { + t.writeError(w, http.StatusInternalServerError, "Internal server error") + }) + + switch r.Method { + case http.MethodPost: + t.handlePost(w, r) + case http.MethodGet: + t.handleGet(w, r) + case http.MethodDelete: + t.handleDelete(w, r) + default: + t.writeError(w, http.StatusMethodNotAllowed, "Method not allowed") + } +} + +func (t *streamableHTTPServerTransport) handlePost(w http.ResponseWriter, r *http.Request) { + // Validate Accept header + accept := r.Header.Get("Accept") + if accept == "" { + t.writeError(w, http.StatusBadRequest, "Missing Accept header") + return + } + + // Read and process the message + bs, err := io.ReadAll(r.Body) + if err != nil { + t.writeError(w, http.StatusBadRequest, fmt.Sprintf("Invalid request: %v", err)) + return + } + + // Disconnection SHOULD NOT be interpreted as the client canceling its request. + // To cancel, the client SHOULD explicitly send an MCP CancelledNotification. + ctx := pkg.NewCancelShieldContext(r.Context()) + + // For InitializeRequest HTTP response + if t.stateMode == Stateful { + ctx = context.WithValue(ctx, SessionIDForReturnKey{}, &SessionIDForReturn{}) + } + + outputMsgCh, err := t.receiver.Receive(ctx, r.Header.Get(sessionIDHeader), bs) + if err != nil { + if errors.Is(err, pkg.ErrSessionClosed) { + t.writeError(w, http.StatusNotFound, fmt.Sprintf("Failed to receive: %v", err)) + return + } + t.writeError(w, http.StatusBadRequest, fmt.Sprintf("Failed to receive: %v", err)) + return + } + + if outputMsgCh == nil { // reply response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + return + } + + msg := <-outputMsgCh + if len(msg) == 0 { + t.writeError(w, http.StatusInternalServerError, "handle request fail") + return + } + + if t.stateMode == Stateful { + if sid := ctx.Value(SessionIDForReturnKey{}).(*SessionIDForReturn); sid.SessionID != "" { // in server.handleRequestWithInitialize assign + w.Header().Set(sessionIDHeader, sid.SessionID) + } + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + if _, err = w.Write(msg); err != nil { + t.logger.Errorf("streamableHTTPServerTransport post write: %+v", err) + return + } +} + +func (t *streamableHTTPServerTransport) handleGet(w http.ResponseWriter, r *http.Request) { + defer pkg.RecoverWithFunc(func(_ any) { + t.writeError(w, http.StatusInternalServerError, "Internal server error") + }) + + if t.stateMode == Stateless { + t.writeError(w, http.StatusMethodNotAllowed, "server is stateless, not support sse connection") + return + } + + if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") { + t.writeError(w, http.StatusBadRequest, "Must accept text/event-stream") + return + } + + // Set headers for SSE + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + // Create flush-supporting writer + flusher, ok := w.(http.Flusher) + if !ok { + t.writeError(w, http.StatusInternalServerError, "Streaming not supported") + return + } + sessionID := r.Header.Get(sessionIDHeader) + if sessionID == "" { + t.writeError(w, http.StatusBadRequest, "Missing Session ID") + flusher.Flush() + return + } + if err := t.sessionManager.OpenMessageQueueForSend(sessionID); err != nil { + t.writeError(w, http.StatusBadRequest, err.Error()) + flusher.Flush() + return + } + w.WriteHeader(http.StatusOK) + flusher.Flush() + + for { + msg, err := t.sessionManager.DequeueMessageForSend(r.Context(), sessionID) + if err != nil { + if errors.Is(err, pkg.ErrSendEOF) { + return + } + t.logger.Debugf("sse connect dequeueMessage err: %+v, sessionID=%s", err.Error(), sessionID) + return + } + + t.logger.Debugf("Sending message: %s", string(msg)) + + if _, err = fmt.Fprintf(w, "data: %s\n\n", msg); err != nil { + t.logger.Errorf("Failed to write message: %v", err) + continue + } + flusher.Flush() + } +} + +func (t *streamableHTTPServerTransport) handleDelete(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.writeError(w, http.StatusBadRequest, "Missing session ID") + return + } + + t.sessionManager.CloseSession(sessionID) + w.WriteHeader(http.StatusOK) +} + +func (t *streamableHTTPServerTransport) writeError(w http.ResponseWriter, code int, message string) { + if code == http.StatusMethodNotAllowed { + t.logger.Infof("streamableHTTPServerTransport response: code: %d, message: %s", code, message) + } else { + t.logger.Errorf("streamableHTTPServerTransport Error: code: %d, message: %s", code, message) + } + + resp := protocol.NewJSONRPCErrorResponse(nil, protocol.InternalError, message) + bytes, err := json.Marshal(resp) + if err != nil { + t.logger.Errorf("streamableHTTPServerTransport writeError json.Marshal: %v", err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if _, err := w.Write(bytes); err != nil { + t.logger.Errorf("streamableHTTPServerTransport writeError Write: %v", err) + } +} + +func (t *streamableHTTPServerTransport) Shutdown(userCtx context.Context, serverCtx context.Context) error { + shutdownFunc := func() { + <-serverCtx.Done() + + t.cancel() + + t.inFlySend.Wait() + + t.sessionManager.CloseAllSessions() + } + + if t.httpSvr == nil { + shutdownFunc() + return nil + } + + t.httpSvr.RegisterOnShutdown(shutdownFunc) + + if err := t.httpSvr.Shutdown(userCtx); err != nil { + return fmt.Errorf("failed to shutdown HTTP server: %w", err) + } + + return nil +} diff --git a/transport/streamable_http_test.go b/transport/streamable_http_test.go new file mode 100644 index 0000000..1044299 --- /dev/null +++ b/transport/streamable_http_test.go @@ -0,0 +1,31 @@ +package transport + +import ( + "fmt" + "testing" +) + +func TestStreamableHTTP(t *testing.T) { + var ( + err error + svr ServerTransport + client ClientTransport + ) + + // Get an available port + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + serverAddr := fmt.Sprintf("127.0.0.1:%d", port) + serverURL := fmt.Sprintf("http://%s/mcp", serverAddr) + + svr = NewStreamableHTTPServerTransport(serverAddr) + + if client, err = NewStreamableHTTPClientTransport(serverURL); err != nil { + t.Fatalf("NewStreamableHTTPClientTransport failed: %v", err) + } + + testTransport(t, client, svr) +} diff --git a/transport/transport.go b/transport/transport.go index c20c592..e7ef5a4 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -67,19 +67,20 @@ type ServerTransport interface { } type serverReceiver interface { - Receive(ctx context.Context, sessionID string, msg []byte) error + Receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) } -type ServerReceiverF func(ctx context.Context, sessionID string, msg []byte) error +type ServerReceiverF func(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) -func (f ServerReceiverF) Receive(ctx context.Context, sessionID string, msg []byte) error { +func (f ServerReceiverF) Receive(ctx context.Context, sessionID string, msg []byte) (<-chan []byte, error) { return f(ctx, sessionID, msg) } type sessionManager interface { - CreateSession(sessionID string) - SendMessage(ctx context.Context, sessionID string, message []byte) error - GetMessageForSend(ctx context.Context, sessionID string) ([]byte, error) + CreateSession() string + OpenMessageQueueForSend(sessionID string) error + EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error + DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) CloseSession(sessionID string) CloseAllSessions() } diff --git a/transport/transport_test.go b/transport/transport_test.go index a7c80c3..e3c3003 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/ThinkInAIXYZ/go-mcp/pkg" ) @@ -17,8 +19,19 @@ func newMockSessionManager() *mockSessionManager { return &mockSessionManager{} } -func (m *mockSessionManager) CreateSession(sessionID string) { +func (m *mockSessionManager) CreateSession() string { + sessionID := uuid.NewString() + m.Store(sessionID, nil) + return sessionID +} + +func (m *mockSessionManager) OpenMessageQueueForSend(sessionID string) error { + _, ok := m.Load(sessionID) + if !ok { + return pkg.ErrLackSession + } m.Store(sessionID, make(chan []byte)) + return nil } func (m *mockSessionManager) IsExistSession(sessionID string) bool { @@ -26,7 +39,7 @@ func (m *mockSessionManager) IsExistSession(sessionID string) bool { return has } -func (m *mockSessionManager) SendMessage(ctx context.Context, sessionID string, message []byte) error { +func (m *mockSessionManager) EnqueueMessageForSend(ctx context.Context, sessionID string, message []byte) error { ch, has := m.Load(sessionID) if !has { return pkg.ErrLackSession @@ -40,7 +53,7 @@ func (m *mockSessionManager) SendMessage(ctx context.Context, sessionID string, } } -func (m *mockSessionManager) GetMessageForSend(ctx context.Context, sessionID string) ([]byte, error) { +func (m *mockSessionManager) DequeueMessageForSend(ctx context.Context, sessionID string) ([]byte, error) { ch, has := m.Load(sessionID) if !has { return nil, pkg.ErrLackSession @@ -75,15 +88,18 @@ func (m *mockSessionManager) CloseAllSessions() { } func testTransport(t *testing.T, client ClientTransport, server ServerTransport) { - msgWithServer := "hello" + testMsg := "hello server" expectedMsgWithServerCh := make(chan string, 1) - server.SetReceiver(ServerReceiverF(func(_ context.Context, _ string, msg []byte) error { + server.SetReceiver(ServerReceiverF(func(_ context.Context, _ string, msg []byte) (<-chan []byte, error) { expectedMsgWithServerCh <- string(msg) - return nil + msgCh := make(chan []byte, 1) + go func() { + msgCh <- msg + }() + return msgCh, nil })) server.SetSessionManager(newMockSessionManager()) - msgWithClient := "hello" expectedMsgWithClientCh := make(chan string, 1) client.SetReceiver(ClientReceiverF(func(_ context.Context, msg []byte) error { expectedMsgWithClientCh <- string(msg) @@ -129,25 +145,15 @@ func testTransport(t *testing.T, client ClientTransport, server ServerTransport) } }() - if err := client.Send(context.Background(), Message(msgWithServer)); err != nil { + if err := client.Send(context.Background(), Message(testMsg)); err != nil { t.Fatalf("client.Send() failed: %v", err) } expectedMsg := <-expectedMsgWithServerCh - if !reflect.DeepEqual(expectedMsg, msgWithServer) { - t.Fatalf("client.Send() got %v, want %v", expectedMsg, msgWithServer) - } - - sessionID := "" - if cli, ok := client.(*sseClientTransport); ok { - sessionID = cli.messageEndpoint.Query().Get("sessionID") + if !reflect.DeepEqual(expectedMsg, testMsg) { + t.Fatalf("client.Send() got %v, want %v", expectedMsg, testMsg) } - - if err := server.Send(context.Background(), sessionID, Message(msgWithClient)); err != nil { - t.Fatalf("server.Send() failed: %v", err) - } - expectedMsg = <-expectedMsgWithClientCh - if !reflect.DeepEqual(expectedMsg, msgWithClient) { - t.Fatalf("server.Send() failed: got %v, want %v", expectedMsg, msgWithClient) + if !reflect.DeepEqual(expectedMsg, testMsg) { + t.Fatalf("server.Send() failed: got %v, want %v", expectedMsg, testMsg) } }