From fc31ad9ff44283e229ed98eb3de5e1d1505c4134 Mon Sep 17 00:00:00 2001 From: Sylvain Muller Date: Sat, 25 Jan 2025 18:59:18 +0100 Subject: [PATCH] Improve client ip resolver performance for edge-case (#53) * feat(clientip): improve some resolver performance for edge-case such as header not present or when chaining multiple strategy with single ip header. * feat(http_const): change HeaderXRealIP location for better clarity --------- Co-authored-by: tigerwill90 --- clientip/clientip.go | 43 +++++++++++++++++++++------------------ clientip/clientip_test.go | 4 ++-- http_consts.go | 2 +- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/clientip/clientip.go b/clientip/clientip.go index 58d48ba..afff7fe 100644 --- a/clientip/clientip.go +++ b/clientip/clientip.go @@ -32,9 +32,11 @@ var ( ErrRightmostTrustedRange = errors.New("rightmost trusted range resolver") ) -// Avoid allocating those errors each time since it may happen a lot on adversary header. +// Avoid allocating those errors each time since it may happen a lot on adversary header or when using multiple single ip +// header resolver in order to find the right match. var ( errLeftmostNonPrivate = fmt.Errorf("%w: unable to find a valid or non-private IP", ErrLeftmostNonPrivate) + errSingleIPHeader = fmt.Errorf("%w: header not found", ErrSingleIPHeader) ) // TrustedIPRange returns a set of trusted IP ranges. @@ -150,7 +152,7 @@ func (s SingleIPHeader) ClientIP(c fox.Context) (*net.IPAddr, error) { // in theory it should be the newest value.) ipStr := lastHeader(c.Request().Header, s.headerName) if ipStr == "" { - return nil, fmt.Errorf("%w: header %q not found", ErrSingleIPHeader, s.headerName) + return nil, errSingleIPHeader } return ParseIPAddr(ipStr) @@ -192,13 +194,14 @@ func NewLeftmostNonPrivate(key HeaderKey, limit uint, opts ...BlacklistRangeOpti // ClientIP derives the client IP using the [LeftmostNonPrivate] resolver. The returned [net.IPAddr] may contain a // zone identifier. If no valid IP can be derived, an error returned. func (s LeftmostNonPrivate) ClientIP(c fox.Context) (*net.IPAddr, error) { - for ip := range iterutil.Take(ipAddrSeq(c.Request().Header, s.headerName), s.limit) { - if ip != nil && !isIPContainedInRanges(ip.IP, s.blacklistedRanges) { - // This is the leftmost valid, non-private IP - return ip, nil + if values, ok := c.Request().Header[s.headerName]; ok && len(values) > 0 { + for ip := range iterutil.Take(ipAddrSeq(values, s.headerName), s.limit) { + if ip != nil && !isIPContainedInRanges(ip.IP, s.blacklistedRanges) { + // This is the leftmost valid, non-private IP + return ip, nil + } } } - // We failed to find any valid, non-private IP return nil, errLeftmostNonPrivate } @@ -232,12 +235,13 @@ func NewRightmostNonPrivate(key HeaderKey, opts ...TrustedRangeOption) (Rightmos // ClientIP derives the client IP using the [RightmostNonPrivate] resolver. The returned [net.IPAddr] may contain a // zone identifier. If no valid IP can be derived, an error returned. func (s RightmostNonPrivate) ClientIP(c fox.Context) (*net.IPAddr, error) { - for ip := range backwardIpAddrSeq(c.Request().Header, s.headerName) { - if ip != nil && !isIPContainedInRanges(ip.IP, s.trustedRanges) { - return ip, nil + if values, ok := c.Request().Header[s.headerName]; ok && len(values) > 0 { + for ip := range backwardIpAddrSeq(values, s.headerName) { + if ip != nil && !isIPContainedInRanges(ip.IP, s.trustedRanges) { + return ip, nil + } } } - // We failed to find any valid, non-private IP return nil, fmt.Errorf("%w: unable to find a valid or non-private IP", ErrRightmostNonPrivate) } @@ -267,7 +271,7 @@ func NewRightmostTrustedCount(key HeaderKey, trustedCount uint) (RightmostTruste // ClientIP derives the client IP using the [RightmostTrustedCount] resolver. The returned [net.IPAddr] may contain a // zone identifier. If no valid IP can be derived, an error returned. func (s RightmostTrustedCount) ClientIP(c fox.Context) (*net.IPAddr, error) { - ip, ok := iterutil.At(backwardIpAddrSeq(c.Request().Header, s.headerName), s.trustedCount-1) + ip, ok := iterutil.At(backwardIpAddrSeq(c.Request().Header[s.headerName], s.headerName), s.trustedCount-1) if !ok { // This is a misconfiguration error. There were fewer IPs than we expected. return nil, fmt.Errorf("%w: expected at least %d IP(s)", ErrRightmostTrustedCount, s.trustedCount) @@ -317,7 +321,7 @@ func (s RightmostTrustedRange) ClientIP(c fox.Context) (*net.IPAddr, error) { return nil, fmt.Errorf("%w: unable to resolve trusted ip range: %w", ErrRightmostTrustedRange, err) } - for ip := range backwardIpAddrSeq(c.Request().Header, s.headerName) { + for ip := range backwardIpAddrSeq(c.Request().Header[s.headerName], s.headerName) { if ip != nil && isIPContainedInRanges(ip.IP, trustedRange) { // This IP is trusted continue @@ -467,10 +471,9 @@ func lastHeader(headers http.Header, headerName string) string { // backwardIpAddrSeq returns a range iterator over the X-Forwarded-For or Forwarded header // values, in reverse order. Any invalid IPs will result in nil elements. headerName must already -// be canonicalized. -func backwardIpAddrSeq(headers http.Header, headerName string) iter.Seq[*net.IPAddr] { +// be in canonical form. +func backwardIpAddrSeq(values []string, headerName string) iter.Seq[*net.IPAddr] { return func(yield func(*net.IPAddr) bool) { - values := headers[headerName] for i := len(values) - 1; i >= 0; i-- { for rawListItem := range iterutil.BackwardSplitStringSeq(values[i], ",") { // The IPs are often comma-space separated, so we'll need to trim the string @@ -495,12 +498,12 @@ func backwardIpAddrSeq(headers http.Header, headerName string) iter.Seq[*net.IPA // ipAddrSeq returns a range iterator over the X-Forwarded-For or Forwarded header // values, in order. Any invalid IPs will result in nil elements. headerName must already -// be canonicalized. -func ipAddrSeq(headers http.Header, headerName string) iter.Seq[*net.IPAddr] { +// be in canonical form. +func ipAddrSeq(values []string, headerName string) iter.Seq[*net.IPAddr] { return func(yield func(*net.IPAddr) bool) { - for _, h := range headers[headerName] { + for _, v := range values { // We now have a sequence of comma-separated list items. - for rawListItem := range iterutil.SplitStringSeq(h, ",") { + for rawListItem := range iterutil.SplitStringSeq(v, ",") { // The IPs are often comma-space separated, so we'll need to trim the string rawListItem = strings.TrimSpace(rawListItem) diff --git a/clientip/clientip_test.go b/clientip/clientip_test.go index d63cad9..bb26489 100644 --- a/clientip/clientip_test.go +++ b/clientip/clientip_test.go @@ -218,7 +218,7 @@ func TestChain_ClientIP(t *testing.T) { assert.ErrorIs(t, err, ErrSingleIPHeader) assert.ErrorIs(t, err, ErrRemoteAddress) assert.ErrorIs(t, err, ErrInvalidIpAddress) - assert.ErrorContains(t, err, "header \"Cf-Connecting-Ip\" not found") + assert.ErrorContains(t, err, "header not found") } func TestAddressesAndRangesToIPNets(t *testing.T) { @@ -721,7 +721,7 @@ func Test_forwardedHeaderRFCDeviations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := slices.Collect(ipAddrSeq(tt.args.headers, tt.args.headerName)) + got := slices.Collect(ipAddrSeq(tt.args.headers[tt.args.headerName], tt.args.headerName)) assert.Equal(t, tt.want, got) }) } diff --git a/http_consts.go b/http_consts.go index 41971c5..28329d3 100644 --- a/http_consts.go +++ b/http_consts.go @@ -51,7 +51,6 @@ const ( HeaderXForwardedProto = "X-Forwarded-Proto" HeaderXForwardedProtocol = "X-Forwarded-Protocol" HeaderXForwardedSsl = "X-Forwarded-Ssl" - HeaderXRealIP = "X-Real-Ip" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" HeaderXRequestID = "X-Request-Id" @@ -92,6 +91,7 @@ const ( HeaderXAzureSocketIP = "X-Azure-SocketIP" HeaderXAppengineRemoteAddr = "X-Appengine-Remote-Addr" HeaderFlyClientIP = "Fly-Client-IP" + HeaderXRealIP = "X-Real-Ip" ) // nolint:gosec