diff --git a/context.go b/context.go index dcb8619..9303490 100644 --- a/context.go +++ b/context.go @@ -19,6 +19,7 @@ import ( // to release resources after use. type ContextCloser interface { Context + // Close releases the context to be reused later. Close() } @@ -82,12 +83,13 @@ type Context interface { // This functionality is particularly beneficial for middlewares that need to wrap // their custom ResponseWriter while preserving the state of the original Context. CloneWith(w ResponseWriter, r *http.Request) ContextCloser + // Scope returns the HandlerScope associated with the current Context. + // This indicates the scope in which the handler is being executed, such as RouteHandler, NoRouteHandler, etc. + Scope() HandlerScope // Tree is a local copy of the Tree in use to serve the request. Tree() *Tree // Fox returns the Router instance. Fox() *Router - // Reset resets the Context to its initial state, attaching the provided ResponseWriter and http.Request. - Reset(w ResponseWriter, r *http.Request) } // cTx holds request-related information and allows interaction with the ResponseWriter. @@ -105,6 +107,7 @@ type cTx struct { fox *Router cachedQuery url.Values rec recorder + scope HandlerScope tsr bool } @@ -115,11 +118,12 @@ func (c *cTx) Reset(w ResponseWriter, r *http.Request) { c.tsr = false c.cachedQuery = nil c.route = nil + c.scope = RouteHandler *c.params = (*c.params)[:0] } // reset resets the Context to its initial state, attaching the provided http.ResponseWriter and http.Request. -// Caution: You should always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to +// Caution: always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to // avoid wrapping the ResponseWriter within itself. Use wisely! func (c *cTx) reset(w http.ResponseWriter, r *http.Request) { c.rec.reset(w) @@ -127,6 +131,7 @@ func (c *cTx) reset(w http.ResponseWriter, r *http.Request) { c.w = &c.rec c.cachedQuery = nil c.route = nil + c.scope = RouteHandler *c.params = (*c.params)[:0] } @@ -299,6 +304,7 @@ func (c *cTx) Clone() Context { fox: c.fox, tree: c.tree, route: c.route, + scope: c.scope, } cp.rec.ResponseWriter = noopWriter{c.rec.Header().Clone()} @@ -320,6 +326,7 @@ func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { cp.req = r cp.w = w cp.route = c.route + cp.scope = c.scope cp.cachedQuery = nil if cap(*c.params) > cap(*cp.params) { // Grow cp.params to a least cap(c.params) @@ -332,6 +339,12 @@ func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { return cp } +// Scope returns the HandlerScope associated with the current Context. +// This indicates the scope in which the handler is being executed, such as RouteHandler, NoRouteHandler, etc. +func (c *cTx) Scope() HandlerScope { + return c.scope +} + // Close releases the context to be reused later. func (c *cTx) Close() { // Put back the context, if not extended more than max params or max depth, allowing diff --git a/context_test.go b/context_test.go index 2eecd8a..7ec3cb6 100644 --- a/context_test.go +++ b/context_test.go @@ -265,6 +265,75 @@ func TestContext_Tree(t *testing.T) { f.ServeHTTP(w, req) } +func TestContext_Scope(t *testing.T) { + t.Parallel() + + f := New( + WithRedirectTrailingSlash(true), + WithMiddlewareFor(RedirectHandler, func(next HandlerFunc) HandlerFunc { + return func(c Context) { + assert.Equal(t, RedirectHandler, c.Scope()) + next(c) + } + }), + WithNoRouteHandler(func(c Context) { + assert.Equal(t, NoRouteHandler, c.Scope()) + }), + WithNoMethodHandler(func(c Context) { + assert.Equal(t, NoMethodHandler, c.Scope()) + }), + WithOptionsHandler(func(c Context) { + assert.Equal(t, OptionsHandler, c.Scope()) + }), + ) + require.NoError(t, f.Handle(http.MethodGet, "/foo", func(c Context) { + assert.Equal(t, RouteHandler, c.Scope()) + })) + + cases := []struct { + name string + req *http.Request + w http.ResponseWriter + }{ + { + name: "route handler scope", + req: httptest.NewRequest(http.MethodGet, "/foo", nil), + w: httptest.NewRecorder(), + }, + { + name: "redirect handler scope", + req: httptest.NewRequest(http.MethodGet, "/foo/", nil), + w: httptest.NewRecorder(), + }, + { + name: "no method handler scope", + req: httptest.NewRequest(http.MethodPost, "/foo", nil), + w: httptest.NewRecorder(), + }, + { + name: "options handler scope", + req: httptest.NewRequest(http.MethodOptions, "/foo", nil), + w: httptest.NewRecorder(), + }, + { + name: "options handler scope", + req: httptest.NewRequest(http.MethodOptions, "/foo", nil), + w: httptest.NewRecorder(), + }, + { + name: "no route handler scope", + req: httptest.NewRequest(http.MethodOptions, "/bar", nil), + w: httptest.NewRecorder(), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + f.ServeHTTP(tc.w, tc.req) + }) + } +} + func TestWrapF(t *testing.T) { t.Parallel() diff --git a/fox.go b/fox.go index f9617b1..f1f1f95 100644 --- a/fox.go +++ b/fox.go @@ -66,21 +66,48 @@ func (f ClientIPStrategyFunc) ClientIP(c Context) (*net.IPAddr, error) { return f(c) } +// HandlerScope represents different scopes where a handler may be called. It also allows for fine-grained control +// over where middleware is applied. +type HandlerScope uint8 + +const ( + // RouteHandler scope applies to regular routes registered in the router. + RouteHandler HandlerScope = 1 << (8 - 1 - iota) + // NoRouteHandler scope applies to the NoRoute handler, which is invoked when no route matches the request. + NoRouteHandler + // NoMethodHandler scope applies to the NoMethod handler, which is invoked when a route exists, but the method is not allowed. + NoMethodHandler + // RedirectHandler scope applies to the internal redirect handler, used for handling requests with trailing slashes. + RedirectHandler + // OptionsHandler scope applies to the automatic OPTIONS handler, which handles pre-flight or cross-origin requests. + OptionsHandler + // AllHandlers is a combination of all the above scopes, which can be used to apply middlewares to all types of handlers. + AllHandlers = RouteHandler | NoRouteHandler | NoMethodHandler | RedirectHandler | OptionsHandler +) + // Route represent a registered route in the route tree. // Most of the Route API is EXPERIMENTAL and is likely to change in future release. type Route struct { ipStrategy ClientIPStrategy - base HandlerFunc - handler HandlerFunc + hbase HandlerFunc + hself HandlerFunc + hall HandlerFunc path string mws []middleware redirectTrailingSlash bool ignoreTrailingSlash bool } -// Handle calls the base handler with the provided Context. +// Handle calls the handler with the provided Context. See also HandleMiddleware. func (r *Route) Handle(c Context) { - r.base(c) + r.hbase(c) +} + +// HandleMiddleware calls the handler with route-specific middleware applied, using the provided Context. +func (r *Route) HandleMiddleware(c Context, _ ...struct{}) { + // The variadic parameter is intentionally added to prevent this method from having the same signature as HandlerFunc. + // This avoids accidental use of HandleMiddleware where a HandlerFunc is required. + r.hself(c) } // Path returns the route path. @@ -127,7 +154,8 @@ type Router struct { type middleware struct { m MiddlewareFunc - scope MiddlewareScope + scope HandlerScope + g bool } var _ http.Handler = (*Router)(nil) @@ -363,7 +391,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !tsr && n != nil { c.route = n.route c.tsr = tsr - n.route.handler(c) + n.route.hall(c) // Put back the context, if not extended more than max params or max depth, allowing // the slice to naturally grow within the constraint. if cap(*c.params) <= int(tree.maxParams.Load()) && cap(*c.skipNds) <= int(tree.maxDepth.Load()) { @@ -376,7 +404,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { if n.route.ignoreTrailingSlash { c.route = n.route c.tsr = tsr - n.route.handler(c) + n.route.hall(c) c.Close() return } @@ -385,6 +413,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Reset params as it may have recorded wildcard segment (the context may still be used in a middleware) *c.params = (*c.params)[:0] c.tsr = false + c.scope = RedirectHandler fox.tsrRedirect(c) c.Close() return @@ -425,6 +454,7 @@ NoMethodFallback: sb.WriteString(", ") sb.WriteString(http.MethodOptions) w.Header().Set(HeaderAllow, sb.String()) + c.scope = OptionsHandler fox.autoOptions(c) c.Close() return @@ -443,12 +473,14 @@ NoMethodFallback: } if sb.Len() > 0 { w.Header().Set(HeaderAllow, sb.String()) + c.scope = NoMethodHandler fox.noMethod(c) c.Close() return } } + c.scope = NoRouteHandler fox.noRoute(c) c.Close() } @@ -645,7 +677,7 @@ func isRemovable(method string) bool { return true } -func applyMiddleware(scope MiddlewareScope, mws []middleware, h HandlerFunc) HandlerFunc { +func applyMiddleware(scope HandlerScope, mws []middleware, h HandlerFunc) HandlerFunc { m := h for i := len(mws) - 1; i >= 0; i-- { if mws[i].scope&scope != 0 { @@ -655,6 +687,21 @@ func applyMiddleware(scope MiddlewareScope, mws []middleware, h HandlerFunc) Han return m } +func applyRouteMiddleware(mws []middleware, base HandlerFunc) (HandlerFunc, HandlerFunc) { + rte := base + all := base + for i := len(mws) - 1; i >= 0; i-- { + if mws[i].scope&RouteHandler != 0 { + all = mws[i].m(all) + // route specific only + if !mws[i].g { + rte = mws[i].m(rte) + } + } + } + return rte, all +} + // localRedirect redirect the client to the new path, but it does not convert relative paths to absolute paths // like Redirect does. If the Content-Type header has not been set, localRedirect sets it to "text/html; charset=utf-8" // and writes a small HTML body. Setting the Content-Type header to any value, including nil, disables that behavior. diff --git a/fox_test.go b/fox_test.go index cfe2bdc..6a8bc0e 100644 --- a/fox_test.go +++ b/fox_test.go @@ -6,6 +6,7 @@ package fox import ( "fmt" + "github.com/tigerwill90/fox/internal/iterutil" "log" "math/rand" "net" @@ -642,6 +643,24 @@ func TestStaticRouteMalloc(t *testing.T) { } } +func TestRoute_HandleMiddlewareMalloc(t *testing.T) { + f := New() + for _, rte := range githubAPI { + require.NoError(t, f.Tree().Handle(rte.method, rte.path, emptyHandler)) + } + + for _, rte := range githubAPI { + req := httptest.NewRequest(rte.method, rte.path, nil) + w := httptest.NewRecorder() + r, c, _ := f.Lookup(&recorder{ResponseWriter: w}, req) + allocs := testing.AllocsPerRun(100, func() { + r.HandleMiddleware(c) + }) + c.Close() + assert.Equal(t, float64(0), allocs) + } +} + func TestParamsRoute(t *testing.T) { rx := regexp.MustCompile("({|\\*{)[A-z]+[}]") r := New() @@ -1814,7 +1833,7 @@ func TestRouterWithClientIPStrategy(t *testing.T) { require.NotNil(t, rte) assert.True(t, rte.ClientIPStrategyEnabled()) - require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler, WithClientIPStrategy(noClientIPStrategy{}))) + require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler, WithClientIPStrategy(nil))) rte = f.Tree().Route(http.MethodGet, "/foo") require.NotNil(t, rte) assert.False(t, rte.ClientIPStrategyEnabled()) @@ -2134,7 +2153,7 @@ func TestTree_Remove(t *testing.T) { } it := tree.Iter() - cnt := len(slices.Collect(right(it.All()))) + cnt := len(slices.Collect(iterutil.Right(it.All()))) assert.Equal(t, 0, cnt) assert.Equal(t, 4, len(*tree.nodes.Load())) @@ -2153,14 +2172,14 @@ func TestTree_Methods(t *testing.T) { require.NoError(t, f.Handle(rte.method, rte.path, emptyHandler)) } - methods := slices.Sorted(left(f.Iter().Reverse(f.Iter().Methods(), "/gists/123/star"))) + methods := slices.Sorted(iterutil.Left(f.Iter().Reverse(f.Iter().Methods(), "/gists/123/star"))) assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) methods = slices.Sorted(f.Iter().Methods()) assert.Equal(t, []string{"DELETE", "GET", "POST", "PUT"}, methods) // Ignore trailing slash disable - methods = slices.Sorted(left(f.Iter().Reverse(f.Iter().Methods(), "/gists/123/star/"))) + methods = slices.Sorted(iterutil.Left(f.Iter().Reverse(f.Iter().Methods(), "/gists/123/star/"))) assert.Empty(t, methods) } @@ -2171,13 +2190,13 @@ func TestTree_MethodsWithIgnoreTsEnable(t *testing.T) { require.NoError(t, f.Handle(method, "/john/doe/", emptyHandler)) } - methods := slices.Sorted(left(f.Iter().Reverse(f.Iter().Methods(), "/foo/bar/"))) + methods := slices.Sorted(iterutil.Left(f.Iter().Reverse(f.Iter().Methods(), "/foo/bar/"))) assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) - methods = slices.Sorted(left(f.Iter().Reverse(f.Iter().Methods(), "/john/doe"))) + methods = slices.Sorted(iterutil.Left(f.Iter().Reverse(f.Iter().Methods(), "/john/doe"))) assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) - methods = slices.Sorted(left(f.Iter().Reverse(f.Iter().Methods(), "/foo/bar/baz"))) + methods = slices.Sorted(iterutil.Left(f.Iter().Reverse(f.Iter().Methods(), "/foo/bar/baz"))) assert.Empty(t, methods) } @@ -2391,7 +2410,7 @@ func TestRouterWithAutomaticOptions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { - c.SetHeader("Allow", strings.Join(slices.Sorted(left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().URL.Path))), ", ")) + c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().URL.Path))), ", ")) c.Writer().WriteHeader(http.StatusNoContent) })) } @@ -2467,7 +2486,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { - c.SetHeader("Allow", strings.Join(slices.Sorted(left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().URL.Path))), ", ")) + c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().URL.Path))), ", ")) c.Writer().WriteHeader(http.StatusNoContent) })) } @@ -2512,7 +2531,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { - c.SetHeader("Allow", strings.Join(slices.Sorted(left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().URL.Path))), ", ")) + c.SetHeader("Allow", strings.Join(slices.Sorted(iterutil.Left(c.Tree().Iter().Reverse(c.Tree().Iter().Methods(), c.Request().URL.Path))), ", ")) c.Writer().WriteHeader(http.StatusNoContent) })) } @@ -2584,7 +2603,7 @@ func TestUpdateWithMiddleware(t *testing.T) { next(c) } }) - f := New() + f := New(WithMiddleware(Recovery())) f.MustHandle(http.MethodGet, "/foo", emptyHandler) req := httptest.NewRequest(http.MethodGet, "/foo", nil) w := httptest.NewRecorder() @@ -2595,10 +2614,29 @@ func TestUpdateWithMiddleware(t *testing.T) { assert.True(t, called) called = false + rte := f.Tree().Route(http.MethodGet, "/foo") + rte.Handle(newTestContextTree(f.Tree())) + assert.False(t, called) + called = false + + rte.HandleMiddleware(newTestContextTree(f.Tree())) + assert.True(t, called) + called = false + // Remove middleware require.NoError(t, f.Update(http.MethodGet, "/foo", emptyHandler)) f.ServeHTTP(w, req) assert.False(t, called) + called = false + + rte = f.Tree().Route(http.MethodGet, "/foo") + rte.Handle(newTestContextTree(f.Tree())) + assert.False(t, called) + called = false + + rte = f.Tree().Route(http.MethodGet, "/foo") + rte.HandleMiddleware(newTestContextTree(f.Tree())) + assert.False(t, called) } func TestRouteMiddleware(t *testing.T) { @@ -2641,6 +2679,28 @@ func TestRouteMiddleware(t *testing.T) { assert.True(t, c0) assert.False(t, c1) assert.True(t, c2) + + c0, c1, c2 = false, false, false + rte1 := f.Tree().Route(http.MethodGet, "/1") + require.NotNil(t, rte1) + rte1.Handle(newTestContextTree(f.Tree())) + assert.False(t, c0) + assert.False(t, c1) + assert.False(t, c2) + c0, c1, c2 = false, false, false + + rte1.HandleMiddleware(newTestContextTree(f.Tree())) + assert.False(t, c0) + assert.True(t, c1) + assert.False(t, c2) + c0, c1, c2 = false, false, false + + rte2 := f.Tree().Route(http.MethodGet, "/2") + require.NotNil(t, rte2) + rte2.HandleMiddleware(newTestContextTree(f.Tree())) + assert.False(t, c0) + assert.False(t, c1) + assert.True(t, c2) } func TestWithNotFoundHandler(t *testing.T) { @@ -3005,7 +3065,7 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { } it := tree.Iter() - countPath := len(slices.Collect(right(it.All()))) + countPath := len(slices.Collect(iterutil.Right(it.All()))) assert.Equal(t, len(routes), countPath) for rte := range routes { @@ -3027,7 +3087,7 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { } it = tree.Iter() - countPath = len(slices.Collect(right(it.All()))) + countPath = len(slices.Collect(iterutil.Right(it.All()))) assert.Equal(t, 0, countPath) } diff --git a/internal/iterutil/iterutil.go b/internal/iterutil/iterutil.go new file mode 100644 index 0000000..e6da8f8 --- /dev/null +++ b/internal/iterutil/iterutil.go @@ -0,0 +1,33 @@ +package iterutil + +import "iter" + +func Left[K, V any](seq iter.Seq2[K, V]) iter.Seq[K] { + return func(yield func(K) bool) { + for k, _ := range seq { + if !yield(k) { + return + } + } + } +} + +func Right[K, V any](seq iter.Seq2[K, V]) iter.Seq[V] { + return func(yield func(V) bool) { + for _, v := range seq { + if !yield(v) { + return + } + } + } +} + +func SeqOf[E any](elems ...E) iter.Seq[E] { + return func(yield func(E) bool) { + for _, e := range elems { + if !yield(e) { + return + } + } + } +} diff --git a/iter.go b/iter.go index 9d6eeba..50b50be 100644 --- a/iter.go +++ b/iter.go @@ -210,33 +210,3 @@ func (it Iter) All() iter.Seq2[string, *Route] { } } } - -func left[K, V any](seq iter.Seq2[K, V]) iter.Seq[K] { - return func(yield func(K) bool) { - for k := range seq { - if !yield(k) { - return - } - } - } -} - -func right[K, V any](seq iter.Seq2[K, V]) iter.Seq[V] { - return func(yield func(V) bool) { - for _, v := range seq { - if !yield(v) { - return - } - } - } -} - -func seqOf[E any](elems ...E) iter.Seq[E] { - return func(yield func(E) bool) { - for _, e := range elems { - if !yield(e) { - return - } - } - } -} diff --git a/iter_test.go b/iter_test.go index aa046a1..95e1a8a 100644 --- a/iter_test.go +++ b/iter_test.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tigerwill90/fox/internal/iterutil" "net/http" "slices" "testing" @@ -134,7 +135,7 @@ func TestIter_RootPrefixOneMethod(t *testing.T) { results := make(map[string][]string) it := tree.Iter() - for method, route := range it.Prefix(seqOf(http.MethodHead), "/") { + for method, route := range it.Prefix(iterutil.SeqOf(http.MethodHead), "/") { assert.NotNil(t, route) results[method] = append(results[method], route.Path()) } @@ -169,12 +170,12 @@ func TestIter_EdgeCase(t *testing.T) { tree := New().Tree() it := tree.Iter() - assert.Empty(t, slices.Collect(left(it.Prefix(seqOf("GET"), "/")))) - assert.Empty(t, slices.Collect(left(it.Prefix(seqOf("CONNECT"), "/")))) - assert.Empty(t, slices.Collect(left(it.Reverse(seqOf("GET"), "/")))) - assert.Empty(t, slices.Collect(left(it.Reverse(seqOf("CONNECT"), "/")))) - assert.Empty(t, slices.Collect(left(it.Routes(seqOf("GET"), "/")))) - assert.Empty(t, slices.Collect(left(it.Routes(seqOf("CONNECT"), "/")))) + assert.Empty(t, slices.Collect(iterutil.Left(it.Prefix(iterutil.SeqOf("GET"), "/")))) + assert.Empty(t, slices.Collect(iterutil.Left(it.Prefix(iterutil.SeqOf("CONNECT"), "/")))) + assert.Empty(t, slices.Collect(iterutil.Left(it.Reverse(iterutil.SeqOf("GET"), "/")))) + assert.Empty(t, slices.Collect(iterutil.Left(it.Reverse(iterutil.SeqOf("CONNECT"), "/")))) + assert.Empty(t, slices.Collect(iterutil.Left(it.Routes(iterutil.SeqOf("GET"), "/")))) + assert.Empty(t, slices.Collect(iterutil.Left(it.Routes(iterutil.SeqOf("CONNECT"), "/")))) } func TestIter_PrefixWithMethod(t *testing.T) { @@ -189,7 +190,7 @@ func TestIter_PrefixWithMethod(t *testing.T) { results := make(map[string][]string) it := tree.Iter() - for method, route := range it.Prefix(seqOf(http.MethodHead), "/foo") { + for method, route := range it.Prefix(iterutil.SeqOf(http.MethodHead), "/foo") { assert.NotNil(t, route) results[method] = append(results[method], route.Path()) } diff --git a/options.go b/options.go index 62a6b45..daa0c75 100644 --- a/options.go +++ b/options.go @@ -4,23 +4,7 @@ package fox -// MiddlewareScope is a type that represents different scopes for applying middleware. -type MiddlewareScope uint8 - -const ( - // RouteHandlers scope applies middleware only to regular routes registered in the router. - RouteHandlers MiddlewareScope = 1 << (8 - 1 - iota) - // NoRouteHandler scope applies middleware to the NoRoute handler. - NoRouteHandler - // NoMethodHandler scope applies middleware to the NoMethod handler. - NoMethodHandler - // RedirectHandler scope applies middleware to the internal redirect trailing slash handler. - RedirectHandler - // OptionsHandler scope applies middleware to the automatic OPTIONS handler. - OptionsHandler - // AllHandlers is a combination of all the above scopes, which means the middleware will be applied to all types of handlers. - AllHandlers = RouteHandlers | NoRouteHandler | NoMethodHandler | RedirectHandler | OptionsHandler -) +import "cmp" type Option interface { GlobalOption @@ -109,12 +93,12 @@ func WithMiddleware(m ...MiddlewareFunc) Option { return optionFunc(func(router *Router, route *Route) { if router != nil { for i := range m { - router.mws = append(router.mws, middleware{m[i], AllHandlers}) + router.mws = append(router.mws, middleware{m[i], AllHandlers, true}) } } if route != nil { for i := range m { - route.mws = append(route.mws, middleware{m[i], RouteHandlers}) + route.mws = append(route.mws, middleware{m[i], RouteHandler, false}) } } }) @@ -122,12 +106,12 @@ func WithMiddleware(m ...MiddlewareFunc) Option { // WithMiddlewareFor attaches middleware to the router for a specified scope. Middlewares provided will be chained // in the order they were added. The scope parameter determines which types of handlers the middleware will be applied to. -// Possible scopes include RouteHandlers (regular routes), NoRouteHandler, NoMethodHandler, RedirectHandler, OptionsHandler, +// Possible scopes include RouteHandler (regular routes), NoRouteHandler, NoMethodHandler, RedirectHandler, OptionsHandler, // and any combination of these. Use this option when you need fine-grained control over where the middleware is applied. -func WithMiddlewareFor(scope MiddlewareScope, m ...MiddlewareFunc) GlobalOption { +func WithMiddlewareFor(scope HandlerScope, m ...MiddlewareFunc) GlobalOption { return globOptionFunc(func(r *Router) { for i := range m { - r.mws = append(r.mws, middleware{m[i], scope}) + r.mws = append(r.mws, middleware{m[i], scope, true}) } }) } @@ -223,27 +207,27 @@ func WithIgnoreTrailingSlash(enable bool) Option { // - If applied to a specific route, it will override the global setting for that route. // - The option must be explicitly reapplied when updating a route. If not, the route will fall back // to the global client IP strategy (if one is configured). +// - Setting the strategy to nil is equivalent to no strategy configured. func WithClientIPStrategy(strategy ClientIPStrategy) Option { return optionFunc(func(router *Router, route *Route) { - if strategy != nil { - if router != nil { - router.ipStrategy = strategy - } - if route != nil { - route.ipStrategy = strategy - } + if router != nil { + router.ipStrategy = cmp.Or(strategy, ClientIPStrategy(noClientIPStrategy{})) + } + + if route != nil { + route.ipStrategy = cmp.Or(strategy, ClientIPStrategy(noClientIPStrategy{})) } }) } -// DefaultOptions configure the router to use the Recovery middleware for the RouteHandlers scope, the Logger middleware +// DefaultOptions configure the router to use the Recovery middleware for the RouteHandler scope, the Logger middleware // for AllHandlers scope and enable automatic OPTIONS response. Note that DefaultOptions push the Recovery and Logger middleware // respectively to the first and second position of the middleware chains. func DefaultOptions() GlobalOption { return globOptionFunc(func(r *Router) { r.mws = append([]middleware{ - {Recovery(), RouteHandlers}, - {Logger(), AllHandlers}, + {Recovery(), RouteHandler, true}, + {Logger(), AllHandlers, true}, }, r.mws...) r.handleOptions = true }) diff --git a/tree.go b/tree.go index 74d8d3f..fd62253 100644 --- a/tree.go +++ b/tree.go @@ -911,7 +911,7 @@ func (t *Tree) updateMaxDepth(max uint32) { func (t *Tree) newRoute(path string, handler HandlerFunc, opts ...PathOption) *Route { rte := &Route{ ipStrategy: t.fox.ipStrategy, - base: handler, + hbase: handler, path: path, mws: t.mws, redirectTrailingSlash: t.fox.redirectTrailingSlash, @@ -921,7 +921,7 @@ func (t *Tree) newRoute(path string, handler HandlerFunc, opts ...PathOption) *R for _, opt := range opts { opt.applyPath(rte) } - rte.handler = applyMiddleware(RouteHandlers, rte.mws, handler) + rte.hself, rte.hall = applyRouteMiddleware(rte.mws, handler) return rte }