From f55df186e8657d8271b695a1afd5d8fae831bfad Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 30 Oct 2021 14:51:12 +0200 Subject: [PATCH] limit the number of concurrent incoming streams --- mux.go | 6 ++++++ session.go | 23 +++++++++++++++++----- session_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 5 deletions(-) diff --git a/mux.go b/mux.go index 49bd51c..be13e11 100644 --- a/mux.go +++ b/mux.go @@ -31,6 +31,11 @@ type Config struct { // an expectation that things will move along quickly. ConnectionWriteTimeout time.Duration + // MaxIncomingStreams is maximum number of concurrent incoming streams + // that we accept. If the peer tries to open more streams, those will be + // reset immediately. + MaxIncomingStreams uint32 + // InitialStreamWindowSize is used to control the initial // window size that we allow for a stream. InitialStreamWindowSize uint32 @@ -65,6 +70,7 @@ func DefaultConfig() *Config { EnableKeepAlive: true, KeepAliveInterval: 30 * time.Second, ConnectionWriteTimeout: 10 * time.Second, + MaxIncomingStreams: 1000, InitialStreamWindowSize: initialStreamWindow, MaxStreamWindowSize: maxStreamWindow, LogOutput: os.Stderr, diff --git a/session.go b/session.go index 218d345..1027715 100644 --- a/session.go +++ b/session.go @@ -15,7 +15,7 @@ import ( "sync/atomic" "time" - "github.com/libp2p/go-buffer-pool" + pool "github.com/libp2p/go-buffer-pool" ) // Session is used to wrap a reliable ordered connection and to @@ -55,9 +55,10 @@ type Session struct { // streams maps a stream id to a stream, and inflight has an entry // for any outgoing stream that has not yet been established. Both are // protected by streamLock. - streams map[uint32]*Stream - inflight map[uint32]struct{} - streamLock sync.Mutex + numIncomingStreams uint32 + streams map[uint32]*Stream + inflight map[uint32]struct{} + streamLock sync.Mutex // synCh acts like a semaphore. It is sized to the AcceptBacklog which // is assumed to be symmetric between the client and server. This allows @@ -735,6 +736,15 @@ func (s *Session) incomingStream(id uint32) error { return ErrDuplicateStream } + if s.numIncomingStreams >= s.config.MaxIncomingStreams { + // too many active streams at the same time + s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset") + delete(s.streams, id) + hdr := encode(typeWindowUpdate, flagRST, id, 0) + return s.sendMsg(hdr, nil, nil) + } + + s.numIncomingStreams++ // Register the stream s.streams[id] = stream @@ -744,7 +754,7 @@ func (s *Session) incomingStream(id uint32) error { return nil default: // Backlog exceeded! RST the stream - s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") + s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset") delete(s.streams, id) hdr := encode(typeWindowUpdate, flagRST, id, 0) return s.sendMsg(hdr, nil, nil) @@ -764,6 +774,9 @@ func (s *Session) closeStream(id uint32) { } delete(s.inflight, id) } + if s.client == (id%2 == 0) { + s.numIncomingStreams-- + } delete(s.streams, id) s.streamLock.Unlock() } diff --git a/session_test.go b/session_test.go index b857b14..085e7f8 100644 --- a/session_test.go +++ b/session_test.go @@ -1732,3 +1732,55 @@ func TestInitialStreamWindow(t *testing.T) { } } } + +func TestMaxIncomingStreams(t *testing.T) { + const maxIncomingStreams = 5 + conn1, conn2 := testConn() + client, err := Client(conn1, DefaultConfig()) + require.NoError(t, err) + defer client.Close() + + conf := DefaultConfig() + conf.MaxIncomingStreams = maxIncomingStreams + server, err := Server(conn2, conf) + require.NoError(t, err) + defer server.Close() + + strChan := make(chan *Stream, maxIncomingStreams) + go func() { + defer close(strChan) + for { + str, err := server.AcceptStream() + if err != nil { + return + } + _, err = str.Write([]byte("foobar")) + require.NoError(t, err) + strChan <- str + } + }() + + for i := 0; i < maxIncomingStreams; i++ { + str, err := client.OpenStream(context.Background()) + require.NoError(t, err) + _, err = str.Read(make([]byte, 6)) + require.NoError(t, err) + require.NoError(t, str.CloseWrite()) + } + // The server now has maxIncomingStreams incoming streams. + // It will now reset the next stream that is opened. + str, err := client.OpenStream(context.Background()) + require.NoError(t, err) + str.SetDeadline(time.Now().Add(time.Second)) + _, err = str.Read([]byte{0}) + require.EqualError(t, err, "stream reset") + + // Now close one of the streams. + // This should then allow the client to open a new stream. + require.NoError(t, (<-strChan).Close()) + str, err = client.OpenStream(context.Background()) + require.NoError(t, err) + str.SetDeadline(time.Now().Add(time.Second)) + _, err = str.Read([]byte{0}) + require.NoError(t, err) +}