Skip to content

Commit

Permalink
TUN-8861: Add session limiter to UDP session manager
Browse files Browse the repository at this point in the history
## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This first
commit introduces the session limiter and adds it to the UDP
handling path. For now the limiter is set to run only in
unlimited mode.
  • Loading branch information
jcsf committed Jan 20, 2025
1 parent 8918b67 commit bf4954e
Show file tree
Hide file tree
Showing 66 changed files with 3,409 additions and 1,184 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,7 @@ fmt-check:
.PHONY: lint
lint:
@golangci-lint run

.PHONY: mocks
mocks:
go generate mocks/mockgen.go
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ For example, as of January 2023 Cloudflare will support cloudflared version 2023
- [capnpc-go](https://pkg.go.dev/zombiezen.com/go/capnproto2/capnpc-go)
- [goimports](https://pkg.go.dev/golang.org/x/tools/cmd/goimports)
- [golangci-lint](https://github.com/golangci/golangci-lint)
- [gomocks](https://pkg.go.dev/go.uber.org/mock)

### Build
To build cloudflared locally run `make cloudflared`
Expand All @@ -76,3 +77,6 @@ To locally run the tests run `make test`

### Linting
To format the code and keep a good code quality use `make fmt` and `make lint`

### Mocks
After changes on interfaces you might need to regenerate the mocks, so run `make mock`
45 changes: 30 additions & 15 deletions connection/quic_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/nettest"

cfdsession "github.com/cloudflare/cloudflared/session"

"github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/packet"
Expand All @@ -53,7 +55,8 @@ var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
func TestQUICServer(t *testing.T) {
// This is simply a sample websocket frame message.
wsBuf := &bytes.Buffer{}
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
err := wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
require.NoError(t, err)

var tests = []struct {
desc string
Expand Down Expand Up @@ -158,17 +161,19 @@ func TestQUICServer(t *testing.T) {

serverDone := make(chan struct{})
go func() {
// nolint: testifylint
quicServer(
ctx, t, quicListener, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
)
close(serverDone)
}()

// nolint: gosec
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i))

connDone := make(chan struct{})
go func() {
tunnelConn.Serve(ctx)
_ = tunnelConn.Serve(ctx)
close(connDone)
}()

Expand Down Expand Up @@ -254,14 +259,14 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.T
case "/ok":
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
case "/slow_echo_body":
time.Sleep(5)
time.Sleep(5 * time.Nanosecond)
fallthrough
case "/echo_body":
resp := &http.Response{
StatusCode: http.StatusOK,
}
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
io.Copy(w, r.Body)
_, _ = io.Copy(w, r.Body)
case "/error":
return fmt.Errorf("Failed to proxy to origin")
default:
Expand Down Expand Up @@ -493,16 +498,16 @@ func TestBuildHTTPRequest(t *testing.T) {
test := test // capture range variable
t.Run(test.name, func(t *testing.T) {
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, 0, &log)
assert.NoError(t, err)
require.NoError(t, err)
test.req = test.req.WithContext(req.Context())
assert.Equal(t, test.req, req.Request)
require.Equal(t, test.req, req.Request)
})
}
}

func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
rwa.AckConnection("")
io.Copy(rwa, rwa)
_ = rwa.AckConnection("")
_, _ = io.Copy(rwa, rwa)
return nil
}

Expand All @@ -520,16 +525,19 @@ func TestServeUDPSession(t *testing.T) {
edgeQUICSessionChan := make(chan quic.Connection)
go func() {
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
require.NoError(t, err)
assert.NoError(t, err)

edgeQUICSession, err := earlyListener.Accept(ctx)
require.NoError(t, err)
assert.NoError(t, err)

edgeQUICSessionChan <- edgeQUICSession
}()

// Random index to avoid reusing port
tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28)
go tunnelConn.Serve(ctx)
go func() {
_ = tunnelConn.Serve(ctx)
}()

edgeQUICSession := <-edgeQUICSessionChan

Expand All @@ -545,14 +553,14 @@ func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {

n, err := readerWriter.Read(buffer)
require.NoError(t, err)
require.Equal(t, n, 5)
require.Equal(t, 5, n)

// close
require.NoError(t, readerWriter.Close())

// read should get error
n, err = readerWriter.Read(buffer)
require.Equal(t, n, 0)
require.Equal(t, 0, n)
require.Equal(t, err, fmt.Errorf("closed by handler"))
}

Expand All @@ -562,7 +570,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {

n, err := readerWriter.Read(buffer)
require.NoError(t, err)
require.Equal(t, n, 9)
require.Equal(t, 9, n)

// force another read to read eof
_, err = readerWriter.Read(buffer)
Expand All @@ -573,7 +581,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {

// read should get EOF still
n, err = readerWriter.Read(buffer)
require.Equal(t, n, 0)
require.Equal(t, 0, n)
require.Equal(t, err, io.EOF)
}

Expand Down Expand Up @@ -669,6 +677,7 @@ func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQ
unregisterReason: expectedReason,
calledUnregisterChan: unregisterFromEdgeChan,
}
// nolint: testifylint
go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t)

<-unregisterFromEdgeChan
Expand Down Expand Up @@ -729,6 +738,7 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI

func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) {
tlsClientConfig := &tls.Config{
// nolint: gosec
InsecureSkipVerify: true,
NextProtos: []string{"argotunnel"},
}
Expand All @@ -747,6 +757,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
index,
&log,
)
require.NoError(t, err)

// Start a session manager for the connection
sessionDemuxChan := make(chan *packet.Session, 4)
Expand All @@ -757,7 +768,9 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)

datagramConn := &datagramV2Connection{
conn,
index,
sessionManager,
cfdsession.NewLimiter(0),
datagramMuxer,
packetRouter,
15 * time.Second,
Expand Down Expand Up @@ -796,6 +809,7 @@ func (m *mockReaderNoopWriter) Close() error {

// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server
func GenerateTLSConfig() *tls.Config {
// nolint: gosec
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
Expand All @@ -812,6 +826,7 @@ func GenerateTLSConfig() *tls.Config {
if err != nil {
panic(err)
}
// nolint: gosec
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"argotunnel"},
Expand Down
45 changes: 35 additions & 10 deletions connection/quic_datagram_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ import (
"time"

"github.com/google/uuid"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"

cfdsession "github.com/cloudflare/cloudflared/session"

"github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
Expand All @@ -38,10 +41,14 @@ type DatagramSessionHandler interface {
}

type datagramV2Connection struct {
conn quic.Connection
conn quic.Connection
index uint8

// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager
// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
sessionLimiter cfdsession.Limiter

// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *cfdquic.DatagramMuxerV2
packetRouter *ingress.PacketRouter
Expand All @@ -58,6 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
index uint8,
rpcTimeout time.Duration,
streamWriteTimeout time.Duration,
sessionLimiter cfdsession.Limiter,
logger *zerolog.Logger,
) DatagramSessionHandler {
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
Expand All @@ -66,13 +74,15 @@ func NewDatagramV2Connection(ctx context.Context,
packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger)

return &datagramV2Connection{
conn,
sessionManager,
datagramMuxer,
packetRouter,
rpcTimeout,
streamWriteTimeout,
logger,
conn: conn,
index: index,
sessionManager: sessionManager,
sessionLimiter: sessionLimiter,
datagramMuxer: datagramMuxer,
packetRouter: packetRouter,
rpcTimeout: rpcTimeout,
streamWriteTimeout: streamWriteTimeout,
logger: logger,
}
}

Expand Down Expand Up @@ -109,12 +119,23 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
))
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()

// Try to start a new session
if err := q.sessionLimiter.Acquire(management.UDP.String()); err != nil {
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)

err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
tracing.EndWithErrorStatus(registerSpan, err)
return nil, err
}

// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
originProxy, err := ingress.DialUDP(dstIP, dstPort)
if err != nil {
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
tracing.EndWithErrorStatus(registerSpan, err)
q.sessionLimiter.Release()
return nil, err
}
registerSpan.SetAttributes(
Expand All @@ -127,10 +148,14 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
originProxy.Close()
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err)
q.sessionLimiter.Release()
return nil, err
}

go q.serveUDPSession(session, closeAfterIdleHint)
go func() {
defer q.sessionLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
q.serveUDPSession(session, closeAfterIdleHint)
}()

log.Debug().
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
Expand Down Expand Up @@ -170,7 +195,7 @@ func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session,

// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
_ = q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
quicStream, err := q.conn.OpenStream()
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
Expand Down
Loading

0 comments on commit bf4954e

Please sign in to comment.