Skip to content

Feat/add streamable http #126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion client/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"sync/atomic"
Expand Down Expand Up @@ -220,7 +221,7 @@ func (client *Client) sendNotification4Initialized(ctx context.Context) error {
// Responsible for request and response assembly
func (client *Client) callServer(ctx context.Context, method protocol.Method, params protocol.ClientRequest) (json.RawMessage, error) {
if !client.ready.Load() && (method != protocol.Initialize && method != protocol.Ping) {
return nil, fmt.Errorf("client not ready")
return nil, errors.New("callServer: client not ready")
}

requestID := strconv.FormatInt(atomic.AddInt64(&client.requestID, 1), 10)
Expand Down
4 changes: 3 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"sync"
"time"

cmap "github.com/orcaman/concurrent-map/v2"
Expand Down Expand Up @@ -47,7 +48,8 @@ type Client struct {

requestID int64

ready *pkg.AtomicBool
ready *pkg.AtomicBool
initializationMu sync.Mutex

clientInfo *protocol.Implementation
clientCapabilities *protocol.ClientCapabilities
Expand Down
34 changes: 29 additions & 5 deletions client/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package client
import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/ThinkInAIXYZ/go-mcp/pkg"
"github.com/ThinkInAIXYZ/go-mcp/protocol"
)

Expand All @@ -20,8 +22,13 @@ func (client *Client) sendMsgWithRequest(ctx context.Context, requestID protocol
return err
}

if err := client.transport.Send(ctx, message); err != nil {
return fmt.Errorf("sendRequest: transport send: %w", err)
if err = client.transport.Send(ctx, message); err != nil {
if !errors.Is(err, pkg.ErrSessionClosed) {
return fmt.Errorf("sendRequest: transport send: %w", err)
}
if err = client.againInitialization(ctx); err != nil {
return err
}
}
return nil
}
Expand All @@ -38,7 +45,7 @@ func (client *Client) sendMsgWithResponse(ctx context.Context, requestID protoco
return err
}

if err := client.transport.Send(ctx, message); err != nil {
if err = client.transport.Send(ctx, message); err != nil {
return fmt.Errorf("sendResponse: transport send: %w", err)
}
return nil
Expand All @@ -52,7 +59,7 @@ func (client *Client) sendMsgWithNotification(ctx context.Context, method protoc
return err
}

if err := client.transport.Send(ctx, message); err != nil {
if err = client.transport.Send(ctx, message); err != nil {
return fmt.Errorf("sendNotification: transport send: %w", err)
}
return nil
Expand All @@ -70,8 +77,25 @@ func (client *Client) sendMsgWithError(ctx context.Context, requestID protocol.R
return err
}

if err := client.transport.Send(ctx, message); err != nil {
if err = client.transport.Send(ctx, message); err != nil {
return fmt.Errorf("sendResponse: transport send: %w", err)
}
return nil
}

func (client *Client) againInitialization(ctx context.Context) error {
client.ready.Store(false)

client.initializationMu.Lock()
defer client.initializationMu.Unlock()

if client.ready.Load() {
return nil
}

if _, err := client.initialization(ctx, protocol.NewInitializeRequest(*client.clientInfo, *client.clientCapabilities)); err != nil {
return err
}
client.ready.Store(true)
return nil
}
14 changes: 10 additions & 4 deletions examples/current_time_server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,27 @@ func getTransport() (t transport.ServerTransport) {
addr = "127.0.0.1:8080"
)

flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\"")
flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\" or \"streamable_http\"")
flag.Parse()

if mode == "stdio" {
switch mode {
case "stdio":
log.Println("start current time mcp server with stdio transport")
t = transport.NewStdioServerTransport()
} else {
case "sse":
log.Printf("start current time mcp server with sse transport, listen %s", addr)
t, _ = transport.NewSSEServerTransport(addr)
case "streamable_http":
log.Printf("start current time mcp server with streamable_http transport, listen %s", addr)
t = transport.NewStreamableHTTPServerTransport(addr)
default:
panic(fmt.Errorf("unknown mode: %s", mode))
}

return t
}

func currentTime(request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
req := new(currentTimeReq)
if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil {
return nil, err
Expand Down
19 changes: 13 additions & 6 deletions examples/everything/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,32 @@ func main() {
}

func getTransport() (t transport.ServerTransport) {
mode := ""
port := ""
flag.StringVar(&mode, "transport", "stdio", "The transport to use, should be \"stdio\" or \"sse\"")
var mode, port, stateMode string
flag.StringVar(&mode, "transport", "streamable_http", "The transport to use, should be \"stdio\" or \"sse\" or \"streamable_http\"")
flag.StringVar(&port, "port", "8080", "sse server address")
flag.StringVar(&stateMode, "state_mode", "stateless", "streamable_http server state mode, should be \"stateless\" or \"stateful\"")
flag.Parse()

if mode == "stdio" {
switch mode {
case "stdio":
log.Println("start current time mcp server with stdio transport")
t = transport.NewStdioServerTransport()
} else {
case "sse":
addr := fmt.Sprintf("127.0.0.1:%s", port)
log.Printf("start current time mcp server with sse transport, listen %s", addr)
t, _ = transport.NewSSEServerTransport(addr)
case "streamable_http":
addr := fmt.Sprintf("127.0.0.1:%s", port)
log.Printf("start current time mcp server with streamable_http transport, listen %s", addr)
t = transport.NewStreamableHTTPServerTransport(addr, transport.WithStreamableHTTPServerTransportOptionStateMode(transport.StateMode(stateMode)))
default:
panic(fmt.Errorf("unknown mode: %s", mode))
}

return t
}

func currentTime(request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
req := new(currentTimeReq)
if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion examples/http_handler/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func main() {
}
}

func currentTime(request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
func currentTime(_ context.Context, request *protocol.CallToolRequest) (*protocol.CallToolResult, error) {
req := new(currentTimeReq)
if err := protocol.VerifyAndUnmarshal(request.RawArguments, &req); err != nil {
return nil, err
Expand Down
18 changes: 18 additions & 0 deletions pkg/atomic_bool.go → pkg/atomic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,21 @@ func (b *AtomicBool) Store(value bool) {
func (b *AtomicBool) Load() bool {
return b.b.Load().(bool)
}

type AtomicString struct {
b atomic.Value
}

func NewAtomicString() *AtomicString {
b := &AtomicString{}
b.b.Store("")
return b
}

func (b *AtomicString) Store(value string) {
b.b.Store(value)
}

func (b *AtomicString) Load() string {
return b.b.Load().(string)
}
26 changes: 26 additions & 0 deletions pkg/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package pkg

import (
"context"
"time"
)

type CancelShieldContext struct {
context.Context
}

func NewCancelShieldContext(ctx context.Context) context.Context {
return CancelShieldContext{Context: ctx}
}

func (v CancelShieldContext) Deadline() (deadline time.Time, ok bool) {
return
}

func (v CancelShieldContext) Done() <-chan struct{} {
return nil
}

func (v CancelShieldContext) Err() error {
return nil
}
1 change: 1 addition & 0 deletions pkg/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ var (
ErrJSONUnmarshal = errors.New("json unmarshal error")
ErrSessionHasNotInitialized = errors.New("the session has not been initialized")
ErrLackSession = errors.New("lack session")
ErrSessionClosed = errors.New("session closed")
ErrSendEOF = errors.New("send EOF")
)

Expand Down
9 changes: 4 additions & 5 deletions protocol/schema_generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func compareInputSchema(a, b *InputSchema) bool {
return false
}

// 比较Required字段
// compare required field
if len(a.Required) != len(b.Required) {
return false
}
Expand Down Expand Up @@ -331,12 +331,11 @@ func compareProperty(a, b *Property) bool {
return false
}

// 比较Items字段
// compare Items field
if !compareProperty(a.Items, b.Items) {
return false
}

// 比较Properties字段
// compare Properties field
if len(a.Properties) != len(b.Properties) {
return false
}
Expand All @@ -350,7 +349,7 @@ func compareProperty(a, b *Property) bool {
}
}

// 比较Required字段
// compare Required field比
if len(a.Required) != len(b.Required) {
return false
}
Expand Down
1 change: 0 additions & 1 deletion protocol/schema_validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ func Test_Validate(t *testing.T) {
},
Required: []string{"string"},
}}, false},
// 嵌套匿名结构体测试
{"nested anonymous struct", args{data: map[string]any{
"user": map[string]any{
"name": "test",
Expand Down
7 changes: 6 additions & 1 deletion protocol/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package protocol

const Version = "2024-11-05"
const Version = "2025-03-26"

var SupportedVersion = map[string]struct{}{
"2024-11-05": {},
"2025-03-26": {},
}

// Method represents the JSON-RPC method name
type Method string
Expand Down
4 changes: 2 additions & 2 deletions server/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (server *Server) SendNotification4ResourcesUpdated(ctx context.Context, not
func (server *Server) callClient(ctx context.Context, sessionID string, method protocol.Method, params protocol.ServerRequest) (json.RawMessage, error) {
session, ok := server.sessionManager.GetSession(sessionID)
if !ok {
return nil, pkg.ErrLackSession
return nil, fmt.Errorf("callClient: %w", pkg.ErrLackSession)
}

requestID := strconv.FormatInt(session.IncRequestID(), 10)
Expand All @@ -107,7 +107,7 @@ func (server *Server) callClient(ctx context.Context, sessionID string, method p
defer session.GetReqID2respChan().Remove(requestID)

if err := server.sendMsgWithRequest(ctx, sessionID, requestID, method, params); err != nil {
return nil, err
return nil, fmt.Errorf("callClient: %w", err)
}

select {
Expand Down
Loading
Loading