Skip to content

Commit

Permalink
TUN-8236: Add write timeout to quic and tcp connections
Browse files Browse the repository at this point in the history
## Summary
To prevent bad eyeballs and severs to be able to exhaust the quic
control flows we are adding the possibility of having a timeout
for a write operation to be acknowledged. This will prevent hanging
connections from exhausting the quic control flows, creating a DDoS.
  • Loading branch information
jcsf committed Feb 15, 2024
1 parent 56aeb6b commit 76badfa
Show file tree
Hide file tree
Showing 18 changed files with 146 additions and 54 deletions.
10 changes: 10 additions & 0 deletions cmd/cloudflared/tunnel/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ const (
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"

// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"

// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
Expand Down Expand Up @@ -696,6 +699,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: 5 * time.Second,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: writeStreamTimeout,
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
Value: 0 * time.Second,
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: quicDisablePathMTUDiscovery,
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
Expand Down
2 changes: 2 additions & 0 deletions cmd/cloudflared/tunnel/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ func prepareTunnelConfig(
FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
}
packetConfig, err := newPacketConfig(c, log)
Expand All @@ -259,6 +260,7 @@ func prepareTunnelConfig(
Ingress: &ingressRules,
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
ConfigurationFlags: parseConfigFlags(c),
WriteTimeout: c.Duration(writeStreamTimeout),
}
return tunnelConfig, orchestratorConfig, nil
}
Expand Down
7 changes: 5 additions & 2 deletions connection/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type QUICConnection struct {
connIndex uint8

udpUnregisterTimeout time.Duration
streamWriteTimeout time.Duration
}

// NewQUICConnection returns a new instance of QUICConnection.
Expand All @@ -82,6 +83,7 @@ func NewQUICConnection(
logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration,
streamWriteTimeout time.Duration,
) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil {
Expand Down Expand Up @@ -117,6 +119,7 @@ func NewQUICConnection(
connOptions: connOptions,
connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout,
streamWriteTimeout: streamWriteTimeout,
}, nil
}

Expand Down Expand Up @@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {

func (q *QUICConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context()
stream := quicpogs.NewSafeStreamCloser(quicStream)
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()

// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
Expand Down Expand Up @@ -373,7 +376,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
return
}

stream := quicpogs.NewSafeStreamCloser(quicStream)
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion connection/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
KeepAlivePeriod: 5 * time.Second,
EnableDatagrams: true,
}
defaultQUICTimeout = 30 * time.Second
)

var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
Expand Down Expand Up @@ -197,7 +198,7 @@ func quicServer(

quicStream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err)
stream := quicpogs.NewSafeStreamCloser(quicStream)
stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)

reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
Expand Down Expand Up @@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
&log,
nil,
5*time.Second,
0*time.Second,
)
require.NoError(t, err)
return qc
Expand Down
7 changes: 7 additions & 0 deletions ingress/constants_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ingress

import "github.com/cloudflare/cloudflared/logger"

var (
TestLogger = logger.Create(nil)
)
26 changes: 22 additions & 4 deletions ingress/origin_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"time"

"github.com/rs/zerolog"

Expand Down Expand Up @@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze

// tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct {
conn net.Conn
net.Conn
writeTimeout time.Duration
logger *zerolog.Logger
}

func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
stream.Pipe(tunnelConn, tc.conn, log)
func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
stream.Pipe(tunnelConn, tc, tc.logger)
}

func (tc *tcpConnection) Write(b []byte) (int, error) {
if tc.writeTimeout > 0 {
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
}
}

nBytes, err := tc.Conn.Write(b)
if err != nil {
tc.logger.Err(err).Msg("Error writing to the TCP connection")
}

return nBytes, err
}

func (tc *tcpConnection) Close() {
tc.conn.Close()
tc.Conn.Close()
}

// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
Expand Down
15 changes: 7 additions & 8 deletions ingress/origin_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"golang.org/x/net/proxy"
"golang.org/x/sync/errgroup"

"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/websocket"
Expand All @@ -31,15 +30,15 @@ const (
)

var (
testLogger = logger.Create(nil)
testMessage = []byte("TestStreamOriginConnection")
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
)

func TestStreamTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe()
tcpConn := tcpConnection{
conn: cfdConn,
Conn: cfdConn,
writeTimeout: 30 * time.Second,
}

eyeballConn, edgeConn := net.Pipe()
Expand All @@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
return nil
})

tcpConn.Stream(ctx, edgeConn, testLogger)
tcpConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait())
}

Expand All @@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
return nil
})

tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait())
}

Expand Down Expand Up @@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {

errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
return nil
})

Expand All @@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
require.NoError(t, err)
defer wsForwarderInConn.Close()

stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
return nil
})

Expand Down Expand Up @@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
originConn.Close()
}()
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
})
server := httptest.NewServer(handler)
defer server.Close()
Expand Down
14 changes: 9 additions & 5 deletions ingress/origin_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net/http"

"github.com/rs/zerolog"
)

// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
Expand All @@ -14,7 +16,7 @@ type HTTPOriginProxy interface {

// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error)
EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
}

// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
Expand Down Expand Up @@ -62,19 +64,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return resp, nil
}

func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
if err != nil {
return nil, err
}

originConn := &tcpConnection{
conn: conn,
Conn: conn,
writeTimeout: o.writeTimeout,
logger: logger,
}
return originConn, nil
}

func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
var err error
if !o.isBastion {
dest = o.dest
Expand All @@ -92,6 +96,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)

}

func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) {
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
return o.conn, nil
}
10 changes: 5 additions & 5 deletions ingress/origin_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err)

// Origin not listening for new connection, should return an error
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String())
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
require.Error(t, err)
}

Expand Down Expand Up @@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
t.Run(test.testCase, func(t *testing.T) {
if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(context.Background(), bastionHost)
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}
})
Expand All @@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(context.Background(), bastionHost)
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}
}
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
url: originURL,
}
shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))

req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err)
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
url: originURL,
}
shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))

// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
protos := []string{"https", "http", "dne"}
Expand Down
11 changes: 7 additions & 4 deletions ingress/origin_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,17 @@ func (o httpService) MarshalJSON() ([]byte, error) {
// rawTCPService dials TCP to the destination specified by the client
// It's used by warp routing
type rawTCPService struct {
name string
dialer net.Dialer
name string
dialer net.Dialer
writeTimeout time.Duration
logger *zerolog.Logger
}

func (o *rawTCPService) String() string {
return o.name
}

func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
return nil
}

Expand Down Expand Up @@ -285,13 +287,14 @@ type WarpRoutingService struct {
Proxy StreamBasedOriginProxy
}

func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService {
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
svc := &rawTCPService{
name: ServiceWarpRouting,
dialer: net.Dialer{
Timeout: config.ConnectTimeout.Duration,
KeepAlive: config.TCPKeepAlive.Duration,
},
writeTimeout: writeTimeout,
}

return &WarpRoutingService{Proxy: svc}
Expand Down
6 changes: 4 additions & 2 deletions orchestration/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package orchestration

import (
"encoding/json"
"time"

"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress"
Expand All @@ -19,8 +20,9 @@ type newLocalConfig struct {

// Config is the original config as read and parsed by cloudflared.
type Config struct {
Ingress *ingress.Ingress
WarpRouting ingress.WarpRoutingConfig
Ingress *ingress.Ingress
WarpRouting ingress.WarpRoutingConfig
WriteTimeout time.Duration

// Extra settings used to configure this instance but that are not eligible for remotely management
// ie. (--protocol, --loglevel, ...)
Expand Down
Loading

0 comments on commit 76badfa

Please sign in to comment.