diff --git a/connection/connection_test.go b/connection/connection_test.go index ffd483d24db..24534ec8aea 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -2,14 +2,18 @@ package connection import ( "context" + "crypto/rand" "fmt" "io" - "math/rand" + "math/big" "net/http" "time" + pkgerrors "github.com/pkg/errors" "github.com/rs/zerolog" + cfdsession "github.com/cloudflare/cloudflared/session" + "github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/tracing" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP( return wsFlakyEndpoint(w, req) default: originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found")) - return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path) + return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path) } } switch req.URL.Path { @@ -95,7 +99,6 @@ func (moc *mockOriginProxy) ProxyHTTP( originRespEndpoint(w, http.StatusNotFound, []byte("page not found")) } return nil - } func (moc *mockOriginProxy) ProxyTCP( @@ -103,6 +106,10 @@ func (moc *mockOriginProxy) ProxyTCP( rwa ReadWriteAcker, r *TCPRequest, ) error { + if r.CfTraceID == "flow-rate-limited" { + return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited") + } + return nil } @@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error { wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log) - closedAfter := time.Millisecond * time.Duration(rand.Intn(50)) + rInt, _ := rand.Int(rand.Reader, big.NewInt(50)) + closedAfter := time.Millisecond * time.Duration(rInt.Int64()) originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} stream.Pipe(wsConn, originConn, &log) cancel() diff --git a/connection/header.go b/connection/header.go index 516c5df6db9..269e56a7830 100644 --- a/connection/header.go +++ b/connection/header.go @@ -22,8 +22,9 @@ var ( var ( // pre-generate possible values for res - responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared") - responseMetaHeaderOrigin = mustInitRespMetaHeader("origin") + responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false) + responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true) + responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false) ) // HTTPHeader is a custom header struct that expects only ever one value for the header. @@ -34,11 +35,12 @@ type HTTPHeader struct { } type responseMetaHeader struct { - Source string `json:"src"` + Source string `json:"src"` + FlowRateLimited bool `json:"flow_rate_limited,omitempty"` } -func mustInitRespMetaHeader(src string) string { - header, err := json.Marshal(responseMetaHeader{Source: src}) +func mustInitRespMetaHeader(src string, flowRateLimited bool) string { + header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited}) if err != nil { panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err)) } @@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string { func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) { const unableToDeserializeErr = "Unable to deserialize headers" - var deserialized []HTTPHeader + deserialized := make([]HTTPHeader, 0) for _, serializedPair := range strings.Split(serializedHeaders, ";") { if len(serializedPair) == 0 { continue diff --git a/connection/http2.go b/connection/http2.go index aee9d9dab83..daf395901e1 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -16,6 +16,8 @@ import ( "github.com/rs/zerolog" "golang.org/x/net/http2" + cfdsession "github.com/cloudflare/cloudflared/session" + "github.com/cloudflare/cloudflared/tracing" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.log.Error().Err(requestErr).Msg("failed to serve incoming request") // WriteErrorResponse will return false if status was already written. we need to abort handler. - if !respWriter.WriteErrorResponse() { + if !respWriter.WriteErrorResponse(requestErr) { c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent") panic(http.ErrAbortHandler) } @@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l w: w, log: log, } - respWriter.WriteErrorResponse() - return nil, fmt.Errorf("%T doesn't implement http.Flusher", w) + err := fmt.Errorf("%T doesn't implement http.Flusher", w) + respWriter.WriteErrorResponse(err) + return nil, err } return &http2RespWriter{ @@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) { rp.log.Warn().Msg("WriteHeader after hijack") return } - rp.WriteRespHeaders(status, rp.respHeaders) + _ = rp.WriteRespHeaders(status, rp.respHeaders) } func (rp *http2RespWriter) hijacked() bool { @@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return conn, readWriter, nil } -func (rp *http2RespWriter) WriteErrorResponse() bool { +func (rp *http2RespWriter) WriteErrorResponse(err error) bool { if rp.statusWritten { return false } - rp.setResponseMetaHeader(responseMetaHeaderCfd) + if errors.Is(err, cfdsession.ErrTooManyActiveSessions) { + rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited) + } else { + rp.setResponseMetaHeader(responseMetaHeaderCfd) + } rp.w.WriteHeader(http.StatusBadGateway) rp.statusWritten = true diff --git a/connection/http2_test.go b/connection/http2_test.go index 92665688966..d2045600a39 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -20,6 +20,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/http2" + "github.com/cloudflare/cloudflared/tracing" + "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -65,19 +67,18 @@ func TestHTTP2ConfigurationSet(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - http2Conn.Serve(ctx) + _ = http2Conn.Serve(ctx) }() edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) require.NoError(t, err) - endpoint := fmt.Sprintf("http://localhost:8080/ok") reqBody := []byte(`{ "version": 2, "config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}} `) reader := bytes.NewReader(reqBody) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader) require.NoError(t, err) req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate) @@ -85,11 +86,11 @@ func TestHTTP2ConfigurationSet(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) bdy, err := io.ReadAll(resp.Body) + defer resp.Body.Close() require.NoError(t, err) assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy)) cancel() wg.Wait() - } func TestServeHTTP(t *testing.T) { @@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - http2Conn.Serve(ctx) + _ = http2Conn.Serve(ctx) }() edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) @@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) { require.NoError(t, err) require.Equal(t, test.expectedBody, respBody) } + _ = resp.Body.Close() if test.isProxyError { require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader)) } else { @@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) { respBody, err := wsutil.ReadServerBinary(respWriter.RespBody()) require.NoError(t, err) - require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) + require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody)) cancel() resp := respWriter.Result() + defer resp.Body.Close() // http2RespWriter should rewrite status 101 to 200 require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) @@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) { serverDone := make(chan struct{}) go func() { defer close(serverDone) - cfdHTTP2Conn.Serve(ctx) + _ = cfdHTTP2Conn.Serve(ctx) }() edgeTransport := http2.Transport{} @@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) { readPipe, writePipe := io.Pipe() reqCtx, reqCancel := context.WithCancel(ctx) req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe) - require.NoError(t, err) + assert.NoError(t, err) + req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) resp, err := edgeHTTP2Conn.RoundTrip(req) - require.NoError(t, err) + assert.NoError(t, err) + _ = resp.Body.Close() + // http2RespWriter should rewrite status 101 to 200 - require.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, http.StatusOK, resp.StatusCode) wg.Add(1) go func() { @@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - http2Conn.Serve(ctx) + _ = http2Conn.Serve(ctx) }() req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) @@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - edgeHTTP2Conn.RoundTrip(req) + // nolint: bodyclose + _, _ = edgeHTTP2Conn.RoundTrip(req) }() <-rpcClientFactory.registered @@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - http2Conn.Serve(ctx) + _ = http2Conn.Serve(ctx) }() req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) @@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) { require.NoError(t, err) resp, err := edgeHTTP2Conn.RoundTrip(req) require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusBadGateway, resp.StatusCode) - assert.NotNil(t, http2Conn.controlStreamErr) + require.Error(t, http2Conn.controlStreamErr) cancel() wg.Wait() } @@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - http2Conn.Serve(ctx) + _ = http2Conn.Serve(ctx) }() req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) @@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { wg.Add(1) go func() { defer wg.Done() + // nolint: bodyclose _, _ = edgeHTTP2Conn.RoundTrip(req) }() @@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) { }) } +func TestServeTCP_RateLimited(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + http2Conn, edgeConn := newTestHTTP2Connection() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = http2Conn.Serve(ctx) + }() + + edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil) + require.NoError(t, err) + req.Header.Set(InternalTCPProxySrcHeader, "tcp") + req.Header.Set(tracing.TracerContextName, "flow-rate-limited") + + resp, err := edgeHTTP2Conn.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusBadGateway, resp.StatusCode) + require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader)) + + cancel() + wg.Wait() +} + func benchmarkServeHTTP(b *testing.B, test testRequest) { http2Conn, edgeConn := newTestHTTP2Connection() @@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) { wg.Add(1) go func() { defer wg.Done() - http2Conn.Serve(ctx) + _ = http2Conn.Serve(ctx) }() endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint) diff --git a/connection/quic_connection.go b/connection/quic_connection.go index 7a20e15aec8..7f22b2a4171 100644 --- a/connection/quic_connection.go +++ b/connection/quic_connection.go @@ -17,6 +17,8 @@ import ( "github.com/rs/zerolog" "golang.org/x/sync/errgroup" + cfdsession "github.com/cloudflare/cloudflared/session" + cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -108,7 +110,6 @@ func (q *quicConnection) Serve(ctx context.Context) error { } cancel() return err - }) errGroup.Go(func() error { defer cancel() @@ -129,7 +130,7 @@ func (q *quicConnection) serveControlStream(ctx context.Context, controlStream q // Close the connection with no errors specified. func (q *quicConnection) Close() { - q.conn.CloseWithError(0, "") + _ = q.conn.CloseWithError(0, "") } func (q *quicConnection) acceptStream(ctx context.Context) error { @@ -182,7 +183,13 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R return err } - if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { + var metadata []pogs.Metadata + // Check the type of error that was throw and add metadata that will help identify it on OTD. + if errors.Is(err, cfdsession.ErrTooManyActiveSessions) { + metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey) + } + + if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil { return writeRespErr } } @@ -278,7 +285,7 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) func (hrw *httpResponseAdapter) Write(p []byte) (int, error) { // Make sure to send WriteHeader response if not called yet if !hrw.connectResponseSent { - hrw.WriteRespHeaders(http.StatusOK, hrw.headers) + _ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers) } return hrw.RequestServerStream.Write(p) } @@ -291,7 +298,7 @@ func (hrw *httpResponseAdapter) Header() http.Header { func (hrw *httpResponseAdapter) Flush() {} func (hrw *httpResponseAdapter) WriteHeader(status int) { - hrw.WriteRespHeaders(status, hrw.headers) + _ = hrw.WriteRespHeaders(status, hrw.headers) } func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { @@ -304,7 +311,7 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { - hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) + _ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error { diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 80a4c09fea2..b9db4e67204 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "errors" "fmt" "io" "math/big" @@ -21,7 +22,7 @@ import ( "github.com/gobwas/ws/wsutil" "github.com/google/uuid" - "github.com/pkg/errors" + pkgerrors "github.com/pkg/errors" "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -506,6 +507,10 @@ func TestBuildHTTPRequest(t *testing.T) { } func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error { + if tcpRequest.Dest == "rate-limit-me" { + return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream") + } + _ = rwa.AckConnection("") _, _ = io.Copy(rwa, rwa) return nil @@ -597,6 +602,59 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) { } } +// TestTCPProxy_FlowRateLimited tests if the pogs.ConnectResponse returns the expected error and metadata, when a +// new flow is rate limited. +func TestTCPProxy_FlowRateLimited(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Start a UDP Listener for QUIC. + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + + udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) + require.NoError(t, err) + defer udpListener.Close() + + quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16} + quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig) + require.NoError(t, err) + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + + session, err := quicListener.Accept(ctx) + assert.NoError(t, err) + + quicStream, err := session.OpenStreamSync(context.Background()) + assert.NoError(t, err) + stream := cfdquic.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log) + + reqClientStream := rpcquic.RequestClientStream{ReadWriteCloser: stream} + err = reqClientStream.WriteConnectRequestData("rate-limit-me", pogs.ConnectionTypeTCP) + assert.NoError(t, err) + + response, err := reqClientStream.ReadConnectResponseData() + assert.NoError(t, err) + + // Got Rate Limited + assert.NotEmpty(t, response.Error) + assert.Contains(t, response.Metadata, pogs.ErrorFlowConnectRateLimitedKey) + }() + + tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(0)) + + connDone := make(chan struct{}) + go func() { + defer close(connDone) + _ = tunnelConn.Serve(ctx) + }() + + <-serverDone + cancel() + <-connDone +} + func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) { logger := zerolog.Nop() conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger) diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index a800db301c1..1dd25f77a3b 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -141,7 +141,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { return errors.Wrap(err, "failed to start origin") } - proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log) + proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log) o.proxy.Store(proxy) o.config.Ingress = &ingressRules o.config.WarpRouting = warpRouting diff --git a/proxy/proxy.go b/proxy/proxy.go index dd999f87098..dbcbbb18a52 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -9,10 +9,14 @@ import ( "time" "github.com/pkg/errors" + pkgerrors "github.com/pkg/errors" "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "github.com/cloudflare/cloudflared/management" + cfdsession "github.com/cloudflare/cloudflared/session" + "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/cfio" "github.com/cloudflare/cloudflared/connection" @@ -30,11 +34,11 @@ const ( // Proxy represents a means to Proxy between cloudflared and the origin services. type Proxy struct { - ingressRules ingress.Ingress - warpRouting *ingress.WarpRoutingService - management *ingress.ManagementService - tags []pogs.Tag - log *zerolog.Logger + ingressRules ingress.Ingress + warpRouting *ingress.WarpRoutingService + tags []pogs.Tag + sessionLimiter cfdsession.Limiter + log *zerolog.Logger } // NewOriginProxy returns a new instance of the Proxy struct. @@ -42,13 +46,15 @@ func NewOriginProxy( ingressRules ingress.Ingress, warpRouting ingress.WarpRoutingConfig, tags []pogs.Tag, + sessionLimiter cfdsession.Limiter, writeTimeout time.Duration, log *zerolog.Logger, ) *Proxy { proxy := &Proxy{ - ingressRules: ingressRules, - tags: tags, - log: log, + ingressRules: ingressRules, + tags: tags, + sessionLimiter: sessionLimiter, + log: log, } proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout) @@ -64,7 +70,7 @@ func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w co } if result.ShouldFilterRequest { - w.WriteRespHeaders(result.StatusCode, nil) + _ = w.WriteRespHeaders(result.StatusCode, nil) return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true } } @@ -152,10 +158,18 @@ func (p *Proxy) ProxyTCP( return err } + logger := newTCPLogger(p.log, req) + + // Try to start a new session + if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil { + logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy") + return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting") + } + defer p.sessionLimiter.Release() + serveCtx, cancel := context.WithCancel(ctx) defer cancel() - logger := newTCPLogger(p.log, req) tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger) logger.Debug().Msg("tcp proxy stream started") diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index c1c60fc4f4f..9c3a42b133d 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -21,8 +21,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/urfave/cli/v2" + "go.uber.org/mock/gomock" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/mocks" + + cfdsession "github.com/cloudflare/cloudflared/session" + "github.com/cloudflare/cloudflared/cfio" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" @@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader") } -// respHeaders is a test function to read respHeaders -func (w *mockHTTPRespWriter) headers() http.Header { - return w.Header() -} - func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { panic("Hijack not implemented") } @@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) { return w.reader.Read(data) } -func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { panic("Hijack not implemented") } @@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) { require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log) + proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) @@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) { _ = responseWriter.Close() close(finished) - errGroup.Wait() + _ = errGroup.Wait() } } @@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { defer wg.Done() log := zerolog.Nop() err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false) - require.Equal(t, err.Error(), "context canceled") + require.Equal(t, "context canceled", err.Error()) require.Equal(t, http.StatusOK, responseWriter.Code) }() @@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { for i := 0; i < pushCount; i++ { line := responseWriter.ReadBytes() expect := fmt.Sprintf("%d\n\n", i) - require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line)) + require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line) } cancel() @@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) { responseWriter := newMockSSERespWriter() // responseWriter uses an unbuffered channel, so we call in a different go-routine - go cfio.Copy(responseWriter, eyeballReader) + go func() { + _, _ = cfio.Copy(responseWriter, eyeballReader) + }() result := string(<-responseWriter.writeNotification) require.Equal(t, "data\r\r", result) @@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat ctx, cancel := context.WithCancel(context.Background()) require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log) + proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -414,23 +416,18 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log) + proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) - assert.NoError(t, err) + require.NoError(t, err) - assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)) + require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)) } type replayer struct { sync.RWMutex - writeDone chan struct{} - rw *bytes.Buffer -} - -func newReplayer(buffer *bytes.Buffer) { - + rw *bytes.Buffer } func (r *replayer) Read(p []byte) (int, error) { @@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte { // eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access). func TestConnections(t *testing.T) { logger := logger.Create(nil) - replayer := &replayer{rw: &bytes.Buffer{}} + replayer := &replayer{rw: bytes.NewBuffer([]byte{})} type args struct { ingressServiceScheme string originService func(*testing.T, net.Listener) @@ -486,6 +483,9 @@ func TestConnections(t *testing.T) { // requestheaders to be sent in the call to proxy.Proxy requestHeaders http.Header + + // sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call + sessionLimiterResponse error } type want struct { @@ -663,6 +663,25 @@ func TestConnections(t *testing.T) { err: true, }, }, + { + name: "tcp-* proxy rate limited flow", + args: args{ + ingressServiceScheme: "tcp://", + originService: runEchoTCPService, + eyeballResponseWriter: newTCPRespWriter(replayer), + eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")), + warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), + connectionType: connection.TypeTCP, + requestHeaders: map[string][]string{ + "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, + }, + sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions, + }, + want: want{ + message: []byte{}, + err: true, + }, + }, } for _, test := range tests { @@ -674,8 +693,16 @@ func TestConnections(t *testing.T) { test.args.originService(t, ln) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) - ingressRule.StartOrigins(logger, ctx.Done()) - proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger) + _ = ingressRule.StartOrigins(logger, ctx.Done()) + + // Mock session limiter + ctrl := gomock.NewController(t) + defer ctrl.Finish() + sessionLimiter := mocks.NewMockLimiter(ctrl) + sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse) + sessionLimiter.EXPECT().Release().AnyTimes() + + proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger) proxy.warpRouting = test.args.warpRoutingService dest := ln.Addr().String() @@ -693,7 +720,7 @@ func TestConnections(t *testing.T) { respWriter = newTCPRespWriter(pipedReqBody.pipedConn) go func() { resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String()) - replayer.Write(resp) + _, _ = replayer.Write(resp) }() } if test.args.connectionType == connection.TypeTCP { @@ -705,9 +732,9 @@ func TestConnections(t *testing.T) { } cancel() - assert.Equal(t, test.want.err, err != nil) - assert.Equal(t, test.want.message, replayer.Bytes()) - assert.Equal(t, test.want.headers, respWriter.Header()) + require.Equal(t, test.want.err, err != nil) + require.Equal(t, test.want.message, replayer.Bytes()) + require.Equal(t, test.want.headers, respWriter.Header()) replayer.rw.Reset() }) } @@ -720,7 +747,9 @@ type requestBody struct { func newWSRequestBody(data []byte) *requestBody { pr, pw := io.Pipe() - go wsutil.WriteClientBinary(pw, data) + go func() { + _ = wsutil.WriteClientBinary(pw, data) + }() return &requestBody{ pr: pr, pw: pw, @@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody { } func newTCPRequestBody(data []byte) *requestBody { pr, pw := io.Pipe() - go pw.Write(data) + go func() { + _, _ = pw.Write(data) + }() return &requestBody{ pr: pr, pw: pw, @@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) { } func (r *requestBody) Close() error { - r.pw.Close() - r.pr.Close() + _ = r.pw.Close() + _ = r.pr.Close() return nil } @@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte { panic(err) } defer conn.Close() + defer resp.Body.Close() if resp.StatusCode != http.StatusSwitchingProtocols { panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode)) @@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) { go func() { for { conn, err := l.Accept() - require.NoError(t, err) + if err != nil { + panic(err) + } defer conn.Close() for { @@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) { } } + // nolint: gosec server := http.Server{ Handler: http.HandlerFunc(ws), } go func() { err := server.Serve(l) - require.NoError(t, err) + if err != nil { + panic(err) + } }() } diff --git a/tunnelrpc/pogs/quic_metadata_protocol.go b/tunnelrpc/pogs/quic_metadata_protocol.go index cfbfe845031..d73c973208d 100644 --- a/tunnelrpc/pogs/quic_metadata_protocol.go +++ b/tunnelrpc/pogs/quic_metadata_protocol.go @@ -18,6 +18,11 @@ const ( ConnectionTypeTCP ) +var ( + // ErrorFlowConnectRateLimitedKey is the Metadata entry that allows to know if a request was rate limited on connect. + ErrorFlowConnectRateLimitedKey = Metadata{Key: "FlowConnectRateLimited", Val: "true"} +) + func (c ConnectionType) String() string { switch c { case ConnectionTypeHTTP: diff --git a/tunnelrpc/quic/request_server_stream.go b/tunnelrpc/quic/request_server_stream.go index c0aee434da2..93b42b82889 100644 --- a/tunnelrpc/quic/request_server_stream.go +++ b/tunnelrpc/quic/request_server_stream.go @@ -37,7 +37,8 @@ func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata var connectResponse *pogs.ConnectResponse if respErr != nil { connectResponse = &pogs.ConnectResponse{ - Error: respErr.Error(), + Error: respErr.Error(), + Metadata: metadata, } } else { connectResponse = &pogs.ConnectResponse{ diff --git a/tunnelrpc/quic/request_server_stream_test.go b/tunnelrpc/quic/request_server_stream_test.go index 0be170495bf..e6972263d27 100644 --- a/tunnelrpc/quic/request_server_stream_test.go +++ b/tunnelrpc/quic/request_server_stream_test.go @@ -98,12 +98,7 @@ func TestConnectResponseMeta(t *testing.T) { reqClientStream := RequestClientStream{noopCloser{b}} respMeta, err := reqClientStream.ReadConnectResponseData() require.NoError(t, err) - - if respMeta.Error == "" { - assert.Equal(t, test.metadata, respMeta.Metadata) - } else { - assert.Equal(t, 0, len(respMeta.Metadata)) - } + require.Equal(t, test.metadata, respMeta.Metadata) }) } } @@ -153,21 +148,21 @@ func TestRegisterUdpSession(t *testing.T) { }() rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second) - assert.NoError(t, err) + require.NoError(t, err) reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext) - assert.NoError(t, err) - assert.NoError(t, reg.Err) + require.NoError(t, err) + require.NoError(t, reg.Err) - // Different sessionID, the RPC server should reject the registraion + // Different sessionID, the RPC server should reject the registration reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext) - assert.NoError(t, err) - assert.Error(t, reg.Err) + require.NoError(t, err) + require.Error(t, reg.Err) - assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage)) + require.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage)) - // Different sessionID, the RPC server should reject the unregistraion - assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage)) + // Different sessionID, the RPC server should reject the unregistration + require.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage)) rpcClientStream.Close() <-sessionRegisteredChan @@ -200,10 +195,10 @@ func TestManageConfiguration(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second) - assert.NoError(t, err) + require.NoError(t, err) result, err := rpcClientStream.UpdateConfiguration(ctx, version, config) - assert.NoError(t, err) + require.NoError(t, err) require.Equal(t, version, result.LastAppliedVersion) require.NoError(t, result.Err)