From 9172015c050d07cbc19aadb4d196e62efc6d6d1c Mon Sep 17 00:00:00 2001 From: Sylvain Muller Date: Mon, 7 Oct 2024 20:06:20 +0200 Subject: [PATCH] Introducing per route options (#37) * feat: wip per route options * feat: wip per route options * feat: wip per route options * Revert "feat: wip per route options" This reverts commit fb7f2e133292da1df0c37e2865e309ba4f7a234f. * feat: make sure that tree from different router cannot be swapped * feat: wip per route options * feat: wip per route options * feat: improve test coverage * feat: improve test coverage * feat: better handle stack frame skipping with recovery handler * feat: improve test coverage * feat: improve test coverage * feat: disable some lint rules * feat: replace sort.Slice by slices.SortFunc to improve sort performance * docs: fix comment for options that can be applied on a route basis. * feat: remove Route.HandleWithMiddleware API as it's not clear for now how and when to use it. * feat: make FixTrailingSlash public * feat: rework on the Match method --- context.go | 32 ++++-- fox.go | 126 ++++++++++++++------- fox_test.go | 258 ++++++++++++++++++++++++++++++++++--------- iter.go | 6 +- node.go | 30 +++-- options.go | 162 +++++++++++++++++++++------ path.go | 5 +- recovery.go | 2 +- response_writer.go | 7 +- strategy/strategy.go | 2 +- tree.go | 127 ++++++++++++--------- 11 files changed, 540 insertions(+), 217 deletions(-) diff --git a/context.go b/context.go index 4d74322..dcb8619 100644 --- a/context.go +++ b/context.go @@ -97,13 +97,13 @@ type cTx struct { params *Params tsrParams *Params skipNds *skippedNodes + route *Route // tree at allocation (read-only, no reset) tree *Tree // router at allocation (read-only, no reset) fox *Router cachedQuery url.Values - path string rec recorder tsr bool } @@ -112,9 +112,9 @@ type cTx struct { func (c *cTx) Reset(w ResponseWriter, r *http.Request) { c.req = r c.w = w - c.path = "" c.tsr = false c.cachedQuery = nil + c.route = nil *c.params = (*c.params)[:0] } @@ -125,16 +125,16 @@ func (c *cTx) reset(w http.ResponseWriter, r *http.Request) { c.rec.reset(w) c.req = r c.w = &c.rec - c.path = "" c.cachedQuery = nil + c.route = nil *c.params = (*c.params)[:0] } func (c *cTx) resetNil() { c.req = nil c.w = nil - c.path = "" c.cachedQuery = nil + c.route = nil *c.params = (*c.params)[:0] } @@ -186,8 +186,12 @@ func (c *cTx) RemoteIP() *net.IPAddr { // worthy of panicking. // This api is EXPERIMENTAL and is likely to change in future release. func (c *cTx) ClientIP() (*net.IPAddr, error) { - ipStrategy := c.Fox().ipStrategy - return ipStrategy.ClientIP(c) + // We may be in a handler which does not match a route like NotFound handler. + if c.route == nil { + ipStrategy := c.fox.ipStrategy + return ipStrategy.ClientIP(c) + } + return c.route.ipStrategy.ClientIP(c) } // Params returns a Params slice containing the matched @@ -235,7 +239,10 @@ func (c *cTx) Header(key string) string { // Path returns the registered path for the handler. func (c *cTx) Path() string { - return c.path + if c.route == nil { + return "" + } + return c.route.path } // String sends a formatted string with the specified status code. @@ -287,10 +294,11 @@ func (c *cTx) Fox() *Router { // Any attempt to write on the ResponseWriter will panic with the error ErrDiscardedResponseWriter. func (c *cTx) Clone() Context { cp := cTx{ - rec: c.rec, - req: c.req.Clone(c.req.Context()), - fox: c.fox, - tree: c.tree, + rec: c.rec, + req: c.req.Clone(c.req.Context()), + fox: c.fox, + tree: c.tree, + route: c.route, } cp.rec.ResponseWriter = noopWriter{c.rec.Header().Clone()} @@ -311,7 +319,7 @@ func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { cp := c.tree.ctx.Get().(*cTx) cp.req = r cp.w = w - cp.path = c.path + cp.route = c.route cp.cachedQuery = nil if cap(*c.params) > cap(*cp.params) { // Grow cp.params to a least cap(c.params) diff --git a/fox.go b/fox.go index a0a0f19..01f3bc0 100644 --- a/fox.go +++ b/fox.go @@ -67,6 +67,49 @@ func (f ClientIPStrategyFunc) ClientIP(c Context) (*net.IPAddr, error) { return f(c) } +// Route represent a registered route in the route tree. +// Most of the Route API is EXPERIMENTAL and is likely to change in future release. +type Route struct { + ipStrategy ClientIPStrategy + base HandlerFunc + handler HandlerFunc + path string + mws []middleware + redirectTrailingSlash bool + ignoreTrailingSlash bool +} + +// Handle calls the base handler with the provided Context. +func (r *Route) Handle(c Context) { + r.base(c) +} + +// Path returns the route path. +func (r *Route) Path() string { + return r.path +} + +// RedirectTrailingSlashEnabled returns whether the route is configured to automatically +// redirect requests that include or omit a trailing slash. +// This api is EXPERIMENTAL and is likely to change in future release. +func (r *Route) RedirectTrailingSlashEnabled() bool { + return r.redirectTrailingSlash +} + +// IgnoreTrailingSlashEnabled returns whether the route is configured to ignore +// trailing slashes in requests when matching routes. +// This api is EXPERIMENTAL and is likely to change in future release. +func (r *Route) IgnoreTrailingSlashEnabled() bool { + return r.ignoreTrailingSlash +} + +// ClientIPStrategyEnabled returns whether the route is configured with a ClientIPStrategy. +// This api is EXPERIMENTAL and is likely to change in future release. +func (r *Route) ClientIPStrategyEnabled() bool { + _, ok := r.ipStrategy.(noClientIPStrategy) + return !ok +} + // Router is a lightweight high performance HTTP request router that support mutation on its routing tree // while handling request concurrently. type Router struct { @@ -91,16 +134,16 @@ type middleware struct { var _ http.Handler = (*Router)(nil) // New returns a ready to use instance of Fox router. -func New(opts ...Option) *Router { +func New(opts ...GlobalOption) *Router { r := new(Router) - r.noRoute = DefaultNotFoundHandler() - r.noMethod = DefaultMethodNotAllowedHandler() - r.autoOptions = DefaultOptionsHandler() + r.noRoute = DefaultNotFoundHandler + r.noMethod = DefaultMethodNotAllowedHandler + r.autoOptions = DefaultOptionsHandler r.ipStrategy = noClientIPStrategy{} for _, opt := range opts { - opt.apply(r) + opt.applyGlob(r) } r.noRoute = applyMiddleware(NoRouteHandler, r.mws, r.noRoute) @@ -181,8 +224,13 @@ func (fox *Router) Tree() *Tree { } // Swap atomically replaces the currently in-use routing tree with the provided new tree, and returns the previous tree. -// This API is EXPERIMENTAL and is likely to change in future release. +// Note that the swap will panic if the current tree belongs to a different instance of the router, preventing accidental +// replacement of trees from different routers. func (fox *Router) Swap(new *Tree) (old *Tree) { + current := fox.tree.Load() + if current.fox != new.fox { + panic("swap failed: current and new routing trees belong to different router instances") + } return fox.tree.Swap(new) } @@ -190,11 +238,11 @@ func (fox *Router) Swap(new *Tree) (old *Tree) { // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. This function is safe for concurrent use by multiple goroutine. // To override an existing route, use Update. -func (fox *Router) Handle(method, path string, handler HandlerFunc) error { +func (fox *Router) Handle(method, path string, handler HandlerFunc, opts ...PathOption) error { t := fox.Tree() t.Lock() defer t.Unlock() - return t.Handle(method, path, handler) + return t.Handle(method, path, handler, opts...) } // MustHandle registers a new handler for the given method and path. This function is a convenience @@ -202,8 +250,8 @@ func (fox *Router) Handle(method, path string, handler HandlerFunc) error { // with another route. It's perfectly safe to add a new handler while the tree is in use for serving // requests. This function is safe for concurrent use by multiple goroutines. // To override an existing route, use Update. -func (fox *Router) MustHandle(method, path string, handler HandlerFunc) { - if err := fox.Handle(method, path, handler); err != nil { +func (fox *Router) MustHandle(method, path string, handler HandlerFunc, opts ...PathOption) { + if err := fox.Handle(method, path, handler, opts...); err != nil { panic(err) } } @@ -212,11 +260,11 @@ func (fox *Router) MustHandle(method, path string, handler HandlerFunc) { // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. This function is safe for concurrent use by multiple goroutine. // To add new handler, use Handle method. -func (fox *Router) Update(method, path string, handler HandlerFunc) error { +func (fox *Router) Update(method, path string, handler HandlerFunc, opts ...PathOption) error { t := fox.Tree() t.Lock() defer t.Unlock() - return t.Update(method, path, handler) + return t.Update(method, path, handler, opts...) } // Remove delete an existing handler for the given method and path. If the route does not exist, the function @@ -230,11 +278,11 @@ func (fox *Router) Remove(method, path string) error { } // Lookup is a helper that calls Tree.Lookup. For more details, refer to Tree.Lookup. -// It performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a ContextCloser, +// It performs a manual route lookup for a given http.Request, returning the matched Route along with a ContextCloser, // and a boolean indicating if a trailing slash action (e.g. redirect) is recommended (tsr). The ContextCloser should always // be closed if non-nil. // This API is EXPERIMENTAL and is likely to change in future release. -func (fox *Router) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { +func (fox *Router) Lookup(w ResponseWriter, r *http.Request) (route *Route, cc ContextCloser, tsr bool) { tree := fox.tree.Load() return tree.Lookup(w, r) } @@ -257,7 +305,7 @@ Next: method := nds[i].key it := newRawIterator(nds[i]) for it.hasNext() { - err := fn(method, it.path, it.current.handler) + err := fn(method, it.path, it.current.route.handler) if err != nil { if errors.Is(err, SkipMethod) { continue Next @@ -270,27 +318,21 @@ Next: return nil } -// DefaultNotFoundHandler returns a simple HandlerFunc that replies to each request +// DefaultNotFoundHandler is a simple HandlerFunc that replies to each request // with a “404 page not found” reply. -func DefaultNotFoundHandler() HandlerFunc { - return func(c Context) { - http.Error(c.Writer(), "404 page not found", http.StatusNotFound) - } +func DefaultNotFoundHandler(c Context) { + http.Error(c.Writer(), "404 page not found", http.StatusNotFound) } -// DefaultMethodNotAllowedHandler returns a simple HandlerFunc that replies to each request +// DefaultMethodNotAllowedHandler is a simple HandlerFunc that replies to each request // with a “405 Method Not Allowed” reply. -func DefaultMethodNotAllowedHandler() HandlerFunc { - return func(c Context) { - http.Error(c.Writer(), http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - } +func DefaultMethodNotAllowedHandler(c Context) { + http.Error(c.Writer(), http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) } -// DefaultOptionsHandler returns a simple HandlerFunc that replies to each request with a "200 OK" reply. -func DefaultOptionsHandler() HandlerFunc { - return func(c Context) { - c.Writer().WriteHeader(http.StatusOK) - } +// DefaultOptionsHandler is a simple HandlerFunc that replies to each request with a "200 OK" reply. +func DefaultOptionsHandler(c Context) { + c.Writer().WriteHeader(http.StatusOK) } func defaultRedirectTrailingSlashHandler(c Context) { @@ -304,9 +346,9 @@ func defaultRedirectTrailingSlashHandler(c Context) { var url string if len(req.URL.RawPath) > 0 { - url = fixTrailingSlash(req.URL.RawPath) + url = FixTrailingSlash(req.URL.RawPath) } else { - url = fixTrailingSlash(req.URL.Path) + url = FixTrailingSlash(req.URL.Path) } if url[len(url)-1] == '/' { @@ -343,9 +385,9 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { n, tsr = tree.lookup(nds[index], target, c, false) if !tsr && n != nil { - c.path = n.path + c.route = n.route c.tsr = tsr - n.handler(c) + n.route.handler(c) // Put back the context, if not extended more than max params or max depth, allowing // the slice to naturally grow within the constraint. if cap(*c.params) <= int(tree.maxParams.Load()) && cap(*c.skipNds) <= int(tree.maxDepth.Load()) { @@ -355,15 +397,15 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if r.Method != http.MethodConnect && r.URL.Path != "/" && tsr { - if fox.ignoreTrailingSlash { - c.path = n.path + if n.route.ignoreTrailingSlash { + c.route = n.route c.tsr = tsr - n.handler(c) + n.route.handler(c) c.Close() return } - if fox.redirectTrailingSlash && target == CleanPath(target) { + if n.route.redirectTrailingSlash && target == CleanPath(target) { // Reset params as it may have recorded wildcard segment (the context may still be used in a middleware) *c.params = (*c.params)[:0] c.tsr = false @@ -395,7 +437,7 @@ NoMethodFallback: } else { // Since different method and route may match (e.g. GET /foo/bar & POST /foo/{name}), we cannot set the path and params. for i := 0; i < len(nds); i++ { - if n, tsr := tree.lookup(nds[i], target, c, true); n != nil && (!tsr || fox.ignoreTrailingSlash) { + if n, tsr := tree.lookup(nds[i], target, c, true); n != nil && (!tsr || n.route.ignoreTrailingSlash) { if sb.Len() > 0 { sb.WriteString(", ") } @@ -415,7 +457,7 @@ NoMethodFallback: var sb strings.Builder for i := 0; i < len(nds); i++ { if nds[i].key != r.Method { - if n, tsr := tree.lookup(nds[i], target, c, true); n != nil && (!tsr || fox.ignoreTrailingSlash) { + if n, tsr := tree.lookup(nds[i], target, c, true); n != nil && (!tsr || n.route.ignoreTrailingSlash) { if sb.Len() > 0 { sb.WriteString(", ") } @@ -604,7 +646,7 @@ func getRouteConflict(n *node) []string { routes := make([]string, 0) if n.isCatchAll() { - routes = append(routes, n.path) + routes = append(routes, n.route.path) return routes } @@ -613,7 +655,7 @@ func getRouteConflict(n *node) []string { } it := newRawIterator(n) for it.hasNext() { - routes = append(routes, it.current.path) + routes = append(routes, it.current.route.path) } return routes } diff --git a/fox_test.go b/fox_test.go index f129171..b79eb99 100644 --- a/fox_test.go +++ b/fox_test.go @@ -750,8 +750,9 @@ func TestRouteWithParams(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], rte, c, false) require.NotNil(t, n) + require.NotNil(t, n.route) assert.False(t, tsr) - assert.Equal(t, rte, n.path) + assert.Equal(t, rte, n.route.path) } } @@ -1196,9 +1197,9 @@ func TestOverlappingRoute(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], tc.path, c, false) require.NotNil(t, n) - require.NotNil(t, n.handler) + require.NotNil(t, n.route) assert.False(t, tsr) - assert.Equal(t, tc.wantMatch, n.path) + assert.Equal(t, tc.wantMatch, n.route.path) if len(tc.wantParams) == 0 { assert.Empty(t, c.Params()) } else { @@ -1209,10 +1210,10 @@ func TestOverlappingRoute(t *testing.T) { c = newTestContextTree(tree) n, tsr = tree.lookup(nds[0], tc.path, c, true) require.NotNil(t, n) - require.NotNil(t, n.handler) + require.NotNil(t, n.route) assert.False(t, tsr) assert.Empty(t, c.Params()) - assert.Equal(t, tc.wantMatch, n.path) + assert.Equal(t, tc.wantMatch, n.route.path) }) } } @@ -1378,6 +1379,18 @@ func TestUpdateConflict(t *testing.T) { } } +func TestInvalidRoute(t *testing.T) { + f := New() + // Invalid route on insert + assert.ErrorIs(t, f.Handle("get", "/foo", emptyHandler), ErrInvalidRoute) + assert.ErrorIs(t, f.Handle("", "/foo", emptyHandler), ErrInvalidRoute) + assert.ErrorIs(t, f.Handle(http.MethodGet, "/foo", nil), ErrInvalidRoute) + + // Invalid route on update + assert.ErrorIs(t, f.Update("", "/foo", emptyHandler), ErrInvalidRoute) + assert.ErrorIs(t, f.Update(http.MethodGet, "/foo", nil), ErrInvalidRoute) +} + func TestUpdateRoute(t *testing.T) { cases := []struct { name string @@ -1643,7 +1656,7 @@ func TestTree_LookupTsr(t *testing.T) { t.Run(tc.name, func(t *testing.T) { tree := New().Tree() for _, path := range tc.paths { - require.NoError(t, tree.insert(http.MethodGet, path, "", 0, emptyHandler)) + require.NoError(t, tree.insert(http.MethodGet, path, "", 0, tree.newRoute(path, emptyHandler))) } nds := *tree.nodes.Load() c := newTestContextTree(tree) @@ -1651,7 +1664,8 @@ func TestTree_LookupTsr(t *testing.T) { assert.Equal(t, tc.want, got) if tc.want { require.NotNil(t, n) - assert.Equal(t, tc.wantPath, n.path) + require.NotNil(t, n.route) + assert.Equal(t, tc.wantPath, n.route.path) } }) } @@ -1765,6 +1779,9 @@ func TestRouterWithIgnoreTrailingSlash(t *testing.T) { require.NoError(t, r.Tree().Handle(tc.method, path, func(c Context) { _ = c.String(http.StatusOK, c.Path()) })) + rte := r.Tree().Route(tc.method, path) + require.NotNil(t, rte) + assert.True(t, rte.IgnoreTrailingSlashEnabled()) } req := httptest.NewRequest(tc.method, tc.req, nil) @@ -1779,10 +1796,33 @@ func TestRouterWithIgnoreTrailingSlash(t *testing.T) { } func TestRouterWithClientIPStrategy(t *testing.T) { - f := New(WithClientIPStrategy(ClientIPStrategyFunc(func(c Context) (*net.IPAddr, error) { + c1 := ClientIPStrategyFunc(func(c Context) (*net.IPAddr, error) { return c.RemoteIP(), nil - }))) - require.True(t, f.ClientIPStrategyEnabled()) + }) + f := New(WithClientIPStrategy(c1), WithNoRouteHandler(func(c Context) { + assert.Empty(t, c.Path()) + ip, err := c.ClientIP() + assert.NoError(t, err) + assert.NotNil(t, ip) + DefaultNotFoundHandler(c) + })) + f.MustHandle(http.MethodGet, "/foo", emptyHandler) + assert.True(t, f.ClientIPStrategyEnabled()) + + rte := f.Tree().Route(http.MethodGet, "/foo") + require.NotNil(t, rte) + assert.True(t, rte.ClientIPStrategyEnabled()) + + require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler, WithClientIPStrategy(noClientIPStrategy{}))) + rte = f.Tree().Route(http.MethodGet, "/foo") + require.NotNil(t, rte) + assert.False(t, rte.ClientIPStrategyEnabled()) + + // On not found handler, fallback to global ip strategy + req := httptest.NewRequest(http.MethodGet, "/bar", nil) + w := httptest.NewRecorder() + f.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) } func TestRedirectTrailingSlash(t *testing.T) { @@ -1916,6 +1956,9 @@ func TestRedirectTrailingSlash(t *testing.T) { require.True(t, r.RedirectTrailingSlashEnabled()) for _, path := range tc.paths { require.NoError(t, r.Tree().Handle(tc.method, path, emptyHandler)) + rte := r.Tree().Route(tc.method, path) + require.NotNil(t, rte) + assert.True(t, rte.RedirectTrailingSlashEnabled()) } req := httptest.NewRequest(tc.method, tc.req, nil) @@ -2060,7 +2103,6 @@ func TestRouterWithTsrParams(t *testing.T) { f := New(WithIgnoreTrailingSlash(true)) for _, rte := range tc.routes { require.NoError(t, f.Handle(http.MethodGet, rte, func(c Context) { - fmt.Println(c.Path(), c.Params()) assert.Equal(t, tc.wantPath, c.Path()) assert.Equal(t, tc.wantParams, c.Params()) assert.Equal(t, tc.wantTsr, unwrapContext(t, c).tsr) @@ -2536,6 +2578,73 @@ func TestWithScopedMiddleware(t *testing.T) { assert.True(t, called) } +func TestUpdateWithMiddleware(t *testing.T) { + called := false + m := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + called = true + next(c) + } + }) + f := New() + f.MustHandle(http.MethodGet, "/foo", emptyHandler) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + w := httptest.NewRecorder() + + // Add middleware + require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler, WithMiddleware(m))) + f.ServeHTTP(w, req) + assert.True(t, called) + called = false + + // Remove middleware + require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler)) + f.ServeHTTP(w, req) + assert.False(t, called) +} + +func TestRouteMiddleware(t *testing.T) { + var c0, c1, c2 bool + m0 := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + c0 = true + next(c) + } + }) + + m1 := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + c1 = true + next(c) + } + }) + + m2 := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { + return func(c Context) { + c2 = true + next(c) + } + }) + f := New(WithMiddleware(m0)) + f.MustHandle(http.MethodGet, "/1", emptyHandler, WithMiddleware(m1)) + f.MustHandle(http.MethodGet, "/2", emptyHandler, WithMiddleware(m2)) + + req := httptest.NewRequest(http.MethodGet, "/1", nil) + w := httptest.NewRecorder() + + f.ServeHTTP(w, req) + assert.True(t, c0) + assert.True(t, c1) + assert.False(t, c2) + c0, c1, c2 = false, false, false + + req.URL.Path = "/2" + f.ServeHTTP(w, req) + assert.True(t, c0) + assert.False(t, c1) + assert.True(t, c2) +} + func TestWithNotFoundHandler(t *testing.T) { notFound := func(c Context) { _ = c.String(http.StatusNotFound, "NOT FOUND\n") @@ -2561,9 +2670,10 @@ func TestRouter_Lookup(t *testing.T) { for _, rte := range githubAPI { req := httptest.NewRequest(rte.method, rte.path, nil) - handler, cc, _ := f.Lookup(newResponseWriter(mockResponseWriter{}), req) + route, cc, _ := f.Lookup(newResponseWriter(mockResponseWriter{}), req) require.NotNil(t, cc) - assert.NotNil(t, handler) + require.NotNil(t, route) + assert.Equal(t, rte.path, route.Path()) matches := rx.FindAllString(rte.path, -1) for _, match := range matches { @@ -2582,14 +2692,14 @@ func TestRouter_Lookup(t *testing.T) { // No method match req := httptest.NewRequest("ANY", "/bar", nil) - handler, cc, _ := f.Lookup(newResponseWriter(mockResponseWriter{}), req) - assert.Nil(t, handler) + route, cc, _ := f.Lookup(newResponseWriter(mockResponseWriter{}), req) + assert.Nil(t, route) assert.Nil(t, cc) // No path match req = httptest.NewRequest(http.MethodGet, "/bar", nil) - handler, cc, _ = f.Lookup(newResponseWriter(mockResponseWriter{}), req) - assert.Nil(t, handler) + route, cc, _ = f.Lookup(newResponseWriter(mockResponseWriter{}), req) + assert.Nil(t, route) assert.Nil(t, cc) } @@ -2670,9 +2780,10 @@ func TestTree_Match(t *testing.T) { } cases := []struct { - name string - path string - want string + name string + path string + want string + wantTsr bool }{ { name: "reverse static route", @@ -2680,8 +2791,10 @@ func TestTree_Match(t *testing.T) { want: "/foo/bar", }, { - name: "reverse static route with tsr disable", - path: "/foo/bar/", + name: "reverse static route with tsr disable", + path: "/foo/bar/", + want: "/foo/bar", + wantTsr: true, }, { name: "reverse params route", @@ -2701,7 +2814,14 @@ func TestTree_Match(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path)) + route, tsr := r.Tree().Match(http.MethodGet, tc.path) + if tc.want != "" { + require.NotNil(t, route) + assert.Equal(t, tc.want, route.Path()) + assert.Equal(t, tc.wantTsr, tsr) + return + } + assert.Nil(t, route) }) } } @@ -2719,9 +2839,10 @@ func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { } cases := []struct { - name string - path string - want string + name string + path string + want string + wantTsr bool }{ { name: "reverse static route", @@ -2729,9 +2850,10 @@ func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { want: "/foo/bar", }, { - name: "reverse static route with tsr", - path: "/foo/bar/", - want: "/foo/bar", + name: "reverse static route with tsr", + path: "/foo/bar/", + want: "/foo/bar", + wantTsr: true, }, { name: "reverse params route", @@ -2739,9 +2861,10 @@ func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { want: "/welcome/{name}/", }, { - name: "reverse params route with tsr", - path: "/welcome/fox", - want: "/welcome/{name}/", + name: "reverse params route with tsr", + path: "/welcome/fox", + want: "/welcome/{name}/", + wantTsr: true, }, { name: "reverse mid params route", @@ -2749,9 +2872,10 @@ func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { want: "/users/uid_{id}", }, { - name: "reverse mid params route with tsr", - path: "/users/uid_123/", - want: "/users/uid_{id}", + name: "reverse mid params route with tsr", + path: "/users/uid_123/", + want: "/users/uid_{id}", + wantTsr: true, }, { name: "reverse no match", @@ -2761,7 +2885,14 @@ func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path)) + route, tsr := r.Tree().Match(http.MethodGet, tc.path) + if tc.want != "" { + require.NotNil(t, route) + assert.Equal(t, tc.want, route.Path()) + assert.Equal(t, tc.wantTsr, tsr) + return + } + assert.Nil(t, route) }) } } @@ -2780,6 +2911,20 @@ func TestEncodedPath(t *testing.T) { assert.Equal(t, encodedPath, w.Body.String()) } +func TestTreeSwap(t *testing.T) { + f := New() + tree := f.NewTree() + assert.NotPanics(t, func() { + f.Swap(tree) + }) + assert.Equal(t, tree, f.Tree()) + + f2 := New() + assert.Panics(t, func() { + f2.Swap(tree) + }) +} + func TestFuzzInsertLookupParam(t *testing.T) { // no '*', '{}' and '/' and invalid escape char unicodeRanges := fuzz.UnicodeRanges{ @@ -2804,14 +2949,16 @@ func TestFuzzInsertLookupParam(t *testing.T) { if s1 == "" || s2 == "" || e1 == "" || e2 == "" || e3 == "" { continue } - if err := tree.insert(http.MethodGet, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), "", 3, emptyHandler); err == nil { + path := fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3) + if err := tree.insert(http.MethodGet, path, "", 3, tree.newRoute(path, emptyHandler)); err == nil { nds := *tree.nodes.Load() c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c, false) require.NotNil(t, n) + require.NotNil(t, n.route) assert.False(t, tsr) - assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.path) + assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.route.path) assert.Equal(t, "xxxx", c.Param(e1)) assert.Equal(t, "xxxx", c.Param(e2)) assert.Equal(t, "xxxx", c.Param(e3)) @@ -2833,7 +2980,7 @@ func TestFuzzInsertNoPanics(t *testing.T) { continue } require.NotPanicsf(t, func() { - _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, emptyHandler) + _ = tree.insert(http.MethodGet, rte, catchAllKey, 0, tree.newRoute(appendCatchAll(rte, catchAllKey), emptyHandler)) }, fmt.Sprintf("rte: %s, catch all: %s", rte, catchAllKey)) } } @@ -2854,7 +3001,8 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { f.Fuzz(&routes) for rte := range routes { - err := tree.insert(http.MethodGet, "/"+rte, "", 0, emptyHandler) + path := "/" + rte + err := tree.insert(http.MethodGet, path, "", 0, tree.newRoute(path, emptyHandler)) require.NoError(t, err) } @@ -2870,10 +3018,12 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], "/"+rte, c, true) require.NotNilf(t, n, "route /%s", rte) + require.NotNilf(t, n.route, "route /%s", rte) require.Falsef(t, tsr, "tsr: %t", tsr) require.Truef(t, n.isLeaf(), "route /%s", rte) - require.Equal(t, "/"+rte, n.path) - require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", emptyHandler)) + require.Equal(t, "/"+rte, n.route.path) + path := "/" + rte + require.NoError(t, tree.update(http.MethodGet, path, "", tree.newRoute(path, emptyHandler))) } for rte := range routes { @@ -3034,6 +3184,12 @@ func TestNode_String(t *testing.T) { assert.Equal(t, want, strings.TrimSuffix(nds[0].String(), "\n")) } +func TestFixTrailingSlash(t *testing.T) { + assert.Equal(t, "/foo/", FixTrailingSlash("/foo")) + assert.Equal(t, "/foo", FixTrailingSlash("/foo/")) + assert.Equal(t, "/", FixTrailingSlash("")) +} + // This example demonstrates how to create a simple router using the default options, // which include the Recovery and Logger middleware. A basic route is defined, along with a // custom middleware to log the request metrics. @@ -3131,15 +3287,15 @@ func ExampleRouter_Lookup() { target := req.URL.Path cleanedPath := CleanPath(target) - // Nothing to clean, call next handler or middleware. + // Nothing to clean, call next handler. if cleanedPath == target { next(c) return } req.URL.Path = cleanedPath - handler, cc, tsr := c.Fox().Lookup(c.Writer(), req) - if handler != nil { + route, cc, tsr := c.Fox().Lookup(c.Writer(), req) + if route != nil { defer cc.Close() code := http.StatusMovedPermanently @@ -3148,7 +3304,7 @@ func ExampleRouter_Lookup() { } // Redirect the client if direct match or indirect match. - if !tsr || c.Fox().IgnoreTrailingSlashEnabled() { + if !tsr || route.IgnoreTrailingSlashEnabled() { if err := c.Redirect(code, cleanedPath); err != nil { // Only if not in the range 300..308, so not possible here! panic(err) @@ -3157,8 +3313,8 @@ func ExampleRouter_Lookup() { } // Add or remove an extra trailing slash and redirect the client. - if c.Fox().RedirectTrailingSlashEnabled() { - if err := c.Redirect(code, fixTrailingSlash(cleanedPath)); err != nil { + if route.RedirectTrailingSlashEnabled() { + if err := c.Redirect(code, FixTrailingSlash(cleanedPath)); err != nil { // Only if not in the range 300..308, so not possible here panic(err) } @@ -3189,8 +3345,8 @@ func ExampleTree_Match() { f.MustHandle(http.MethodGet, "/hello/{name}", emptyHandler) tree := f.Tree() - matched := tree.Match(http.MethodGet, "/hello/fox") - fmt.Println(matched) // /hello/{name} + route, _ := tree.Match(http.MethodGet, "/hello/fox") + fmt.Println(route.Path()) // /hello/{name} } // This example demonstrates how to check if a given route is registered in the tree. @@ -3199,6 +3355,6 @@ func ExampleTree_Has() { f.MustHandle(http.MethodGet, "/hello/{name}", emptyHandler) tree := f.Tree() - exist := tree.Match(http.MethodGet, "/hello/{name}") + exist := tree.Has(http.MethodGet, "/hello/{name}") fmt.Println(exist) // true } diff --git a/iter.go b/iter.go index 9344deb..bfe092b 100644 --- a/iter.go +++ b/iter.go @@ -167,7 +167,7 @@ func (it *Iterator) Next() { // Path returns the registered path for the current route. func (it *Iterator) Path() string { if it.current != nil { - return it.current.path + return it.current.route.path } return "" } @@ -180,7 +180,7 @@ func (it *Iterator) Method() string { // Handler return the registered handler for the current route. func (it *Iterator) Handler() HandlerFunc { if it.current != nil { - return it.current.handler + return it.current.route.handler } return nil } @@ -221,7 +221,7 @@ func (it *rawIterator) hasNext() bool { it.current = elem if it.current.isLeaf() { - it.path = elem.path + it.path = elem.route.Path() return true } } diff --git a/node.go b/node.go index da8b570..4d8e1bf 100644 --- a/node.go +++ b/node.go @@ -5,16 +5,17 @@ package fox import ( - "sort" + "cmp" + "slices" "strconv" "strings" "sync/atomic" ) type node struct { - // The registered handler matching the full path. Nil if the node is not a leaf. - // Once assigned, handler is immutable. - handler HandlerFunc + // The registered route matching the full path. Nil if the node is not a leaf. + // Once assigned, route is immutable. + route *Route // key represent a segment of a route which share a common prefix with it parent. key string @@ -23,9 +24,6 @@ type node struct { // Once assigned, catchAllKey is immutable. catchAllKey string - // The full path when it's a leaf node - path string - // First char of each outgoing edges from this node sorted in ascending order. // Once assigned, this is a read only slice. It allows to lazily traverse the // tree without the extra cost of atomic load operation. @@ -42,9 +40,9 @@ type node struct { paramChildIndex int } -func newNode(key string, handler HandlerFunc, children []*node, catchAllKey string, path string) *node { - sort.Slice(children, func(i, j int) bool { - return children[i].key < children[j].key +func newNode(key string, route *Route, children []*node, catchAllKey string) *node { + slices.SortFunc(children, func(a, b *node) int { + return cmp.Compare(a.key, b.key) }) nds := make([]atomic.Pointer[node], len(children)) childKeys := make([]byte, len(children)) @@ -58,24 +56,23 @@ func newNode(key string, handler HandlerFunc, children []*node, catchAllKey stri } } - return newNodeFromRef(key, handler, nds, childKeys, catchAllKey, childIndex, path) + return newNodeFromRef(key, route, nds, childKeys, catchAllKey, childIndex) } -func newNodeFromRef(key string, handler HandlerFunc, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int, path string) *node { +func newNodeFromRef(key string, route *Route, children []atomic.Pointer[node], childKeys []byte, catchAllKey string, childIndex int) *node { return &node{ key: key, childKeys: childKeys, children: children, - handler: handler, + route: route, catchAllKey: catchAllKey, - path: appendCatchAll(path, catchAllKey), paramChildIndex: childIndex, params: parseWildcard(key), } } func (n *node) isLeaf() bool { - return n.handler != nil + return n.route != nil } func (n *node) isCatchAll() bool { @@ -134,6 +131,7 @@ func linearSearch(keys []byte, s byte) int { func binarySearch(keys []byte, s byte) int { low, high := 0, len(keys)-1 for low <= high { + // nolint:gosec mid := int(uint(low+high) >> 1) // avoid overflow cmp := compare(keys[mid], s) if cmp < 0 { @@ -211,7 +209,7 @@ func (n *node) string(space int) string { } if n.isLeaf() { sb.WriteString(" [leaf=") - sb.WriteString(n.path) + sb.WriteString(n.route.path) sb.WriteString("]") } if n.hasWildcard() { diff --git a/options.go b/options.go index b2d17ab..62a6b45 100644 --- a/options.go +++ b/options.go @@ -23,19 +23,46 @@ const ( ) type Option interface { - apply(*Router) + GlobalOption + PathOption } -type optionFunc func(*Router) +type GlobalOption interface { + applyGlob(*Router) +} + +type PathOption interface { + applyPath(*Route) +} + +type globOptionFunc func(*Router) + +func (o globOptionFunc) applyGlob(r *Router) { + o(r) +} + +// nolint:unused +type pathOptionFunc func(*Route) -func (o optionFunc) apply(r *Router) { +// nolint:unused +func (o pathOptionFunc) applyPath(r *Route) { o(r) } +type optionFunc func(*Router, *Route) + +func (o optionFunc) applyGlob(r *Router) { + o(r, nil) +} + +func (o optionFunc) applyPath(r *Route) { + o(nil, r) +} + // WithNoRouteHandler register an HandlerFunc which is called when no matching route is found. // By default, the DefaultNotFoundHandler is used. -func WithNoRouteHandler(handler HandlerFunc) Option { - return optionFunc(func(r *Router) { +func WithNoRouteHandler(handler HandlerFunc) GlobalOption { + return globOptionFunc(func(r *Router) { if handler != nil { r.noRoute = handler } @@ -46,8 +73,8 @@ func WithNoRouteHandler(handler HandlerFunc) Option { // but the same route exist for other methods. The "Allow" header it automatically set before calling the // handler. By default, the DefaultMethodNotAllowedHandler is used. Note that this option automatically // enable WithNoMethod. -func WithNoMethodHandler(handler HandlerFunc) Option { - return optionFunc(func(r *Router) { +func WithNoMethodHandler(handler HandlerFunc) GlobalOption { + return globOptionFunc(func(r *Router) { if handler != nil { r.noMethod = handler r.handleMethodNotAllowed = true @@ -59,9 +86,8 @@ func WithNoMethodHandler(handler HandlerFunc) Option { // respond with a 200 OK status code. The "Allow" header it automatically set before calling the handler. Note that custom OPTIONS // handler take priority over automatic replies. By default, DefaultOptionsHandler is used. Note that this option // automatically enable WithAutoOptions. -// This api is EXPERIMENTAL and is likely to change in future release. -func WithOptionsHandler(handler HandlerFunc) Option { - return optionFunc(func(r *Router) { +func WithOptionsHandler(handler HandlerFunc) GlobalOption { + return globOptionFunc(func(r *Router) { if handler != nil { r.autoOptions = handler r.handleOptions = true @@ -69,20 +95,37 @@ func WithOptionsHandler(handler HandlerFunc) Option { }) } -// WithMiddleware attaches a global middleware to the router. Middlewares provided will be chained -// in the order they were added. Note that this option apply middleware to all handler, including NotFound, -// MethodNotAllowed and the internal redirect handler. +// WithMiddleware attaches middleware to the router or to a specific route. The middlewares are executed +// in the order they are added. When applied globally, the middleware affects all handlers, including special handlers +// such as NotFound, MethodNotAllowed, AutoOption, and the internal redirect handler. +// +// This option can be applied on a per-route basis or globally: +// - If applied globally, the middleware will be applied to all routes and handlers by default. +// - If applied to a specific route, the middleware will only apply to that route and will be chained after any global middleware. +// +// Route-specific middleware must be explicitly reapplied when updating a route. If not, any middleware will be removed, +// and the route will fall back to using only global middleware (if any). func WithMiddleware(m ...MiddlewareFunc) Option { - return WithMiddlewareFor(AllHandlers, m...) + return optionFunc(func(router *Router, route *Route) { + if router != nil { + for i := range m { + router.mws = append(router.mws, middleware{m[i], AllHandlers}) + } + } + if route != nil { + for i := range m { + route.mws = append(route.mws, middleware{m[i], RouteHandlers}) + } + } + }) } // WithMiddlewareFor attaches middleware to the router for a specified scope. Middlewares provided will be chained // in the order they were added. The scope parameter determines which types of handlers the middleware will be applied to. // Possible scopes include RouteHandlers (regular routes), NoRouteHandler, NoMethodHandler, RedirectHandler, OptionsHandler, // and any combination of these. Use this option when you need fine-grained control over where the middleware is applied. -// This api is EXPERIMENTAL and is likely to change in future release. -func WithMiddlewareFor(scope MiddlewareScope, m ...MiddlewareFunc) Option { - return optionFunc(func(r *Router) { +func WithMiddlewareFor(scope MiddlewareScope, m ...MiddlewareFunc) GlobalOption { + return globOptionFunc(func(r *Router) { for i := range m { r.mws = append(r.mws, middleware{m[i], scope}) } @@ -93,8 +136,8 @@ func WithMiddlewareFor(scope MiddlewareScope, m ...MiddlewareFunc) Option { // when the route exist for another http verb. The "Allow" header it automatically set before calling the // handler. Note that this option is automatically enabled when providing a custom handler with the // option WithNoMethodHandler. -func WithNoMethod(enable bool) Option { - return optionFunc(func(r *Router) { +func WithNoMethod(enable bool) GlobalOption { + return globOptionFunc(func(r *Router) { r.handleMethodNotAllowed = enable }) } @@ -103,9 +146,9 @@ func WithNoMethod(enable bool) Option { // Use the WithOptionsHandler option to customize the response. When this option is enabled, the router automatically // determines the "Allow" header value based on the methods registered for the given route. Note that custom OPTIONS // handler take priority over automatic replies. This option is automatically enabled when providing a custom handler with -// the option WithOptionsHandler. This api is EXPERIMENTAL and is likely to change in future release. -func WithAutoOptions(enable bool) Option { - return optionFunc(func(r *Router) { +// the option WithOptionsHandler. +func WithAutoOptions(enable bool) GlobalOption { + return globOptionFunc(func(r *Router) { r.handleOptions = enable }) } @@ -113,22 +156,59 @@ func WithAutoOptions(enable bool) Option { // WithRedirectTrailingSlash enable automatic redirection fallback when the current request does not match but // another handler is found with/without an additional trailing slash. E.g. /foo/bar/ request does not match // but /foo/bar would match. The client is redirected with a http status code 301 for GET requests and 308 for -// all other methods. Note that this option is mutually exclusive with WithIgnoreTrailingSlash, and if both are -// enabled, WithIgnoreTrailingSlash takes precedence. +// all other methods. +// +// This option can be applied on a per-route basis or globally: +// - If applied globally, it affects all routes by default. +// - If applied to a specific route, it will override the global setting for that route. +// - The option must be explicitly reapplied when updating a route. If not, the route will fall back +// to the global configuration for trailing slash behavior. +// +// Note that this option is mutually exclusive with WithIgnoreTrailingSlash, and if enabled will +// automatically deactivate WithIgnoreTrailingSlash. func WithRedirectTrailingSlash(enable bool) Option { - return optionFunc(func(r *Router) { - r.redirectTrailingSlash = enable + return optionFunc(func(router *Router, route *Route) { + if router != nil { + router.redirectTrailingSlash = enable + if enable { + router.ignoreTrailingSlash = false + } + } + if route != nil { + route.redirectTrailingSlash = enable + if enable { + route.ignoreTrailingSlash = false + } + } }) } // WithIgnoreTrailingSlash allows the router to match routes regardless of whether a trailing slash is present or not. // E.g. /foo/bar/ and /foo/bar would both match the same handler. This option prevents the router from issuing -// a redirect and instead matches the request directly. Note that this option is mutually exclusive with -// WithRedirectTrailingSlash, and if both are enabled, WithIgnoreTrailingSlash takes precedence. -// This api is EXPERIMENTAL and is likely to change in future release. +// a redirect and instead matches the request directly. +// +// This option can be applied on a per-route basis or globally: +// - If applied globally, it affects all routes by default. +// - If applied to a specific route, it will override the global setting for that route. +// - The option must be explicitly reapplied when updating a route. If not, the route will fall back +// to the global configuration for trailing slash behavior. +// +// Note that this option is mutually exclusive with +// WithRedirectTrailingSlash, and if enabled will automatically deactivate WithRedirectTrailingSlash. func WithIgnoreTrailingSlash(enable bool) Option { - return optionFunc(func(r *Router) { - r.ignoreTrailingSlash = enable + return optionFunc(func(router *Router, route *Route) { + if router != nil { + router.ignoreTrailingSlash = enable + if enable { + router.redirectTrailingSlash = false + } + } + if route != nil { + route.ignoreTrailingSlash = enable + if enable { + route.redirectTrailingSlash = false + } + } }) } @@ -137,11 +217,21 @@ func WithIgnoreTrailingSlash(enable bool) Option { // configuration to ensure it never returns an error -- i.e., never fails to find a candidate for the "real" IP. // Consequently, getting an error result should be treated as an application error, perhaps even worthy of panicking. // There is no sane default, so if no strategy is configured, Context.ClientIP returns ErrNoClientIPStrategy. -// This API is EXPERIMENTAL and is likely to change in future releases. +// +// This option can be applied on a per-route basis or globally: +// - If applied globally, it affects all routes by default. +// - If applied to a specific route, it will override the global setting for that route. +// - The option must be explicitly reapplied when updating a route. If not, the route will fall back +// to the global client IP strategy (if one is configured). func WithClientIPStrategy(strategy ClientIPStrategy) Option { - return optionFunc(func(r *Router) { + return optionFunc(func(router *Router, route *Route) { if strategy != nil { - r.ipStrategy = strategy + if router != nil { + router.ipStrategy = strategy + } + if route != nil { + route.ipStrategy = strategy + } } }) } @@ -149,8 +239,8 @@ func WithClientIPStrategy(strategy ClientIPStrategy) Option { // DefaultOptions configure the router to use the Recovery middleware for the RouteHandlers scope, the Logger middleware // for AllHandlers scope and enable automatic OPTIONS response. Note that DefaultOptions push the Recovery and Logger middleware // respectively to the first and second position of the middleware chains. -func DefaultOptions() Option { - return optionFunc(func(r *Router) { +func DefaultOptions() GlobalOption { + return globOptionFunc(func(r *Router) { r.mws = append([]middleware{ {Recovery(), RouteHandlers}, {Logger(), AllHandlers}, diff --git a/path.go b/path.go index 571227e..c67984e 100644 --- a/path.go +++ b/path.go @@ -149,7 +149,10 @@ func bufApp(buf *[]byte, s string, w int, c byte) { b[w] = c } -func fixTrailingSlash(path string) string { +// FixTrailingSlash ensures a consistent trailing slash handling for a given path. +// If the path has more than one character and ends with a slash, it removes the trailing slash. +// Otherwise, it adds a trailing slash to the path. +func FixTrailingSlash(path string) string { if len(path) > 1 && path[len(path)-1] == '/' { return path[:len(path)-1] } diff --git a/recovery.go b/recovery.go index ca095ae..3f21efa 100644 --- a/recovery.go +++ b/recovery.go @@ -78,7 +78,7 @@ func recovery(logger *slog.Logger, c Context, handle RecoveryFunc) { } sb.WriteString("Stack:\n") - sb.WriteString(stacktrace(4, 6)) + sb.WriteString(stacktrace(3, 6)) params := c.Params() attrs := make([]any, 0, len(params)) diff --git a/response_writer.go b/response_writer.go index aa4ff9f..66841f9 100644 --- a/response_writer.go +++ b/response_writer.go @@ -162,7 +162,7 @@ func (r *recorder) FlushError() error { flusher.Flush() return nil default: - return errNotSupported() + return ErrNotSupported() } } @@ -181,7 +181,7 @@ func (r *recorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hijacker, ok := r.ResponseWriter.(http.Hijacker); ok { return hijacker.Hijack() } - return nil, nil, errNotSupported() + return nil, nil, ErrNotSupported() } type noUnwrap struct { @@ -225,6 +225,7 @@ func relevantCaller() runtime.Frame { return frame } -func errNotSupported() error { +// ErrNotSupported returns an error that Is ErrNotSupported, but is not == to it. +func ErrNotSupported() error { return fmt.Errorf("%w", http.ErrNotSupported) } diff --git a/strategy/strategy.go b/strategy/strategy.go index 2e302cd..91e8c58 100644 --- a/strategy/strategy.go +++ b/strategy/strategy.go @@ -292,8 +292,8 @@ func (s RightmostTrustedCount) ClientIP(c fox.Context) (*net.IPAddr, error) { // attacker creates a CF distribution that points at your origin server. The attacker uses Lambda@Edge to spoof the Host // and X-Forwarded-For headers. Now your "trusted" reverse proxy is no longer trustworthy. type RightmostTrustedRange struct { - headerName string resolver TrustedIPRange + headerName string } // NewRightmostTrustedRange creates a RightmostTrustedRange strategy. headerName must be "X-Forwarded-For" diff --git a/tree.go b/tree.go index 8c7b4aa..7bd437c 100644 --- a/tree.go +++ b/tree.go @@ -44,7 +44,10 @@ type Tree struct { // is already registered or conflict with another. It's perfectly safe to add a new handler while the tree is in use // for serving requests. However, this function is NOT thread-safe and should be run serially, along with all other // Tree APIs that perform write operations. To override an existing route, use Update. -func (t *Tree) Handle(method, path string, handler HandlerFunc) error { +func (t *Tree) Handle(method, path string, handler HandlerFunc, opts ...PathOption) error { + if handler == nil { + return fmt.Errorf("%w: nil handler", ErrInvalidRoute) + } if matched := regEnLetter.MatchString(method); !matched { return fmt.Errorf("%w: missing or invalid http method", ErrInvalidRoute) } @@ -54,14 +57,18 @@ func (t *Tree) Handle(method, path string, handler HandlerFunc) error { return err } - return t.insert(method, p, catchAllKey, uint32(n), applyMiddleware(RouteHandlers, t.mws, handler)) + // nolint:gosec + return t.insert(method, p, catchAllKey, uint32(n), t.newRoute(path, handler, opts...)) } // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. However, this function is NOT thread-safe and should be run serially, along with all other // Tree APIs that perform write operations. To add a new handler, use Handle method. -func (t *Tree) Update(method, path string, handler HandlerFunc) error { +func (t *Tree) Update(method, path string, handler HandlerFunc, opts ...PathOption) error { + if handler == nil { + return fmt.Errorf("%w: nil handler", ErrInvalidRoute) + } if method == "" { return fmt.Errorf("%w: missing http method", ErrInvalidRoute) } @@ -71,7 +78,7 @@ func (t *Tree) Update(method, path string, handler HandlerFunc) error { return err } - return t.update(method, p, catchAllKey, applyMiddleware(RouteHandlers, t.mws, handler)) + return t.update(method, p, catchAllKey, t.newRoute(path, handler, opts...)) } // Remove delete an existing handler for the given method and path. If the route does not exist, the function @@ -95,46 +102,54 @@ func (t *Tree) Remove(method, path string) error { return nil } -// Has allows to check if the given method and path exactly match a registered route. This function is safe for concurrent -// use by multiple goroutine and while mutation on Tree are ongoing. +// Has allows to check if the given method and path exactly match a registered route. This function is safe for +// concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Has(method, path string) bool { + return t.Route(method, path) != nil +} + +// Route performs a lookup for a registered route matching the given method and path. It returns the route if a +// match is found or nil otherwise. This function is safe for concurrent use by multiple goroutine and while +// mutation on Tree are ongoing. +// This API is EXPERIMENTAL and is likely to change in future release. +func (t *Tree) Route(method, path string) *Route { nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return false + return nil } c := t.ctx.Get().(*cTx) c.resetNil() n, tsr := t.lookup(nds[index], path, c, true) c.Close() - if n != nil && !tsr { - return n.path == path + if n != nil && !tsr && n.route.path == path { + return n.route } - return false + return nil } -// Match perform a reverse lookup on the tree for the given method and path and return the matching registered route if any. When -// WithIgnoreTrailingSlash or WithRedirectTrailingSlash are enabled, Match will match a registered route regardless of an -// extra or missing trailing slash. This function is safe for concurrent use by multiple goroutine and while mutation on -// Tree are ongoing. See also Tree.Lookup as an alternative. +// Match perform a reverse lookup on the tree for the given method and path and return the matching registered route +// (if any) along with a boolean indicating if the route was matched by adding or removing a trailing slash +// (trailing slash action is recommended). This function is safe for concurrent use by multiple goroutine and while +// mutation on Tree are ongoing. See also Tree.Lookup as an alternative. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Match(method, path string) string { +func (t *Tree) Match(method, path string) (route *Route, tsr bool) { nds := *t.nodes.Load() index := findRootNode(method, nds) if index < 0 { - return "" + return nil, false } c := t.ctx.Get().(*cTx) c.resetNil() n, tsr := t.lookup(nds[index], path, c, true) c.Close() - if n != nil && (!tsr || t.fox.redirectTrailingSlash || t.fox.ignoreTrailingSlash) { - return n.path + if n != nil { + return n.route, tsr } - return "" + return nil, false } // Methods returns a sorted list of HTTP methods associated with a given path in the routing tree. If the path is "*", @@ -161,7 +176,7 @@ func (t *Tree) Methods(path string) []string { c.resetNil() for i := range nds { n, tsr := t.lookup(nds[i], path, c, true) - if n != nil && (!tsr || t.fox.redirectTrailingSlash || t.fox.ignoreTrailingSlash) { + if n != nil && (!tsr || n.route.redirectTrailingSlash || n.route.ignoreTrailingSlash) { if methods == nil { methods = make([]string, 0) } @@ -175,14 +190,14 @@ func (t *Tree) Methods(path string) []string { return methods } -// Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a -// ContextCloser, and a boolean indicating if the handler was matched by adding or removing a trailing slash +// Lookup performs a manual route lookup for a given http.Request, returning the matched Route along with a +// ContextCloser, and a boolean indicating if the route was matched by adding or removing a trailing slash // (trailing slash action is recommended). The ContextCloser should always be closed if non-nil. This method is primarily // intended for integrating the fox router into custom routing solutions or middleware. This function is safe for concurrent // use by multiple goroutine and while mutation on Tree are ongoing. If there is a direct match or a tsr is possible, -// Lookup always return a HandlerFunc and a ContextCloser. +// Lookup always return a Route and a ContextCloser. // This API is EXPERIMENTAL and is likely to change in future release. -func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { +func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (route *Route, cc ContextCloser, tsr bool) { nds := *t.nodes.Load() index := findRootNode(r.Method, nds) @@ -201,9 +216,9 @@ func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, c n, tsr := t.lookup(nds[index], target, c, false) if n != nil { - c.path = n.path + c.route = n.route c.tsr = tsr - return n.handler, c, tsr + return n.route, c, tsr } c.Close() return nil, nil, tsr @@ -211,7 +226,7 @@ func (t *Tree) Lookup(w ResponseWriter, r *http.Request) (handler HandlerFunc, c // Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use // parseRoute before. -func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler HandlerFunc) error { +func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, route *Route) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. var rootNode *node nds := *t.nodes.Load() @@ -237,12 +252,12 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler if result.matched.isCatchAll() && isCatchAll { return newConflictErr(method, path, catchAllKey, getRouteConflict(result.matched)) } - return fmt.Errorf("%w: new route %s %s conflict with %s", ErrRouteExist, method, appendCatchAll(path, catchAllKey), result.matched.path) + return fmt.Errorf("%w: new route %s %s conflict with %s", ErrRouteExist, method, route.path, result.matched.route.path) } // We are updating an existing node. We only need to create a new node from // the matched one with the updated/added value (handler and wildcard). - n := newNodeFromRef(result.matched.key, handler, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex, path) + n := newNodeFromRef(result.matched.key, route, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex) t.updateMaxParams(paramsN) result.p.updateEdge(n) @@ -269,20 +284,18 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler child := newNodeFromRef( suffixFromExistingEdge, - result.matched.handler, + result.matched.route, result.matched.children, result.matched.childKeys, result.matched.catchAllKey, result.matched.paramChildIndex, - result.matched.path, ) parent := newNode( cPrefix, - handler, + route, []*node{child}, catchAllKey, - path, ) t.updateMaxParams(paramsN) @@ -306,15 +319,14 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler keySuffix := path[result.charsMatched:] // No children, so no paramChild - child := newNode(keySuffix, handler, nil, catchAllKey, path) + child := newNode(keySuffix, route, nil, catchAllKey) edges := result.matched.getEdgesShallowCopy() edges = append(edges, child) n := newNode( result.matched.key, - result.matched.handler, + result.matched.route, edges, result.matched.catchAllKey, - result.matched.path, ) t.updateMaxDepth(result.depth + 1) @@ -364,19 +376,18 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler keySuffix := path[result.charsMatched:] // No children, so no paramChild - n1 := newNodeFromRef(keySuffix, handler, nil, nil, catchAllKey, -1, path) // inserted node + n1 := newNodeFromRef(keySuffix, route, nil, nil, catchAllKey, -1) // inserted node n2 := newNodeFromRef( suffixFromExistingEdge, - result.matched.handler, + result.matched.route, result.matched.children, result.matched.childKeys, result.matched.catchAllKey, result.matched.paramChildIndex, - result.matched.path, ) // previous matched node // n3 children never start with a param - n3 := newNode(cPrefix, nil, []*node{n1, n2}, "", "") // intermediary node + n3 := newNode(cPrefix, nil, []*node{n1, n2}, "") // intermediary node t.updateMaxDepth(result.depth + 1) t.updateMaxParams(paramsN) @@ -389,7 +400,7 @@ func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler } // update is not safe for concurrent use. -func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFunc) error { +func (t *Tree) update(method string, path, catchAllKey string, route *Route) error { // Note that we need a consistent view of the tree during the patching so search must imperatively be locked. nds := *t.nodes.Load() index := findRootNode(method, nds) @@ -403,7 +414,7 @@ func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFu } if catchAllKey != result.matched.catchAllKey { - err := newConflictErr(method, path, catchAllKey, []string{result.matched.path}) + err := newConflictErr(method, path, catchAllKey, []string{result.matched.route.path}) err.isUpdate = true return err } @@ -412,12 +423,11 @@ func (t *Tree) update(method string, path, catchAllKey string, handler HandlerFu // the matched one with the updated/added value (handler and wildcard). n := newNodeFromRef( result.matched.key, - handler, + route, result.matched.children, result.matched.childKeys, catchAllKey, result.matched.paramChildIndex, - path, ) result.p.updateEdge(n) return nil @@ -450,7 +460,6 @@ func (t *Tree) remove(method, path string) bool { result.matched.childKeys, "", result.matched.paramChildIndex, - "", ) result.p.updateEdge(n) return true @@ -461,12 +470,11 @@ func (t *Tree) remove(method, path string) bool { mergedPath := fmt.Sprintf("%s%s", result.matched.key, child.key) n := newNodeFromRef( mergedPath, - child.handler, + child.route, child.children, child.childKeys, child.catchAllKey, child.paramChildIndex, - child.path, ) result.p.updateEdge(n) return true @@ -490,20 +498,18 @@ func (t *Tree) remove(method, path string) bool { mergedPath := fmt.Sprintf("%s%s", result.p.key, child.key) parent = newNodeFromRef( mergedPath, - child.handler, + child.route, child.children, child.childKeys, child.catchAllKey, child.paramChildIndex, - child.path, ) } else { parent = newNode( result.p.key, - result.p.handler, + result.p.route, parentEdges, result.p.catchAllKey, - result.p.path, ) } @@ -932,3 +938,22 @@ func (t *Tree) updateMaxDepth(max uint32) { t.maxDepth.Store(max) } } + +// newRoute create a new route, apply path options and apply middleware on the handler. +func (t *Tree) newRoute(path string, handler HandlerFunc, opts ...PathOption) *Route { + rte := &Route{ + ipStrategy: t.fox.ipStrategy, + base: handler, + path: path, + mws: t.mws, + redirectTrailingSlash: t.fox.redirectTrailingSlash, + ignoreTrailingSlash: t.fox.ignoreTrailingSlash, + } + + for _, opt := range opts { + opt.applyPath(rte) + } + rte.handler = applyMiddleware(RouteHandlers, rte.mws, handler) + + return rte +}