diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index aae7bf6..ae2c5d2 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,16 +9,16 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [ '>=1.19' ] + go: [ '>=1.21' ] steps: - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache: false - name: Check out code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run tests run: go test -v -coverprofile=coverage.txt -covermode=atomic ./... @@ -27,24 +27,25 @@ jobs: run: go test -v -race -run TestDataRace -count=10 ./... - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: - flags: coverage.txt + files: ./coverage.txt + token: ${{ secrets.CODECOV_TOKEN }} lint: name: Lint Fox runs-on: ubuntu-latest strategy: matrix: - go: [ '>=1.19' ] + go: [ '>=1.21' ] steps: - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache: false - name: Check out code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Run linter - uses: golangci/golangci-lint-action@v3 \ No newline at end of file + uses: golangci/golangci-lint-action@v6 diff --git a/README.md b/README.md index 73981f5..9e32d1f 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,8 @@ or missing trailing slash, at no extra cost. **Automatic OPTIONS replies:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), the router has built-in native support for [OPTIONS requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS). +**Client IP Derivation:** Accurately determine the "real" client IP address using best practices tailored to your network topology. + Of course, you can also register custom `NotFound` and `MethodNotAllowed` handlers. ## Getting started @@ -484,6 +486,55 @@ f := fox.New( ) ```` +## Client IP Derivation +The `WithClientIPStrategy` option allows you to set up strategies to resolve the client IP address based on your +use case and network topology. Accurately determining the client IP is hard, particularly in environments with proxies or +load balancers. For example, the leftmost IP in the `X-Forwarded-For` header is commonly used and is often regarded as the +"closest to the client" and "most real," but it can be easily spoofed. Therefore, you should absolutely avoid using it +for any security-related purposes, such as request throttling. + +The strategy used must be chosen and tuned for your network configuration. This should result in the strategy never returning +an error and if it does, it should be treated as an application issue or a misconfiguration, rather than defaulting to an +untrustworthy IP. + +The sub-package `github.com/tigerwill90/fox/strategy` provides a set of best practices strategies that should cover most use cases. + +````go +f := fox.New( + fox.DefaultOptions(), + fox.WithClientIPStrategy( + // We are behind one or many trusted proxies that have all private-space IP addresses. + strategy.NewRightmostNonPrivate(fox.HeaderXForwardedFor), + ), +) + +f.MustHandle(http.MethodGet, "/foo/bar", func(c fox.Context) { + ipAddr, err := c.ClientIP() + if err != nil { + // If the current strategy is not able to derive the client IP, an error + // will be returned rather than falling back on an untrustworthy IP. It + // should be treated as an application issue or a misconfiguration. + panic(err) + } + fmt.Println(ipAddr.String()) +}) +```` + +It is also possible to create a chain with multiple strategies that attempt to derive the client IP, stopping when the first one succeeds. + +````go +f := fox.New( + fox.DefaultOptions(), + fox.WithClientIPStrategy(strategy.NewChain( + strategy.NewLeftmostNonPrivate(fox.HeaderXForwardedFor), + strategy.NewRemoteAddr(), + )), +) +```` + +Note that there is no "sane" default strategy, so calling `Context.ClientIP` without a strategy configured will return an `ErrNoClientIPStrategy`. + +See this [blog post](https://adam-p.ca/blog/2022/03/x-forwarded-for/) for general guidance on choosing a strategy that fit your needs. ## Benchmark The primary goal of Fox is to be a lightweight, high performance router which allow routes modification at runtime. The following benchmarks attempt to compare Fox to various popular alternatives, including both fully-featured web frameworks @@ -623,7 +674,7 @@ BenchmarkPat_GithubAll 424 2899405 ns/op 1843501 - [x] [Update route syntax](https://github.com/tigerwill90/fox/pull/10#issue-1643728309) @v0.6.0 - [x] [Route overlapping](https://github.com/tigerwill90/fox/pull/9#issue-1642887919) @v0.7.0 - [x] [Route overlapping (catch-all and params)](https://github.com/tigerwill90/fox/pull/24#issue-1784686061) @v0.10.0 -- [x] [Ignore trailing slash](https://github.com/tigerwill90/fox/pull/32) @v0.14.0 +- [x] [Ignore trailing slash](https://github.com/tigerwill90/fox/pull/32), [Builtin Logger Middleware](https://github.com/tigerwill90/fox/pull/33), [Client IP Derivation](https://github.com/tigerwill90/fox/pull/33) @v0.14.0 - [ ] Improving performance and polishing ## Contributions diff --git a/context.go b/context.go index e82ae30..39d0a7c 100644 --- a/context.go +++ b/context.go @@ -5,11 +5,14 @@ package fox import ( - netcontext "context" + "context" "fmt" "io" + "net" "net/http" "net/url" + "slices" + "strings" ) // ContextCloser extends Context for manually created instances, adding a Close method @@ -24,27 +27,30 @@ type ContextCloser interface { // duration of the HandlerFunc execution, as the underlying implementation may be reused a soon as the handler return. // (see Clone method). type Context interface { - // Ctx returns the context associated with the current request. - Ctx() netcontext.Context // Request returns the current *http.Request. Request() *http.Request // SetRequest sets the *http.Request. SetRequest(r *http.Request) - // Writer method returns a custom ResponseWriter implementation. The returned ResponseWriter object implements additional - // http.Flusher, http.Hijacker, io.ReaderFrom interfaces for HTTP/1.x requests and http.Flusher, http.Pusher interfaces - // for HTTP/2 requests. These additional interfaces provide extra functionality and are used by underlying HTTP protocols - // for specific tasks. - // - // In actual workload scenarios, the custom ResponseWriter satisfies interfaces for HTTP/1.x and HTTP/2 protocols, - // however, if testing with e.g. httptest.Recorder, only the http.Flusher is available to the underlying ResponseWriter. - // Therefore, while asserting interfaces like http.Hijacker will not fail, invoking Hijack method will panic if the - // underlying ResponseWriter does not implement this interface. - // - // To facilitate testing with e.g. httptest.Recorder, use the WrapTestContextFlusher helper function which only exposes the - // http.Flusher interface for the ResponseWriter. + // Writer method returns a custom ResponseWriter implementation. Writer() ResponseWriter // SetWriter sets the ResponseWriter. SetWriter(w ResponseWriter) + // RemoteIP parses the IP from Request.RemoteAddr, normalizes it, and returns an IP address. The returned *net.IPAddr + // may contain a zone identifier. RemoteIP never returns nil, even if parsing the IP fails. + RemoteIP() *net.IPAddr + // ClientIP returns the "real" client IP address based on the configured ClientIPStrategy. + // The strategy is set using the WithClientIPStrategy option. There is no sane default, so if no strategy is configured, + // the method returns the ErrNoClientIPStrategy. + // + // The strategy used must be chosen and tuned for your network configuration. This should result + // in the strategy never returning an error -- i.e., never failing to find a candidate for the "real" IP. + // Consequently, getting an error result should be treated as an application error, perhaps even + // worthy of panicking. + // + // The returned *net.IPAddr may contain a zone identifier. + // + // This api is EXPERIMENTAL and is likely to change in future release. + ClientIP() (*net.IPAddr, error) // Path returns the registered path for the handler. Path() string // Params returns a Params slice containing the matched @@ -80,14 +86,12 @@ type Context interface { Tree() *Tree // Fox returns the Router instance. Fox() *Router - // Reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request. - // Caution: You should always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to - // avoid wrapping the ResponseWriter within itself. Use wisely! - Reset(w http.ResponseWriter, r *http.Request) + // Reset resets the Context to its initial state, attaching the provided ResponseWriter and http.Request. + Reset(w ResponseWriter, r *http.Request) } -// context holds request-related information and allows interaction with the ResponseWriter. -type context struct { +// cTx holds request-related information and allows interaction with the ResponseWriter. +type cTx struct { w ResponseWriter req *http.Request params *Params @@ -102,62 +106,96 @@ type context struct { rec recorder } -// Reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request. +// Reset resets the Context to its initial state, attaching the provided ResponseWriter and http.Request. +func (c *cTx) Reset(w ResponseWriter, r *http.Request) { + c.req = r + c.w = w + c.path = "" + c.cachedQuery = nil + *c.params = (*c.params)[:0] +} + +// reset resets the Context to its initial state, attaching the provided http.ResponseWriter and http.Request. // Caution: You should always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to // avoid wrapping the ResponseWriter within itself. Use wisely! -func (c *context) Reset(w http.ResponseWriter, r *http.Request) { +func (c *cTx) reset(w http.ResponseWriter, r *http.Request) { c.rec.reset(w) c.req = r c.w = &c.rec - c.fox = c.tree.fox c.path = "" c.cachedQuery = nil *c.params = (*c.params)[:0] } -func (c *context) resetNil() { +func (c *cTx) resetNil() { c.req = nil c.w = nil - c.fox = nil c.path = "" c.cachedQuery = nil *c.params = (*c.params)[:0] } // Request returns the *http.Request. -func (c *context) Request() *http.Request { +func (c *cTx) Request() *http.Request { return c.req } // SetRequest sets the *http.Request. -func (c *context) SetRequest(r *http.Request) { +func (c *cTx) SetRequest(r *http.Request) { c.req = r } // Writer returns the ResponseWriter. -func (c *context) Writer() ResponseWriter { +func (c *cTx) Writer() ResponseWriter { return c.w } // SetWriter sets the ResponseWriter. -func (c *context) SetWriter(w ResponseWriter) { +func (c *cTx) SetWriter(w ResponseWriter) { c.w = w } -// Ctx returns the context associated with the current request. -func (c *context) Ctx() netcontext.Context { - return c.req.Context() +// RemoteIP parses the IP from Request.RemoteAddr, normalizes it, and returns a *net.IPAddr. +// It never returns nil, even if parsing the IP fails. +func (c *cTx) RemoteIP() *net.IPAddr { + ipStr, _, _ := net.SplitHostPort(c.req.RemoteAddr) + + ip, zone := splitHostZone(ipStr) + ipAddr := &net.IPAddr{ + IP: net.ParseIP(ip), + Zone: zone, + } + + if ipAddr.IP == nil { + return &net.IPAddr{} + } + + return ipAddr +} + +// ClientIP returns the "real" client IP address based on the configured ClientIPStrategy. +// The strategy is set using the WithClientIPStrategy option. If no strategy is configured, +// the method returns the error ErrNoClientIPStrategy. +// +// The strategy used must be chosen and tuned for your network configuration. This should result +// in the strategy never returning an error -- i.e., never failing to find a candidate for the "real" IP. +// Consequently, getting an error result should be treated as an application error, perhaps even +// worthy of panicking. +// This api is EXPERIMENTAL and is likely to change in future release. +func (c *cTx) ClientIP() (*net.IPAddr, error) { + ipStrategy := c.Fox().ipStrategy + return ipStrategy.ClientIP(c) } // Params returns a Params slice containing the matched // wildcard parameters. -func (c *context) Params() Params { +func (c *cTx) Params() Params { return *c.params } // Param retrieve a matching wildcard segment by name. // It's a helper for c.Params.Get(name). -func (c *context) Param(name string) string { +func (c *cTx) Param(name string) string { for _, p := range c.Params() { if p.Key == name { return p.Value @@ -169,33 +207,33 @@ func (c *context) Param(name string) string { // QueryParams parses RawQuery and returns the corresponding values. // It's a helper for c.Request.URL.Query(). Note that the parsed // result is cached. -func (c *context) QueryParams() url.Values { +func (c *cTx) QueryParams() url.Values { return c.getQueries() } // QueryParam returns the first value associated with the given key. // It's a helper for c.QueryParams().Get(name). -func (c *context) QueryParam(name string) string { +func (c *cTx) QueryParam(name string) string { return c.getQueries().Get(name) } // SetHeader sets the response header for the given key to the specified value. -func (c *context) SetHeader(key, value string) { +func (c *cTx) SetHeader(key, value string) { c.w.Header().Set(key, value) } // Header retrieves the value of the request header for the given key. -func (c *context) Header(key string) string { +func (c *cTx) Header(key string) string { return c.req.Header.Get(key) } // Path returns the registered path for the handler. -func (c *context) Path() string { +func (c *cTx) Path() string { return c.path } // String sends a formatted string with the specified status code. -func (c *context) String(code int, format string, values ...any) (err error) { +func (c *cTx) String(code int, format string, values ...any) (err error) { if c.w.Header().Get(HeaderContentType) == "" { c.w.Header().Set(HeaderContentType, MIMETextPlainCharsetUTF8) } @@ -205,7 +243,7 @@ func (c *context) String(code int, format string, values ...any) (err error) { } // Blob sends a byte slice with the specified status code and content type. -func (c *context) Blob(code int, contentType string, buf []byte) (err error) { +func (c *cTx) Blob(code int, contentType string, buf []byte) (err error) { c.w.Header().Set(HeaderContentType, contentType) c.w.WriteHeader(code) _, err = c.w.Write(buf) @@ -213,7 +251,7 @@ func (c *context) Blob(code int, contentType string, buf []byte) (err error) { } // Stream sends data from an io.Reader with the specified status code and content type. -func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { +func (c *cTx) Stream(code int, contentType string, r io.Reader) (err error) { c.w.Header().Set(HeaderContentType, contentType) c.w.WriteHeader(code) _, err = io.Copy(c.w, r) @@ -221,7 +259,7 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error) } // Redirect sends an HTTP redirect response with the given status code and URL. -func (c *context) Redirect(code int, url string) error { +func (c *cTx) Redirect(code int, url string) error { if code < http.StatusMultipleChoices || code > http.StatusPermanentRedirect { return ErrInvalidRedirectCode } @@ -230,19 +268,19 @@ func (c *context) Redirect(code int, url string) error { } // Tree is a local copy of the Tree in use to serve the request. -func (c *context) Tree() *Tree { +func (c *cTx) Tree() *Tree { return c.tree } // Fox returns the Router instance. -func (c *context) Fox() *Router { +func (c *cTx) Fox() *Router { return c.fox } // Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. // Any attempt to write on the ResponseWriter will panic with the error ErrDiscardedResponseWriter. -func (c *context) Clone() Context { - cp := context{ +func (c *cTx) Clone() Context { + cp := cTx{ rec: c.rec, req: c.req.Clone(c.req.Context()), fox: c.fox, @@ -263,16 +301,15 @@ func (c *context) Clone() Context { // copy process. The returned ContextCloser must be closed once no longer needed. // This functionality is particularly beneficial for middlewares that need to wrap // their custom ResponseWriter while preserving the state of the original Context. -func (c *context) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { - cp := c.tree.ctx.Get().(*context) +func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { + cp := c.tree.ctx.Get().(*cTx) cp.req = r cp.w = w cp.path = c.path - cp.fox = c.fox cp.cachedQuery = nil if len(*c.params) > len(*cp.params) { // Grow cp.params to a least cap(c.params) - *cp.params = grow(*cp.params, len(*c.params)-len(*cp.params)) + *cp.params = slices.Grow(*cp.params, len(*c.params)-len(*cp.params)) } // cap(cp.params) >= cap(c.params) // now constraint into len(c.params) & cap(c.params) @@ -282,7 +319,7 @@ func (c *context) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { } // Close releases the context to be reused later. -func (c *context) Close() { +func (c *cTx) Close() { // Put back the context, if not extended more than max params or max depth, allowing // the slice to naturally grow within the constraint. if cap(*c.params) > int(c.tree.maxParams.Load()) || cap(*c.skipNds) > int(c.tree.maxDepth.Load()) { @@ -291,7 +328,7 @@ func (c *context) Close() { c.tree.ctx.Put(c) } -func (c *context) getQueries() url.Values { +func (c *cTx) getQueries() url.Values { if c.cachedQuery == nil { if c.req != nil { c.cachedQuery = c.req.URL.Query() @@ -307,7 +344,7 @@ func (c *context) getQueries() url.Values { func WrapF(f http.HandlerFunc) HandlerFunc { return func(c Context) { if len(c.Params()) > 0 { - ctx := netcontext.WithValue(c.Ctx(), paramsKey, c.Params().Clone()) + ctx := context.WithValue(c.Request().Context(), paramsKey, c.Params().Clone()) f.ServeHTTP(c.Writer(), c.Request().WithContext(ctx)) return } @@ -321,7 +358,7 @@ func WrapF(f http.HandlerFunc) HandlerFunc { func WrapH(h http.Handler) HandlerFunc { return func(c Context) { if len(c.Params()) > 0 { - ctx := netcontext.WithValue(c.Ctx(), paramsKey, c.Params().Clone()) + ctx := context.WithValue(c.Request().Context(), paramsKey, c.Params().Clone()) h.ServeHTTP(c.Writer(), c.Request().WithContext(ctx)) return } @@ -329,3 +366,16 @@ func WrapH(h http.Handler) HandlerFunc { h.ServeHTTP(c.Writer(), c.Request()) } } + +func splitHostZone(s string) (host, zone string) { + // This is copied from an unexported function in the Go stdlib: + // https://github.com/golang/go/blob/5c9b6e8e63e012513b1cb1a4a08ff23dec4137a1/src/net/ipsock.go#L219-L228 + + // The IPv6 scoped addressing zone identifier starts after the last percent sign. + if i := strings.LastIndexByte(s, '%'); i > 0 { + host, zone = s[:i], s[i+1:] + } else { + host = s + } + return +} diff --git a/context_test.go b/context_test.go index d89d2ea..2eecd8a 100644 --- a/context_test.go +++ b/context_test.go @@ -6,7 +6,6 @@ package fox import ( "bytes" - netcontext "context" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "net/http" @@ -124,17 +123,6 @@ func TestContext_CloneWith(t *testing.T) { assert.Nil(t, cc.cachedQuery) } -func TestContext_Ctx(t *testing.T) { - t.Parallel() - req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) - ctx, cancel := netcontext.WithCancel(netcontext.Background()) - cancel() - req = req.WithContext(ctx) - _, c := NewTestContext(httptest.NewRecorder(), req) - <-c.Ctx().Done() - require.ErrorIs(t, c.Request().Context().Err(), netcontext.Canceled) -} - func TestContext_Redirect(t *testing.T) { t.Parallel() w := httptest.NewRecorder() @@ -160,6 +148,29 @@ func TestContext_Blob(t *testing.T) { assert.True(t, c.Writer().Written()) } +func TestContext_RemoteIP(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + r.RemoteAddr = "192.0.2.1:8080" + _, c := NewTestContext(w, r) + assert.Equal(t, "192.0.2.1", c.RemoteIP().String()) + + r.RemoteAddr = "[::1]:80" + _, c = NewTestContext(w, r) + assert.Equal(t, "::1", c.RemoteIP().String()) +} + +func TestContext_ClientIP(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + r.RemoteAddr = "192.0.2.1:8080" + c := NewTestContextOnly(New(), w, r) + _, err := c.ClientIP() + assert.ErrorIs(t, err, ErrNoClientIPStrategy) +} + func TestContext_Stream(t *testing.T) { t.Parallel() w := httptest.NewRecorder() @@ -303,7 +314,7 @@ func TestWrapF(t *testing.T) { params := make(Params, 0) if tc.params != nil { params = tc.params.Clone() - c.(*context).params = ¶ms + c.(*cTx).params = ¶ms } WrapF(tc.handler(params))(c) @@ -363,7 +374,7 @@ func TestWrapH(t *testing.T) { params := make(Params, 0) if tc.params != nil { params = tc.params.Clone() - c.(*context).params = ¶ms + c.(*cTx).params = ¶ms } WrapH(tc.handler(params))(c) diff --git a/error.go b/error.go index 759bc24..75b4771 100644 --- a/error.go +++ b/error.go @@ -17,6 +17,7 @@ var ( ErrInvalidRoute = errors.New("invalid route") ErrDiscardedResponseWriter = errors.New("discarded response writer") ErrInvalidRedirectCode = errors.New("invalid redirect code") + ErrNoClientIPStrategy = errors.New("no client ip strategy") ) // RouteConflictError is a custom error type used to represent conflicts when diff --git a/fox.go b/fox.go index 5b34293..9e05d28 100644 --- a/fox.go +++ b/fox.go @@ -7,6 +7,7 @@ package fox import ( "errors" "fmt" + "net" "net/http" "path" "regexp" @@ -44,6 +45,28 @@ type HandlerFunc func(c Context) // be thread-safe, as they will be called concurrently. type MiddlewareFunc func(next HandlerFunc) HandlerFunc +// ClientIPStrategy define a strategy for obtaining the "real" client IP from HTTP requests. The strategy used must be +// chosen and tuned for your network configuration. This should result in the strategy never returning an error +// i.e., never failing to find a candidate for the "real" IP. Consequently, getting an error result should be treated as +// an application error, perhaps even worthy of panicking. Builtin best practices strategies can be found in the +// github.com/tigerwill90/fox/strategy package. See https://adam-p.ca/blog/2022/03/x-forwarded-for/ for more details on +// how to choose the right strategy for your use-case and network. +type ClientIPStrategy interface { + // ClientIP returns the "real" client IP according to the implemented strategy. It returns an error if no valid IP + // address can be derived using the strategy. This is typically considered a misconfiguration error, unless the strategy + // involves obtaining an untrustworthy or optional value. + ClientIP(c Context) (*net.IPAddr, error) +} + +// The ClientIPStrategyFunc type is an adapter to allow the use of ordinary functions as ClientIPStrategy. If f is a function +// with the appropriate signature, ClientIPStrategyFunc(f) is a ClientIPStrategyFunc that calls f. +type ClientIPStrategyFunc func(c Context) (*net.IPAddr, error) + +// ClientIP calls f(c). +func (f ClientIPStrategyFunc) ClientIP(c Context) (*net.IPAddr, error) { + return f(c) +} + // Router is a lightweight high performance HTTP request router that support mutation on its routing tree // while handling request concurrently. type Router struct { @@ -52,6 +75,7 @@ type Router struct { tsrRedirect HandlerFunc autoOptions HandlerFunc tree atomic.Pointer[Tree] + ipStrategy ClientIPStrategy mws []middleware handleMethodNotAllowed bool handleOptions bool @@ -73,6 +97,7 @@ func New(opts ...Option) *Router { r.noRoute = DefaultNotFoundHandler() r.noMethod = DefaultMethodNotAllowedHandler() r.autoOptions = DefaultOptionsHandler() + r.ipStrategy = noClientIPStrategy{} for _, opt := range opts { opt.apply(r) @@ -115,6 +140,13 @@ func (fox *Router) IgnoreTrailingSlashEnabled() bool { return fox.ignoreTrailingSlash } +// ClientIPStrategyEnabled returns whether the router is configured with a ClientIPStrategy. +// This api is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) ClientIPStrategyEnabled() bool { + _, ok := fox.ipStrategy.(noClientIPStrategy) + return !ok +} + // NewTree returns a fresh routing Tree that inherits all registered router options. It's safe to create multiple Tree // concurrently. However, a Tree itself is not thread-safe and all its APIs that perform write operations should be run // serially. Note that a Tree give direct access to the underlying sync.Mutex. @@ -202,7 +234,7 @@ func (fox *Router) Remove(method, path string) error { // and a boolean indicating if a trailing slash action (e.g. redirect) is recommended (tsr). The ContextCloser should always // be closed if non-nil. // This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { +func (fox *Router) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { tree := fox.tree.Load() return tree.Lookup(w, r) } @@ -300,8 +332,8 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { } tree := fox.tree.Load() - c := tree.ctx.Get().(*context) - c.Reset(w, r) + c := tree.ctx.Get().(*cTx) + c.reset(w, r) nds := *tree.nodes.Load() index := findRootNode(r.Method, nds) @@ -372,7 +404,7 @@ NoMethodFallback: if sb.Len() > 0 { sb.WriteString(", ") sb.WriteString(http.MethodOptions) - w.Header().Set("Allow", sb.String()) + w.Header().Set(HeaderAllow, sb.String()) fox.autoOptions(c) c.Close() return @@ -390,7 +422,7 @@ NoMethodFallback: } } if sb.Len() > 0 { - w.Header().Set("Allow", sb.String()) + w.Header().Set(HeaderAllow, sb.String()) fox.noMethod(c) c.Close() return @@ -452,13 +484,6 @@ type searchResult struct { depth uint32 } -func min(a, b int) int { - if a < b { - return a - } - return b -} - func commonPrefix(k1, k2 string) string { minLength := min(len(k1), len(k2)) for i := 0; i < minLength; i++ { @@ -638,20 +663,6 @@ func localRedirect(w http.ResponseWriter, r *http.Request, path string, code int } } -// grow increases the slice's capacity, if necessary, to guarantee space for -// another n elements. After Grow(n), at least n elements can be appended -// to the slice without another allocation. If n is negative or too large to -// allocate the memory, Grow panics. -func grow[S ~[]E, E any](s S, n int) S { - if n < 0 { - panic("cannot be negative") - } - if n -= cap(s) - len(s); n > 0 { - s = append(s[:cap(s)], make([]E, n)...)[:len(s)] - } - return s -} - func hexEscapeNonASCII(s string) string { newLen := 0 for i := 0; i < len(s); i++ { @@ -695,3 +706,9 @@ var htmlReplacer = strings.NewReplacer( func htmlEscape(s string) string { return htmlReplacer.Replace(s) } + +type noClientIPStrategy struct{} + +func (s noClientIPStrategy) ClientIP(_ Context) (*net.IPAddr, error) { + return nil, ErrNoClientIPStrategy +} diff --git a/fox_test.go b/fox_test.go index f9f0a42..f591584 100644 --- a/fox_test.go +++ b/fox_test.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "math/rand" + "net" "net/http" "net/http/httptest" "reflect" @@ -1744,6 +1745,13 @@ func TestRouterWithIgnoreTrailingSlash(t *testing.T) { } } +func TestRouterWithClientIPStrategy(t *testing.T) { + f := New(WithClientIPStrategy(ClientIPStrategyFunc(func(c Context) (*net.IPAddr, error) { + return c.RemoteIP(), nil + }))) + require.True(t, f.ClientIPStrategyEnabled()) +} + func TestRedirectTrailingSlash(t *testing.T) { cases := []struct { @@ -2338,7 +2346,7 @@ func TestDefaultOptions(t *testing.T) { } }) r := New(WithMiddleware(m), DefaultOptions()) - assert.Equal(t, reflect.ValueOf(m).Pointer(), reflect.ValueOf(r.mws[1].m).Pointer()) + assert.Equal(t, reflect.ValueOf(m).Pointer(), reflect.ValueOf(r.mws[2].m).Pointer()) assert.True(t, r.handleOptions) } @@ -2389,7 +2397,7 @@ func TestRouter_Lookup(t *testing.T) { for _, rte := range githubAPI { req := httptest.NewRequest(rte.method, rte.path, nil) - handler, cc, _ := f.Lookup(mockResponseWriter{}, req) + handler, cc, _ := f.Lookup(newResponseWriter(mockResponseWriter{}), req) require.NotNil(t, cc) assert.NotNil(t, handler) @@ -2410,13 +2418,13 @@ func TestRouter_Lookup(t *testing.T) { // No method match req := httptest.NewRequest("ANY", "/bar", nil) - handler, cc, _ := f.Lookup(mockResponseWriter{}, req) + handler, cc, _ := f.Lookup(newResponseWriter(mockResponseWriter{}), req) assert.Nil(t, handler) assert.Nil(t, cc) // No path match req = httptest.NewRequest(http.MethodGet, "/bar", nil) - handler, cc, _ = f.Lookup(mockResponseWriter{}, req) + handler, cc, _ = f.Lookup(newResponseWriter(mockResponseWriter{}), req) assert.Nil(t, handler) assert.Nil(t, cc) } @@ -2852,10 +2860,10 @@ func atomicSync() (start func(), wait func()) { } // This example demonstrates how to create a simple router using the default options, -// which include the Recovery middleware. A basic route is defined, along with a +// which include the Recovery and Logger middleware. A basic route is defined, along with a // custom middleware to log the request metrics. func ExampleNew() { - // Create a new router with default options, which include the Recovery middleware + // Create a new router with default options, which include the Recovery and Logger middleware r := New(DefaultOptions()) // Define a custom middleware to measure the time taken for request processing and diff --git a/go.mod b/go.mod index d873c60..1067303 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/tigerwill90/fox -go 1.19 +go 1.21 require ( github.com/google/gofuzz v1.2.0 diff --git a/helpers.go b/helpers.go index 47a47e7..df46a14 100644 --- a/helpers.go +++ b/helpers.go @@ -21,25 +21,32 @@ func NewTestContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) Con return newTextContextOnly(fox, w, r) } -func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *context { +func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *cTx { c := fox.Tree().allocateContext() c.resetNil() - c.fox = fox c.req = r c.rec.reset(w) c.w = &c.rec return c } -func newTestContextTree(t *Tree) *context { +func newTestContextTree(t *Tree) *cTx { c := t.allocateContext() c.resetNil() return c } -func unwrapContext(t *testing.T, c Context) *context { +func newResponseWriter(w http.ResponseWriter) ResponseWriter { + return &recorder{ + ResponseWriter: w, + size: notWritten, + status: http.StatusOK, + } +} + +func unwrapContext(t *testing.T, c Context) *cTx { t.Helper() - cc, ok := c.(*context) + cc, ok := c.(*cTx) if !ok { t.Fatal("unable to unwrap context") } diff --git a/http_consts.go b/http_consts.go index ed2ef20..6bb0ccc 100644 --- a/http_consts.go +++ b/http_consts.go @@ -36,6 +36,7 @@ const ( // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 HeaderAllow = "Allow" HeaderAuthorization = "Authorization" + HeaderProxyAuthorization = "Proxy-Authorization" HeaderContentDisposition = "Content-Disposition" HeaderContentEncoding = "Content-Encoding" HeaderContentLength = "Content-Length" @@ -50,12 +51,13 @@ const ( HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" HeaderXForwardedFor = "X-Forwarded-For" + HeaderForwarded = "Forwarded" HeaderXForwardedProto = "X-Forwarded-Proto" HeaderXForwardedProtocol = "X-Forwarded-Protocol" HeaderXForwardedSsl = "X-Forwarded-Ssl" + HeaderXRealIP = "X-Real-Ip" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXRealIP = "X-Real-Ip" HeaderXRequestID = "X-Request-Id" HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" @@ -85,4 +87,16 @@ const ( // nolint:gosec HeaderXCSRFToken = "X-CSRF-Token" HeaderReferrerPolicy = "Referrer-Policy" + + // Platform Header for single IP + HeaderCFConnectionIP = "CF-Connecting-IP" + HeaderTrueClientIP = "True-Client-IP" + HeaderFastClientIP = "Fastly-Client-IP" + HeaderXAzureClientIP = "X-Azure-ClientIP" + HeaderXAzureSocketIP = "X-Azure-SocketIP" + HeaderXAppengineRemoteAddr = "X-Appengine-Remote-Addr" + HeaderFlyClientIP = "Fly-Client-IP" ) + +// nolint:gosec +var blacklistedHeader = []string{HeaderAuthorization, HeaderProxyAuthorization, "X-Vault-Token", HeaderCookie, HeaderSetCookie, HeaderXCSRFToken} diff --git a/internal/ansi/ansi.go b/internal/ansi/ansi.go new file mode 100644 index 0000000..ce6f744 --- /dev/null +++ b/internal/ansi/ansi.go @@ -0,0 +1,25 @@ +// Copyright 2023 GreyXor. All rights reserved. +// Mount of this source code is governed by a MIT license that can be found +// at https://gitlab.com/greyxor/slogor/-/blob/main/LICENSE?ref_type=heads. + +package ansi + +// ANSI codes for text styling and formatting. +const ( + Reset = "\033[0m" + Bold = "\033[1m" + Faint = "\033[2m" + NormalIntensity = "\033[22m" + // Foreground colors + FgRed = "\033[31m" + FgGreen = "\033[32m" + FgYellow = "\033[33m" + FgMagenta = "\033[35m" + FgCyan = "\033[36m" + + // Background colors + BgRed = "\033[41m" + BgYellow = "\033[43m" + BgBlue = "\033[44m" + BgMagenta = "\033[45m" +) diff --git a/internal/ansi/ansi_windows.go b/internal/ansi/ansi_windows.go new file mode 100644 index 0000000..7ffea6b --- /dev/null +++ b/internal/ansi/ansi_windows.go @@ -0,0 +1,35 @@ +// Copyright 2023 GreyXor. All rights reserved. +// Mount of this source code is governed by a MIT license that can be found +// at https://gitlab.com/greyxor/slogor/-/blob/main/LICENSE?ref_type=heads. + +package ansi + +import ( + "golang.org/x/sys/windows" + "os" +) + +// init initializes the Windows console mode to add colors support to it. +func init() { + // Get the file descriptor for the standard output (stdout). + stdout := windows.Handle(os.Stdout.Fd()) + + // Declare a variable to store the original console mode. + var originalMode uint32 + + // Retrieve the current console mode for the standard output. + // The retrieved mode will be stored in the originalMode variable. + windows.GetConsoleMode(stdout, &originalMode) + + // Calculate the new console mode by combining the original mode with various + // flags to enhance the terminal's capabilities for better logging. + // Here, ENABLE_PROCESSED_OUTPUT ensures that the output is processed before being written to the console. + // ENABLE_WRAP_AT_EOL_OUTPUT enables automatic wrapping at the end of the line. + // ENABLE_VIRTUAL_TERMINAL_PROCESSING enables processing of virtual terminal sequences for colors and formatting. + // More information about console mode flags can be found at: https://learn.microsoft.com/en-us/windows/console/setconsolemode + newConsoleMode := originalMode | windows.ENABLE_PROCESSED_OUTPUT | + windows.ENABLE_WRAP_AT_EOL_OUTPUT | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING + + // Set the new console mode for the standard output. + windows.SetConsoleMode(stdout, newConsoleMode) +} diff --git a/internal/slogpretty/handler.go b/internal/slogpretty/handler.go new file mode 100644 index 0000000..f3af513 --- /dev/null +++ b/internal/slogpretty/handler.go @@ -0,0 +1,263 @@ +// The code in this package is derivative of https://gitlab.com/greyxor/slogor. +// Mount of this source code is governed by a MIT license that can be found +// at https://gitlab.com/greyxor/slogor/-/blob/main/LICENSE?ref_type=heads. + +package slogpretty + +import ( + "context" + "fmt" + "github.com/tigerwill90/fox/internal/ansi" + "io" + "log/slog" + "os" + "sync" + "time" +) + +const ( + maxBufferSize = 16 << 10 // 16384 + initialBufferSize = 1024 +) + +var _ slog.Handler = (*Handler)(nil) + +var logBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 0, initialBufferSize) + return &b + }, +} + +var ( + DefaultHandler = &Handler{ + We: &lockedWriter{w: os.Stderr}, + Wo: &lockedWriter{w: os.Stdout}, + Lvl: slog.LevelDebug, + Goa: make([]GroupOrAttrs, 0), + } + timeFormat = fmt.Sprintf("%s %s", time.DateOnly, time.TimeOnly) +) + +func freeBuf(b *[]byte) { + if cap(*b) <= maxBufferSize { + *b = (*b)[:0] + logBufPool.Put(b) + } +} + +type GroupOrAttrs struct { + attr slog.Attr + group string +} + +type Handler struct { + We io.Writer + Wo io.Writer + Lvl slog.Leveler + Goa []GroupOrAttrs +} + +func (h *Handler) Enabled(_ context.Context, level slog.Level) bool { + return level >= h.Lvl.Level() +} + +func (h *Handler) Handle(_ context.Context, record slog.Record) error { + bufp := logBufPool.Get().(*[]byte) + buf := *bufp + + defer func() { + *bufp = buf + freeBuf(bufp) + }() + + buf = append(buf, "[FOX] "...) + + if !record.Time.IsZero() { + buf = append(buf, ansi.Faint...) + buf = append(buf, record.Time.Format(timeFormat)...) + buf = append(buf, ansi.NormalIntensity...) + buf = append(buf, " "...) + } + + // Write level with appropriate formatting and color. + // Also append right padding depending on the log level. + buf = append(buf, "| "...) + switch record.Level { + case slog.LevelInfo: + buf = append(buf, ansi.FgGreen...) + buf = append(buf, record.Level.String()...) + buf = append(buf, " "...) + case slog.LevelError: + buf = append(buf, ansi.FgRed...) + buf = append(buf, record.Level.String()...) + case slog.LevelWarn: + buf = append(buf, ansi.FgYellow...) + buf = append(buf, record.Level.String()...) + buf = append(buf, " "...) + case slog.LevelDebug: + buf = append(buf, ansi.FgMagenta...) + buf = append(buf, record.Level.String()...) + } + + buf = append(buf, ansi.Reset...) + buf = append(buf, " | "...) + // Write the log message. + if record.Message == "unknown" { + // special case if the ip cannot be found using the ClientIPStrategy. + buf = append(buf, ansi.FgRed...) + buf = append(buf, record.Message...) + buf = append(buf, ansi.Reset...) + } else { + buf = append(buf, record.Message...) + } + buf = append(buf, " | "...) + + lastGroup := "" + for _, goa := range h.Goa { + switch { + case goa.group != "": + lastGroup += goa.group + "." + default: + attr := goa.attr + if lastGroup != "" { + attr.Key = lastGroup + attr.Key + } + + buf = appendAttr(record.Level, buf, attr) + } + } + + // If there are additional attributes, append them to the log record. + if record.NumAttrs() > 0 { + record.Attrs(func(attr slog.Attr) bool { + if lastGroup != "" { + attr.Key = lastGroup + attr.Key + } + buf = appendAttr(record.Level, buf, attr) + + return true + }) + } + + // Replace the latest space by an EOL. + buf[len(buf)-1] = '\n' + + if record.Level >= slog.LevelError { + if _, err := h.We.Write(buf); err != nil { + return fmt.Errorf("failed to write buffer: %w", err) + } + } else { + if _, err := h.Wo.Write(buf); err != nil { + return fmt.Errorf("failed to write buffer: %w", err) + } + } + + return nil +} + +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + newAttrs := make([]GroupOrAttrs, len(attrs)) + for i, attr := range attrs { + newAttrs[i] = GroupOrAttrs{attr: attr} + } + + return &Handler{ + We: h.We, + Wo: h.Wo, + Lvl: h.Lvl, + Goa: append(h.Goa, newAttrs...), + } +} + +func (h *Handler) WithGroup(name string) slog.Handler { + return &Handler{ + We: h.We, + Wo: h.Wo, + Lvl: h.Lvl, + Goa: append(h.Goa, GroupOrAttrs{group: name}), + } +} + +// appendAttr appends the attribute to the buffer. +func appendAttr(level slog.Level, buf []byte, attr slog.Attr) []byte { + // Resolve the Attr's value before doing anything else. + attr.Value = attr.Value.Resolve() + + // Ignore empty Attrs. + if attr.Equal(slog.Attr{}) { + return buf + } + + buf = append(buf, ansi.Faint...) + buf = append(buf, ansi.Bold...) + + buf = append(buf, attr.Key...) + buf = append(buf, "="...) + buf = append(buf, ansi.NormalIntensity...) + + var addWhitespace bool + if _, isErr := attr.Value.Any().(error); isErr { + buf = append(buf, ansi.FgRed...) + } else { + switch attr.Key { + case "method": + buf = append(buf, ansi.BgBlue...) + addWhitespace = true + case "status": + buf = append(buf, levelColor(level)...) + addWhitespace = true + case "location": + buf = append(buf, ansi.FgYellow...) + case "latency": + buf = append(buf, latencyColor(attr.Value.Duration())...) + default: + buf = append(buf, ansi.FgCyan...) + } + } + + if addWhitespace { + buf = append(buf, " "+attr.Value.String()+" "...) + } else { + buf = append(buf, attr.Value.String()...) + } + buf = append(buf, ansi.Reset...) + buf = append(buf, " "...) + + return buf +} + +type lockedWriter struct { + w io.Writer + sync.Mutex +} + +func (w *lockedWriter) Write(p []byte) (n int, err error) { + w.Lock() + n, err = w.w.Write(p) + w.Unlock() + return +} + +func levelColor(level slog.Level) string { + switch level { + case slog.LevelInfo: + return ansi.BgBlue + case slog.LevelWarn: + return ansi.BgYellow + case slog.LevelError: + return ansi.BgRed + default: + return ansi.BgMagenta + } +} + +func latencyColor(d time.Duration) string { + if d < 100*time.Millisecond { + return ansi.FgGreen + } + if d < 500*time.Millisecond { + return ansi.FgYellow + } + return ansi.FgRed +} diff --git a/internal/slogpretty/handler_test.go b/internal/slogpretty/handler_test.go new file mode 100644 index 0000000..0dc5eeb --- /dev/null +++ b/internal/slogpretty/handler_test.go @@ -0,0 +1,43 @@ +package slogpretty + +import ( + "bytes" + "context" + "github.com/stretchr/testify/require" + "log/slog" + "net/http" + "testing" + "time" +) + +func TestLogHandler_Handle(t *testing.T) { + bufWo := bytes.NewBuffer(nil) + bufWe := bytes.NewBuffer(nil) + + h := &Handler{ + We: &lockedWriter{w: bufWe}, + Wo: &lockedWriter{w: bufWo}, + Lvl: slog.LevelDebug, + Goa: make([]GroupOrAttrs, 0), + } + + record := slog.Record{ + Time: time.Date(2024, 06, 26, 0, 0, 0, 0, time.UTC), + Message: "::1", + Level: slog.LevelDebug, + } + record.Add("method", http.MethodGet) + record.Add("status", http.StatusOK) + record.Add("latency", 2*time.Second) + record.Add("location", "../foo") + record.Add(slog.Group("foo", slog.String("bar", "bar"))) + require.NoError(t, h.Handle(context.Background(), record)) + record.Level = slog.LevelInfo + require.NoError(t, h.Handle(context.Background(), record)) + record.Level = slog.LevelWarn + require.NoError(t, h.Handle(context.Background(), record)) + record.Level = slog.LevelError + require.NoError(t, h.Handle(context.Background(), record)) + record.Message = "unknown" + require.NoError(t, h.Handle(context.Background(), record)) +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..96afe8f --- /dev/null +++ b/logger.go @@ -0,0 +1,107 @@ +// Copyright 2022 Sylvain Müller. All rights reserved. +// Mount of this source code is governed by a Apache-2.0 license that can be found +// at https://github.com/tigerwill90/fox/blob/master/LICENSE.txt. + +package fox + +import ( + "errors" + "github.com/tigerwill90/fox/internal/slogpretty" + "log/slog" + "time" +) + +// LoggerWithHandler returns middleware that logs request information using the provided slog.Handler. +// It logs details such as the remote IP, HTTP method, request path, status code and latency. +func LoggerWithHandler(handler slog.Handler) MiddlewareFunc { + log := slog.New(handler) + return func(next HandlerFunc) HandlerFunc { + return func(c Context) { + start := time.Now() + next(c) + latency := time.Since(start) + + req := c.Request() + lvl := level(c.Writer().Status()) + var location string + if lvl.Level() == slog.LevelDebug { + location = c.Writer().Header().Get(HeaderLocation) + } + + var ipStr string + ip, err := c.ClientIP() + if err == nil { + ipStr = ip.String() + } else if errors.Is(err, ErrNoClientIPStrategy) { + ipStr = c.RemoteIP().String() + } else { + ipStr = "unknown" + } + + if location == "" { + log.LogAttrs( + req.Context(), + lvl, + ipStr, + slog.Int("status", c.Writer().Status()), + slog.String("method", req.Method), + slog.String("path", c.Request().URL.String()), + slog.Duration("latency", roundLatency(latency)), + ) + } else { + location = c.Writer().Header().Get(HeaderLocation) + log.LogAttrs( + req.Context(), + lvl, + ipStr, + slog.Int("status", c.Writer().Status()), + slog.String("method", req.Method), + slog.String("path", c.Request().URL.String()), + slog.Duration("latency", roundLatency(latency)), + slog.String("location", location), + ) + } + + } + } +} + +// Logger returns middleware that logs request information to os.Stdout and os.Stderr. +// It logs details such as the remote IP, HTTP method, request path, status code and latency. +func Logger() MiddlewareFunc { + return LoggerWithHandler(slogpretty.DefaultHandler) +} + +func level(status int) slog.Level { + switch { + case status >= 200 && status < 300: + return slog.LevelInfo + case status >= 300 && status < 400: + return slog.LevelDebug + case status >= 400 && status < 500: + return slog.LevelWarn + case status >= 500: + return slog.LevelError + default: + return slog.LevelInfo + } +} + +func roundLatency(d time.Duration) time.Duration { + switch { + case d < 1*time.Microsecond: + return d.Round(100 * time.Nanosecond) + case d < 1*time.Millisecond: + return d.Round(10 * time.Microsecond) + case d < 10*time.Millisecond: + return d.Round(100 * time.Microsecond) + case d < 100*time.Millisecond: + return d.Round(1 * time.Millisecond) + case d < 1*time.Second: + return d.Round(10 * time.Millisecond) + case d < 10*time.Second: + return d.Round(100 * time.Millisecond) + default: + return d.Round(1 * time.Second) + } +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..3c26451 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,73 @@ +package fox + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "log/slog" + "net/http" + "net/http/httptest" + "testing" +) + +func TestLoggerWithHandler(t *testing.T) { + buf := bytes.NewBuffer(nil) + f := New( + WithRedirectTrailingSlash(true), + WithMiddleware(LoggerWithHandler(slog.NewTextHandler(buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == "time" { + return slog.String("time", "time") + } + if a.Key == "latency" { + return slog.String("latency", "latency") + } + return a + }, + }))), + ) + require.NoError(t, f.Handle(http.MethodGet, "/success", func(c Context) { + c.Writer().WriteHeader(http.StatusOK) + })) + require.NoError(t, f.Handle(http.MethodGet, "/failure", func(c Context) { + c.Writer().WriteHeader(http.StatusInternalServerError) + })) + + cases := []struct { + name string + req *http.Request + want string + }{ + { + name: "should log info level", + req: httptest.NewRequest(http.MethodGet, "/success", nil), + want: "time=time level=INFO msg=192.0.2.1 status=200 method=GET path=/success latency=latency\n", + }, + { + name: "should log error level", + req: httptest.NewRequest(http.MethodGet, "/failure", nil), + want: "time=time level=ERROR msg=192.0.2.1 status=500 method=GET path=/failure latency=latency\n", + }, + { + name: "should log warn level", + req: httptest.NewRequest(http.MethodGet, "/foobar", nil), + want: "time=time level=WARN msg=192.0.2.1 status=404 method=GET path=/foobar latency=latency\n", + }, + { + name: "should log debug level", + req: httptest.NewRequest(http.MethodGet, "/success/", nil), + want: "time=time level=DEBUG msg=192.0.2.1 status=301 method=GET path=/success/ latency=latency location=../success\n", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + buf.Reset() + w := httptest.NewRecorder() + f.ServeHTTP(w, tc.req) + assert.Equal(t, tc.want, buf.String()) + }) + } + +} diff --git a/options.go b/options.go index 071a6c3..b2d17ab 100644 --- a/options.go +++ b/options.go @@ -132,12 +132,29 @@ func WithIgnoreTrailingSlash(enable bool) Option { }) } -// DefaultOptions configure the router to use the Recovery middleware for the RouteHandlers scope and enable -// automatic OPTIONS response. Note that DefaultOptions push the Recovery middleware to the first position of the -// middleware chains. +// WithClientIPStrategy sets the strategy for obtaining the "real" client IP address from HTTP requests. +// This strategy is used by the Context.ClientIP method. The strategy must be chosen and tuned for your network +// configuration to ensure it never returns an error -- i.e., never fails to find a candidate for the "real" IP. +// Consequently, getting an error result should be treated as an application error, perhaps even worthy of panicking. +// There is no sane default, so if no strategy is configured, Context.ClientIP returns ErrNoClientIPStrategy. +// This API is EXPERIMENTAL and is likely to change in future releases. +func WithClientIPStrategy(strategy ClientIPStrategy) Option { + return optionFunc(func(r *Router) { + if strategy != nil { + r.ipStrategy = strategy + } + }) +} + +// DefaultOptions configure the router to use the Recovery middleware for the RouteHandlers scope, the Logger middleware +// for AllHandlers scope and enable automatic OPTIONS response. Note that DefaultOptions push the Recovery and Logger middleware +// respectively to the first and second position of the middleware chains. func DefaultOptions() Option { return optionFunc(func(r *Router) { - r.mws = append([]middleware{{Recovery(DefaultHandleRecovery), RouteHandlers}}, r.mws...) + r.mws = append([]middleware{ + {Recovery(), RouteHandlers}, + {Logger(), AllHandlers}, + }, r.mws...) r.handleOptions = true }) } diff --git a/params.go b/params.go index 17df796..7b80f84 100644 --- a/params.go +++ b/params.go @@ -4,7 +4,7 @@ package fox -import netcontext "context" +import "context" type ctxKey struct{} @@ -46,9 +46,9 @@ func (p Params) Clone() Params { return cloned } -// ParamsFromContext is a helper to retrieve params from context when a http.Handler +// ParamsFromContext is a helper to retrieve params from context.Context when a http.Handler // is registered using WrapF or WrapH. -func ParamsFromContext(ctx netcontext.Context) Params { +func ParamsFromContext(ctx context.Context) Params { p, _ := ctx.Value(paramsKey).(Params) return p } diff --git a/params_test.go b/params_test.go index 1f49628..135046c 100644 --- a/params_test.go +++ b/params_test.go @@ -5,7 +5,7 @@ package fox import ( - netcontext "context" + "context" "testing" "github.com/stretchr/testify/assert" @@ -67,17 +67,17 @@ func TestParamsFromContext(t *testing.T) { cases := []struct { name string - ctx netcontext.Context + ctx context.Context expectedParams Params }{ { name: "empty context", - ctx: netcontext.Background(), + ctx: context.Background(), expectedParams: nil, }, { name: "context with params", - ctx: func() netcontext.Context { + ctx: func() context.Context { params := make(Params, 0, 2) params = append(params, Param{ @@ -85,7 +85,7 @@ func TestParamsFromContext(t *testing.T) { Value: "bar", }, ) - return netcontext.WithValue(netcontext.Background(), paramsKey, params) + return context.WithValue(context.Background(), paramsKey, params) }(), expectedParams: func() Params { params := make(Params, 0, 2) diff --git a/path_test.go b/path_test.go index 748ceb9..0680495 100644 --- a/path_test.go +++ b/path_test.go @@ -1,7 +1,7 @@ // Copyright 2013 Julien Schmidt. All rights reserved. // Based on the path package, Copyright 2009 The Go Authors. -// Use of this source code is governed by a BSD-style license that can be found -// in the LICENSE file. +// Mount of this source code is governed by a BSD-style license that can be found +// at https://github.com/julienschmidt/httprouter/blob/master/LICENSE. package fox diff --git a/recovery.go b/recovery.go index 2c557f4..20b7e89 100644 --- a/recovery.go +++ b/recovery.go @@ -6,50 +6,96 @@ package fox import ( "errors" - "log" + "fmt" + "github.com/tigerwill90/fox/internal/slogpretty" + "log/slog" "net" "net/http" + "net/http/httputil" "os" - "runtime/debug" + "runtime" + "slices" "strings" ) -var stdErr = log.New(os.Stderr, "", log.LstdFlags) - // RecoveryFunc is a function type that defines how to handle panics that occur during the // handling of an HTTP request. type RecoveryFunc func(c Context, err any) -// Recovery is a middleware that captures panics and recovers from them. It takes a custom handle function -// that will be called with the Context and the value recovered from the panic. -// Note that the middleware check if the panic is caused by http.ErrAbortHandler and re-panic if true -// allowing the http server to handle it as an abort. -func Recovery(handle RecoveryFunc) MiddlewareFunc { +// CustomRecoveryWithLogHandler returns middleware for a given slog.Handler that recovers from any panics, +// logs the error, request details, and stack trace, and then calls the provided handle function to handle the recovery. +func CustomRecoveryWithLogHandler(handler slog.Handler, handle RecoveryFunc) MiddlewareFunc { + slogger := slog.New(handler) return func(next HandlerFunc) HandlerFunc { return func(c Context) { - defer recovery(c, handle) + defer recovery(slogger, c, handle) next(c) } } } +// CustomRecovery returns middleware that recovers from any panics, logs the error, request details, and stack trace, +// and then calls the provided handle function to handle the recovery. +func CustomRecovery(handle RecoveryFunc) MiddlewareFunc { + return CustomRecoveryWithLogHandler(slogpretty.DefaultHandler, handle) +} + +// Recovery returns middleware that recovers from any panics, logs the error, request details, and stack trace, +// and writes a 500 status code response if a panic occurs. +func Recovery() MiddlewareFunc { + return CustomRecovery(DefaultHandleRecovery) +} + // DefaultHandleRecovery is a default implementation of the RecoveryFunc. -// It logs the recovered panic error to stderr, including the stack trace. -// If the response has not been written yet and the error is not caused by a broken connection, -// it sets the status code to http.StatusInternalServerError and writes a generic error message. -func DefaultHandleRecovery(c Context, err any) { - stdErr.Printf("[PANIC] %q recovered\n%s", err, debug.Stack()) - if !c.Writer().Written() && !connIsBroken(err) { - http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - } +// It responds with a status code 500 and writes a generic error message. +func DefaultHandleRecovery(c Context, _ any) { + http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } -func recovery(c Context, handle RecoveryFunc) { +func recovery(logger *slog.Logger, c Context, handle RecoveryFunc) { if err := recover(); err != nil { if abortErr, ok := err.(error); ok && errors.Is(abortErr, http.ErrAbortHandler) { panic(abortErr) } - handle(c, err) + + var sb strings.Builder + + sb.WriteString("Recovered from PANIC\n") + sb.WriteString("Request Dump:\n") + + httpRequest, _ := httputil.DumpRequest(c.Request(), false) + headers := strings.Split(string(httpRequest), "\r\n") + sb.WriteString(headers[0]) + for i := 1; i < len(headers); i++ { + sb.WriteString("\r\n") + current := strings.Split(headers[i], ":") + if slices.Contains(blacklistedHeader, current[0]) { + sb.WriteString(current[0]) + sb.WriteString(": ") + continue + } + sb.WriteString(headers[i]) + } + + sb.WriteString("Stack:\n") + sb.WriteString(stacktrace(4, 6)) + + params := c.Params() + attrs := make([]any, 0, len(params)) + for _, param := range params { + attrs = append(attrs, slog.String(param.Key, param.Value)) + } + + logger.Error( + sb.String(), + slog.String("path", c.Path()), + slog.Group("param", attrs...), + slog.Any("error", err), + ) + + if !c.Writer().Written() && !connIsBroken(err) { + handle(c, err) + } } } @@ -64,3 +110,30 @@ func connIsBroken(err any) bool { } return false } + +func stacktrace(skip, nFrames int) string { + pcs := make([]uintptr, nFrames+1) + n := runtime.Callers(skip+1, pcs) + if n == 0 { + return "(no stack)" + } + frames := runtime.CallersFrames(pcs[:n]) + var b strings.Builder + i := 0 + for { + frame, more := frames.Next() + if i > 0 { + b.WriteByte('\n') + } + _, _ = fmt.Fprintf(&b, "called from %s %s:%d", frame.Function, frame.File, frame.Line) + if !more { + break + } + i++ + if i >= nFrames { + _, _ = fmt.Fprintf(&b, "\n(rest of stack elided)") + break + } + } + return b.String() +} diff --git a/recovery_test.go b/recovery_test.go index 27c886e..6791780 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -1,8 +1,11 @@ package fox import ( + "bytes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tigerwill90/fox/internal/slogpretty" + "log/slog" "net" "net/http" "net/http/httptest" @@ -12,7 +15,7 @@ import ( ) func TestAbortHandler(t *testing.T) { - m := Recovery(func(c Context, err any) { + m := CustomRecovery(func(c Context, err any) { c.Writer().WriteHeader(http.StatusInternalServerError) _, _ = c.Writer().Write([]byte(err.(error).Error())) }) @@ -24,8 +27,9 @@ func TestAbortHandler(t *testing.T) { _ = c.String(200, "foo") } - require.NoError(t, r.Tree().Handle(http.MethodPost, "/", h)) - req := httptest.NewRequest(http.MethodPost, "/", nil) + require.NoError(t, r.Tree().Handle(http.MethodPost, "/{foo}", h)) + req := httptest.NewRequest(http.MethodPost, "/foo", nil) + req.Header.Set(HeaderAuthorization, "foobar") w := httptest.NewRecorder() defer func() { @@ -39,7 +43,14 @@ func TestAbortHandler(t *testing.T) { } func TestRecoveryMiddleware(t *testing.T) { - m := Recovery(func(c Context, err any) { + woBuf := bytes.NewBuffer(nil) + weBuf := bytes.NewBuffer(nil) + + m := CustomRecoveryWithLogHandler(&slogpretty.Handler{ + We: weBuf, + Wo: woBuf, + Lvl: slog.LevelDebug, + }, func(c Context, err any) { c.Writer().WriteHeader(http.StatusInternalServerError) _, _ = c.Writer().Write([]byte(err.(string))) }) @@ -54,13 +65,19 @@ func TestRecoveryMiddleware(t *testing.T) { require.NoError(t, r.Tree().Handle(http.MethodPost, "/", h)) req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(HeaderAuthorization, "foobar") w := httptest.NewRecorder() r.ServeHTTP(w, req) require.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, errMsg, w.Body.String()) + assert.Equal(t, woBuf.Len(), 0) + assert.NotEqual(t, weBuf.Len(), 0) } func TestRecoveryMiddlewareWithBrokenPipe(t *testing.T) { + woBuf := bytes.NewBuffer(nil) + weBuf := bytes.NewBuffer(nil) + expectMsgs := map[syscall.Errno]string{ syscall.EPIPE: "broken pipe", syscall.ECONNRESET: "connection reset by peer", @@ -68,7 +85,11 @@ func TestRecoveryMiddlewareWithBrokenPipe(t *testing.T) { for errno, expectMsg := range expectMsgs { t.Run(expectMsg, func(t *testing.T) { - f := New(WithMiddleware(Recovery(func(c Context, err any) { + f := New(WithMiddleware(CustomRecoveryWithLogHandler(&slogpretty.Handler{ + We: weBuf, + Wo: woBuf, + Lvl: slog.LevelDebug, + }, func(c Context, err any) { if !connIsBroken(err) { http.Error(c.Writer(), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } @@ -81,8 +102,9 @@ func TestRecoveryMiddlewareWithBrokenPipe(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/foo", nil) w := httptest.NewRecorder() f.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, woBuf.Len(), 0) + assert.NotEqual(t, weBuf.Len(), 0) }) } } diff --git a/strategy/strategy.go b/strategy/strategy.go new file mode 100644 index 0000000..ee837e0 --- /dev/null +++ b/strategy/strategy.go @@ -0,0 +1,624 @@ +// The code in this package is derivative of https://github.com/realclientip/realclientip-go (all credit to Adam Pritchard). +// Mount of this source code is governed by a BSD Zero Clause License that can be found +// at https://github.com/realclientip/realclientip-go/blob/main/LICENSE. + +package strategy + +import ( + "errors" + "fmt" + "github.com/tigerwill90/fox" + "net" + "net/http" + "strings" +) + +const ( + xForwardedForHdr = "X-Forwarded-For" + forwardedHdr = "Forwarded" +) + +var ( + ErrInvalidIpAddress = errors.New("invalid ip address") + ErrUnspecifiedIpAddress = errors.New("unspecified ip address") + ErrRemoteAddress = errors.New("remote address strategy") + ErrSingleIPHeader = errors.New("single ip header strategy") + ErrLeftmostNonPrivate = errors.New("leftmost non private strategy") + ErrRightmostNonPrivate = errors.New("rightmost non private strategy") + ErrRightmostTrustedCount = errors.New("rightmost trusted count strategy") + ErrRightmostTrustedRange = errors.New("rightmost trusted range strategy") +) + +// Chain attempts to use the given strategies in order. If the first one returns an error, the second one is +// tried, and so on, until a good IP is found or the strategies are exhausted. A common use for this is if a server is +// both directly connected to the internet and expecting a header to check. It might be called like: +// +// NewChain(NewLeftmostNonPrivate("X-Forwarded-For")), NewRemoteAddr()) +type Chain struct { + strategies []fox.ClientIPStrategy +} + +// NewChain creates a Chain that attempts to use the given strategies to +// derive the client IP, stopping when the first one succeeds. +func NewChain(strategies ...fox.ClientIPStrategy) Chain { + return Chain{strategies: strategies} +} + +// ClientIP derives the client IP using this strategy. +// headers is expected to be like http.Request.Header. +// remoteAddr is expected to be like http.Request.RemoteAddr. +// The returned IP may contain a zone identifier. +// If all chained strategies fail to derive a valid IP, an empty string is returned. +func (s Chain) ClientIP(c fox.Context) (*net.IPAddr, error) { + var errs error + for _, sub := range s.strategies { + ipAddr, err := sub.ClientIP(c) + if err == nil { + return ipAddr, nil + } + errs = errors.Join(errs, err) + } + + return nil, errs +} + +// RemoteAddr returns the client socket IP, stripped of port. +// This strategy should be used if the server accept direct connections, rather than +// through a reverse proxy. +type RemoteAddr struct{} + +// NewRemoteAddr that uses request remote address to get the client IP. +func NewRemoteAddr() RemoteAddr { + return RemoteAddr{} +} + +// ClientIP derives the client IP using the RemoteAddr strategy. The returned net.IPAddr may contain a zone identifier. +// This should only happen if remoteAddr has been modified to something illegal, or if the server is accepting connections +// on a Unix domain socket (in which case RemoteAddr is "@"). If no valid IP can be derived, an error is returned. +func (s RemoteAddr) ClientIP(c fox.Context) (*net.IPAddr, error) { + ipAddr, err := ParseIPAddr(c.Request().RemoteAddr) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrRemoteAddress, err) + } + return ipAddr, nil +} + +// SingleIPHeader derives an IP address from a single-IP header. A non-exhaustive list of such single-IP headers +// is: X-Real-IP, CF-Connecting-IP, True-Client-IP, Fastly-Client-IP, X-Azure-ClientIP, X-Azure-SocketIP. This strategy +// should be used when the given header is added by a trusted reverse proxy. You must ensure that this header is not +// spoofable (as is possible with Akamai's use of True-Client-IP, Fastly's default use of Fastly-Client-IP, +// and Azure's X-Azure-ClientIP). +// See the single-IP wiki page for more info: https://github.com/realclientip/realclientip-go/wiki/Single-IP-Headers +type SingleIPHeader struct { + headerName string +} + +// NewSingleIPHeader creates a SingleIPHeader strategy that uses the headerName request header to get the client IP. +func NewSingleIPHeader(headerName string) SingleIPHeader { + if headerName == "" { + panic(errors.New("header must not be empty")) + } + + // We will be using the headerName for lookups in the http.Header map, which is keyed + // by canonicalized header name. We'll canonicalize here so we only have to do it once. + headerName = http.CanonicalHeaderKey(headerName) + + if headerName == xForwardedForHdr || headerName == forwardedHdr { + panic(fmt.Errorf("header must not be %s or %s", xForwardedForHdr, forwardedHdr)) + } + + return SingleIPHeader{headerName: headerName} +} + +// ClientIP derives the client IP using the SingleIPHeader. The returned net.IPAddr may contain a zone identifier. +// If no valid IP can be derived, an error is returned. +func (s SingleIPHeader) ClientIP(c fox.Context) (*net.IPAddr, error) { + // RFC 2616 does not allow multiple instances of single-IP headers (or any non-list header). + // It is debatable whether it is better to treat multiple such headers as an error + // (more correct) or simply pick one of them (more flexible). As we've already + // told the user tom make sure the header is not spoofable, we're going to use the + // last header instance if there are multiple. (Using the last is arbitrary, but + // in theory it should be the newest value.) + ipStr := lastHeader(c.Request().Header, s.headerName) + if ipStr == "" { + return nil, fmt.Errorf("%w: header %q not found", ErrSingleIPHeader, s.headerName) + } + + return ParseIPAddr(ipStr) +} + +// LeftmostNonPrivate derives the client IP from the leftmost valid and non-private IP address in the +// X-Fowarded-For or Forwarded header. This strategy should be used when a valid, non-private IP closest to the client is desired. +// Note that this MUST NOT BE USED FOR SECURITY PURPOSES. This IP can be TRIVIALLY SPOOFED. +type LeftmostNonPrivate struct { + headerName string +} + +// NewLeftmostNonPrivate creates a LeftmostNonPrivate strategy. headerName must be "X-Forwarded-For" or "Forwarded". +func NewLeftmostNonPrivate(headerName string) LeftmostNonPrivate { + if headerName == "" { + panic("header must not be empty") + } + + // We will be using the headerName for lookups in the http.Header map, which is keyed + // by canonicalized header name. We'll do that here so we only have to do it once. + headerName = http.CanonicalHeaderKey(headerName) + + if headerName != xForwardedForHdr && headerName != forwardedHdr { + panic(fmt.Errorf("header must be %s or %s", xForwardedForHdr, forwardedHdr)) + } + + return LeftmostNonPrivate{headerName: headerName} +} + +// ClientIP derives the client IP using the LeftmostNonPrivate. +// The returned net.IPAddr may contain a zone identifier. If no valid IP can be derived, an error returned. +func (s LeftmostNonPrivate) ClientIP(c fox.Context) (*net.IPAddr, error) { + ipAddrs := getIPAddrList(c.Request().Header, s.headerName) + for _, ip := range ipAddrs { + if ip != nil && !isPrivateOrLocal(ip.IP) { + // This is the leftmost valid, non-private IP + return ip, nil + } + } + + // We failed to find any valid, non-private IP + return nil, fmt.Errorf("%w: unable to find a valid or non-private IP", ErrLeftmostNonPrivate) +} + +// RightmostNonPrivate derives the client IP from the rightmost valid, non-private/non-internal IP address in +// the X-Fowarded-For or Forwarded header. This strategy should be used when all reverse proxies between the internet +// and the server have private-space IP addresses. +type RightmostNonPrivate struct { + headerName string +} + +// NewRightmostNonPrivate creates a RightmostNonPrivate strategy. headerName must be "X-Forwarded-For" or "Forwarded". +func NewRightmostNonPrivate(headerName string) RightmostNonPrivate { + if headerName == "" { + panic(errors.New("header must not be empty")) + } + + // We will be using the headerName for lookups in the http.Header map, which is keyed + // by canonicalized header name. We'll do that here so we only have to do it once. + headerName = http.CanonicalHeaderKey(headerName) + + if headerName != xForwardedForHdr && headerName != forwardedHdr { + panic(fmt.Errorf("header must be %s or %s", xForwardedForHdr, forwardedHdr)) + } + + return RightmostNonPrivate{headerName: headerName} +} + +// ClientIP derives the client IP using the RightmostNonPrivate. +// The returned net.IPAddr may contain a zone identifier. If no valid IP can be derived, an error returned. +func (s RightmostNonPrivate) ClientIP(c fox.Context) (*net.IPAddr, error) { + ipAddrs := getIPAddrList(c.Request().Header, s.headerName) + // Look backwards through the list of IP addresses + for i := len(ipAddrs) - 1; i >= 0; i-- { + if ipAddrs[i] != nil && !isPrivateOrLocal(ipAddrs[i].IP) { + // This is the rightmost non-private IP + return ipAddrs[i], nil + } + } + + // We failed to find any valid, non-private IP + return nil, fmt.Errorf("%w: unable to find a valid or non-private IP", ErrRightmostNonPrivate) +} + +// RightmostTrustedCount derives the client IP from the valid IP address added by the first trusted reverse +// proxy to the X-Forwarded-For or Forwarded header. This strategy should be used when there is a fixed number of +// trusted reverse proxies that are appending IP addresses to the header. +type RightmostTrustedCount struct { + headerName string + trustedCount int +} + +// NewRightmostTrustedCount creates a RightmostTrustedCount strategy. headerName must be "X-Forwarded-For" or "Forwarded". +// trustedCount is the number of trusted reverse proxies. The IP returned will be the (trustedCount-1)th from the right. For +// example, if there's only one trusted proxy, this strategy will return the last (rightmost) IP address. +func NewRightmostTrustedCount(headerName string, trustedCount int) RightmostTrustedCount { + if headerName == "" { + panic(errors.New("header must not be empty")) + } + + if trustedCount <= 0 { + panic(fmt.Errorf("count must be greater than zero")) + } + + // We will be using the headerName for lookups in the http.Header map, which is keyed + // by canonicalized header name. We'll do that here so we only have to do it once. + headerName = http.CanonicalHeaderKey(headerName) + + if headerName != xForwardedForHdr && headerName != forwardedHdr { + panic(fmt.Errorf("header must be %s or %s", xForwardedForHdr, forwardedHdr)) + } + + return RightmostTrustedCount{headerName: headerName, trustedCount: trustedCount} +} + +// ClientIP derives the client IP using the RightmostTrustedCount. +// The returned net.IPAddr may contain a zone identifier. If no valid IP can be derived, an error returned. +func (s RightmostTrustedCount) ClientIP(c fox.Context) (*net.IPAddr, error) { + ipAddrs := getIPAddrList(c.Request().Header, s.headerName) + + // We want the (N-1)th from the rightmost. For example, if there's only one + // trusted proxy, we want the last. + rightmostIndex := len(ipAddrs) - 1 + targetIndex := rightmostIndex - (s.trustedCount - 1) + + if targetIndex < 0 { + // This is a misconfiguration error. There were fewer IPs than we expected. + return nil, fmt.Errorf("%w: expected %d IP(s) but found %d", ErrRightmostTrustedCount, s.trustedCount, len(ipAddrs)) + } + + ipAddr := ipAddrs[targetIndex] + + if ipAddr == nil { + // This is a misconfiguration error. Our first trusted proxy didn't add a + // valid IP address to the header. + return nil, fmt.Errorf("%w: invalid IP address from the first trusted proxy", ErrRightmostTrustedCount) + } + + return ipAddr, nil +} + +// RightmostTrustedRange derives the client IP from the rightmost valid IP address in the X-Forwarded-For or Forwarded +// header which is not in a set of trusted IP ranges. This strategy should be used when the IP ranges of the reverse +// proxies between the internet and the server are known. If a third-party WAF, CDN, etc., is used, you SHOULD use a +// method of verifying its access to your origin that is stronger than checking its IP address (e.g., using authenticated pulls). +// Failure to do so can result in scenarios like: You use AWS CloudFront in front of a server you host elsewhere. An +// attacker creates a CF distribution that points at your origin server. The attacker uses Lambda@Edge to spoof the Host +// and X-Forwarded-For headers. Now your "trusted" reverse proxy is no longer trustworthy. +type RightmostTrustedRange struct { + headerName string + trustedRanges []net.IPNet +} + +// NewRightmostTrustedRange creates a RightmostTrustedRange strategy. headerName must be "X-Forwarded-For" +// or "Forwarded". trustedRanges must contain all trusted reverse proxies on the path to this server. trustedRanges can +// be private/internal or external (for example, if a third-party reverse proxy is used). +func NewRightmostTrustedRange(headerName string, trustedRanges []net.IPNet) RightmostTrustedRange { + if headerName == "" { + panic(errors.New("header must not be empty")) + } + + // We will be using the headerName for lookups in the http.Header map, which is keyed + // by canonicalized header name. We'll do that here so we only have to do it once. + headerName = http.CanonicalHeaderKey(headerName) + + if headerName != xForwardedForHdr && headerName != forwardedHdr { + panic(fmt.Errorf("header must be %s or %s", xForwardedForHdr, forwardedHdr)) + } + + return RightmostTrustedRange{headerName: headerName, trustedRanges: trustedRanges} +} + +// ClientIP derives the client IP using the RightmostTrustedRange. +// The returned net.IPAddr may contain a zone identifier. If no valid IP can be derived, an error is returned. +func (s RightmostTrustedRange) ClientIP(c fox.Context) (*net.IPAddr, error) { + ipAddrs := getIPAddrList(c.Request().Header, s.headerName) + // Look backwards through the list of IP addresses + for i := len(ipAddrs) - 1; i >= 0; i-- { + if ipAddrs[i] != nil && isIPContainedInRanges(ipAddrs[i].IP, s.trustedRanges) { + // This IP is trusted + continue + } + + // At this point we have found the first-from-the-rightmost untrusted IP + if ipAddrs[i] == nil { + return nil, fmt.Errorf("%w: unable to find a valid IP address", ErrRightmostTrustedRange) + } + + return ipAddrs[i], nil + } + + // Either there are no addresses or they are all in our trusted ranges + return nil, fmt.Errorf("%w: unable to find a valid IP address", ErrRightmostTrustedRange) +} + +// MustParseIPAddr panics if ParseIPAddr fails. +func MustParseIPAddr(ipStr string) *net.IPAddr { + ipAddr, err := ParseIPAddr(ipStr) + if err != nil { + panic(fmt.Sprintf("ParseIPAddr failed: %v", err)) + } + return ipAddr +} + +// ParseIPAddr safely parses the given string into a net.IPAddr. It also returns an error for unspecified (like "::") and zero-value +// addresses (like "0.0.0.0"). These are nominally valid IPs (net.ParseIP will accept them), but they are never valid "real" client IPs. +// +// The function returns the following errors: +// - ErrInvalidIpAddress: if the IP address cannot be parsed. +// - ErrUnspecifiedIpAddress: if the IP address is unspecified (e.g., "::" or "0.0.0.0"). +func ParseIPAddr(ip string) (*net.IPAddr, error) { + host, _, err := net.SplitHostPort(ip) + if err == nil { + ip = host + } + + // We continue even if net.SplitHostPort returned an error. This is because it may + // complain that there are "too many colons" in an IPv6 address that has no brackets + // and no port. net.ParseIP will be the final arbiter of validity. + + // Square brackets around IPv6 addresses may be used in the Forwarded header. + // net.ParseIP doesn't like them, so we'll trim them off. + ip = trimMatchedEnds(ip, "[]") + + ipStr, zone := splitHostZone(ip) + ipAddr := &net.IPAddr{ + IP: net.ParseIP(ipStr), + Zone: zone, + } + + if ipAddr.IP == nil { + return nil, ErrInvalidIpAddress + } + + if ipAddr.IP.IsUnspecified() { + return nil, ErrUnspecifiedIpAddress + } + + return ipAddr, nil +} + +// AddressesAndRangesToIPNets converts a slice of strings with IPv4 and IPv6 addresses and CIDR ranges (prefixes) to +// net.IPNet instances. If net.ParseCIDR or net.ParseIP fail, an error will be returned. Zones in addresses or ranges +// are not allowed and will result in an error. +func AddressesAndRangesToIPNets(ranges ...string) ([]net.IPNet, error) { + var result []net.IPNet + for _, r := range ranges { + if strings.Contains(r, "%") { + return nil, fmt.Errorf("zones are not allowed: %q", r) + } + + if strings.Contains(r, "/") { + // This is a CIDR/prefix + _, ipNet, err := net.ParseCIDR(r) + if err != nil { + return nil, fmt.Errorf("net.ParseCIDR failed for %q: %w", r, err) + } + result = append(result, *ipNet) + } else { + // This is a single IP; convert it to a range including only itself + ip := net.ParseIP(r) + if ip == nil { + return nil, fmt.Errorf("net.ParseIP failed for %q", r) + } + + // To use the right size IP and mask, we need to know if the address is IPv4 or v6. + // Attempt to convert it to IPv4 to find out. + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + } + + // Mask all the bits + mask := len(ip) * 8 + result = append(result, net.IPNet{ + IP: ip, + Mask: net.CIDRMask(mask, mask), + }) + } + } + + return result, nil +} + +func splitHostZone(s string) (host, zone string) { + // This is copied from an unexported function in the Go stdlib: + // https://github.com/golang/go/blob/5c9b6e8e63e012513b1cb1a4a08ff23dec4137a1/src/net/ipsock.go#L219-L228 + + // The IPv6 scoped addressing zone identifier starts after the last percent sign. + if i := strings.LastIndexByte(s, '%'); i > 0 { + host, zone = s[:i], s[i+1:] + } else { + host = s + } + return +} + +// trimMatchedEnds trims s if and only if the first and last bytes in s are in chars. +// If chars is a single character (like `"`), then the first and last bytes must match +// that single character. If chars is two characters (like `[]`), the first byte in s +// must match the first byte in chars, and the last bytes in s must match the last byte +// in chars. +// This helps us ensure that we only trim _matched_ quotes and brackets, +// which strings.Trim doesn't provide. +func trimMatchedEnds(s string, chars string) string { + if len(chars) != 1 && len(chars) != 2 { + panic("chars must be length 1 or 2") + } + + first, last := chars[0], chars[0] + if len(chars) > 1 { + last = chars[1] + } + + if len(s) < 2 { + return s + } + + if s[0] != first { + return s + } + + if s[len(s)-1] != last { + return s + } + + return s[1 : len(s)-1] +} + +// lastHeader returns the last header with the given name. It returns empty string if the +// header is not found or if the header has an empty value. No validation is done on the +// IP string. headerName must already be canonicalized. +// This should be used with single-IP headers, like X-Real-IP. Per RFC 2616, they should +// not have multiple headers, but if they do we can hope we're getting the newest/best by +// taking the last instance. +// This MUST NOT be used with list headers, like X-Forwarded-For and Forwarded. +func lastHeader(headers http.Header, headerName string) string { + // Note that Go's Header map uses canonicalized keys + matches, ok := headers[headerName] + if !ok || len(matches) == 0 { + // For our uses of this function, returning an empty string in this case is fine + return "" + } + + return matches[len(matches)-1] +} + +// getIPAddrList creates a single list of all the X-Forwarded-For or Forwarded header +// values, in order. Any invalid IPs will result in nil elements. headerName must already +// be canonicalized. +func getIPAddrList(headers http.Header, headerName string) []*net.IPAddr { + var result []*net.IPAddr + + // There may be multiple XFF headers present. We need to iterate through them all, + // in order, and collect all the IPs. + // Note that we're not joining all the headers into a single string and then + // splitting. Doing it that way would use more memory. + // Note that Go's Header map uses canonicalized keys. + for _, h := range headers[headerName] { + // We now have a string with comma-separated list items + for _, rawListItem := range strings.Split(h, ",") { + // The IPs are often comma-space separated, so we'll need to trim the string + rawListItem = strings.TrimSpace(rawListItem) + + var ipAddr *net.IPAddr + // If this is the XFF header, rawListItem is just an IP; + // if it's the Forwarded header, then there's more parsing to do. + if headerName == forwardedHdr { + ipAddr = parseForwardedListItem(rawListItem) + } else { // == XFF + ipAddr, _ = ParseIPAddr(rawListItem) + } + + // ipAddr is nil if not valid + result = append(result, ipAddr) + } + } + + // Possible performance improvements: + // Here we are parsing _all_ of the IPs in the XFF headers, but we don't need all of + // them. Instead, we could start from the left or the right (depending on strategy), + // parse as we go, and stop when we've come to the one we want. But that would make + // the various strategies somewhat more complex. + + return result +} + +// parseForwardedListItem parses a Forwarded header list item, and returns the "for" IP +// address. Nil is returned if the "for" IP is absent or invalid. +func parseForwardedListItem(fwd string) *net.IPAddr { + // The header list item can look like these kinds of thing: + // For="[2001:db8:cafe::17%zone]:4711" + // For="[2001:db8:cafe::17%zone]" + // for=192.0.2.60;proto=http; by=203.0.113.43 + // for=192.0.2.43 + + // First split up "for=", "by=", "host=", etc. + fwdParts := strings.Split(fwd, ";") + + // Find the "for=" part, since that has the IP we want (maybe) + var forPart string + for _, fp := range fwdParts { + // Whitespace is allowed around the semicolons + fp = strings.TrimSpace(fp) + + fpSplit := strings.Split(fp, "=") + if len(fpSplit) != 2 { + // There are too many or too few equal signs in this part + continue + } + + if strings.EqualFold(fpSplit[0], "for") { + // We found the "for=" part + forPart = fpSplit[1] + break + } + } + + // There shouldn't (per RFC 7239) be spaces around the semicolon or equal sign. It might + // be more correct to consider spaces an error, but we'll tolerate and trim them. + forPart = strings.TrimSpace(forPart) + + // Get rid of any quotes, such as surrounding IPv6 addresses. + // Note that doing this without checking if the quotes are present means that we are + // effectively accepting IPv6 addresses that don't strictly conform to RFC 7239, which + // requires quotes. https://www.rfc-editor.org/rfc/rfc7239#section-4 + // This behaviour is debatable. + // It also means that we will accept IPv4 addresses with quotes, which is correct. + forPart = trimMatchedEnds(forPart, `"`) + + if forPart == "" { + // We failed to find a "for=" part + return nil + } + + ipAddr, _ := ParseIPAddr(forPart) + if ipAddr == nil { + // The IP extracted from the "for=" part isn't valid + return nil + } + + return ipAddr +} + +// mustParseCIDR panics if net.ParseCIDR fails +func mustParseCIDR(s string) net.IPNet { + _, ipNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return *ipNet +} + +// privateAndLocalRanges net.IPNets that are loopback, private, link local, default unicast. +// Based on https://github.com/wader/filtertransport/blob/bdd9e61eee7804e94ceb927c896b59920345c6e4/filter.go#L36-L64 +// which is based on https://github.com/letsencrypt/boulder/blob/master/bdns/dns.go +var privateAndLocalRanges = []net.IPNet{ + mustParseCIDR("10.0.0.0/8"), // RFC1918 + mustParseCIDR("172.16.0.0/12"), // private + mustParseCIDR("192.168.0.0/16"), // private + mustParseCIDR("127.0.0.0/8"), // RFC5735 + mustParseCIDR("0.0.0.0/8"), // RFC1122 Section 3.2.1.3 + mustParseCIDR("169.254.0.0/16"), // RFC3927 + mustParseCIDR("192.0.0.0/24"), // RFC 5736 + mustParseCIDR("192.0.2.0/24"), // RFC 5737 + mustParseCIDR("198.51.100.0/24"), // Assigned as TEST-NET-2 + mustParseCIDR("203.0.113.0/24"), // Assigned as TEST-NET-3 + mustParseCIDR("192.88.99.0/24"), // RFC 3068 + mustParseCIDR("192.18.0.0/15"), // RFC 2544 + mustParseCIDR("224.0.0.0/4"), // RFC 3171 + mustParseCIDR("240.0.0.0/4"), // RFC 1112 + mustParseCIDR("255.255.255.255/32"), // RFC 919 Section 7 + mustParseCIDR("100.64.0.0/10"), // RFC 6598 + mustParseCIDR("::/128"), // RFC 4291: Unspecified Address + mustParseCIDR("::1/128"), // RFC 4291: Loopback Address + mustParseCIDR("100::/64"), // RFC 6666: Discard Address Block + mustParseCIDR("2001::/23"), // RFC 2928: IETF Protocol Assignments + mustParseCIDR("2001:2::/48"), // RFC 5180: Benchmarking + mustParseCIDR("2001:db8::/32"), // RFC 3849: Documentation + mustParseCIDR("2001::/32"), // RFC 4380: TEREDO + mustParseCIDR("fc00::/7"), // RFC 4193: Unique-Local + mustParseCIDR("fe80::/10"), // RFC 4291: Section 2.5.6 Link-Scoped Unicast + mustParseCIDR("ff00::/8"), // RFC 4291: Section 2.7 + mustParseCIDR("2002::/16"), // RFC 7526: 6to4 anycast prefix deprecated +} + +// isIPContainedInRanges returns true if the given IP is contained in at least one of the given ranges +func isIPContainedInRanges(ip net.IP, ranges []net.IPNet) bool { + for _, r := range ranges { + if r.Contains(ip) { + return true + } + } + return false +} + +// isPrivateOrLocal return true if the given IP address is private, local, or otherwise +// not suitable for an external client IP. +func isPrivateOrLocal(ip net.IP) bool { + return isIPContainedInRanges(ip, privateAndLocalRanges) +} diff --git a/strategy/strategy_test.go b/strategy/strategy_test.go new file mode 100644 index 0000000..f933299 --- /dev/null +++ b/strategy/strategy_test.go @@ -0,0 +1,2057 @@ +// The code in this package is derivative of https://github.com/realclientip/realclientip-go (all credit to Adam Pritchard). +// Mount of this source code is governed by a BSD Zero Clause License that can be found +// at https://github.com/realclientip/realclientip-go/blob/main/LICENSE. + +package strategy + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tigerwill90/fox" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRemoteAddrStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("X-Forwarded-For", "1.1.1.1, 2001:db8:cafe::99%eth0, 3.3.3.3, 192.168.1.1") + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + + cases := []struct { + name string + remoteIP string + wantIP string + wantZone string + wantErr error + }{ + { + name: "should return an ipv4 address", + remoteIP: "192.0.2.1:56235", + wantIP: "192.0.2.1", + }, + { + name: "should return an ipv6 address", + remoteIP: "[fe80::1ff:fe23:4567:890a]:56235", + wantIP: "fe80::1ff:fe23:4567:890a", + }, + { + name: "should return an ipv6 address with zone", + remoteIP: "[fe80::1ff:fe23:4567:890a%eth0]:56235", + wantIP: "fe80::1ff:fe23:4567:890a", + wantZone: "eth0", + }, + { + name: "should return an an invalid ip address error", + remoteIP: "@", + wantErr: ErrInvalidIpAddress, + }, + { + // This is for coverage. It should not be possible for RemoteAddr. + name: "should return an an unspecified ip address error", + remoteIP: "0.0.0.0", + wantErr: ErrUnspecifiedIpAddress, + }, + } + + s := NewRemoteAddr() + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + c.Request().RemoteAddr = tc.remoteIP + ipAddr, err := s.ClientIP(c) + if tc.wantErr != nil { + assert.ErrorIs(t, err, tc.wantErr) + return + } + assert.Equal(t, tc.wantIP, ipAddr.IP.String()) + assert.Equal(t, tc.wantZone, ipAddr.Zone) + + }) + } + +} + +func TestSingleIPHeaderStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("X-Real-IP", "4.4.4.4") + req.Header.Add("X-Real-IP", "5.5.5.5") + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + + s := NewSingleIPHeader("X-Real-IP") + ipAddr, err := s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "5.5.5.5", ipAddr.String()) + + c.Request().Header.Del("X-Real-IP") + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrSingleIPHeader) +} + +func TestLeftmostNonPrivateStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("Forwarded", `For=fe80::abcd;By=fe80::1234, Proto=https;For=::ffff:188.0.2.128, For="[2001:db8:cafe::17]:4848", For=fc00::1`) + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + + s := NewLeftmostNonPrivate("Forwarded") + ipAddr, err := s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "188.0.2.128", ipAddr.String()) + + // Only private IP address + req.Header.Set("Forwarded", `for=192.168.1.1, for=10.0.0.1, for="[fd00::1]", for=172.16.0.1`) + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrLeftmostNonPrivate) +} + +func TestRightmostNonPrivateStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("X-Forwarded-For", "1.1.1.1, 2001:db8:cafe::99%eth0, 3.3.3.3, 192.168.1.1") + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + s := NewRightmostNonPrivate("X-Forwarded-For") + ipAddr, err := s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "3.3.3.3", ipAddr.String()) + + // With no whitespace + req.Header.Set("X-Forwarded-For", "1.1.1.1,2001:db8:cafe::99%eth0, 3.3.3.3,192.168.1.1") + ipAddr, err = s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "3.3.3.3", ipAddr.String()) + + // Only private IP + req.Header.Set("X-Forwarded-For", "192.168.1.1, 10.0.0.1, [fd00::1], 172.16.0.1") + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrRightmostNonPrivate) +} + +func TestRightmostTrustedCountStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("Forwarded", `For=fe80::abcd;By=fe80::1234, Proto=https;For=::ffff:188.0.2.128, For="[2001:db8:cafe::17]:4848", For=fc00::1`) + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + s := NewRightmostTrustedCount("Forwarded", 2) + ipAddr, err := s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "2001:db8:cafe::17", ipAddr.String()) + + req.Header.Set("Forwarded", `For=fc00::1`) + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrRightmostTrustedCount) + assert.ErrorContains(t, err, "expected 2 IP(s) but found 1") + + req.Header.Set("Forwarded", `For=fe80::abcd;By=fe80::1234, Proto=https;For=::ffff:188.0.2.128, For="invalid", For=fc00::1`) + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrRightmostTrustedCount) + assert.ErrorContains(t, err, "invalid IP address from the first trusted proxy") +} + +func TestRightmostTrustedRangeStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("X-Forwarded-For", "1.1.1.1, 2001:db8:cafe::99%eth0, 3.3.3.3, 192.168.1.1") + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + trustedRanges, _ := AddressesAndRangesToIPNets([]string{"192.168.0.0/16", "3.3.3.3"}...) + s := NewRightmostTrustedRange("X-Forwarded-For", trustedRanges) + ipAddr, err := s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "2001:db8:cafe::99%eth0", ipAddr.String()) + + // Invalid IP + req.Header.Set("X-Forwarded-For", "1.1.1.1, invalid, 3.3.3.3, 192.168.1.1") + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrRightmostTrustedRange) + assert.ErrorContains(t, err, "unable to find a valid IP address") + + req.Header.Set("X-Forwarded-For", "192.168.1.2, 3.3.3.3, 192.168.1.1") + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrRightmostTrustedRange) + assert.ErrorContains(t, err, "unable to find a valid IP address") +} + +func TestChainStrategy_ClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + req.Header.Add("X-Real-IP", "4.4.4.4") + req.RemoteAddr = "192.0.2.1:8080" + w := httptest.NewRecorder() + + c := fox.NewTestContextOnly(fox.New(), w, req) + s := NewChain( + NewSingleIPHeader("Cf-Connecting-IP"), + NewRemoteAddr(), + ) + ipAddr, err := s.ClientIP(c) + require.NoError(t, err) + assert.Equal(t, "192.0.2.1", ipAddr.String()) + + // Invalid remote ip + req.RemoteAddr = " @" + _, err = s.ClientIP(c) + assert.ErrorIs(t, err, ErrSingleIPHeader) + assert.ErrorIs(t, err, ErrRemoteAddress) + assert.ErrorIs(t, err, ErrInvalidIpAddress) + assert.ErrorContains(t, err, "header \"Cf-Connecting-Ip\" not found") +} + +func TestAddressesAndRangesToIPNets(t *testing.T) { + tests := []struct { + name string + ranges []string + want []string + wantErr bool + }{ + { + name: "Empty input", + ranges: []string{}, + want: nil, + }, + { + name: "Single IPv4 address", + ranges: []string{"1.1.1.1"}, + want: []string{"1.1.1.1/32"}, + }, + { + name: "Single IPv6 address", + ranges: []string{"2607:f8b0:4004:83f::200e"}, + want: []string{"2607:f8b0:4004:83f::200e/128"}, + }, + { + name: "Single IPv4 range", + ranges: []string{"1.1.1.1/16"}, + want: []string{"1.1.0.0/16"}, + }, + { + name: "Single IPv6 range", + ranges: []string{"2607:f8b0:4004:83f::200e/48"}, + want: []string{"2607:f8b0:4004::/48"}, + }, + { + name: "Mixed input", + ranges: []string{ + "1.1.1.1", "2607:f8b0:4004:83f::200e", + "1.1.1.1/32", "2607:f8b0:4004:83f::200e/128", + "1.1.1.1/16", "2607:f8b0:4004:83f::200e/56", + "::ffff:188.0.2.128/112", "::ffff:bc15:0006/104", + "64:ff9b::188.0.2.128/112", + }, + want: []string{ + "1.1.1.1/32", "2607:f8b0:4004:83f::200e/128", + "1.1.1.1/32", "2607:f8b0:4004:83f::200e/128", + "1.1.0.0/16", "2607:f8b0:4004:800::/56", + "188.0.0.0/16", "188.0.0.0/8", + "64:ff9b::bc00:0/112", + }, + }, + { + name: "No input", + ranges: nil, + want: nil, + }, + { + name: "Error: garbage CIDR", + ranges: []string{"2607:f8b0:4004:83f::200e/nope"}, + wantErr: true, + }, + { + name: "Error: CIDR with zone", + ranges: []string{"fe80::abcd%nope/64"}, + wantErr: true, + }, + { + name: "Error: garbage IP", + ranges: []string{"1.1.1.nope"}, + wantErr: true, + }, + { + name: "Error: empty value", + ranges: []string{""}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := AddressesAndRangesToIPNets(tt.ranges...) + if tt.wantErr { + assert.Error(t, err) + return + } + + if err != nil { + // We can't continue + return + } + + require.Equal(t, len(tt.want), len(got)) + for i := 0; i < len(got); i++ { + if got[i].String() != tt.want[i] { + assert.Equal(t, tt.want[i], got[i].String()) + } + } + }) + } +} + +func TestMustParseIPAddr(t *testing.T) { + // We test the non-panic path elsewhere, but we need to specifically check the panic case + assert.Panics(t, func() { + MustParseIPAddr("nope") + }) +} + +func TestParseIPAddr(t *testing.T) { + tests := []struct { + name string + ipStr string + want net.IPAddr + wantErr bool + }{ + { + name: "Empty zone", + ipStr: "1.1.1.1%", + want: net.IPAddr{IP: net.ParseIP("1.1.1.1"), Zone: ""}, + }, + { + name: "No zone", + ipStr: "1.1.1.1", + want: net.IPAddr{IP: net.ParseIP("1.1.1.1"), Zone: ""}, + }, + { + name: "With zone", + ipStr: "fe80::abcd%zone", + want: net.IPAddr{IP: net.ParseIP("fe80::abcd"), Zone: "zone"}, + }, + { + name: "With zone and port", + ipStr: "[2607:f8b0:4004:83f::200e%zone]:4484", + want: net.IPAddr{IP: net.ParseIP("2607:f8b0:4004:83f::200e"), Zone: "zone"}, + }, + { + name: "With port", + ipStr: "1.1.1.1:48944", + want: net.IPAddr{IP: net.ParseIP("1.1.1.1"), Zone: ""}, + }, + { + name: "Bad port (is discarded)", + ipStr: "[fe80::abcd%eth0]:xyz", + want: net.IPAddr{IP: net.ParseIP("fe80::abcd"), Zone: "eth0"}, + }, + { + name: "Zero address", + ipStr: "0.0.0.0", + wantErr: true, + }, + { + name: "Unspecified address", + ipStr: "::", + wantErr: true, + }, + { + name: "Error: bad IP with zone", + ipStr: "nope%zone", + wantErr: true, + }, + { + name: "Error: bad IP", + ipStr: "nope!!", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseIPAddr(tt.ipStr) + if tt.wantErr { + assert.Error(t, err) + return + } + + if !ipAddrsEqual(*got, tt.want) { + t.Fatalf("ParseIPAddr() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_isPrivateOrLocal(t *testing.T) { + tests := []struct { + name string + ip string + want bool + }{ + { + name: "IPv4 loopback", + ip: `127.0.0.2`, + want: true, + }, + { + name: "IPv6 loopback", + ip: `::1`, + want: true, + }, + { + name: "IPv4 10.*", + ip: `10.0.0.1`, + want: true, + }, + { + name: "IPv4 192.168.*", + ip: `192.168.1.1`, + want: true, + }, + { + name: "IPv6 unique local address", + ip: `fd12:3456:789a:1::1`, + want: true, + }, + { + name: "IPv4 link-local", + ip: `169.254.1.1`, + want: true, + }, + { + name: "IPv6 link-local", + ip: `fe80::abcd`, + want: true, + }, + { + name: "Non-local IPv4", + ip: `1.1.1.1`, + want: false, + }, + { + name: "Non-local IPv4-mapped IPv6", + ip: `::ffff:188.0.2.128`, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip) + got := isPrivateOrLocal(ip) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_mustParseCIDR(t *testing.T) { + // We test the non-panic path elsewhere, but we need to specifically check the panic case + assert.Panics(t, func() { + mustParseCIDR("nope") + }) +} + +func Test_trimMatchedEnds(t *testing.T) { + // We test the non-panic paths elsewhere, but we need to specifically check the panic case + assert.Panics(t, func() { + trimMatchedEnds("nope", "abcd") + }) +} + +func Test_parseForwardedListItem(t *testing.T) { + tests := []struct { + name string + fwd string + want *net.IPAddr + }{ + { + // This is the correct form for IPv6 wit port + name: "IPv6 with port and quotes", + fwd: `For="[2607:f8b0:4004:83f::200e]:4711"`, + want: MustParseIPAddr("2607:f8b0:4004:83f::200e"), + }, + { + // This is the correct form for IP with no port + name: "IPv6 with quotes, brackets and no port", + fwd: `fOR="[2607:f8b0:4004:83f::200e]"`, + want: MustParseIPAddr("2607:f8b0:4004:83f::200e"), + }, + { + // RFC deviation: missing brackets + name: "IPv6 with quotes, no brackets, and no port", + fwd: `for="2607:f8b0:4004:83f::200e"`, + want: MustParseIPAddr("2607:f8b0:4004:83f::200e"), + }, + { + // RFC deviation: missing quotes + name: "IPv6 with brackets, no quotes, and no port", + fwd: `FOR=[2607:f8b0:4004:83f::200e]`, + want: MustParseIPAddr("2607:f8b0:4004:83f::200e"), + }, + { + // RFC deviation: missing quotes + name: "IPv6 with port and no quotes", + fwd: `For=[2607:f8b0:4004:83f::200e]:4711`, + want: MustParseIPAddr("2607:f8b0:4004:83f::200e"), + }, + { + name: "IPv6 with port, quotes, and zone", + fwd: `For="[fe80::abcd%zone]:4711"`, + want: MustParseIPAddr("fe80::abcd%zone"), + }, + { + // RFC deviation: missing brackets + name: "IPv6 with zone, no quotes, no port", + fwd: `For="fe80::abcd%zone"`, + want: MustParseIPAddr("fe80::abcd%zone"), + }, + { + // RFC deviation: missing quotes + name: "IPv4 with port", + fwd: `FoR=192.0.2.60:4711`, + want: MustParseIPAddr("192.0.2.60"), + }, + { + name: "IPv4 with no port", + fwd: `for=192.0.2.60`, + want: MustParseIPAddr("192.0.2.60"), + }, + { + name: "IPv4 with quotes", + fwd: `for="192.0.2.60"`, + want: MustParseIPAddr("192.0.2.60"), + }, + { + name: "IPv4 with port and quotes", + fwd: `for="192.0.2.60:4823"`, + want: MustParseIPAddr("192.0.2.60"), + }, + { + name: "Error: invalid IPv4", + fwd: `for=192.0.2.999`, + want: nil, + }, + { + name: "Error: invalid IPv6", + fwd: `for="2607:f8b0:4004:83f::999999"`, + want: nil, + }, + { + name: "Error: non-IP identifier", + fwd: `for="_test"`, + want: nil, + }, + { + name: "Error: empty IP value", + fwd: `for=`, + want: nil, + }, + { + name: "Multiple IPv4 directives", + fwd: `by=1.1.1.1; for=2.2.2.2;host=myhost; proto=https`, + want: MustParseIPAddr("2.2.2.2"), + }, + { + // RFC deviation: missing quotes around IPv6 + name: "Multiple IPv6 directives", + fwd: `by=1::1;host=myhost;for=2::2;proto=https`, + want: MustParseIPAddr("2::2"), + }, + { + // RFC deviation: missing quotes around IPv6 + name: "Multiple mixed directives", + fwd: `by=1::1;host=myhost;proto=https;for=2.2.2.2`, + want: MustParseIPAddr("2.2.2.2"), + }, + { + name: "IPv4-mapped IPv6", + fwd: `for="[::ffff:188.0.2.128]"`, + want: MustParseIPAddr("188.0.2.128"), + }, + { + name: "IPv4-mapped IPv6 with port and quotes", + fwd: `for="[::ffff:188.0.2.128]:49428"`, + want: MustParseIPAddr("188.0.2.128"), + }, + { + name: "IPv4-mapped IPv6 in IPv6 form", + fwd: `for="[0:0:0:0:0:ffff:bc15:0006]"`, + want: MustParseIPAddr("188.21.0.6"), + }, + { + name: "NAT64 IPv4-mapped IPv6", + fwd: `for="[64:ff9b::188.0.2.128]"`, + want: MustParseIPAddr("64:ff9b::188.0.2.128"), + }, + { + name: "IPv4 loopback", + fwd: `for=127.0.0.1`, + want: MustParseIPAddr("127.0.0.1"), + }, + { + name: "IPv6 loopback", + fwd: `for="[::1]"`, + want: MustParseIPAddr("::1"), + }, + { + // RFC deviation: quotes must be matched + name: "Error: Unmatched quote", + fwd: `for="1.1.1.1`, + want: nil, + }, + { + // RFC deviation: brackets must be matched + name: "Error: IPv6 loopback", + fwd: `for="::1]"`, + want: nil, + }, + { + name: "Error: misplaced quote", + fwd: `for="[0:0:0:0:0:ffff:bc15:0006"]`, + want: nil, + }, + { + name: "Error: garbage", + fwd: "ads\x00jkl&#*(383fdljk", + want: nil, + }, + { + // Per RFC 7230 section 3.2.6, this should not be an error, but we don't have + // full syntax support yet. + name: "RFC deviation: quoted pair", + fwd: `for=1.1.1.\1`, + want: nil, + }, + { + // Per RFC 7239, this extraneous whitespace should be an error, but we don't + // have full syntax support yet. + name: "RFC deviation: Incorrect whitespace", + fwd: `for= 1.1.1.1`, + want: MustParseIPAddr("1.1.1.1"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseForwardedListItem(tt.fwd) + + if got == nil || tt.want == nil { + assert.Equal(t, tt.want, got) + return + } + + if !ipAddrsEqual(*got, *tt.want) { + t.Fatalf("parseForwardedListItem() = %v, want %v", got, tt.want) + } + }) + } +} + +// Demonstrate parsing deviations from Forwarded header syntax RFCs, particularly +// RFC 7239 (Forwarded header) and RFC 7230 (HTTP/1.1 syntax) section 3.2.6. +func Test_forwardedHeaderRFCDeviations(t *testing.T) { + type args struct { + headers http.Header + headerName string + } + tests := []struct { + name string + args args + want []*net.IPAddr + }{ + { + // The value in quotes should be a single value but we split by comma, so it's not. + // The first and third "For=" bits have one double-quote in them, so they are + // considered invalid by our parser. The second is still in the quoted-string, + // but doesn't have any quotes in it, so it parses okay. + name: "Comma in quotes", + args: args{ + headers: http.Header{"Forwarded": []string{`For="1.1.1.1, For=2.2.2.2, For=3.3.3.3", For="4.4.4.4"`}}, + headerName: "Forwarded", + }, + // There are really only two values, so we actually want: {nil, "4.4.4.4"} + want: []*net.IPAddr{nil, MustParseIPAddr("2.2.2.2"), nil, MustParseIPAddr("4.4.4.4")}, + }, + { + // Per 7239, the opening unmatched quote makes the whole rest of the header invalid. + // But that would mean that an attacker can invalidate the whole header with a + // quote character early on, even the trusted IPs added by our reverse proxies. + // Our actual behaviour is probably the best approach. + name: "Unmatched quote", + args: args{ + headers: http.Header{"Forwarded": []string{`For="1.1.1.1, For=2.2.2.2`}}, + headerName: "Forwarded", + }, + // There are really only two values, so the RFC would require: {nil} (or empty slice?) + want: []*net.IPAddr{nil, MustParseIPAddr("2.2.2.2")}, + }, + { + // The invalid non-For parameter should invalidate the whole item, but we're + // not checking anything but the "For=" part. + name: "Invalid characters", + args: args{ + headers: http.Header{"Forwarded": []string{`For=1.1.1.1;@!=😀, For=2.2.2.2`}}, + headerName: "Forwarded", + }, + // Only the last value is valid, so it should be: {nil, "2.2.2.2"} + want: []*net.IPAddr{MustParseIPAddr("1.1.1.1"), MustParseIPAddr("2.2.2.2")}, + }, + { + // The duplicate "For=" parameter should invalidate the whole item but we don't check for it + name: "Duplicate token", + args: args{ + headers: http.Header{"Forwarded": []string{`For=1.1.1.1;For=2.2.2.2, For=3.3.3.3`}}, + headerName: "Forwarded", + }, + // Only the last value is valid, so it should be: {nil, "3.3.3.3"} + want: []*net.IPAddr{MustParseIPAddr("1.1.1.1"), MustParseIPAddr("3.3.3.3")}, + }, + { + // An escaped character in quotes should be unescaped, but we're not doing it. + // (And if we do end up doing it, make sure that `\\` becomes `\` after escaping. + // And escaping is only allowed in quoted strings.) + // There is no good reason for any part of an IP address to be escaped anyway. + name: "Escaped character", + args: args{ + headers: http.Header{"Forwarded": []string{`For="3.3.3.\3"`}}, + headerName: "Forwarded", + }, + // The value is valid, so it should be: {nil, "3.3.3.3"} + want: []*net.IPAddr{nil}, + }, + { + // Spaces are not allowed around the equal signs, but due to the way we parse + // a space after the equal will pass but one before won't. + name: "Equal sign spaces", + args: args{ + headers: http.Header{"Forwarded": []string{`For =1.1.1.1, For= 3.3.3.3`}}, + headerName: "Forwarded", + }, + // Neither value is valid, so it should be: {nil, nil} + want: []*net.IPAddr{nil, MustParseIPAddr("3.3.3.3")}, + }, + { + // Disallowed characters are only allowed in quoted strings. This means + // that IPv6 addresses must be quoted. + name: "Disallowed characters in unquoted value (like colons and square brackets", + args: args{ + headers: http.Header{"Forwarded": []string{`For=[2607:f8b0:4004:83f::200e]`}}, + headerName: "Forwarded", + }, + // Value is invalid without quotes, so should be {nil} + want: []*net.IPAddr{MustParseIPAddr("2607:f8b0:4004:83f::200e")}, + }, + { + // IPv6 addresses are required to be contained in square brackets. We don't + // require this simply to be more flexible in what is accepted. + name: "IPv6 brackets", + args: args{ + headers: http.Header{"Forwarded": []string{`For="2607:f8b0:4004:83f::200e"`}}, + headerName: "Forwarded", + }, + // IPv6 is invalid without brackets, so should be {nil} + want: []*net.IPAddr{MustParseIPAddr("2607:f8b0:4004:83f::200e")}, + }, + { + // IPv4 addresses are _not_ supposed to be in square brackets, but we trim + // them unconditionally. + name: "IPv4 brackets", + args: args{ + headers: http.Header{"Forwarded": []string{`For="[1.1.1.1]"`}}, + headerName: "Forwarded", + }, + // IPv4 is invalid with brackets, so should be {nil} + want: []*net.IPAddr{MustParseIPAddr("1.1.1.1")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getIPAddrList(tt.args.headers, tt.args.headerName) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestRemoteAddrStrategy(t *testing.T) { + // Ensure the strategy interface is implemented + var _ fox.ClientIPStrategy = RemoteAddr{} + + type args struct { + headers http.Header + remoteAddr string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "IPv4 with port", + args: args{ + remoteAddr: "2.2.2.2:1234", + }, + want: "2.2.2.2", + }, + { + name: "IPv4 with no port", + args: args{ + remoteAddr: "2.2.2.2", + }, + want: "2.2.2.2", + }, + { + name: "IPv6 with port", + args: args{ + remoteAddr: "[2607:f8b0:4004:83f::18]:3838", + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with no port", + args: args{ + remoteAddr: "2607:f8b0:4004:83f::18", + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with zone and no port", + args: args{ + remoteAddr: `fe80::1111%eth0`, + }, + want: `fe80::1111%eth0`, + }, + { + name: "IPv6 with zone and port", + args: args{ + remoteAddr: `[fe80::2222%eth0]:4848`, + }, + want: `fe80::2222%eth0`, + }, + { + name: "IPv4-mapped IPv6", + args: args{ + remoteAddr: "[::ffff:172.21.0.6]:4747", + }, + // It is okay that this converts to the IPv4 format, since it's most important + // that the respresentation be consistent. It might also be good that it does, + // so that it will match the same plain IPv4 address. + // net/netip.ParseAddr gives a different form: "::ffff:172.21.0.6" + want: "172.21.0.6", + }, + { + name: "IPv4-mapped IPv6 in IPv6 form", + args: args{ + remoteAddr: "0:0:0:0:0:ffff:bc15:0006", + }, + // net/netip.ParseAddr gives a different form: "::ffff:188.21.0.6" + want: "188.21.0.6", + }, + { + name: "NAT64 IPv4-mapped IPv6", + args: args{ + remoteAddr: "[64:ff9b::188.0.2.128]:4747", + }, + // net.ParseIP and net/netip.ParseAddr convert to this. This is fine, as it is + // done consistently. + want: "64:ff9b::bc00:280", + }, + { + name: "6to4 IPv4-mapped IPv6", + args: args{ + remoteAddr: "[2002:c000:204::]:4747", + }, + want: "2002:c000:204::", + }, + { + name: "IPv4 loopback", + args: args{ + remoteAddr: "127.0.0.1", + }, + want: "127.0.0.1", + }, + { + name: "IPv6 loopback", + args: args{ + remoteAddr: "::1", + }, + want: "::1", + }, + { + name: "Garbage header (unused)", + args: args{ + headers: http.Header{"X-Forwarded-For": []string{"!!!"}}, + remoteAddr: "2.2.2.2:1234", + }, + want: "2.2.2.2", + }, + { + name: "Fail: empty RemoteAddr", + args: args{ + remoteAddr: "", + }, + want: "", + }, + { + name: "Fail: garbage RemoteAddr", + args: args{ + remoteAddr: "ohno", + }, + want: "", + }, + { + name: "Fail: zero RemoteAddr IP", + args: args{ + remoteAddr: "0.0.0.0", + }, + want: "", + }, + { + name: "Fail: unspecified RemoteAddr IP", + args: args{ + remoteAddr: "::", + }, + want: "", + }, + { + name: "Fail: Unix domain socket", + args: args{ + remoteAddr: "@", + }, + want: "", + }, + } + + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + w := httptest.NewRecorder() + c := fox.NewTestContextOnly(fox.New(), w, req) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := RemoteAddr{} + c.Request().Header = tt.args.headers + c.Request().RemoteAddr = tt.args.remoteAddr + ipAddr, err := s.ClientIP(c) + if tt.want == "" { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, ipAddr.String()) + }) + } +} + +func TestSingleIPHeaderStrategy(t *testing.T) { + // Ensure the strategy interface is implemented + var _ fox.ClientIPStrategy = SingleIPHeader{} + + type args struct { + headerName string + headers http.Header + remoteAddr string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "IPv4 with port", + args: args{ + headerName: "True-Client-IP", + headers: http.Header{ + "X-Real-Ip": []string{"1.1.1.1"}, + "True-Client-Ip": []string{"2.2.2.2:49489"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "2.2.2.2", + }, + { + name: "IPv4 with no port", + args: args{ + headerName: "X-Real-IP", + headers: http.Header{ + "X-Real-Ip": []string{"1.1.1.1"}, + "True-Client-Ip": []string{"2.2.2.2:49489"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "1.1.1.1", + }, + { + name: "IPv6 with port", + args: args{ + headerName: "X-Real-IP", + headers: http.Header{ + "X-Real-Ip": []string{"[2607:f8b0:4004:83f::18]:3838"}, + "True-Client-Ip": []string{"2.2.2.2:49489"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with no port", + args: args{ + headerName: "X-Real-IP", + headers: http.Header{ + "X-Real-Ip": []string{"2607:f8b0:4004:83f::19"}, + "True-Client-Ip": []string{"2.2.2.2:49489"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "2607:f8b0:4004:83f::19", + }, + { + name: "IPv6 with zone and no port", + args: args{ + headerName: "a-b-c-d", + headers: http.Header{ + "X-Real-Ip": []string{"2607:f8b0:4004:83f::19"}, + "A-B-C-D": []string{"fe80::1111%zone"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "fe80::1111%zone", + }, + { + name: "IPv6 with zone and port", + args: args{ + headerName: "a-b-c-d", + headers: http.Header{ + "X-Real-Ip": []string{"2607:f8b0:4004:83f::19"}, + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "fe80::1111%zone", + }, + { + name: "IPv6 with brackets but no port", + args: args{ + headerName: "x-real-ip", + headers: http.Header{ + "X-Real-Ip": []string{"2607:f8b0:4004:83f::19"}, + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "2607:f8b0:4004:83f::19", + }, + { + name: "IP-mapped IPv6", + args: args{ + headerName: "x-real-ip", + headers: http.Header{ + "X-Real-Ip": []string{"::ffff:172.21.0.6"}, + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "172.21.0.6", + }, + { + name: "IPv4-mapped IPv6 in IPv6 form", + args: args{ + headerName: "x-real-ip", + headers: http.Header{ + "X-Real-Ip": []string{"[64:ff9b::188.0.2.128]:4747"}, + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "64:ff9b::bc00:280", + }, + { + name: "6to4 IPv4-mapped IPv6", + args: args{ + headerName: "x-real-ip", + headers: http.Header{ + "X-Real-Ip": []string{"2002:c000:204::"}, + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "2002:c000:204::", + }, + { + name: "IPv4 loopback", + args: args{ + headerName: "x-real-ip", + headers: http.Header{ + "X-Real-Ip": []string{"127.0.0.1"}, + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "127.0.0.1", + }, + { + name: "Fail: missing header", + args: args{ + headerName: "x-real-ip", + headers: http.Header{ + "A-B-C-D": []string{"[fe80::1111%zone]:4848"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "", + }, + { + name: "Fail: garbage IP", + args: args{ + headerName: "True-Client-Ip", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"nope"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "", + }, + { + name: "Fail: zero IP", + args: args{ + headerName: "True-Client-Ip", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"0.0.0.0"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + want: "", + }, + { + name: "Error: empty header name", + args: args{ + headerName: "", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + { + name: "Error: X-Forwarded-For header", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + } + + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + w := httptest.NewRecorder() + c := fox.NewTestContextOnly(fox.New(), w, req) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s fox.ClientIPStrategy + if tt.wantErr { + require.Panics(t, func() { + s = NewSingleIPHeader(tt.args.headerName) + }) + return + } + + s = NewSingleIPHeader(tt.args.headerName) + + c.Request().Header = tt.args.headers + c.Request().RemoteAddr = tt.args.remoteAddr + ipAddr, err := s.ClientIP(c) + if tt.want == "" { + require.Error(t, err) + return + } + assert.Equal(t, tt.want, ipAddr.String()) + }) + } +} + +func TestLeftmostNonPrivateStrategy(t *testing.T) { + // Ensure the strategy interface is implemented + var _ fox.ClientIPStrategy = LeftmostNonPrivate{} + + type args struct { + headerName string + headers http.Header + remoteAddr string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "IPv4 with port", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + }, + }, + want: "2.2.2.2", + }, + { + name: "IPv4 with no port", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`For=5.5.5.5`, `For=6.6.6.6`}, + }, + }, + want: "5.5.5.5", + }, + { + name: "IPv6 with port", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`[2607:f8b0:4004:83f::18]:3838, 3.3.3.3`, `4.4.4.4`}, + }, + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with no port", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`Host=blah;For="2607:f8b0:4004:83f::18";Proto=https`}, + }, + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with port and zone", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`For=[fe80::1111%zone], Host=blah;For="[2607:f8b0:4004:83f::18%zone]:9943";Proto=https`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "2607:f8b0:4004:83f::18%zone", + }, + { + name: "IPv6 with port and zone, no quotes", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`For=[fe80::1111%zone], Host=blah;For=[2607:f8b0:4004:83f::18%zone]:9943;Proto=https`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "2607:f8b0:4004:83f::18%zone", + }, + { + name: "IPv4-mapped IPv6", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::ffff:188.0.2.128, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`Host=blah;For="7.7.7.7";Proto=https`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "188.0.2.128", + }, + { + name: "IPv4-mapped IPv6 with port", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`[::ffff:188.0.2.128]:48483, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`Host=blah;For="7.7.7.7";Proto=https`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "188.0.2.128", + }, + { + name: "IPv4-mapped IPv6 in IPv6 (hex) form", + args: args{ + headerName: "forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`[::ffff:188.0.2.128]:48483, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "188.21.0.6", + }, + { + name: "NAT64 IPv4-mapped IPv6", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`64:ff9b::188.0.2.128, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "64:ff9b::bc00:280", + }, + { + name: "XFF: leftmost not desirable", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope`, `4.4.4.4, 5.5.5.5`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "4.4.4.4", + }, + { + name: "Forwarded: leftmost not desirable", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope`, `4.4.4.4, 5.5.5.5`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="2607:f8b0:4004:83f::18"`}, + }, + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "Fail: XFF: none acceptable", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope, ::, 0.0.0.0`, `192.168.1.1, !?!`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="fe80::abcd%zone"`}, + }, + }, + want: "", + }, + { + name: "Fail: Forwarded: none acceptable", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope`, `192.168.1.1, 2.2.2.2`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="::ffff:ac15:0006%zone",For="::",For=0.0.0.0`}, + }, + }, + want: "", + }, + { + name: "Fail: XFF: no header", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="::ffff:ac15:0006%zone"`}, + }, + }, + want: "", + }, + { + name: "Fail: Forwarded: no header", + args: args{ + headerName: "forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`64:ff9b::188.0.2.128, 3.3.3.3`, `4.4.4.4`}, + }, + }, + want: "", + }, + { + name: "Error: empty header name", + args: args{ + headerName: "", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + { + name: "Error: invalid header", + args: args{ + headerName: "X-Real-IP", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + } + + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + w := httptest.NewRecorder() + c := fox.NewTestContextOnly(fox.New(), w, req) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s fox.ClientIPStrategy + if tt.wantErr { + require.Panics(t, func() { + s = NewLeftmostNonPrivate(tt.args.headerName) + }) + return + } + + s = NewLeftmostNonPrivate(tt.args.headerName) + + c.Request().Header = tt.args.headers + c.Request().RemoteAddr = tt.args.remoteAddr + ipAddr, err := s.ClientIP(c) + if tt.want == "" { + require.Error(t, err) + return + } + assert.Equal(t, tt.want, ipAddr.String()) + }) + } +} + +func TestRightmostNonPrivateStrategy(t *testing.T) { + // Ensure the strategy interface is implemented + var _ fox.ClientIPStrategy = RightmostNonPrivate{} + + type args struct { + headerName string + headers http.Header + remoteAddr string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "IPv4 with port", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4:39333`}, + }, + }, + want: "4.4.4.4", + }, + { + name: "IPv4 with no port", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`For=5.5.5.5`, `For=6.6.6.6`}, + }, + }, + want: "6.6.6.6", + }, + { + name: "IPv6 with port", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`[2607:f8b0:4004:83f::18]:3838`}, + }, + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with no port", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`host=what;for=6.6.6.6;proto=https`, `Host=blah;For="2607:f8b0:4004:83f::18";Proto=https`}, + }, + }, + want: "2607:f8b0:4004:83f::18", + }, + { + name: "IPv6 with port and zone", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`host=what;for=6.6.6.6;proto=https`, `For="[2607:f8b0:4004:83f::18%eth0]:3393";Proto=https`, `Host=blah;For="[fe80::1111%zone]:9943";Proto=https`}, + }, + }, + want: "2607:f8b0:4004:83f::18%eth0", + }, + { + name: "IPv6 with port and zone, no quotes", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`host=what;for=6.6.6.6;proto=https`, `For="[2607:f8b0:4004:83f::18%eth0]:3393";Proto=https`, `Host=blah;For=[fe80::1111%zone]:9943;Proto=https`}, + }, + }, + want: "2607:f8b0:4004:83f::18%eth0", + }, + { + name: "IPv4-mapped IPv6", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`3.3.3.3`, `4.4.4.4, ::ffff:188.0.2.128`}, + "Forwarded": []string{`Host=blah;For="7.7.7.7";Proto=https`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "188.0.2.128", + }, + { + name: "IPv4-mapped IPv6 with port", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`3.3.3.3`, `4.4.4.4,[::ffff:188.0.2.128]:48483`}, + "Forwarded": []string{`Host=blah;For="7.7.7.7";Proto=https`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "188.0.2.128", + }, + { + name: "IPv4-mapped IPv6 in IPv6 (hex) form", + args: args{ + headerName: "forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`[::ffff:188.0.2.128]:48483, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{`host=what;for=6.6.6.6;proto=https`, `For="::ffff:bc15:0006"`}, + }, + }, + want: "188.21.0.6", + }, + { + name: "NAT64 IPv4-mapped IPv6", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`3.3.3.3`, `4.4.4.4, 64:ff9b::188.0.2.128`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "64:ff9b::bc00:280", + }, + { + name: "XFF: rightmost not desirable", + args: args{ + headerName: "x-forwarded-for", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, nope`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "5.5.5.5", + }, + { + name: "Forwarded: rightmost not desirable", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope`, `4.4.4.4, 5.5.5.5`}, + "Forwarded": []string{`host=what;for=:48485;proto=https,For=2.2.2.2`, `For="", For="::ffff:192.168.1.1"`}, + }, + }, + want: "2.2.2.2", + }, + { + name: "Fail: XFF: none acceptable", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope`, `192.168.1.1, !?!, ::, 0.0.0.0`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="fe80::abcd%zone"`}, + }, + }, + want: "", + }, + { + name: "Fail: Forwarded: none acceptable", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::1, nope`, `192.168.1.1, 2.2.2.2`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="::ffff:ac15:0006%zone", For="::", For=0.0.0.0`}, + }, + }, + want: "", + }, + { + name: "Fail: XFF: no header", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "Forwarded": []string{`For="", For="::ffff:192.168.1.1"`, `host=what;for=:48485;proto=https,For="::ffff:ac15:0006%zone"`}, + }, + remoteAddr: "9.9.9.9", + }, + want: "", + }, + { + name: "Fail: Forwarded: no header", + args: args{ + headerName: "forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`64:ff9b::188.0.2.128, 3.3.3.3`, `4.4.4.4`}, + }, + }, + want: "", + }, + { + name: "Error: empty header name", + args: args{ + headerName: "", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + { + name: "Error: invalid header", + args: args{ + headerName: "X-Real-IP", + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + } + + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + w := httptest.NewRecorder() + c := fox.NewTestContextOnly(fox.New(), w, req) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s fox.ClientIPStrategy + if tt.wantErr { + require.Panics(t, func() { + s = NewRightmostNonPrivate(tt.args.headerName) + }) + return + } + + s = NewRightmostNonPrivate(tt.args.headerName) + + c.Request().Header = tt.args.headers + c.Request().RemoteAddr = tt.args.remoteAddr + ipAddr, err := s.ClientIP(c) + if tt.want == "" { + require.Error(t, err) + return + } + assert.Equal(t, tt.want, ipAddr.String()) + }) + } +} + +func TestRightmostTrustedCountStrategy(t *testing.T) { + // Ensure the strategy interface is implemented + var _ fox.ClientIPStrategy = RightmostTrustedCount{} + + type args struct { + headerName string + trustedCount int + headers http.Header + remoteAddr string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "Count one", + args: args{ + headerName: "Forwarded", + trustedCount: 1, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, fe80::382b:141b:fa4a:2a16%28`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "6.6.6.6", + }, + { + name: "Count five", + args: args{ + headerName: "X-Forwarded-For", + trustedCount: 5, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, fe80::382b:141b:fa4a:2a16%28`, `7.7.7.7.7, 8.8.8.8, 9.9.9.9, 10.10.10.10,11.11.11.11, 12.12.12.12`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "8.8.8.8", + }, + { + name: "Fail: header too short/count too large", + args: args{ + headerName: "X-Forwarded-For", + trustedCount: 50, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, fe80::382b:141b:fa4a:2a16%28`, `7.7.7.7.7, 8.8.8.8, 9.9.9.9, 10.10.10.10,11.11.11.11, 12.12.12.12`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "", + }, + { + name: "Fail: bad value at count index", + args: args{ + headerName: "Forwarded", + trustedCount: 2, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, fe80::382b:141b:fa4a:2a16%28`, `7.7.7.7.7, 8.8.8.8, 9.9.9.9, 10.10.10.10,11.11.11.11, 12.12.12.12`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `For=nope`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "", + }, + { + name: "Fail: zero value at count index", + args: args{ + headerName: "Forwarded", + trustedCount: 2, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, fe80::382b:141b:fa4a:2a16%28`, `7.7.7.7.7, 8.8.8.8, 9.9.9.9, 10.10.10.10,11.11.11.11, 12.12.12.12`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `For=0.0.0.0`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + want: "", + }, + { + name: "Fail: header missing", + args: args{ + headerName: "Forwarded", + trustedCount: 1, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, fe80::382b:141b:fa4a:2a16%28`, `7.7.7.7.7, 8.8.8.8, 9.9.9.9, 10.10.10.10,11.11.11.11, 12.12.12.12`}, + }, + }, + want: "", + }, + { + name: "Error: empty header name", + args: args{ + headerName: "", + trustedCount: 1, + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + { + name: "Error: invalid header", + args: args{ + headerName: "X-Real-IP", + trustedCount: 1, + headers: http.Header{ + "X-Real-Ip": []string{"::1"}, + "True-Client-Ip": []string{"2.2.2.2"}, + "X-Forwarded-For": []string{"3.3.3.3"}}, + }, + wantErr: true, + }, + { + name: "Error: zero trustedCount", + args: args{ + headerName: "x-forwarded-for", + trustedCount: 0, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`4.4.4.4, 5.5.5.5`, `::1, nope`, `fe80::382b:141b:fa4a:2a16%28`}, + "Forwarded": []string{`For="::ffff:bc15:0006"`, `host=what;for=6.6.6.6;proto=https`}, + }, + }, + wantErr: true, + }, + { + name: "Error: negative trustedCount", + args: args{ + headerName: "X-Forwarded-For", + trustedCount: -999, + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4:39333`}, + }, + }, + wantErr: true, + }, + } + + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + w := httptest.NewRecorder() + c := fox.NewTestContextOnly(fox.New(), w, req) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s fox.ClientIPStrategy + if tt.wantErr { + require.Panics(t, func() { + s = NewRightmostTrustedCount(tt.args.headerName, tt.args.trustedCount) + }) + return + } + + s = NewRightmostTrustedCount(tt.args.headerName, tt.args.trustedCount) + + c.Request().Header = tt.args.headers + c.Request().RemoteAddr = tt.args.remoteAddr + ipAddr, err := s.ClientIP(c) + if tt.want == "" { + require.Error(t, err) + return + } + assert.Equal(t, tt.want, ipAddr.String()) + }) + } +} + +func TestRightmostTrustedRangeStrategy(t *testing.T) { + // Ensure the strategy interface is implemented + var _ fox.ClientIPStrategy = RightmostTrustedRange{} + + type args struct { + headerName string + headers http.Header + remoteAddr string + trustedRanges []string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "No ranges", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + }, + trustedRanges: nil, + }, + want: "4.4.4.4", + }, + { + name: "One range", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + }, + trustedRanges: []string{`4.4.4.0/24`}, + }, + want: "3.3.3.3", + }, + { + name: "One IP", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + }, + trustedRanges: []string{`4.4.4.4`}, + }, + want: "3.3.3.3", + }, + { + name: "Many kinds of ranges", + args: args{ + headerName: "Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + "Forwarded": []string{ + `For=99.99.99.99, For=4.4.4.8, For="[2607:f8b0:4004:83f::200e]:4747"`, + `For=2.2.2.2:8883, For=64:ff9b::188.0.2.200, For=3.3.5.5, For=2001:db7::abcd`, + }, + }, + trustedRanges: []string{ + `2.2.2.2/32`, `2607:f8b0:4004:83f::200e/128`, + `3.3.0.0/16`, `2001:db7::/64`, + `::ffff:4.4.4.4/124`, `64:ff9b::188.0.2.128/112`, + }, + }, + want: "99.99.99.99", + }, + { + name: "Cloudflare ranges", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`, `2400:cb00::1`}, + }, + trustedRanges: []string{ + "173.245.48.0/20", + "103.21.244.0/22", + "103.22.200.0/22", + "103.31.4.0/22", + "141.101.64.0/18", + "108.162.192.0/18", + "190.93.240.0/20", + "188.114.96.0/20", + "197.234.240.0/22", + "198.41.128.0/17", + "162.158.0.0/15", + "104.16.0.0/13", + "104.24.0.0/14", + "172.64.0.0/13", + "131.0.72.0/22", + "2400:cb00::/32", + "2606:4700::/32", + "2803:f800::/32", + "2405:b500::/32", + "2405:8100::/32", + "2a06:98c0::/29", + "2c0f:f248::/32", + }, + }, + want: "4.4.4.4", + }, + { + name: "Fail: no non-trusted IP", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 2.2.2.3`, `2.2.2.4`}, + }, + trustedRanges: []string{`2.2.2.0/24`}, + }, + want: "", + }, + { + name: "Fail: rightmost non-trusted IP invalid", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`nope, 2.2.2.2:3384, 2.2.2.3`, `2.2.2.4`}, + }, + trustedRanges: []string{`2.2.2.0/24`}, + }, + want: "", + }, + { + name: "Fail: rightmost non-trusted IP unspecified", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`::, 2.2.2.2:3384, 2.2.2.3`, `2.2.2.4`}, + }, + trustedRanges: []string{`2.2.2.0/24`}, + }, + want: "", + }, + { + name: "Fail: no values in header", + args: args{ + headerName: "X-Forwarded-For", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{}}, + trustedRanges: []string{`2.2.2.0/24`}, + }, + want: "", + }, + { + name: "Error: empty header nanme", + args: args{ + headerName: "", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + }, + trustedRanges: nil, + }, + wantErr: true, + }, + { + name: "Error: bad header nanme", + args: args{ + headerName: "Not-XFF-Or-Forwarded", + headers: http.Header{ + "X-Real-Ip": []string{`1.1.1.1`}, + "X-Forwarded-For": []string{`2.2.2.2:3384, 3.3.3.3`, `4.4.4.4`}, + }, + trustedRanges: nil, + }, + wantErr: true, + }, + } + + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + w := httptest.NewRecorder() + c := fox.NewTestContextOnly(fox.New(), w, req) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + ranges, err := AddressesAndRangesToIPNets(tt.args.trustedRanges...) + if err != nil { + // We're not testing AddressesAndRangesToIPNets here + t.Fatalf("AddressesAndRangesToIPNets failed") + } + + var s fox.ClientIPStrategy + if tt.wantErr { + require.Panics(t, func() { + s = NewRightmostTrustedRange(tt.args.headerName, ranges) + }) + return + } + + s = NewRightmostTrustedRange(tt.args.headerName, ranges) + + c.Request().Header = tt.args.headers + c.Request().RemoteAddr = tt.args.remoteAddr + ipAddr, err := s.ClientIP(c) + if tt.want == "" { + require.Error(t, err) + return + } + assert.Equal(t, tt.want, ipAddr.String()) + }) + } +} + +func ipAddrsEqual(a, b net.IPAddr) bool { + return a.IP.Equal(b.IP) && a.Zone == b.Zone +} diff --git a/tree.go b/tree.go index 3332930..93761a8 100644 --- a/tree.go +++ b/tree.go @@ -104,7 +104,7 @@ func (t *Tree) Has(method, path string) bool { return false } - c := t.ctx.Get().(*context) + c := t.ctx.Get().(*cTx) c.resetNil() n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, true) c.Close() @@ -126,7 +126,7 @@ func (t *Tree) Match(method, path string) string { return "" } - c := t.ctx.Get().(*context) + c := t.ctx.Get().(*cTx) c.resetNil() n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, true) c.Close() @@ -156,7 +156,7 @@ func (t *Tree) Methods(path string) []string { } } } else { - c := t.ctx.Get().(*context) + c := t.ctx.Get().(*cTx) c.resetNil() for i := range nds { n, tsr := t.lookup(nds[i], path, c.params, c.skipNds, true) @@ -177,12 +177,11 @@ func (t *Tree) Methods(path string) []string { // Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a // ContextCloser, and a boolean indicating if the handler was matched by adding or removing a trailing slash // (trailing slash action is recommended). The ContextCloser should always be closed if non-nil. This method is primarily -// intended for integrating the fox router into custom routing solutions or middleware. It requires the use of the original -// http.ResponseWriter, typically obtained from ServeHTTP. This function is safe for concurrent use by multiple goroutine -// and while mutation on Tree are ongoing. If there is a direct match or a tsr is possible, Lookup always return a -// HandlerFunc and a ContextCloser. +// intended for integrating the fox router into custom routing solutions or middleware. This function is safe for concurrent +// use by multiple goroutine and while mutation on Tree are ongoing. If there is a direct match or a tsr is possible, +// Lookup always return a HandlerFunc and a ContextCloser. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { +func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { nds := *t.nodes.Load() index := findRootNode(r.Method, nds) @@ -190,7 +189,7 @@ func (t *Tree) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFu return } - c := t.ctx.Get().(*context) + c := t.ctx.Get().(*cTx) c.Reset(w, r) target := r.URL.Path @@ -804,15 +803,17 @@ STOP: } } -func (t *Tree) allocateContext() *context { +func (t *Tree) allocateContext() *cTx { params := make(Params, 0, t.maxParams.Load()) skipNds := make(skippedNodes, 0, t.maxDepth.Load()) - return &context{ + return &cTx{ params: ¶ms, skipNds: &skipNds, // This is a read only value, no reset, it's always the // owner of the pool. tree: t, + // This is a read only value, no reset. + fox: t.fox, } }