Skip to content

Commit

Permalink
TUN-8861: Add session limiter to TCP session manager
Browse files Browse the repository at this point in the history
## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This commit
adds the session limiter to the HTTP/TCP handling path.
For now the limiter is set to run only in unlimited mode.
  • Loading branch information
jcsf committed Jan 20, 2025
1 parent bf4954e commit 8bfe111
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 102 deletions.
16 changes: 12 additions & 4 deletions connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ package connection

import (
"context"
"crypto/rand"
"fmt"
"io"
"math/rand"
"math/big"
"net/http"
"time"

pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog"

cfdsession "github.com/cloudflare/cloudflared/session"

"github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
Expand Down Expand Up @@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
return wsFlakyEndpoint(w, req)
default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
}
}
switch req.URL.Path {
Expand All @@ -95,14 +99,17 @@ func (moc *mockOriginProxy) ProxyHTTP(
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
}
return nil

}

func (moc *mockOriginProxy) ProxyTCP(
ctx context.Context,
rwa ReadWriteAcker,
r *TCPRequest,
) error {
if r.CfTraceID == "flow-rate-limited" {
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
}

return nil
}

Expand Down Expand Up @@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {

wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)

closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
stream.Pipe(wsConn, originConn, &log)
cancel()
Expand Down
14 changes: 8 additions & 6 deletions connection/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ var (

var (
// pre-generate possible values for res
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
)

// HTTPHeader is a custom header struct that expects only ever one value for the header.
Expand All @@ -34,11 +35,12 @@ type HTTPHeader struct {
}

type responseMetaHeader struct {
Source string `json:"src"`
Source string `json:"src"`
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
}

func mustInitRespMetaHeader(src string) string {
header, err := json.Marshal(responseMetaHeader{Source: src})
func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
if err != nil {
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
}
Expand Down Expand Up @@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string {
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
const unableToDeserializeErr = "Unable to deserialize headers"

var deserialized []HTTPHeader
deserialized := make([]HTTPHeader, 0)
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
if len(serializedPair) == 0 {
continue
Expand Down
19 changes: 13 additions & 6 deletions connection/http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/rs/zerolog"
"golang.org/x/net/http2"

cfdsession "github.com/cloudflare/cloudflared/session"

"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
Expand Down Expand Up @@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")

// WriteErrorResponse will return false if status was already written. we need to abort handler.
if !respWriter.WriteErrorResponse() {
if !respWriter.WriteErrorResponse(requestErr) {
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
panic(http.ErrAbortHandler)
}
Expand Down Expand Up @@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
w: w,
log: log,
}
respWriter.WriteErrorResponse()
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
err := fmt.Errorf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse(err)
return nil, err
}

return &http2RespWriter{
Expand Down Expand Up @@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
rp.log.Warn().Msg("WriteHeader after hijack")
return
}
rp.WriteRespHeaders(status, rp.respHeaders)
_ = rp.WriteRespHeaders(status, rp.respHeaders)
}

func (rp *http2RespWriter) hijacked() bool {
Expand Down Expand Up @@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return conn, readWriter, nil
}

func (rp *http2RespWriter) WriteErrorResponse() bool {
func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
if rp.statusWritten {
return false
}

rp.setResponseMetaHeader(responseMetaHeaderCfd)
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
} else {
rp.setResponseMetaHeader(responseMetaHeaderCfd)
}
rp.w.WriteHeader(http.StatusBadGateway)
rp.statusWritten = true

Expand Down
71 changes: 55 additions & 16 deletions connection/http2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"

"github.com/cloudflare/cloudflared/tracing"

"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
Expand Down Expand Up @@ -65,31 +67,30 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()

edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)

endpoint := fmt.Sprintf("http://localhost:8080/ok")
reqBody := []byte(`{
"version": 2,
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`)
reader := bytes.NewReader(reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)

resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
bdy, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
require.NoError(t, err)
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
cancel()
wg.Wait()

}

func TestServeHTTP(t *testing.T) {
Expand Down Expand Up @@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()

edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
Expand All @@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) {
require.NoError(t, err)
require.Equal(t, test.expectedBody, respBody)
}
_ = resp.Body.Close()
if test.isProxyError {
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
} else {
Expand Down Expand Up @@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) {

respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))

cancel()
resp := respWriter.Result()
defer resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
Expand All @@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
cfdHTTP2Conn.Serve(ctx)
_ = cfdHTTP2Conn.Serve(ctx)
}()

edgeTransport := http2.Transport{}
Expand All @@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
readPipe, writePipe := io.Pipe()
reqCtx, reqCancel := context.WithCancel(ctx)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
require.NoError(t, err)
assert.NoError(t, err)

req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)

resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
assert.NoError(t, err)
_ = resp.Body.Close()

// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusOK, resp.StatusCode)

wg.Add(1)
go func() {
Expand Down Expand Up @@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
Expand All @@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
edgeHTTP2Conn.RoundTrip(req)
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req)
}()

<-rpcClientFactory.registered
Expand Down Expand Up @@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
Expand All @@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) {
require.NoError(t, err)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode)

assert.NotNil(t, http2Conn.controlStreamErr)
require.Error(t, http2Conn.controlStreamErr)
cancel()
wg.Wait()
}
Expand Down Expand Up @@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
Expand All @@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req)
}()

Expand Down Expand Up @@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
})
}

func TestServeTCP_RateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
http2Conn, edgeConn := newTestHTTP2Connection()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = http2Conn.Serve(ctx)
}()

edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")

resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusBadGateway, resp.StatusCode)
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))

cancel()
wg.Wait()
}

func benchmarkServeHTTP(b *testing.B, test testRequest) {
http2Conn, edgeConn := newTestHTTP2Connection()

Expand All @@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()

endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
Expand Down
Loading

0 comments on commit 8bfe111

Please sign in to comment.