Skip to content

Commit

Permalink
Improve client ip resolver performance for edge-case (#53)
Browse files Browse the repository at this point in the history
* feat(clientip): improve some resolver performance for edge-case such as header not present or when chaining multiple strategy with single ip header.

* feat(http_const): change HeaderXRealIP location for better clarity

---------

Co-authored-by: tigerwill90 <[email protected]>
  • Loading branch information
tigerwill90 and tigerwill90 authored Jan 25, 2025
1 parent e3f0f52 commit fc31ad9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
43 changes: 23 additions & 20 deletions clientip/clientip.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ var (
ErrRightmostTrustedRange = errors.New("rightmost trusted range resolver")
)

// Avoid allocating those errors each time since it may happen a lot on adversary header.
// Avoid allocating those errors each time since it may happen a lot on adversary header or when using multiple single ip
// header resolver in order to find the right match.
var (
errLeftmostNonPrivate = fmt.Errorf("%w: unable to find a valid or non-private IP", ErrLeftmostNonPrivate)
errSingleIPHeader = fmt.Errorf("%w: header not found", ErrSingleIPHeader)
)

// TrustedIPRange returns a set of trusted IP ranges.
Expand Down Expand Up @@ -150,7 +152,7 @@ func (s SingleIPHeader) ClientIP(c fox.Context) (*net.IPAddr, error) {
// in theory it should be the newest value.)
ipStr := lastHeader(c.Request().Header, s.headerName)
if ipStr == "" {
return nil, fmt.Errorf("%w: header %q not found", ErrSingleIPHeader, s.headerName)
return nil, errSingleIPHeader
}

return ParseIPAddr(ipStr)
Expand Down Expand Up @@ -192,13 +194,14 @@ func NewLeftmostNonPrivate(key HeaderKey, limit uint, opts ...BlacklistRangeOpti
// ClientIP derives the client IP using the [LeftmostNonPrivate] resolver. The returned [net.IPAddr] may contain a
// zone identifier. If no valid IP can be derived, an error returned.
func (s LeftmostNonPrivate) ClientIP(c fox.Context) (*net.IPAddr, error) {
for ip := range iterutil.Take(ipAddrSeq(c.Request().Header, s.headerName), s.limit) {
if ip != nil && !isIPContainedInRanges(ip.IP, s.blacklistedRanges) {
// This is the leftmost valid, non-private IP
return ip, nil
if values, ok := c.Request().Header[s.headerName]; ok && len(values) > 0 {
for ip := range iterutil.Take(ipAddrSeq(values, s.headerName), s.limit) {
if ip != nil && !isIPContainedInRanges(ip.IP, s.blacklistedRanges) {
// This is the leftmost valid, non-private IP
return ip, nil
}
}
}

// We failed to find any valid, non-private IP
return nil, errLeftmostNonPrivate
}
Expand Down Expand Up @@ -232,12 +235,13 @@ func NewRightmostNonPrivate(key HeaderKey, opts ...TrustedRangeOption) (Rightmos
// ClientIP derives the client IP using the [RightmostNonPrivate] resolver. The returned [net.IPAddr] may contain a
// zone identifier. If no valid IP can be derived, an error returned.
func (s RightmostNonPrivate) ClientIP(c fox.Context) (*net.IPAddr, error) {
for ip := range backwardIpAddrSeq(c.Request().Header, s.headerName) {
if ip != nil && !isIPContainedInRanges(ip.IP, s.trustedRanges) {
return ip, nil
if values, ok := c.Request().Header[s.headerName]; ok && len(values) > 0 {
for ip := range backwardIpAddrSeq(values, s.headerName) {
if ip != nil && !isIPContainedInRanges(ip.IP, s.trustedRanges) {
return ip, nil
}
}
}

// We failed to find any valid, non-private IP
return nil, fmt.Errorf("%w: unable to find a valid or non-private IP", ErrRightmostNonPrivate)
}
Expand Down Expand Up @@ -267,7 +271,7 @@ func NewRightmostTrustedCount(key HeaderKey, trustedCount uint) (RightmostTruste
// ClientIP derives the client IP using the [RightmostTrustedCount] resolver. The returned [net.IPAddr] may contain a
// zone identifier. If no valid IP can be derived, an error returned.
func (s RightmostTrustedCount) ClientIP(c fox.Context) (*net.IPAddr, error) {
ip, ok := iterutil.At(backwardIpAddrSeq(c.Request().Header, s.headerName), s.trustedCount-1)
ip, ok := iterutil.At(backwardIpAddrSeq(c.Request().Header[s.headerName], s.headerName), s.trustedCount-1)
if !ok {
// This is a misconfiguration error. There were fewer IPs than we expected.
return nil, fmt.Errorf("%w: expected at least %d IP(s)", ErrRightmostTrustedCount, s.trustedCount)
Expand Down Expand Up @@ -317,7 +321,7 @@ func (s RightmostTrustedRange) ClientIP(c fox.Context) (*net.IPAddr, error) {
return nil, fmt.Errorf("%w: unable to resolve trusted ip range: %w", ErrRightmostTrustedRange, err)
}

for ip := range backwardIpAddrSeq(c.Request().Header, s.headerName) {
for ip := range backwardIpAddrSeq(c.Request().Header[s.headerName], s.headerName) {
if ip != nil && isIPContainedInRanges(ip.IP, trustedRange) {
// This IP is trusted
continue
Expand Down Expand Up @@ -467,10 +471,9 @@ func lastHeader(headers http.Header, headerName string) string {

// backwardIpAddrSeq returns a range iterator over the X-Forwarded-For or Forwarded header
// values, in reverse order. Any invalid IPs will result in nil elements. headerName must already
// be canonicalized.
func backwardIpAddrSeq(headers http.Header, headerName string) iter.Seq[*net.IPAddr] {
// be in canonical form.
func backwardIpAddrSeq(values []string, headerName string) iter.Seq[*net.IPAddr] {
return func(yield func(*net.IPAddr) bool) {
values := headers[headerName]
for i := len(values) - 1; i >= 0; i-- {
for rawListItem := range iterutil.BackwardSplitStringSeq(values[i], ",") {
// The IPs are often comma-space separated, so we'll need to trim the string
Expand All @@ -495,12 +498,12 @@ func backwardIpAddrSeq(headers http.Header, headerName string) iter.Seq[*net.IPA

// ipAddrSeq returns a range iterator over the X-Forwarded-For or Forwarded header
// values, in order. Any invalid IPs will result in nil elements. headerName must already
// be canonicalized.
func ipAddrSeq(headers http.Header, headerName string) iter.Seq[*net.IPAddr] {
// be in canonical form.
func ipAddrSeq(values []string, headerName string) iter.Seq[*net.IPAddr] {
return func(yield func(*net.IPAddr) bool) {
for _, h := range headers[headerName] {
for _, v := range values {
// We now have a sequence of comma-separated list items.
for rawListItem := range iterutil.SplitStringSeq(h, ",") {
for rawListItem := range iterutil.SplitStringSeq(v, ",") {
// The IPs are often comma-space separated, so we'll need to trim the string
rawListItem = strings.TrimSpace(rawListItem)

Expand Down
4 changes: 2 additions & 2 deletions clientip/clientip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func TestChain_ClientIP(t *testing.T) {
assert.ErrorIs(t, err, ErrSingleIPHeader)
assert.ErrorIs(t, err, ErrRemoteAddress)
assert.ErrorIs(t, err, ErrInvalidIpAddress)
assert.ErrorContains(t, err, "header \"Cf-Connecting-Ip\" not found")
assert.ErrorContains(t, err, "header not found")
}

func TestAddressesAndRangesToIPNets(t *testing.T) {
Expand Down Expand Up @@ -721,7 +721,7 @@ func Test_forwardedHeaderRFCDeviations(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := slices.Collect(ipAddrSeq(tt.args.headers, tt.args.headerName))
got := slices.Collect(ipAddrSeq(tt.args.headers[tt.args.headerName], tt.args.headerName))
assert.Equal(t, tt.want, got)
})
}
Expand Down
2 changes: 1 addition & 1 deletion http_consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ const (
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXForwardedProtocol = "X-Forwarded-Protocol"
HeaderXForwardedSsl = "X-Forwarded-Ssl"
HeaderXRealIP = "X-Real-Ip"
HeaderXUrlScheme = "X-Url-Scheme"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
HeaderXRequestID = "X-Request-Id"
Expand Down Expand Up @@ -92,6 +91,7 @@ const (
HeaderXAzureSocketIP = "X-Azure-SocketIP"
HeaderXAppengineRemoteAddr = "X-Appengine-Remote-Addr"
HeaderFlyClientIP = "Fly-Client-IP"
HeaderXRealIP = "X-Real-Ip"
)

// nolint:gosec
Expand Down

0 comments on commit fc31ad9

Please sign in to comment.