diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index a7287fb0efa..146755c4a60 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -126,7 +126,7 @@ var ( routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+ "most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+ "any existing DNS records for this hostname.", overwriteDNSFlag) - errDeprecatedClassicTunnel = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)") + errDeprecatedClassicTunnel = errors.New("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)") // TODO: TUN-8756 the list below denotes the flags that do not possess any kind of sensitive information // however this approach is not maintainble in the long-term. nonSecretFlagsList = []string{ @@ -214,6 +214,7 @@ var ( "protocol", "overwrite-dns", "help", + "max-active-flows", } ) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 15c54954d42..2a0ac4ab915 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -38,7 +38,7 @@ const ( var ( secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag} - configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"} + configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address", "max-active-flows"} ) func logClientOptions(c *cli.Context, log *zerolog.Logger) { diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index cfbffcc9791..e86a02c671b 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -230,6 +230,11 @@ var ( Usage: "Network diagnostics won't be performed", Value: false, } + maxActiveFlowsFlag = &cli.Uint64Flag{ + Name: "max-active-flows", + Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports", + EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"}, + } ) func buildCreateCommand() *cli.Command { @@ -705,6 +710,7 @@ func buildRunCommand() *cli.Command { tunnelTokenFlag, icmpv4SrcFlag, icmpv6SrcFlag, + maxActiveFlowsFlag, } flags = append(flags, configureProxyFlags(false)...) return &cli.Command{ diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index 4c44143ea69..fc259d94d64 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -4,16 +4,16 @@ import ( "context" "encoding/json" "fmt" + "strconv" "sync" "sync/atomic" - "github.com/pkg/errors" + pkgerrors "github.com/pkg/errors" "github.com/rs/zerolog" - cfdflow "github.com/cloudflare/cloudflared/flow" - "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" + cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/proxy" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -117,6 +117,30 @@ func (o *Orchestrator) UpdateConfig(version int32, config []byte) *pogs.UpdateCo } } +// overrideRemoteWarpRoutingWithLocalValues overrides the ingress.WarpRoutingConfig that comes from the remote with +// the local values if there is any. +func (o *Orchestrator) overrideRemoteWarpRoutingWithLocalValues(remoteWarpRouting *ingress.WarpRoutingConfig) error { + return o.overrideMaxActiveFlows(o.config.ConfigurationFlags["max-active-flows"], remoteWarpRouting) +} + +// overrideMaxActiveFlows checks the local configuration flags, and if a value is found for the flags.MaxActiveFlows +// overrides the value that comes on the remote ingress.WarpRoutingConfig with the local value. +func (o *Orchestrator) overrideMaxActiveFlows(maxActiveFlowsLocalConfig string, remoteWarpRouting *ingress.WarpRoutingConfig) error { + // If max active flows isn't defined locally just use the remote value + if maxActiveFlowsLocalConfig == "" { + return nil + } + + maxActiveFlowsLocalOverride, err := strconv.ParseUint(maxActiveFlowsLocalConfig, 10, 64) + if err != nil { + return pkgerrors.Wrapf(err, "failed to parse %s", "max-active-flows") + } + + // Override the value that comes from the remote with the local value + remoteWarpRouting.MaxActiveFlows = maxActiveFlowsLocalOverride + return nil +} + // The caller is responsible to make sure there is no concurrent access func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting ingress.WarpRoutingConfig) error { select { @@ -125,6 +149,11 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i default: } + // Overrides the local values, onto the remote values of the warp routing configuration + if err := o.overrideRemoteWarpRoutingWithLocalValues(&warpRouting); err != nil { + return pkgerrors.Wrap(err, "failed to merge local overrides into warp routing configuration") + } + // Assign the internal ingress rules to the parsed ingress ingressRules.InternalRules = o.internalRules @@ -139,7 +168,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i // The downside is minimized because none of the ingress.OriginService implementation have that requirement proxyShutdownC := make(chan struct{}) if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { - return errors.Wrap(err, "failed to start origin") + return pkgerrors.Wrap(err, "failed to start origin") } // Update the flow limit since the configuration might have changed diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index eb2c6f72be8..fdb9ce34d3e 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" gows "github.com/gorilla/websocket" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/config" @@ -106,25 +107,25 @@ func TestUpdateConfiguration(t *testing.T) { require.Len(t, configV2.Ingress.Rules, 3) // originRequest of this ingress rule overrides global default require.Equal(t, config.CustomDuration{Duration: time.Second * 10}, configV2.Ingress.Rules[0].Config.ConnectTimeout) - require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoTLSVerify) + require.True(t, configV2.Ingress.Rules[0].Config.NoTLSVerify) // Inherited from global default - require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoHappyEyeballs) + require.True(t, configV2.Ingress.Rules[0].Config.NoHappyEyeballs) // Validate ingress rule 1 require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[1].Hostname) require.True(t, configV2.Ingress.Rules[1].Matches("jira.tunnel.org", "/users")) require.Equal(t, "http://172.32.20.6:80", configV2.Ingress.Rules[1].Service.String()) // originRequest of this ingress rule overrides global default require.Equal(t, config.CustomDuration{Duration: time.Second * 30}, configV2.Ingress.Rules[1].Config.ConnectTimeout) - require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoTLSVerify) + require.True(t, configV2.Ingress.Rules[1].Config.NoTLSVerify) // Inherited from global default - require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoHappyEyeballs) + require.True(t, configV2.Ingress.Rules[1].Config.NoHappyEyeballs) // Validate ingress rule 2, it's the catch-all rule require.True(t, configV2.Ingress.Rules[2].Matches("blogs.tunnel.io", "/2022/02/10")) // Inherited from global default require.Equal(t, config.CustomDuration{Duration: time.Second * 90}, configV2.Ingress.Rules[2].Config.ConnectTimeout) - require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify) - require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs) - require.Equal(t, configV2.WarpRouting.ConnectTimeout.Duration, 10*time.Second) + require.False(t, configV2.Ingress.Rules[2].Config.NoTLSVerify) + require.True(t, configV2.Ingress.Rules[2].Config.NoHappyEyeballs) + require.Equal(t, 10*time.Second, configV2.WarpRouting.ConnectTimeout.Duration) originProxyV2, err := orchestrator.GetOriginProxy() require.NoError(t, err) @@ -317,7 +318,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { go func(i int, originProxy connection.OriginProxy) { defer wg.Done() resp, err := proxyHTTP(originProxy, hostname) - require.NoError(t, err, "proxyHTTP %d failed %v", i, err) + assert.NoError(t, err, "proxyHTTP %d failed %v", i, err) defer resp.Body.Close() var warpRoutingDisabled bool @@ -326,16 +327,16 @@ func TestConcurrentUpdateAndRead(t *testing.T) { // v1 proxy, warp enabled case 200: body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, t.Name(), string(body)) + assert.NoError(t, err) + assert.Equal(t, t.Name(), string(body)) warpRoutingDisabled = false // v2 proxy, warp disabled case 204: - require.Greater(t, i, concurrentRequests/4) + assert.Greater(t, i, concurrentRequests/4) warpRoutingDisabled = true // v3 proxy, warp enabled case 418: - require.Greater(t, i, concurrentRequests/2) + assert.Greater(t, i, concurrentRequests/2) warpRoutingDisabled = false } @@ -358,11 +359,10 @@ func TestConcurrentUpdateAndRead(t *testing.T) { err = proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), w, pr) if warpRoutingDisabled { - require.Error(t, err, "expect proxyTCP %d to return error", i) + assert.Error(t, err, "expect proxyTCP %d to return error", i) } else { - require.NoError(t, err, "proxyTCP %d failed %v", i, err) + assert.NoError(t, err, "proxyTCP %d failed %v", i, err) } - }(i, originProxy) if i == concurrentRequests/4 { @@ -388,6 +388,57 @@ func TestConcurrentUpdateAndRead(t *testing.T) { wg.Wait() } +// TestOverrideWarpRoutingConfigWithLocalValues tests that if a value is defined in the Config.ConfigurationFlags, +// it will override the value that comes from the remote result. +func TestOverrideWarpRoutingConfigWithLocalValues(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + assertMaxActiveFlows := func(orchestrator *Orchestrator, expectedValue uint64) { + configJson, err := orchestrator.GetConfigJSON() + require.NoError(t, err) + var result map[string]interface{} + err = json.Unmarshal(configJson, &result) + require.NoError(t, err) + warpRouting := result["warp-routing"].(map[string]interface{}) + require.EqualValues(t, expectedValue, warpRouting["maxActiveFlows"]) + } + + remoteValue := uint64(100) + remoteIngress := ingress.Ingress{} + remoteWarpConfig := ingress.WarpRoutingConfig{ + MaxActiveFlows: remoteValue, + } + remoteConfig := &Config{ + Ingress: &remoteIngress, + WarpRouting: remoteWarpConfig, + ConfigurationFlags: map[string]string{}, + } + orchestrator, err := NewOrchestrator(ctx, remoteConfig, testTags, []ingress.Rule{}, &testLogger) + require.NoError(t, err) + + assertMaxActiveFlows(orchestrator, remoteValue) + + // Add a local override for the maxActiveFlows + localValue := uint64(500) + remoteConfig.ConfigurationFlags["max-active-flows"] = fmt.Sprintf("%d", localValue) + // Force a configuration refresh + err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig) + require.NoError(t, err) + + // Check the value being used is the local one + assertMaxActiveFlows(orchestrator, localValue) + + // Remove local override for the maxActiveFlows + delete(remoteConfig.ConfigurationFlags, "max-active-flows") + // Force a configuration refresh + err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig) + require.NoError(t, err) + + // Check the value being used is now the remote again + assertMaxActiveFlows(orchestrator, remoteValue) +} + func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil) if err != nil { @@ -409,15 +460,16 @@ func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Respo return w.Result(), nil } +// nolint: testifylint // this is used inside go routines so it can't use `require.` func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWriter *respReadWriteFlusher) { writeN, err := reqWriter.Write([]byte(body)) - require.NoError(t, err) + assert.NoError(t, err) readBuffer := make([]byte, writeN) n, err := respReadWriter.Read(readBuffer) - require.NoError(t, err) - require.Equal(t, body, string(readBuffer[:n])) - require.Equal(t, writeN, n) + assert.NoError(t, err) + assert.Equal(t, body, string(readBuffer[:n])) + assert.Equal(t, writeN, n) } func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser) error { @@ -458,14 +510,15 @@ func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) { } } +// nolint: testifylint // this is used inside go routines so it can't use `require.` func echoTCP(t *testing.T, conn net.Conn) { readBuf := make([]byte, 1000) readN, err := conn.Read(readBuf) - require.NoError(t, err) + assert.NoError(t, err) writeN, err := conn.Write(readBuf[:readN]) - require.NoError(t, err) - require.Equal(t, readN, writeN) + assert.NoError(t, err) + assert.Equal(t, readN, writeN) } type validateHostHandler struct { @@ -479,16 +532,17 @@ func (vhh *validateHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request return } w.WriteHeader(http.StatusOK) - w.Write([]byte(vhh.body)) + _, _ = w.Write([]byte(vhh.body)) } +// nolint: testifylint // this is used inside go routines so it can't use `require.` func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int32, config []byte) { resp := orchestrator.UpdateConfig(version, config) - require.NoError(t, resp.Err) - require.Equal(t, version, resp.LastAppliedVersion) + assert.NoError(t, resp.Err) + assert.Equal(t, version, resp.LastAppliedVersion) } -// TestClosePreviousProxies makes sure proxies started in the pervious configuration version are shutdown +// TestClosePreviousProxies makes sure proxies started in the previous configuration version are shutdown func TestClosePreviousProxies(t *testing.T) { var ( hostname = "hello.tunnel1.org" @@ -532,6 +586,7 @@ func TestClosePreviousProxies(t *testing.T) { originProxyV1, err := orchestrator.GetOriginProxy() require.NoError(t, err) + // nolint: bodyclose resp, err := proxyHTTP(originProxyV1, hostname) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -540,12 +595,14 @@ func TestClosePreviousProxies(t *testing.T) { originProxyV2, err := orchestrator.GetOriginProxy() require.NoError(t, err) + // nolint: bodyclose resp, err = proxyHTTP(originProxyV2, hostname) require.NoError(t, err) require.Equal(t, http.StatusTeapot, resp.StatusCode) // The hello-world server in config v1 should have been stopped. We wait a bit since it's closed asynchronously. time.Sleep(time.Millisecond * 10) + // nolint: bodyclose resp, err = proxyHTTP(originProxyV1, hostname) require.Error(t, err) require.Nil(t, resp) @@ -557,6 +614,7 @@ func TestClosePreviousProxies(t *testing.T) { require.NoError(t, err) require.NotEqual(t, originProxyV1, originProxyV3) + // nolint: bodyclose resp, err = proxyHTTP(originProxyV3, hostname) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -566,6 +624,7 @@ func TestClosePreviousProxies(t *testing.T) { // Wait for proxies to shutdown time.Sleep(time.Millisecond * 10) + // nolint: bodyclose resp, err = proxyHTTP(originProxyV3, hostname) require.Error(t, err) require.Nil(t, resp) @@ -622,7 +681,7 @@ func TestPersistentConnection(t *testing.T) { go func() { defer wg.Done() conn, err := tcpOrigin.Accept() - require.NoError(t, err) + assert.NoError(t, err) defer conn.Close() // Expect 3 TCP messages @@ -630,26 +689,26 @@ func TestPersistentConnection(t *testing.T) { echoTCP(t, conn) } }() - // Simulate cloudflared recieving a TCP connection + // Simulate cloudflared receiving a TCP connection go func() { defer wg.Done() - require.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader)) + assert.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader)) }() - // Simulate cloudflared recieving a WS connection + // Simulate cloudflared receiving a WS connection go func() { defer wg.Done() req, err := http.NewRequest(http.MethodGet, hostname, wsReqReader) - require.NoError(t, err) + assert.NoError(t, err) // ProxyHTTP will add Connection, Upgrade and Sec-Websocket-Version headers req.Header.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") log := zerolog.Nop() respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket, &log) - require.NoError(t, err) + assert.NoError(t, err) err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, 0, &log), true) - require.NoError(t, err) + assert.NoError(t, err) }() // Simulate eyeball WS and TCP connections