From 670e306a878e84dbd4034d9baa2712e2f5d271dd Mon Sep 17 00:00:00 2001 From: Sylvain Muller Date: Mon, 24 Jun 2024 19:28:47 +0200 Subject: [PATCH] Ignore trailing slash (#32) * test: improve tsr tests * feat: fix tsr edge case when exact match on a intermediary leaf node * feat: implement ignore trailing slash new feature * feat: enable ignore trailing slash for tree method (has, match and methods) & add tests * docs: fix inconsistent documentation for Tree.Methods * feat: fix Tree.Match * docs(readme): update road to v1 section * feat: enable ignore trailing slash for tree method (has, match and methods) also when redirect trailing slash. * feat: disable ignore trailing slash for Tree.Has * feat: attach fox instance to every tree * docs: update lookup methods documentation * feat: improve local redirect * docs: fix localRedirect docs --- README.md | 4 + context.go | 19 +- fox.go | 170 +++++++++---- fox_test.go | 668 +++++++++++++++++++++++++++++++++++++++++++++++----- go.mod | 2 +- go.sum | 4 +- node.go | 2 +- options.go | 14 +- tree.go | 134 ++++++++--- 9 files changed, 866 insertions(+), 151 deletions(-) diff --git a/README.md b/README.md index a2cd326..73981f5 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,9 @@ priority rules, ensuring that there are no unintended matches and maintaining hi **Redirect trailing slashes:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), the router automatically redirects the client, at no extra cost, if another route match with or without a trailing slash. +**Ignore trailing slashes:** In contrast to redirecting, this option allows the router to handle requests regardless of an extra +or missing trailing slash, at no extra cost. + **Automatic OPTIONS replies:** Inspired from [httprouter](https://github.com/julienschmidt/httprouter), the router has built-in native support for [OPTIONS requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS). @@ -620,6 +623,7 @@ BenchmarkPat_GithubAll 424 2899405 ns/op 1843501 - [x] [Update route syntax](https://github.com/tigerwill90/fox/pull/10#issue-1643728309) @v0.6.0 - [x] [Route overlapping](https://github.com/tigerwill90/fox/pull/9#issue-1642887919) @v0.7.0 - [x] [Route overlapping (catch-all and params)](https://github.com/tigerwill90/fox/pull/24#issue-1784686061) @v0.10.0 +- [x] [Ignore trailing slash](https://github.com/tigerwill90/fox/pull/32) @v0.14.0 - [ ] Improving performance and polishing ## Contributions diff --git a/context.go b/context.go index 1158494..e82ae30 100644 --- a/context.go +++ b/context.go @@ -78,11 +78,12 @@ type Context interface { CloneWith(w ResponseWriter, r *http.Request) ContextCloser // Tree is a local copy of the Tree in use to serve the request. Tree() *Tree - // Fox returns the Router in use to serve the request. + // Fox returns the Router instance. Fox() *Router - // Reset resets the Context to its initial state, attaching the provided Router, - // http.ResponseWriter, and *http.Request. - Reset(fox *Router, w http.ResponseWriter, r *http.Request) + // Reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request. + // Caution: You should always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to + // avoid wrapping the ResponseWriter within itself. Use wisely! + Reset(w http.ResponseWriter, r *http.Request) } // context holds request-related information and allows interaction with the ResponseWriter. @@ -102,13 +103,13 @@ type context struct { } // Reset resets the Context to its initial state, attaching the provided Router, http.ResponseWriter, and *http.Request. -// Caution: You should pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to avoid -// wrapping the ResponseWriter within itself. -func (c *context) Reset(fox *Router, w http.ResponseWriter, r *http.Request) { +// Caution: You should always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to +// avoid wrapping the ResponseWriter within itself. Use wisely! +func (c *context) Reset(w http.ResponseWriter, r *http.Request) { c.rec.reset(w) c.req = r c.w = &c.rec - c.fox = fox + c.fox = c.tree.fox c.path = "" c.cachedQuery = nil *c.params = (*c.params)[:0] @@ -233,7 +234,7 @@ func (c *context) Tree() *Tree { return c.tree } -// Fox returns the Router in use to serve the request. +// Fox returns the Router instance. func (c *context) Fox() *Router { return c.fox } diff --git a/fox.go b/fox.go index 3e89bba..5b34293 100644 --- a/fox.go +++ b/fox.go @@ -10,9 +10,11 @@ import ( "net/http" "path" "regexp" + "strconv" "strings" "sync" "sync/atomic" + "unicode/utf8" ) const verb = 4 @@ -54,6 +56,7 @@ type Router struct { handleMethodNotAllowed bool handleOptions bool redirectTrailingSlash bool + ignoreTrailingSlash bool } type middleware struct { @@ -84,6 +87,34 @@ func New(opts ...Option) *Router { return r } +// MethodNotAllowedEnabled returns whether the router is configured to handle +// requests with methods that are not allowed. +// This api is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) MethodNotAllowedEnabled() bool { + return fox.handleMethodNotAllowed +} + +// AutoOptionsEnabled returns whether the router is configured to automatically +// respond to OPTIONS requests. +// This api is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) AutoOptionsEnabled() bool { + return fox.handleOptions +} + +// RedirectTrailingSlashEnabled returns whether the router is configured to automatically +// redirect requests that include or omit a trailing slash. +// This api is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) RedirectTrailingSlashEnabled() bool { + return fox.redirectTrailingSlash +} + +// IgnoreTrailingSlashEnabled returns whether the router is configured to ignore +// trailing slashes in requests when matching routes. +// This api is EXPERIMENTAL and is likely to change in future release. +func (fox *Router) IgnoreTrailingSlashEnabled() bool { + return fox.ignoreTrailingSlash +} + // NewTree returns a fresh routing Tree that inherits all registered router options. It's safe to create multiple Tree // concurrently. However, a Tree itself is not thread-safe and all its APIs that perform write operations should be run // serially. Note that a Tree give direct access to the underlying sync.Mutex. @@ -91,6 +122,7 @@ func New(opts ...Option) *Router { func (fox *Router) NewTree() *Tree { tree := new(Tree) tree.mws = fox.mws + tree.fox = fox // Pre instantiate nodes for common http verb nds := make([]*node, len(commonVerbs)) @@ -165,38 +197,14 @@ func (fox *Router) Remove(method, path string) error { return t.Remove(method, path) } -// Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a ContextCloser, -// and a boolean indicating if a trailing slash redirect is recommended. The ContextCloser should always be closed if non-nil. -// This method is primarily intended for integrating the fox router into custom routing solutions. It requires the use of the -// original http.ResponseWriter, typically obtained from ServeHTTP. This function is safe for concurrent use by multiple goroutine -// and while mutation on Tree are ongoing. +// Lookup is a helper that calls Tree.Lookup. For more details, refer to Tree.Lookup. +// It performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a ContextCloser, +// and a boolean indicating if a trailing slash action (e.g. redirect) is recommended (tsr). The ContextCloser should always +// be closed if non-nil. // This API is EXPERIMENTAL and is likely to change in future release. func (fox *Router) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { tree := fox.tree.Load() - - nds := *tree.nodes.Load() - index := findRootNode(r.Method, nds) - - if index < 0 { - return - } - - c := tree.ctx.Get().(*context) - c.Reset(fox, w, r) - - target := r.URL.Path - if len(r.URL.RawPath) > 0 { - // Using RawPath to prevent unintended match (e.g. /search/a%2Fb/1) - target = r.URL.RawPath - } - - n, tsr := tree.lookup(nds[index], target, c.params, c.skipNds, false) - if n != nil { - c.path = n.path - return n.handler, c, tsr - } - c.Close() - return nil, nil, tsr + return tree.Lookup(w, r) } // SkipMethod is used as a return value from WalkFunc to indicate that @@ -293,7 +301,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { tree := fox.tree.Load() c := tree.ctx.Get().(*context) - c.Reset(fox, w, r) + c.Reset(w, r) nds := *tree.nodes.Load() index := findRootNode(r.Method, nds) @@ -302,7 +310,7 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { } n, tsr = tree.lookup(nds[index], target, c.params, c.skipNds, false) - if n != nil { + if !tsr && n != nil { c.path = n.path n.handler(c) // Put back the context, if not extended more than max params or max depth, allowing @@ -313,15 +321,27 @@ func (fox *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Reset params as it may have recorded wildcard segment - *c.params = (*c.params)[:0] + if r.Method != http.MethodConnect && r.URL.Path != "/" && tsr { + if fox.ignoreTrailingSlash { + c.path = n.path + n.handler(c) + c.Close() + return + } - if r.Method != http.MethodConnect && r.URL.Path != "/" && tsr && fox.redirectTrailingSlash && target == CleanPath(target) { - fox.tsrRedirect(c) - c.Close() - return + if fox.redirectTrailingSlash && target == CleanPath(target) { + // Reset params as it may have recorded wildcard segment (the context may still be used in a middleware) + *c.params = (*c.params)[:0] + fox.tsrRedirect(c) + c.Close() + return + } } + // Reset params as it may have recorded wildcard segment (the context may still be used in no route, no method and + // automatic option handler or middleware) + *c.params = (*c.params)[:0] + NoMethodFallback: if r.Method == http.MethodOptions && fox.handleOptions { var sb strings.Builder @@ -338,7 +358,7 @@ NoMethodFallback: } } else { for i := 0; i < len(nds); i++ { - if n, _ := tree.lookup(nds[i], target, c.params, c.skipNds, true); n != nil { + if n, tsr := tree.lookup(nds[i], target, c.params, c.skipNds, true); n != nil && (!tsr || fox.ignoreTrailingSlash) { if sb.Len() > 0 { sb.WriteString(", ") } else { @@ -361,7 +381,7 @@ NoMethodFallback: var sb strings.Builder for i := 0; i < len(nds); i++ { if nds[i].key != r.Method { - if n, _ := tree.lookup(nds[i], target, c.params, c.skipNds, true); n != nil { + if n, tsr := tree.lookup(nds[i], target, c.params, c.skipNds, true); n != nil && (!tsr || fox.ignoreTrailingSlash) { if sb.Len() > 0 { sb.WriteString(", ") } @@ -590,14 +610,32 @@ func applyMiddleware(scope MiddlewareScope, mws []middleware, h HandlerFunc) Han return m } -// localRedirect redirect the client to the new path. -// It does not convert relative paths to absolute paths like Redirect does. -func localRedirect(w http.ResponseWriter, r *http.Request, newPath string, code int) { +// localRedirect redirect the client to the new path, but it does not convert relative paths to absolute paths +// like Redirect does. If the Content-Type header has not been set, localRedirect sets it to "text/html; charset=utf-8" +// and writes a small HTML body. Setting the Content-Type header to any value, including nil, disables that behavior. +func localRedirect(w http.ResponseWriter, r *http.Request, path string, code int) { if q := r.URL.RawQuery; q != "" { - newPath += "?" + q + path += "?" + q + } + + h := w.Header() + + // RFC 7231 notes that a short HTML body is usually included in + // the response because older user agents may not understand 301/307. + // Do it only if the request didn't already have a Content-Type header. + _, hadCT := h["Content-Type"] + + h.Set(HeaderLocation, hexEscapeNonASCII(path)) + if !hadCT && (r.Method == "GET" || r.Method == "HEAD") { + h.Set(HeaderContentType, MIMETextHTMLCharsetUTF8) } - w.Header().Set(HeaderLocation, newPath) w.WriteHeader(code) + + // Shouldn't send the body for POST or HEAD; that leaves GET. + if !hadCT && r.Method == "GET" { + body := "" + http.StatusText(code) + ".\n" + _, _ = fmt.Fprintln(w, body) + } } // grow increases the slice's capacity, if necessary, to guarantee space for @@ -613,3 +651,47 @@ func grow[S ~[]E, E any](s S, n int) S { } return s } + +func hexEscapeNonASCII(s string) string { + newLen := 0 + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + newLen += 3 + } else { + newLen++ + } + } + if newLen == len(s) { + return s + } + b := make([]byte, 0, newLen) + var pos int + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + if pos < i { + b = append(b, s[pos:i]...) + } + b = append(b, '%') + b = strconv.AppendInt(b, int64(s[i]), 16) + pos = i + 1 + } + } + if pos < len(s) { + b = append(b, s[pos:]...) + } + return string(b) +} + +var htmlReplacer = strings.NewReplacer( + "&", "&", + "<", "<", + ">", ">", + // """ is shorter than """. + `"`, """, + // "'" is shorter than "'" and apos was not in HTML until HTML5. + "'", "'", +) + +func htmlEscape(s string) string { + return htmlReplacer.Replace(s) +} diff --git a/fox_test.go b/fox_test.go index bafa77f..f9f0a42 100644 --- a/fox_test.go +++ b/fox_test.go @@ -714,8 +714,9 @@ func TestRouteWithParams(t *testing.T) { nds := *tree.nodes.Load() for _, rte := range routes { c := newTestContextTree(tree) - n, _ := tree.lookup(nds[0], rte, c.params, c.skipNds, false) + n, tsr := tree.lookup(nds[0], rte, c.params, c.skipNds, false) require.NotNil(t, n) + assert.False(t, tsr) assert.Equal(t, rte, n.path) } } @@ -1159,10 +1160,10 @@ func TestOverlappingRoute(t *testing.T) { nds := *tree.nodes.Load() c := newTestContextTree(tree) - n, _ := tree.lookup(nds[0], tc.path, c.params, c.skipNds, false) + n, tsr := tree.lookup(nds[0], tc.path, c.params, c.skipNds, false) require.NotNil(t, n) require.NotNil(t, n.handler) - + assert.False(t, tsr) assert.Equal(t, tc.wantMatch, n.path) if len(tc.wantParams) == 0 { assert.Empty(t, c.Params()) @@ -1172,9 +1173,10 @@ func TestOverlappingRoute(t *testing.T) { // Test with lazy c = newTestContextTree(tree) - n, _ = tree.lookup(nds[0], tc.path, c.params, c.skipNds, true) + n, tsr = tree.lookup(nds[0], tc.path, c.params, c.skipNds, true) require.NotNil(t, n) require.NotNil(t, n.handler) + assert.False(t, tsr) assert.Empty(t, c.Params()) assert.Equal(t, tc.wantMatch, n.path) }) @@ -1537,34 +1539,39 @@ func TestParseRoute(t *testing.T) { func TestTree_LookupTsr(t *testing.T) { cases := []struct { - name string - paths []string - key string - want bool + name string + paths []string + key string + want bool + wantPath string }{ { - name: "match mid edge", - paths: []string{"/foo/bar/"}, - key: "/foo/bar", - want: true, + name: "match mid edge", + paths: []string{"/foo/bar/"}, + key: "/foo/bar", + want: true, + wantPath: "/foo/bar/", }, { - name: "incomplete match end of edge", - paths: []string{"/foo/bar"}, - key: "/foo/bar/", - want: true, + name: "incomplete match end of edge", + paths: []string{"/foo/bar"}, + key: "/foo/bar/", + want: true, + wantPath: "/foo/bar", }, { - name: "match mid edge with child node", - paths: []string{"/users/", "/users/{id}"}, - key: "/users", - want: true, + name: "match mid edge with child node", + paths: []string{"/users/", "/users/{id}"}, + key: "/users", + want: true, + wantPath: "/users/", }, { - name: "match mid edge in child node", - paths: []string{"/users", "/users/{id}"}, - key: "/users/", - want: true, + name: "match mid edge in child node", + paths: []string{"/users", "/users/{id}"}, + key: "/users/", + want: true, + wantPath: "/users", }, { name: "match mid edge in child node with invalid remaining prefix", @@ -1606,8 +1613,133 @@ func TestTree_LookupTsr(t *testing.T) { } nds := *tree.nodes.Load() c := newTestContextTree(tree) - _, got := tree.lookup(nds[0], tc.key, c.params, c.skipNds, true) + n, got := tree.lookup(nds[0], tc.key, c.params, c.skipNds, true) assert.Equal(t, tc.want, got) + if tc.want { + require.NotNil(t, n) + assert.Equal(t, tc.wantPath, n.path) + } + }) + } +} + +func TestRouterWithIgnoreTrailingSlash(t *testing.T) { + cases := []struct { + name string + paths []string + req string + method string + wantCode int + wantPath string + }{ + { + name: "current not a leaf with extra ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo", + }, + { + name: "current not a leaf and path does not end with ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/c", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "current not a leaf and path end with extra char and ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/c/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "current not a leaf and path end with ts but last is not a leaf", + paths: []string{"/foo/a/a", "/foo/a/b", "/foo/c/"}, + req: "/foo/a/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "mid edge key with extra ts", + paths: []string{"/foo/bar/"}, + req: "/foo/bar", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo/bar/", + }, + { + name: "mid edge key with without extra ts", + paths: []string{"/foo/bar/baz", "/foo/bar"}, + req: "/foo/bar/", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo/bar", + }, + { + name: "mid edge key without extra ts", + paths: []string{"/foo/bar/baz", "/foo/bar"}, + req: "/foo/bar/", + method: http.MethodPost, + wantCode: http.StatusOK, + wantPath: "/foo/bar", + }, + { + name: "incomplete match end of edge", + paths: []string{"/foo/bar"}, + req: "/foo/bar/", + method: http.MethodGet, + wantCode: http.StatusOK, + wantPath: "/foo/bar", + }, + { + name: "match mid edge with ts and more char after", + paths: []string{"/foo/bar/buzz"}, + req: "/foo/bar", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "match mid edge with ts and more char before", + paths: []string{"/foo/barr/"}, + req: "/foo/bar", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "incomplete match end of edge with ts and more char after", + paths: []string{"/foo/bar"}, + req: "/foo/bar/buzz", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "incomplete match end of edge with ts and more char before", + paths: []string{"/foo/bar"}, + req: "/foo/barr/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r := New(WithIgnoreTrailingSlash(true)) + require.True(t, r.IgnoreTrailingSlashEnabled()) + for _, path := range tc.paths { + require.NoError(t, r.Tree().Handle(tc.method, path, func(c Context) { + _ = c.String(http.StatusOK, c.Path()) + })) + } + + req := httptest.NewRequest(tc.method, tc.req, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + if tc.wantPath != "" { + assert.Equal(t, tc.wantPath, w.Body.String()) + } }) } } @@ -1616,39 +1748,92 @@ func TestRedirectTrailingSlash(t *testing.T) { cases := []struct { name string - path string + paths []string req string method string wantCode int wantLocation string }{ { - name: "mid edge key with get method and status moved permanently", - path: "/foo/bar/", + name: "current not a leaf get method and status moved permanently with extra ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/", + method: http.MethodGet, + wantCode: http.StatusMovedPermanently, + wantLocation: "../foo", + }, + { + name: "current not a leaf post method and status moved permanently with extra ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/", + method: http.MethodPost, + wantCode: http.StatusPermanentRedirect, + wantLocation: "../foo", + }, + { + name: "current not a leaf and path does not end with ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/c", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "current not a leaf and path end with extra char and ts", + paths: []string{"/foo", "/foo/x/", "/foo/z/"}, + req: "/foo/c/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "current not a leaf and path end with ts but last is not a leaf", + paths: []string{"/foo/a/a", "/foo/a/b", "/foo/c/"}, + req: "/foo/a/", + method: http.MethodGet, + wantCode: http.StatusNotFound, + }, + { + name: "mid edge key with get method and status moved permanently with extra ts", + paths: []string{"/foo/bar/"}, req: "/foo/bar", method: http.MethodGet, wantCode: http.StatusMovedPermanently, wantLocation: "bar/", }, { - name: "mid edge key with post method and status permanent redirect", - path: "/foo/bar/", + name: "mid edge key with post method and status permanent redirect with extra ts", + paths: []string{"/foo/bar/"}, req: "/foo/bar", method: http.MethodPost, wantCode: http.StatusPermanentRedirect, wantLocation: "bar/", }, { - name: "incomplete match end of edge", - path: "/foo/bar", + name: "mid edge key with get method and status moved permanently without extra ts", + paths: []string{"/foo/bar/baz", "/foo/bar"}, req: "/foo/bar/", method: http.MethodGet, wantCode: http.StatusMovedPermanently, wantLocation: "../bar", }, { - name: "incomplete match end of edge", - path: "/foo/bar", + name: "mid edge key with post method and status permanent redirect without extra ts", + paths: []string{"/foo/bar/baz", "/foo/bar"}, + req: "/foo/bar/", + method: http.MethodPost, + wantCode: http.StatusPermanentRedirect, + wantLocation: "../bar", + }, + { + name: "incomplete match end of edge with get method", + paths: []string{"/foo/bar"}, + req: "/foo/bar/", + method: http.MethodGet, + wantCode: http.StatusMovedPermanently, + wantLocation: "../bar", + }, + { + name: "incomplete match end of edge with post method", + paths: []string{"/foo/bar"}, req: "/foo/bar/", method: http.MethodPost, wantCode: http.StatusPermanentRedirect, @@ -1656,28 +1841,28 @@ func TestRedirectTrailingSlash(t *testing.T) { }, { name: "match mid edge with ts and more char after", - path: "/foo/bar/buzz", + paths: []string{"/foo/bar/buzz"}, req: "/foo/bar", method: http.MethodGet, wantCode: http.StatusNotFound, }, { name: "match mid edge with ts and more char before", - path: "/foo/barr/", + paths: []string{"/foo/barr/"}, req: "/foo/bar", method: http.MethodGet, wantCode: http.StatusNotFound, }, { name: "incomplete match end of edge with ts and more char after", - path: "/foo/bar", + paths: []string{"/foo/bar"}, req: "/foo/bar/buzz", method: http.MethodGet, wantCode: http.StatusNotFound, }, { name: "incomplete match end of edge with ts and more char before", - path: "/foo/bar", + paths: []string{"/foo/bar"}, req: "/foo/barr/", method: http.MethodGet, wantCode: http.StatusNotFound, @@ -1687,14 +1872,21 @@ func TestRedirectTrailingSlash(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { r := New(WithRedirectTrailingSlash(true)) - require.NoError(t, r.Tree().Handle(tc.method, tc.path, emptyHandler)) + require.True(t, r.RedirectTrailingSlashEnabled()) + for _, path := range tc.paths { + require.NoError(t, r.Tree().Handle(tc.method, path, emptyHandler)) + } req := httptest.NewRequest(tc.method, tc.req, nil) w := httptest.NewRecorder() r.ServeHTTP(w, req) assert.Equal(t, tc.wantCode, w.Code) if w.Code == http.StatusPermanentRedirect || w.Code == http.StatusMovedPermanently { - assert.Equal(t, tc.wantLocation, w.Header().Get("Location")) + assert.Equal(t, tc.wantLocation, w.Header().Get(HeaderLocation)) + if tc.method == http.MethodGet { + assert.Equal(t, MIMETextHTMLCharsetUTF8, w.Header().Get(HeaderContentType)) + assert.Equal(t, ""+http.StatusText(w.Code)+".\n\n", w.Body.String()) + } } }) } @@ -1755,6 +1947,27 @@ func TestTree_Methods(t *testing.T) { methods = f.Tree().Methods("*") assert.Equal(t, []string{"DELETE", "GET", "POST", "PUT"}, methods) + + // Ignore trailing slash disable + methods = f.Tree().Methods("/gists/123/star/") + assert.Empty(t, methods) +} + +func TestTree_MethodsWithIgnoreTsEnable(t *testing.T) { + f := New(WithIgnoreTrailingSlash(true)) + for _, method := range []string{"DELETE", "GET", "PUT"} { + require.NoError(t, f.Handle(method, "/foo/bar", emptyHandler)) + require.NoError(t, f.Handle(method, "/john/doe/", emptyHandler)) + } + + methods := f.Tree().Methods("/foo/bar/") + assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) + + methods = f.Tree().Methods("/john/doe") + assert.Equal(t, []string{"DELETE", "GET", "PUT"}, methods) + + methods = f.Tree().Methods("/foo/bar/baz") + assert.Empty(t, methods) } func TestRouterWithAllowedMethod(t *testing.T) { @@ -1790,6 +2003,7 @@ func TestRouterWithAllowedMethod(t *testing.T) { }, } + require.True(t, r.MethodNotAllowedEnabled()) for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { @@ -1804,6 +2018,91 @@ func TestRouterWithAllowedMethod(t *testing.T) { } } +func TestRouterWithAllowedMethodAndIgnoreTsEnable(t *testing.T) { + r := New(WithNoMethod(true), WithIgnoreTrailingSlash(true)) + + // Support for ignore Trailing slash + cases := []struct { + name string + target string + path string + req string + want string + methods []string + }{ + { + name: "all route except the last one", + methods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead}, + path: "/foo/bar/", + req: "/foo/bar", + target: http.MethodTrace, + want: "GET, POST, PUT, DELETE, PATCH, CONNECT, OPTIONS, HEAD", + }, + { + name: "all route except the first one", + methods: []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead, http.MethodTrace}, + path: "/foo/baz", + req: "/foo/baz/", + target: http.MethodGet, + want: "POST, PUT, DELETE, PATCH, CONNECT, OPTIONS, HEAD, TRACE", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, r.Tree().Handle(method, tc.path, emptyHandler)) + } + req := httptest.NewRequest(tc.target, tc.req, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) + assert.Equal(t, tc.want, w.Header().Get("Allow")) + }) + } +} + +func TestRouterWithAllowedMethodAndIgnoreTsDisable(t *testing.T) { + r := New(WithNoMethod(true)) + + // Support for ignore Trailing slash + cases := []struct { + name string + target string + path string + req string + want int + methods []string + }{ + { + name: "all route except the last one", + methods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead}, + path: "/foo/bar/", + req: "/foo/bar", + target: http.MethodTrace, + }, + { + name: "all route except the first one", + methods: []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodHead, http.MethodTrace}, + path: "/foo/baz", + req: "/foo/baz/", + target: http.MethodGet, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, r.Tree().Handle(method, tc.path, emptyHandler)) + } + req := httptest.NewRequest(tc.target, tc.req, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) + }) + } +} + func TestRouterWithMethodNotAllowedHandler(t *testing.T) { f := New(WithNoMethodHandler(func(c Context) { c.SetHeader("FOO", "BAR") @@ -1876,6 +2175,83 @@ func TestRouterWithAutomaticOptions(t *testing.T) { }, } + require.True(t, f.AutoOptionsEnabled()) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { + c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", ")) + c.Writer().WriteHeader(http.StatusNoContent) + })) + } + req := httptest.NewRequest(http.MethodOptions, tc.target, nil) + w := httptest.NewRecorder() + f.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + assert.Equal(t, tc.want, w.Header().Get("Allow")) + // Reset + f.Swap(f.NewTree()) + }) + } +} + +func TestRouterWithAutomaticOptionsAndIgnoreTsOptionEnable(t *testing.T) { + f := New(WithAutoOptions(true), WithIgnoreTrailingSlash(true)) + + cases := []struct { + name string + target string + path string + want string + wantCode int + methods []string + }{ + { + name: "system-wide requests", + target: "*", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + want: "GET, PUT, TRACE, OPTIONS", + wantCode: http.StatusOK, + }, + { + name: "system-wide with custom options registered", + target: "*", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, + want: "GET, PUT, TRACE, OPTIONS", + wantCode: http.StatusOK, + }, + { + name: "system-wide requests with empty router", + target: "*", + wantCode: http.StatusNotFound, + }, + { + name: "regular option request and ignore ts", + target: "/foo/", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + want: "GET, PUT, TRACE, OPTIONS", + wantCode: http.StatusOK, + }, + { + name: "regular option request with handler priority and ignore ts", + target: "/foo", + path: "/foo/", + methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, + want: "GET, OPTIONS, PUT, TRACE", + wantCode: http.StatusNoContent, + }, + { + name: "regular option request with no matching route", + target: "/bar", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + wantCode: http.StatusNotFound, + }, + } + for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { for _, method := range tc.methods { @@ -1895,6 +2271,50 @@ func TestRouterWithAutomaticOptions(t *testing.T) { } } +func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { + f := New(WithAutoOptions(true)) + + cases := []struct { + name string + target string + path string + wantCode int + methods []string + }{ + { + name: "regular option request and ignore ts", + target: "/foo/", + path: "/foo", + methods: []string{"GET", "TRACE", "PUT"}, + wantCode: http.StatusNotFound, + }, + { + name: "regular option request with handler priority and ignore ts", + target: "/foo", + path: "/foo/", + methods: []string{"GET", "TRACE", "PUT", "OPTIONS"}, + wantCode: http.StatusNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + for _, method := range tc.methods { + require.NoError(t, f.Tree().Handle(method, tc.path, func(c Context) { + c.SetHeader("Allow", strings.Join(c.Tree().Methods(c.Request().URL.Path), ", ")) + c.Writer().WriteHeader(http.StatusNoContent) + })) + } + req := httptest.NewRequest(http.MethodOptions, tc.target, nil) + w := httptest.NewRecorder() + f.ServeHTTP(w, req) + assert.Equal(t, tc.wantCode, w.Code) + // Reset + f.Swap(f.NewTree()) + }) + } +} + func TestRouterWithOptionsHandler(t *testing.T) { f := New(WithOptionsHandler(func(c Context) { assert.Equal(t, "/foo/bar", c.Path()) @@ -2006,6 +2426,7 @@ func TestTree_Has(t *testing.T) { "/foo/bar", "/welcome/{name}", "/users/uid_{id}", + "/john/doe/", } r := New() @@ -2024,9 +2445,18 @@ func TestTree_Has(t *testing.T) { want: true, }, { - name: "no match static route", + name: "strict match static route", + path: "/john/doe/", + want: true, + }, + { + name: "no match static route (tsr)", path: "/foo/bar/", }, + { + name: "no match static route (tsr)", + path: "/john/doe", + }, { name: "strict match route params", path: "/welcome/{name}", @@ -2077,6 +2507,10 @@ func TestTree_Match(t *testing.T) { path: "/foo/bar", want: "/foo/bar", }, + { + name: "reverse static route with tsr disable", + path: "/foo/bar/", + }, { name: "reverse params route", path: "/welcome/fox", @@ -2100,6 +2534,66 @@ func TestTree_Match(t *testing.T) { } } +func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { + routes := []string{ + "/foo/bar", + "/welcome/{name}/", + "/users/uid_{id}", + } + + r := New(WithIgnoreTrailingSlash(true)) + for _, rte := range routes { + require.NoError(t, r.Handle(http.MethodGet, rte, emptyHandler)) + } + + cases := []struct { + name string + path string + want string + }{ + { + name: "reverse static route", + path: "/foo/bar", + want: "/foo/bar", + }, + { + name: "reverse static route with tsr", + path: "/foo/bar/", + want: "/foo/bar", + }, + { + name: "reverse params route", + path: "/welcome/fox/", + want: "/welcome/{name}/", + }, + { + name: "reverse params route with tsr", + path: "/welcome/fox", + want: "/welcome/{name}/", + }, + { + name: "reverse mid params route", + path: "/users/uid_123", + want: "/users/uid_{id}", + }, + { + name: "reverse mid params route with tsr", + path: "/users/uid_123/", + want: "/users/uid_{id}", + }, + { + name: "reverse no match", + path: "/users/fox", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, r.Tree().Match(http.MethodGet, tc.path)) + }) + } +} + func TestEncodedPath(t *testing.T) { encodedPath := "run/cmd/S123L%2FA" req := httptest.NewRequest(http.MethodGet, "/"+encodedPath, nil) @@ -2142,8 +2636,9 @@ func TestFuzzInsertLookupParam(t *testing.T) { nds := *tree.nodes.Load() c := newTestContextTree(tree) - n, _ := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c.params, c.skipNds, false) + n, tsr := tree.lookup(nds[0], fmt.Sprintf(reqFormat, s1, "xxxx", s2, "xxxx", "xxxx"), c.params, c.skipNds, false) require.NotNil(t, n) + assert.False(t, tsr) assert.Equal(t, fmt.Sprintf(routeFormat, s1, e1, s2, e2, e3), n.path) assert.Equal(t, "xxxx", c.Param(e1)) assert.Equal(t, "xxxx", c.Param(e2)) @@ -2201,8 +2696,9 @@ func TestFuzzInsertLookupUpdateAndDelete(t *testing.T) { for rte := range routes { nds := *tree.nodes.Load() c := newTestContextTree(tree) - n, _ := tree.lookup(nds[0], "/"+rte, c.params, c.skipNds, true) + n, tsr := tree.lookup(nds[0], "/"+rte, c.params, c.skipNds, true) require.NotNilf(t, n, "route /%s", rte) + require.Falsef(t, tsr, "tsr: %t", tsr) require.Truef(t, n.isLeaf(), "route /%s", rte) require.Equal(t, "/"+rte, n.path) require.NoError(t, tree.update(http.MethodGet, "/"+rte, "", emptyHandler)) @@ -2359,12 +2855,11 @@ func atomicSync() (start func(), wait func()) { // which include the Recovery middleware. A basic route is defined, along with a // custom middleware to log the request metrics. func ExampleNew() { - // Create a new router with default options, which include the Recovery middleware r := New(DefaultOptions()) // Define a custom middleware to measure the time taken for request processing and - // log the URL, route, time elapsed, and status code + // log the URL, route, time elapsed, and status code. metrics := func(next HandlerFunc) HandlerFunc { return func(c Context) { start := time.Now() @@ -2374,7 +2869,7 @@ func ExampleNew() { } // Define a route with the path "/hello/{name}", apply the custom "metrics" middleware, - // and set a simple handler that greets the user by their name + // and set a simple handler that greets the user by their name. r.MustHandle(http.MethodGet, "/hello/{name}", metrics(func(c Context) { _ = c.String(200, "Hello %s\n", c.Param("name")) })) @@ -2385,11 +2880,9 @@ func ExampleNew() { // This example demonstrates how to register a global middleware that will be // applied to all routes. - func ExampleWithMiddleware() { - // Define a custom middleware to measure the time taken for request processing and - // log the URL, route, time elapsed, and status code + // log the URL, route, time elapsed, and status code. metrics := func(next HandlerFunc) HandlerFunc { return func(c Context) { start := time.Now() @@ -2404,9 +2897,9 @@ func ExampleWithMiddleware() { } } - r := New(WithMiddleware(metrics)) + f := New(WithMiddleware(metrics)) - r.MustHandle(http.MethodGet, "/hello/{name}", func(c Context) { + f.MustHandle(http.MethodGet, "/hello/{name}", func(c Context) { _ = c.String(200, "Hello %s\n", c.Param("name")) }) } @@ -2447,31 +2940,82 @@ func ExampleRouter_Tree() { } // This example demonstrates how to create a custom middleware that cleans the request path and performs a manual -// lookup on the tree. If the cleaned path matches a registered route, the client is redirected with a 301 status -// code (Moved Permanently). -func ExampleTree_Match() { +// lookup on the tree. If the cleaned path matches a registered route, the client is redirected to the valid path. +func ExampleRouter_Lookup() { redirectFixedPath := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return func(c Context) { req := c.Request() + target := req.URL.Path + cleanedPath := CleanPath(target) - cleanedPath := CleanPath(req.URL.Path) - if match := c.Tree().Match(req.Method, cleanedPath); match != "" { - // 301 redirect and returns. - req.URL.Path = cleanedPath - http.Redirect(c.Writer(), req, req.URL.String(), http.StatusMovedPermanently) + // Nothing to clean, call next handler or middleware. + if cleanedPath == target { + next(c) return } + req.URL.Path = cleanedPath + handler, cc, tsr := c.Fox().Lookup(c.Writer(), req) + if handler != nil { + defer cc.Close() + + code := http.StatusMovedPermanently + if req.Method != http.MethodGet { + code = http.StatusPermanentRedirect + } + + // Redirect the client if direct match or indirect match. + if !tsr || c.Fox().IgnoreTrailingSlashEnabled() { + if err := c.Redirect(code, cleanedPath); err != nil { + // Only if not in the range 300..308, so not possible here! + panic(err) + } + return + } + + // Add or remove an extra trailing slash and redirect the client. + if c.Fox().RedirectTrailingSlashEnabled() { + if err := c.Redirect(code, fixTrailingSlash(cleanedPath)); err != nil { + // Only if not in the range 300..308, so not possible here + panic(err) + } + return + } + } + + // rollback to the original path before calling the + // next handler or middleware. + req.URL.Path = target next(c) } }) f := New( // Register the middleware for the NoRouteHandler scope. - WithMiddlewareFor(NoRouteHandler, redirectFixedPath), + WithMiddlewareFor(NoRouteHandler|NoMethodHandler, redirectFixedPath), ) - f.MustHandle(http.MethodGet, "/foo/bar", func(c Context) { - _ = c.String(http.StatusOK, "foo bar") + f.MustHandle(http.MethodGet, "/hello/{name}", func(c Context) { + _ = c.String(200, "Hello %s\n", c.Param("name")) }) } + +// This example demonstrates how to do a reverse lookup on the tree. +func ExampleTree_Match() { + f := New() + f.MustHandle(http.MethodGet, "/hello/{name}", emptyHandler) + + tree := f.Tree() + matched := tree.Match(http.MethodGet, "/hello/fox") + fmt.Println(matched) // /hello/{name} +} + +// This example demonstrates how to check if a given route is registered in the tree. +func ExampleTree_Has() { + f := New() + f.MustHandle(http.MethodGet, "/hello/{name}", emptyHandler) + + tree := f.Tree() + exist := tree.Match(http.MethodGet, "/hello/{name}") + fmt.Println(exist) // true +} diff --git a/go.mod b/go.mod index c98a1aa..d873c60 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/google/gofuzz v1.2.0 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 ) require ( diff --git a/go.sum b/go.sum index eb76b96..3034cdf 100644 --- a/go.sum +++ b/go.sum @@ -17,8 +17,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/node.go b/node.go index d7acd58..7b52568 100644 --- a/node.go +++ b/node.go @@ -214,7 +214,7 @@ func (n *skippedNodes) pop() skippedNode { } type skippedNode struct { - node *node + parent *node pathIndex int paramCnt uint32 seen bool diff --git a/options.go b/options.go index 9bd4da5..071a6c3 100644 --- a/options.go +++ b/options.go @@ -113,13 +113,25 @@ func WithAutoOptions(enable bool) Option { // WithRedirectTrailingSlash enable automatic redirection fallback when the current request does not match but // another handler is found with/without an additional trailing slash. E.g. /foo/bar/ request does not match // but /foo/bar would match. The client is redirected with a http status code 301 for GET requests and 308 for -// all other methods. +// all other methods. Note that this option is mutually exclusive with WithIgnoreTrailingSlash, and if both are +// enabled, WithIgnoreTrailingSlash takes precedence. func WithRedirectTrailingSlash(enable bool) Option { return optionFunc(func(r *Router) { r.redirectTrailingSlash = enable }) } +// WithIgnoreTrailingSlash allows the router to match routes regardless of whether a trailing slash is present or not. +// E.g. /foo/bar/ and /foo/bar would both match the same handler. This option prevents the router from issuing +// a redirect and instead matches the request directly. Note that this option is mutually exclusive with +// WithRedirectTrailingSlash, and if both are enabled, WithIgnoreTrailingSlash takes precedence. +// This api is EXPERIMENTAL and is likely to change in future release. +func WithIgnoreTrailingSlash(enable bool) Option { + return optionFunc(func(r *Router) { + r.ignoreTrailingSlash = enable + }) +} + // DefaultOptions configure the router to use the Recovery middleware for the RouteHandlers scope and enable // automatic OPTIONS response. Note that DefaultOptions push the Recovery middleware to the first position of the // middleware chains. diff --git a/tree.go b/tree.go index e8c8f35..3332930 100644 --- a/tree.go +++ b/tree.go @@ -6,6 +6,7 @@ package fox import ( "fmt" + "net/http" "sort" "strings" "sync" @@ -21,16 +22,17 @@ import ( // the wrong Tree. // // Good: -// t := r.Tree() +// t := fox.Tree() // t.Lock() // defer t.Unlock() // // Dramatically bad, may cause deadlock -// r.Tree().Lock() -// defer r.Tree().Unlock() +// fox.Tree().Lock() +// defer fox.Tree().Unlock() type Tree struct { ctx sync.Pool nodes atomic.Pointer[[]*node] + fox *Router mws []middleware sync.Mutex maxParams atomic.Uint32 @@ -57,7 +59,7 @@ func (t *Tree) Handle(method, path string, handler HandlerFunc) error { // Update override an existing handler for the given method and path. If the route does not exist, // the function return an ErrRouteNotFound. It's perfectly safe to update a handler while the tree is in use for // serving requests. However, this function is NOT thread-safe and should be run serially, along with all other -// Tree APIs that perform write operations. To add new handler, use Handle method. +// Tree APIs that perform write operations. To add a new handler, use Handle method. func (t *Tree) Update(method, path string, handler HandlerFunc) error { if method == "" { return fmt.Errorf("%w: missing http method", ErrInvalidRoute) @@ -92,8 +94,8 @@ func (t *Tree) Remove(method, path string) error { return nil } -// Has allows to check if the given method and path exactly match a registered route. This function is safe for -// concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Has allows to check if the given method and path exactly match a registered route. This function is safe for concurrent +// use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Has(method, path string) bool { nds := *t.nodes.Load() @@ -104,13 +106,18 @@ func (t *Tree) Has(method, path string) bool { c := t.ctx.Get().(*context) c.resetNil() - n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, true) c.Close() - return n != nil && n.path == path + if n != nil && !tsr { + return n.path == path + } + return false } -// Match perform a lookup on the tree for the given method and path and return the matching registered route if any. -// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// Match perform a reverse lookup on the tree for the given method and path and return the matching registered route if any. When +// WithIgnoreTrailingSlash or WithRedirectTrailingSlash are enabled, Match will match a registered route regardless of an +// extra or missing trailing slash. This function is safe for concurrent use by multiple goroutine and while mutation on +// Tree are ongoing. See also Tree.Lookup as an alternative. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Match(method, path string) string { nds := *t.nodes.Load() @@ -121,18 +128,19 @@ func (t *Tree) Match(method, path string) string { c := t.ctx.Get().(*context) c.resetNil() - n, _ := t.lookup(nds[index], path, c.params, c.skipNds, true) + n, tsr := t.lookup(nds[index], path, c.params, c.skipNds, true) c.Close() - if n == nil { - return "" + if n != nil && (!tsr || t.fox.redirectTrailingSlash || t.fox.ignoreTrailingSlash) { + return n.path } - return n.path + return "" } // Methods returns a sorted list of HTTP methods associated with a given path in the routing tree. If the path is "*", -// it returns all HTTP methods that have at least one route registered in the tree. For a specific path, it returns the methods -// that can route requests to that path. -// This function is safe for concurrent use by multiple goroutine and while mutation on Tree are ongoing. +// it returns all HTTP methods that have at least one route registered in the tree. For a specific path, it returns the +// methods that can route requests to that path. When WithIgnoreTrailingSlash or WithRedirectTrailingSlash are enabled, +// Methods will match a registered route regardless of an extra or missing trailing slash. This function is safe for +// concurrent use by multiple goroutine and while mutation on Tree are ongoing. // This API is EXPERIMENTAL and is likely to change in future release. func (t *Tree) Methods(path string) []string { var methods []string @@ -151,8 +159,8 @@ func (t *Tree) Methods(path string) []string { c := t.ctx.Get().(*context) c.resetNil() for i := range nds { - n, _ := t.lookup(nds[i], path, c.params, c.skipNds, true) - if n != nil { + n, tsr := t.lookup(nds[i], path, c.params, c.skipNds, true) + if n != nil && (!tsr || t.fox.redirectTrailingSlash || t.fox.ignoreTrailingSlash) { if methods == nil { methods = make([]string, 0) } @@ -166,6 +174,40 @@ func (t *Tree) Methods(path string) []string { return methods } +// Lookup performs a manual route lookup for a given http.Request, returning the matched HandlerFunc along with a +// ContextCloser, and a boolean indicating if the handler was matched by adding or removing a trailing slash +// (trailing slash action is recommended). The ContextCloser should always be closed if non-nil. This method is primarily +// intended for integrating the fox router into custom routing solutions or middleware. It requires the use of the original +// http.ResponseWriter, typically obtained from ServeHTTP. This function is safe for concurrent use by multiple goroutine +// and while mutation on Tree are ongoing. If there is a direct match or a tsr is possible, Lookup always return a +// HandlerFunc and a ContextCloser. +// This API is EXPERIMENTAL and is likely to change in future release. +func (t *Tree) Lookup(w http.ResponseWriter, r *http.Request) (handler HandlerFunc, cc ContextCloser, tsr bool) { + nds := *t.nodes.Load() + index := findRootNode(r.Method, nds) + + if index < 0 { + return + } + + c := t.ctx.Get().(*context) + c.Reset(w, r) + + target := r.URL.Path + if len(r.URL.RawPath) > 0 { + // Using RawPath to prevent unintended match (e.g. /search/a%2Fb/1) + target = r.URL.RawPath + } + + n, tsr := t.lookup(nds[index], target, c.params, c.skipNds, false) + if n != nil { + c.path = n.path + return n.handler, c, tsr + } + c.Close() + return nil, nil, tsr +} + // Insert is not safe for concurrent use. The path must start by '/' and it's not validated. Use // parseRoute before. func (t *Tree) insert(method, path, catchAllKey string, paramsN uint32, handler HandlerFunc) error { @@ -494,6 +536,7 @@ func (t *Tree) lookup(rootNode *node, path string, params *Params, skipNds *skip paramCnt uint32 ) + var parent *node current := rootNode.children[0].Load() *skipNds = (*skipNds)[:0] @@ -569,6 +612,7 @@ Walk: } idx = current.paramChildIndex + parent = current current = current.children[idx].Load() continue } @@ -577,6 +621,7 @@ Walk: if current.paramChildIndex >= 0 || current.catchAllKey != "" { *skipNds = append(*skipNds, skippedNode{current, charsMatched, paramCnt, false}) } + parent = current current = current.children[idx].Load() } } @@ -585,13 +630,30 @@ Walk: hasSkpNds := len(*skipNds) > 0 if !current.isLeaf() { + + if !tsr { + // Tsr recommendation: remove the extra trailing slash (got an exact match) + // If match the completely /foo/, we end up in an intermediary node which is not a leaf. + // /foo [leaf=/foo] + // / + // b/ [leaf=/foo/b/] + // x/ [leaf=/foo/x/] + // But the parent (/foo) could be a leaf. This is only valid if we have an exact match with + // the intermediary node (charsMatched == len(path)). + if strings.HasSuffix(path, "/") && parent != nil && parent.isLeaf() && charsMatched == len(path) { + tsr = true + n = parent + } + } + if hasSkpNds { goto Backtrack } - return nil, false + return n, tsr } + // From here we are always in a leaf if charsMatched == len(path) { if charsMatchedInNodeFound == len(current.key) { // Exact match, note that if we match a catch-all node @@ -608,11 +670,17 @@ Walk: if strings.HasSuffix(path, "/") { // Tsr recommendation: remove the extra trailing slash (got an exact match) remainingPrefix := current.key[:charsMatchedInNodeFound] - tsr = len(remainingPrefix) == 1 && remainingPrefix[0] == slashDelim + if len(remainingPrefix) == 1 && remainingPrefix[0] == slashDelim { + tsr = true + n = parent + } } else { // Tsr recommendation: add an extra trailing slash (got an exact match) remainingSuffix := current.key[charsMatchedInNodeFound:] - tsr = len(remainingSuffix) == 1 && remainingSuffix[0] == slashDelim + if len(remainingSuffix) == 1 && remainingSuffix[0] == slashDelim { + tsr = true + n = current + } } } @@ -620,7 +688,7 @@ Walk: goto Backtrack } - return nil, tsr + return n, tsr } } @@ -638,23 +706,26 @@ Walk: // Tsr recommendation: remove the extra trailing slash (got an exact match) if !tsr { remainingKeySuffix := path[charsMatched:] - tsr = len(remainingKeySuffix) == 1 && remainingKeySuffix[0] == slashDelim + if len(remainingKeySuffix) == 1 && remainingKeySuffix[0] == slashDelim { + tsr = true + n = current + } } if hasSkpNds { goto Backtrack } - return nil, tsr + return n, tsr } // Finally incomplete match to middle of edge Backtrack: if hasSkpNds { skipped := skipNds.pop() - if skipped.node.paramChildIndex < 0 || skipped.seen { + if skipped.parent.paramChildIndex < 0 || skipped.seen { // skipped is catch all - current = skipped.node + current = skipped.parent *params = (*params)[:skipped.paramCnt] if !lazy { @@ -670,18 +741,19 @@ Backtrack: // /foo/*{any} // /foo/{bar} // In this case we evaluate first the child param node and fall back to the catch-all. - if skipped.node.catchAllKey != "" { - *skipNds = append(*skipNds, skippedNode{skipped.node, skipped.pathIndex, skipped.paramCnt, true}) + if skipped.parent.catchAllKey != "" { + *skipNds = append(*skipNds, skippedNode{skipped.parent, skipped.pathIndex, skipped.paramCnt, true}) } - current = skipped.node.children[skipped.node.paramChildIndex].Load() + parent = skipped.parent + current = skipped.parent.children[skipped.parent.paramChildIndex].Load() *params = (*params)[:skipped.paramCnt] charsMatched = skipped.pathIndex goto Walk } - return nil, tsr + return n, tsr } func (t *Tree) search(rootNode *node, path string) searchResult {