diff --git a/README.md b/README.md index 4d46907..6e3fa38 100644 --- a/README.md +++ b/README.md @@ -102,8 +102,9 @@ if errors.Is(err, fox.ErrRouteConflict) { ``` #### Named parameters -A route can be defined using placeholder (e.g `{name}`). The matching segment are recorder into the `fox.Params` slice accessible -via `fox.Context`. The `Param` and `Get` methods are helpers to retrieve the value using the placeholder name. +A route can be defined using placeholder (e.g `{name}`). The matching segment are recorder into `fox.Param` accessible +via `fox.Context`. `fox.Context.Params` provide an iterator to range over `fox.Param` and `fox.Context.Param` allow +to retrieve directly the value of a parameter using the placeholder name. ```` Pattern /avengers/{name} @@ -168,11 +169,10 @@ GET /fs/*{filepath} #3 => match /fs/avengers/ironman.txt #### Warning about context The `fox.Context` instance is freed once the request handler function returns to optimize resource allocation. -If you need to retain `fox.Context` or `fox.Params` beyond the scope of the handler, use the `Clone` methods. +If you need to retain `fox.Context` beyond the scope of the handler, use the `fox.Context.Clone` methods. ````go func Hello(c fox.Context) { cc := c.Clone() - // cp := c.Params().Clone() go func() { time.Sleep(2 * time.Second) log.Println(cc.Param("name")) // Safe @@ -455,7 +455,7 @@ f := fox.New( Finally, it's also possible to attaches middleware on a per-route basis. Note that route-specific middleware must be explicitly reapplied when updating a route. If not, any middleware will be removed, and the route will fall back to using only global middleware (if any). -```` +````go f := fox.New( fox.WithMiddleware(fox.Logger()), ) diff --git a/context.go b/context.go index 9303490..0e587b6 100644 --- a/context.go +++ b/context.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "io" + "iter" "net" "net/http" "net/url" @@ -15,7 +16,7 @@ import ( "strings" ) -// ContextCloser extends Context for manually created instances, adding a Close method +// ContextCloser extends [Context] for manually created instances, adding a Close method // to release resources after use. type ContextCloser interface { Context @@ -25,41 +26,42 @@ type ContextCloser interface { // Context represents the context of the current HTTP request. It provides methods to access request data and // to write a response. Be aware that the Context API is not thread-safe and its lifetime should be limited to the -// duration of the HandlerFunc execution, as the underlying implementation may be reused a soon as the handler return. -// (see Clone method). +// duration of the [HandlerFunc] execution, as the underlying implementation may be reused a soon as the handler return. +// (see [Context.Clone] method). type Context interface { - // Request returns the current *http.Request. + // Request returns the current [http.Request]. Request() *http.Request - // SetRequest sets the *http.Request. + // SetRequest sets the [*http.Request]. SetRequest(r *http.Request) - // Writer method returns a custom ResponseWriter implementation. + // Writer method returns a custom [ResponseWriter] implementation. Writer() ResponseWriter - // SetWriter sets the ResponseWriter. + // SetWriter sets the [ResponseWriter]. SetWriter(w ResponseWriter) - // RemoteIP parses the IP from Request.RemoteAddr, normalizes it, and returns an IP address. The returned *net.IPAddr + // RemoteIP parses the IP from [http.Request.RemoteAddr], normalizes it, and returns an IP address. The returned [net.IPAddr] // may contain a zone identifier. RemoteIP never returns nil, even if parsing the IP fails. RemoteIP() *net.IPAddr - // ClientIP returns the "real" client IP address based on the configured ClientIPStrategy. - // The strategy is set using the WithClientIPStrategy option. There is no sane default, so if no strategy is configured, - // the method returns ErrNoClientIPStrategy. + // ClientIP returns the "real" client IP address based on the configured [ClientIPStrategy]. + // The strategy is set using the [WithClientIPStrategy] option. There is no sane default, so if no strategy is configured, + // the method returns [ErrNoClientIPStrategy]. // // The strategy used must be chosen and tuned for your network configuration. This should result // in the strategy never returning an error -- i.e., never failing to find a candidate for the "real" IP. // Consequently, getting an error result should be treated as an application error, perhaps even // worthy of panicking. // - // The returned *net.IPAddr may contain a zone identifier. + // The returned [net.IPAddr] may contain a zone identifier. // // This api is EXPERIMENTAL and is likely to change in future release. ClientIP() (*net.IPAddr, error) - // Path returns the registered path for the handler. + // Path returns the registered path or an empty string if the handler is called in a scope other than [RouteHandler]. Path() string - // Params returns a Params slice containing the matched - // wildcard parameters. - Params() Params + // Route returns the registered route or nil if the handler is called in a scope other than [RouteHandler]. + Route() *Route + // Params returns a range iterator over the matched wildcard parameters for the current route. + Params() iter.Seq[Param] // Param retrieve a matching wildcard parameter by name. Param(name string) string - // QueryParams parses the Request RawQuery and returns the corresponding values. + // QueryParams parses the [http.Request] raw query and returns the corresponding values. QueryParams() url.Values // QueryParam returns the first query value associated with the given key. QueryParam(name string) string @@ -71,25 +73,32 @@ type Context interface { String(code int, format string, values ...any) error // Blob sends a byte slice with the specified status code and content type. Blob(code int, contentType string, buf []byte) error - // Stream sends data from an io.Reader with the specified status code and content type. + // Stream sends data from an [io.Reader] with the specified status code and content type. Stream(code int, contentType string, r io.Reader) error // Redirect sends an HTTP redirect response with the given status code and URL. Redirect(code int, url string) error - // Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. + // Clone returns a copy of the [Context] that is safe to use after the [HandlerFunc] returns. Clone() Context - // CloneWith returns a copy of the current Context, substituting its ResponseWriter and - // http.Request with the provided ones. The method is designed for zero allocation during the - // copy process. The returned ContextCloser must be closed once no longer needed. - // This functionality is particularly beneficial for middlewares that need to wrap - // their custom ResponseWriter while preserving the state of the original Context. + // CloneWith returns a shallow copy of the current [Context], substituting its [ResponseWriter] and [http.Request] + // with the provided ones. The method is designed for zero allocation during the copy process. The returned + // [ContextCloser] must be closed once no longer needed. This functionality is particularly beneficial for + // middlewares that need to wrap their custom [ResponseWriter] while preserving the state of the original [Context]. CloneWith(w ResponseWriter, r *http.Request) ContextCloser - // Scope returns the HandlerScope associated with the current Context. - // This indicates the scope in which the handler is being executed, such as RouteHandler, NoRouteHandler, etc. + // Scope returns the [HandlerScope] associated with the current [Context]. + // This indicates the scope in which the handler is being executed, such as [RouteHandler], [NoRouteHandler], etc. Scope() HandlerScope - // Tree is a local copy of the Tree in use to serve the request. + // Tree is a local copy of the [Tree] in use to serve the request. Tree() *Tree - // Fox returns the Router instance. + // Fox returns the [Router] instance. Fox() *Router + // Rehydrate updates the current [Context] to serve the provided [Route], bypassing the need for a full tree lookup. + // It succeeds only if the [http.Request]'s URL path strictly matches the given [Route]. If successful, the internal state + // of the context is updated, allowing the context to serve the route directly, regardless of whether the route + // still exists in the routing tree. This provides a key advantage in concurrent scenarios where routes may be + // modified by other threads, as Rehydrate guarantees success if the path matches, without requiring serial execution + // or tree lookups. Note that the context's state is only mutated if the rehydration is successful. + // This api is EXPERIMENTAL and is likely to change in future release. + Rehydrate(route *Route) bool } // cTx holds request-related information and allows interaction with the ResponseWriter. @@ -122,6 +131,42 @@ func (c *cTx) Reset(w ResponseWriter, r *http.Request) { *c.params = (*c.params)[:0] } +// Rehydrate updates the current Context to serve the provided Route, bypassing the need for a full tree lookup. +// It succeeds only if the Request's URL path strictly matches the given Route. If successful, the internal state +// of the context is updated, allowing the context to serve the route directly, regardless of whether the route +// still exists in the routing tree. This provides a key advantage in concurrent scenarios where routes may be +// modified by other threads, as Rehydrate guarantees success if the path matches, without requiring serial execution +// or tree lookups. Note that the context's state is only mutated if the rehydration is successful. +// This api is EXPERIMENTAL and is likely to change in future release. +func (c *cTx) Rehydrate(route *Route) bool { + + target := c.req.URL.Path + if len(c.req.URL.RawPath) > 0 { + // Using RawPath to prevent unintended match (e.g. /search/a%2Fb/1) + target = c.req.URL.RawPath + } + + var params *Params + if c.tsr { + *c.params = (*c.params)[:0] + params = c.params + } else { + *c.tsrParams = (*c.tsrParams)[:0] + params = c.tsrParams + } + + if !route.hydrateParams(target, params) { + return false + } + + *c.params, *c.tsrParams = *c.tsrParams, *c.params + c.cachedQuery = nil + c.route = route + c.scope = RouteHandler + + return true +} + // reset resets the Context to its initial state, attaching the provided http.ResponseWriter and http.Request. // Caution: always pass the original http.ResponseWriter to this method, not the ResponseWriter itself, to // avoid wrapping the ResponseWriter within itself. Use wisely! @@ -199,19 +244,29 @@ func (c *cTx) ClientIP() (*net.IPAddr, error) { return c.route.ipStrategy.ClientIP(c) } -// Params returns a Params slice containing the matched -// wildcard parameters. -func (c *cTx) Params() Params { - if c.tsr { - return *c.tsrParams +// Params returns an iterator over the matched wildcard parameters for the current route. +func (c *cTx) Params() iter.Seq[Param] { + return func(yield func(Param) bool) { + if c.tsr { + for _, p := range *c.tsrParams { + if !yield(p) { + return + } + } + return + } + for _, p := range *c.params { + if !yield(p) { + return + } + } } - return *c.params } // Param retrieve a matching wildcard segment by name. // It's a helper for c.Params.Get(name). func (c *cTx) Param(name string) string { - for _, p := range c.Params() { + for p := range c.Params() { if p.Key == name { return p.Value } @@ -242,7 +297,7 @@ func (c *cTx) Header(key string) string { return c.req.Header.Get(key) } -// Path returns the registered path for the handler. +// Path returns the registered path or an empty string if the handler is called in a scope other than RouteHandler. func (c *cTx) Path() string { if c.route == nil { return "" @@ -250,6 +305,11 @@ func (c *cTx) Path() string { return c.route.path } +// Route returns the registered route or nil if the handler is called in a scope other than RouteHandler. +func (c *cTx) Route() *Route { + return c.route +} + // String sends a formatted string with the specified status code. func (c *cTx) String(code int, format string, values ...any) (err error) { if c.w.Header().Get(HeaderContentType) == "" { @@ -295,8 +355,8 @@ func (c *cTx) Fox() *Router { return c.fox } -// Clone returns a copy of the Context that is safe to use after the HandlerFunc returns. -// Any attempt to write on the ResponseWriter will panic with the error ErrDiscardedResponseWriter. +// Clone returns a deep copy of the [Context] that is safe to use after the [HandlerFunc] returns. +// Any attempt to write on the [ResponseWriter] will panic with the error [ErrDiscardedResponseWriter]. func (c *cTx) Clone() Context { cp := cTx{ rec: c.rec, @@ -305,22 +365,29 @@ func (c *cTx) Clone() Context { tree: c.tree, route: c.route, scope: c.scope, + tsr: c.tsr, } cp.rec.ResponseWriter = noopWriter{c.rec.Header().Clone()} cp.w = noUnwrap{&cp.rec} - params := make(Params, len(*c.params)) - copy(params, *c.params) - cp.params = ¶ms + if !c.tsr { + params := make(Params, len(*c.params)) + copy(params, *c.params) + cp.params = ¶ms + } else { + tsrParams := make(Params, len(*c.tsrParams)) + copy(tsrParams, *c.tsrParams) + cp.tsrParams = &tsrParams + } + cp.cachedQuery = nil return &cp } -// CloneWith returns a copy of the current Context, substituting its ResponseWriter and -// http.Request with the provided ones. The method is designed for zero allocation during the -// copy process. The returned ContextCloser must be closed once no longer needed. -// This functionality is particularly beneficial for middlewares that need to wrap -// their custom ResponseWriter while preserving the state of the original Context. +// CloneWith returns a shallow copy of the current [Context], substituting its [ResponseWriter] and [http.Request] with the +// provided ones. The method is designed for zero allocation during the copy process. The returned [ContextCloser] must +// be closed once no longer needed. This functionality is particularly beneficial for middlewares that need to wrap +// their custom [ResponseWriter] while preserving the state of the original [Context]. func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { cp := c.tree.ctx.Get().(*cTx) cp.req = r @@ -328,17 +395,28 @@ func (c *cTx) CloneWith(w ResponseWriter, r *http.Request) ContextCloser { cp.route = c.route cp.scope = c.scope cp.cachedQuery = nil - if cap(*c.params) > cap(*cp.params) { - // Grow cp.params to a least cap(c.params) - *cp.params = slices.Grow(*cp.params, cap(*c.params)) + cp.tsr = c.tsr + + if !c.tsr { + copyParams(c.params, cp.params) + } else { + copyParams(c.tsrParams, cp.tsrParams) } - // cap(cp.params) >= cap(c.params) - // now constraint into len(c.params) & cap(c.params) - *cp.params = (*cp.params)[:len(*c.params):cap(*c.params)] - copy(*cp.params, *c.params) + return cp } +func copyParams(src, dst *Params) { + if cap(*src) > cap(*dst) { + // Grow dst to a least cap(src) + *dst = slices.Grow(*dst, cap(*src)) + } + // cap(dst) >= cap(src) + // now constraint into len(src) & cap(src) + *dst = (*dst)[:len(*src):cap(*src)] + copy(*dst, *src) +} + // Scope returns the HandlerScope associated with the current Context. // This indicates the scope in which the handler is being executed, such as RouteHandler, NoRouteHandler, etc. func (c *cTx) Scope() HandlerScope { @@ -370,8 +448,9 @@ func (c *cTx) getQueries() url.Values { // The route parameters are being accessed by the wrapped handler through the context. func WrapF(f http.HandlerFunc) HandlerFunc { return func(c Context) { - if len(c.Params()) > 0 { - ctx := context.WithValue(c.Request().Context(), paramsKey, c.Params().Clone()) + var params Params = slices.Collect(c.Params()) + if len(params) > 0 { + ctx := context.WithValue(c.Request().Context(), paramsKey, params) f.ServeHTTP(c.Writer(), c.Request().WithContext(ctx)) return } @@ -384,8 +463,9 @@ func WrapF(f http.HandlerFunc) HandlerFunc { // The route parameters are being accessed by the wrapped handler through the context. func WrapH(h http.Handler) HandlerFunc { return func(c Context) { - if len(c.Params()) > 0 { - ctx := context.WithValue(c.Request().Context(), paramsKey, c.Params().Clone()) + var params Params = slices.Collect(c.Params()) + if len(params) > 0 { + ctx := context.WithValue(c.Request().Context(), paramsKey, params) h.ServeHTTP(c.Writer(), c.Request().WithContext(ctx)) return } diff --git a/context_test.go b/context_test.go index 7ec3cb6..a0e1aaa 100644 --- a/context_test.go +++ b/context_test.go @@ -8,12 +8,118 @@ import ( "bytes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "io" "net/http" "net/http/httptest" "net/url" + "slices" "testing" ) +func TestContext_Rehydrate(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://example.com/foo/bar/baz", nil) + w := httptest.NewRecorder() + + c := NewTestContextOnly(New(), w, req) + cTx := unwrapContext(t, c) + + cases := []struct { + name string + route *Route + tsr bool + want bool + wantParams Params + }{ + { + name: "succeed using tsr params", + route: &Route{ + path: "/foo/{$1}/{$2}", + }, + tsr: false, + want: true, + wantParams: Params{ + { + Key: "$1", + Value: "bar", + }, + { + Key: "$2", + Value: "baz", + }, + }, + }, + { + name: "succeed using params", + route: &Route{ + path: "/foo/{$1}/{$2}", + }, + tsr: true, + want: true, + wantParams: Params{ + { + Key: "$1", + Value: "bar", + }, + { + Key: "$2", + Value: "baz", + }, + }, + }, + { + name: "fail using tsr params", + route: &Route{ + path: "/foo/{$1}/bili", + }, + tsr: false, + want: false, + wantParams: Params{ + { + Key: "old", + Value: "params", + }, + }, + }, + { + name: "fail using params", + route: &Route{ + path: "/foo/{$1}/bili", + }, + tsr: true, + want: false, + wantParams: Params{ + { + Key: "old", + Value: "tsrParams", + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + *cTx.params = Params{{Key: "old", Value: "params"}} + *cTx.tsrParams = Params{{Key: "old", Value: "tsrParams"}} + cTx.tsr = tc.tsr + cTx.cachedQuery = url.Values{"old": []string{"old"}} + cTx.route = nil + cTx.scope = NoRouteHandler + got := c.Rehydrate(tc.route) + require.Equal(t, tc.want, got) + assert.Equal(t, tc.wantParams, Params(slices.Collect(c.Params()))) + if got { + assert.Equal(t, RouteHandler, c.Scope()) + assert.Equal(t, tc.route, c.Route()) + assert.Nil(t, cTx.cachedQuery) + } else { + assert.Equal(t, NoRouteHandler, c.Scope()) + assert.Nil(t, c.Route()) + assert.Equal(t, url.Values{"old": []string{"old"}}, cTx.cachedQuery) + } + }) + } +} + func TestContext_Writer_ReadFrom(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) w := httptest.NewRecorder() @@ -82,6 +188,35 @@ func TestContext_QueryParam(t *testing.T) { assert.Equal(t, wantValues, c.cachedQuery) } +func TestContext_Route(t *testing.T) { + t.Parallel() + f := New() + f.MustHandle(http.MethodGet, "/foo", func(c Context) { + require.NotNil(t, c.Route()) + _, _ = io.WriteString(c.Writer(), c.Route().Path()) + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) + f.ServeHTTP(w, r) + assert.Equal(t, "/foo", w.Body.String()) +} + +func TestContext_Annotations(t *testing.T) { + t.Parallel() + f := New() + f.MustHandle( + http.MethodGet, + "/foo", + emptyHandler, + WithAnnotations(Annotation{Key: "foo", Value: "bar"}, Annotation{Key: "foo", Value: "baz"}), + WithAnnotation("john", 1), + ) + rte := f.Tree().Route(http.MethodGet, "/foo") + require.NotNil(t, rte) + assert.Equal(t, []Annotation{{"foo", "bar"}, {"foo", "baz"}, {"john", 1}}, slices.Collect(rte.Annotations())) +} + func TestContext_Clone(t *testing.T) { t.Parallel() wantValues := url.Values{ @@ -92,6 +227,7 @@ func TestContext_Clone(t *testing.T) { req.URL.RawQuery = wantValues.Encode() c := newTextContextOnly(New(), httptest.NewRecorder(), req) + *c.params = Params{{Key: "foo", Value: "bar"}} buf := []byte("foo bar") _, err := c.w.Write(buf) @@ -99,11 +235,17 @@ func TestContext_Clone(t *testing.T) { cc := c.Clone() assert.Equal(t, http.StatusOK, cc.Writer().Status()) + assert.Equal(t, slices.Collect(c.Params()), slices.Collect(cc.Params())) assert.Equal(t, len(buf), cc.Writer().Size()) assert.Equal(t, wantValues, c.QueryParams()) assert.Panics(t, func() { _, _ = cc.Writer().Write([]byte("invalid")) }) + + c.tsr = true + *c.tsrParams = Params{{Key: "john", Value: "doe"}} + cc = c.Clone() + assert.Equal(t, slices.Collect(c.Params()), slices.Collect(cc.Params())) } func TestContext_CloneWith(t *testing.T) { @@ -111,16 +253,21 @@ func TestContext_CloneWith(t *testing.T) { w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil) c := newTextContextOnly(New(), w, req) + *c.params = Params{{Key: "foo", Value: "bar"}} cp := c.CloneWith(c.Writer(), c.Request()) cc := unwrapContext(t, cp) - - assert.Equal(t, c.Params(), cp.Params()) + assert.Equal(t, slices.Collect(c.Params()), slices.Collect(cp.Params())) assert.Equal(t, c.Request(), cp.Request()) assert.Equal(t, c.Writer(), cp.Writer()) assert.Equal(t, c.Path(), cp.Path()) assert.Equal(t, c.Fox(), cp.Fox()) assert.Nil(t, cc.cachedQuery) + + c.tsr = true + *c.tsrParams = Params{{Key: "john", Value: "doe"}} + cp = c.CloneWith(c.Writer(), c.Request()) + assert.Equal(t, slices.Collect(c.Params()), slices.Collect(cp.Params())) } func TestContext_Redirect(t *testing.T) { @@ -382,7 +529,7 @@ func TestWrapF(t *testing.T) { params := make(Params, 0) if tc.params != nil { - params = tc.params.Clone() + params = tc.params.clone() c.(*cTx).params = ¶ms } @@ -442,7 +589,7 @@ func TestWrapH(t *testing.T) { params := make(Params, 0) if tc.params != nil { - params = tc.params.Clone() + params = tc.params.clone() c.(*cTx).params = ¶ms } @@ -452,3 +599,20 @@ func TestWrapH(t *testing.T) { }) } } + +func BenchmarkContext_Rehydrate(b *testing.B) { + req := httptest.NewRequest(http.MethodGet, "/foo/ab:1/baz/123/y/bo/lo", nil) + w := httptest.NewRecorder() + + f := New() + f.MustHandle(http.MethodGet, "/foo/ab:{bar}/baz/{x}/{y}/*{zo}", emptyHandler) + rte, c, _ := f.Lookup(&recorder{ResponseWriter: w}, req) + defer c.Close() + + b.ResetTimer() + b.ReportAllocs() + + for range b.N { + c.Rehydrate(rte) + } +} diff --git a/fox.go b/fox.go index f1f1f95..f9613da 100644 --- a/fox.go +++ b/fox.go @@ -85,57 +85,6 @@ const ( AllHandlers = RouteHandler | NoRouteHandler | NoMethodHandler | RedirectHandler | OptionsHandler ) -// Route represent a registered route in the route tree. -// Most of the Route API is EXPERIMENTAL and is likely to change in future release. -type Route struct { - ipStrategy ClientIPStrategy - hbase HandlerFunc - hself HandlerFunc - hall HandlerFunc - path string - mws []middleware - redirectTrailingSlash bool - ignoreTrailingSlash bool -} - -// Handle calls the handler with the provided Context. See also HandleMiddleware. -func (r *Route) Handle(c Context) { - r.hbase(c) -} - -// HandleMiddleware calls the handler with route-specific middleware applied, using the provided Context. -func (r *Route) HandleMiddleware(c Context, _ ...struct{}) { - // The variadic parameter is intentionally added to prevent this method from having the same signature as HandlerFunc. - // This avoids accidental use of HandleMiddleware where a HandlerFunc is required. - r.hself(c) -} - -// Path returns the route path. -func (r *Route) Path() string { - return r.path -} - -// RedirectTrailingSlashEnabled returns whether the route is configured to automatically -// redirect requests that include or omit a trailing slash. -// This api is EXPERIMENTAL and is likely to change in future release. -func (r *Route) RedirectTrailingSlashEnabled() bool { - return r.redirectTrailingSlash -} - -// IgnoreTrailingSlashEnabled returns whether the route is configured to ignore -// trailing slashes in requests when matching routes. -// This api is EXPERIMENTAL and is likely to change in future release. -func (r *Route) IgnoreTrailingSlashEnabled() bool { - return r.ignoreTrailingSlash -} - -// ClientIPStrategyEnabled returns whether the route is configured with a ClientIPStrategy. -// This api is EXPERIMENTAL and is likely to change in future release. -func (r *Route) ClientIPStrategyEnabled() bool { - _, ok := r.ipStrategy.(noClientIPStrategy) - return !ok -} - // Router is a lightweight high performance HTTP request router that support mutation on its routing tree // while handling request concurrently. type Router struct { diff --git a/fox_test.go b/fox_test.go index 6a8bc0e..351dea0 100644 --- a/fox_test.go +++ b/fox_test.go @@ -643,24 +643,6 @@ func TestStaticRouteMalloc(t *testing.T) { } } -func TestRoute_HandleMiddlewareMalloc(t *testing.T) { - f := New() - for _, rte := range githubAPI { - require.NoError(t, f.Tree().Handle(rte.method, rte.path, emptyHandler)) - } - - for _, rte := range githubAPI { - req := httptest.NewRequest(rte.method, rte.path, nil) - w := httptest.NewRecorder() - r, c, _ := f.Lookup(&recorder{ResponseWriter: w}, req) - allocs := testing.AllocsPerRun(100, func() { - r.HandleMiddleware(c) - }) - c.Close() - assert.Equal(t, float64(0), allocs) - } -} - func TestParamsRoute(t *testing.T) { rx := regexp.MustCompile("({|\\*{)[A-z]+[}]") r := New() @@ -810,7 +792,7 @@ func TestRouteParamEmptySegment(t *testing.T) { c := newTestContextTree(tree) n, tsr := tree.lookup(nds[0], tc.path, c, false) assert.Nil(t, n) - assert.Empty(t, c.Params()) + assert.Empty(t, slices.Collect(c.Params())) assert.False(t, tsr) }) } @@ -1204,6 +1186,53 @@ func TestOverlappingRoute(t *testing.T) { }, }, }, + { + name: "param at index 1 with 2 nodes", + path: "/foo/[barr]", + routes: []string{ + "/foo/{bar}", + "/foo/[bar]", + }, + wantMatch: "/foo/{bar}", + wantParams: Params{ + { + Key: "bar", + Value: "[barr]", + }, + }, + }, + { + name: "param at index 1 with 3 nodes", + path: "/foo/|barr|", + routes: []string{ + "/foo/{bar}", + "/foo/[bar]", + "/foo/|bar|", + }, + wantMatch: "/foo/{bar}", + wantParams: Params{ + { + Key: "bar", + Value: "|barr|", + }, + }, + }, + { + name: "param at index 0 with 3 nodes", + path: "/foo/~barr~", + routes: []string{ + "/foo/{bar}", + "/foo/~bar~", + "/foo/|bar|", + }, + wantMatch: "/foo/{bar}", + wantParams: Params{ + { + Key: "bar", + Value: "~barr~", + }, + }, + }, } for _, tc := range cases { @@ -1221,9 +1250,10 @@ func TestOverlappingRoute(t *testing.T) { assert.False(t, tsr) assert.Equal(t, tc.wantMatch, n.route.path) if len(tc.wantParams) == 0 { - assert.Empty(t, c.Params()) + assert.Empty(t, slices.Collect(c.Params())) } else { - assert.Equal(t, tc.wantParams, c.Params()) + var params Params = slices.Collect(c.Params()) + assert.Equal(t, tc.wantParams, params) } // Test with lazy @@ -1232,7 +1262,7 @@ func TestOverlappingRoute(t *testing.T) { require.NotNil(t, n) require.NotNil(t, n.route) assert.False(t, tsr) - assert.Empty(t, c.Params()) + assert.Empty(t, slices.Collect(c.Params())) assert.Equal(t, tc.wantMatch, n.route.path) }) } @@ -1573,6 +1603,12 @@ func TestParseRoute(t *testing.T) { wantErr: ErrInvalidRoute, wantN: -1, }, + { + name: "unexpected character in param", + path: "/foo/{*bar}", + wantErr: ErrInvalidRoute, + wantN: -1, + }, { name: "in flight catch-all after param in one route segment", path: "/foo/{bar}*{baz}", @@ -2112,7 +2148,7 @@ func TestRouterWithTsrParams(t *testing.T) { name: "current not a leaf, should empty params", routes: []string{"/{a}", "/foo", "/foo/x/", "/foo/y/"}, target: "/foo/", - wantParams: Params{}, + wantParams: Params(nil), wantPath: "/foo", wantTsr: true, }, @@ -2124,7 +2160,8 @@ func TestRouterWithTsrParams(t *testing.T) { for _, rte := range tc.routes { require.NoError(t, f.Handle(http.MethodGet, rte, func(c Context) { assert.Equal(t, tc.wantPath, c.Path()) - assert.Equal(t, tc.wantParams, c.Params()) + var params Params = slices.Collect(c.Params()) + assert.Equal(t, tc.wantParams, params) assert.Equal(t, tc.wantTsr, unwrapContext(t, c).tsr) })) } @@ -2548,7 +2585,7 @@ func TestRouterWithAutomaticOptionsAndIgnoreTsOptionDisable(t *testing.T) { func TestRouterWithOptionsHandler(t *testing.T) { f := New(WithOptionsHandler(func(c Context) { assert.Equal(t, "", c.Path()) - assert.Empty(t, c.Params()) + assert.Empty(t, slices.Collect(c.Params())) c.Writer().WriteHeader(http.StatusNoContent) })) @@ -2825,7 +2862,7 @@ func TestTree_Has(t *testing.T) { } } -func TestTree_Match(t *testing.T) { +func TestTree_Route(t *testing.T) { routes := []string{ "/foo/bar", "/welcome/{name}", @@ -2884,7 +2921,7 @@ func TestTree_Match(t *testing.T) { } } -func TestTree_MatchWithIgnoreTrailingSlashEnable(t *testing.T) { +func TestTree_RouteWithIgnoreTrailingSlashEnable(t *testing.T) { routes := []string{ "/foo/bar", "/welcome/{name}/", diff --git a/helpers.go b/helpers.go index df46a14..b8209d7 100644 --- a/helpers.go +++ b/helpers.go @@ -27,6 +27,7 @@ func newTextContextOnly(fox *Router, w http.ResponseWriter, r *http.Request) *cT c.req = r c.rec.reset(w) c.w = &c.rec + c.scope = AllHandlers return c } diff --git a/logger.go b/logger.go index 7b9bdff..a4b7399 100644 --- a/logger.go +++ b/logger.go @@ -60,7 +60,6 @@ func LoggerWithHandler(handler slog.Handler) MiddlewareFunc { slog.String("location", location), ) } - } } } diff --git a/options.go b/options.go index daa0c75..ada69f9 100644 --- a/options.go +++ b/options.go @@ -4,7 +4,9 @@ package fox -import "cmp" +import ( + "cmp" +) type Option interface { GlobalOption @@ -210,16 +212,34 @@ func WithIgnoreTrailingSlash(enable bool) Option { // - Setting the strategy to nil is equivalent to no strategy configured. func WithClientIPStrategy(strategy ClientIPStrategy) Option { return optionFunc(func(router *Router, route *Route) { - if router != nil { - router.ipStrategy = cmp.Or(strategy, ClientIPStrategy(noClientIPStrategy{})) + if router != nil && strategy != nil { + router.ipStrategy = strategy } if route != nil { + // Apply no strategy if nil provided. route.ipStrategy = cmp.Or(strategy, ClientIPStrategy(noClientIPStrategy{})) } }) } +// WithAnnotations attach arbitrary metadata to routes. Annotations are key-value pairs that allow middleware, handler or +// any other components to modify behavior based on the attached metadata. Unlike context-based metadata, which is tied to +// the request lifetime, annotations are bound to the route's lifetime and remain static across all requests for that route. +// Annotations must be explicitly reapplied when updating a route. +func WithAnnotations(annotations ...Annotation) PathOption { + return pathOptionFunc(func(route *Route) { + route.annots = append(route.annots, annotations...) + }) +} + +// WithAnnotation attaches a single key-value annotation to a route. See also [WithAnnotations] and [Annotation] for more details. +func WithAnnotation(key string, value any) PathOption { + return pathOptionFunc(func(route *Route) { + route.annots = append(route.annots, Annotation{key, value}) + }) +} + // DefaultOptions configure the router to use the Recovery middleware for the RouteHandler scope, the Logger middleware // for AllHandlers scope and enable automatic OPTIONS response. Note that DefaultOptions push the Recovery and Logger middleware // respectively to the first and second position of the middleware chains. diff --git a/params.go b/params.go index 7b80f84..4ded7b6 100644 --- a/params.go +++ b/params.go @@ -39,8 +39,8 @@ func (p Params) Has(name string) bool { return false } -// Clone make a copy of Params. -func (p Params) Clone() Params { +// clone make a copy of Params. +func (p Params) clone() Params { cloned := make(Params, len(p)) copy(cloned, p) return cloned diff --git a/params_test.go b/params_test.go index 135046c..7beb46b 100644 --- a/params_test.go +++ b/params_test.go @@ -39,7 +39,7 @@ func TestParams_Clone(t *testing.T) { Value: "doe", }, ) - assert.Equal(t, params, params.Clone()) + assert.Equal(t, params, params.clone()) } func TestParams_Has(t *testing.T) { diff --git a/recovery.go b/recovery.go index 3f21efa..f84013d 100644 --- a/recovery.go +++ b/recovery.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "github.com/tigerwill90/fox/internal/slogpretty" + "iter" "log/slog" "net" "net/http" @@ -80,16 +81,17 @@ func recovery(logger *slog.Logger, c Context, handle RecoveryFunc) { sb.WriteString("Stack:\n") sb.WriteString(stacktrace(3, 6)) - params := c.Params() - attrs := make([]any, 0, len(params)) - for _, param := range params { - attrs = append(attrs, slog.String(param.Key, param.Value)) + params := slices.Collect(mapParamsToAttr(c.Params())) + var annotations []any + if route := c.Route(); route != nil { + annotations = slices.Collect(mapAnnotationsToAttr(route.Annotations())) } logger.Error( sb.String(), slog.String("path", c.Path()), - slog.Group("param", attrs...), + slog.Group("params", params...), + slog.Group("annotations", annotations...), slog.Any("error", err), ) @@ -137,3 +139,23 @@ func stacktrace(skip, nFrames int) string { } return b.String() } + +func mapParamsToAttr(params iter.Seq[Param]) iter.Seq[any] { + return func(yield func(any) bool) { + for p := range params { + if !yield(slog.String(p.Key, p.Value)) { + break + } + } + } +} + +func mapAnnotationsToAttr(annotations iter.Seq[Annotation]) iter.Seq[any] { + return func(yield func(any) bool) { + for a := range annotations { + if !yield(slog.Any(a.Key, a.Value)) { + break + } + } + } +} diff --git a/route.go b/route.go new file mode 100644 index 0000000..2847340 --- /dev/null +++ b/route.go @@ -0,0 +1,154 @@ +package fox + +import ( + "iter" + "strings" +) + +// Annotations is a collection of Annotation key-value pairs that can be attached to routes. +type Annotations []Annotation + +// Annotation represents a single key-value pair that provides metadata for a route. +// Annotations are typically used to store information that can be leveraged by middleware, handlers, or external +// libraries to modify or customize route behavior. +type Annotation struct { + Key string + Value any +} + +// Route represent a registered route in the route tree. +// Most of the Route API is EXPERIMENTAL and is likely to change in future release. +type Route struct { + ipStrategy ClientIPStrategy + hbase HandlerFunc + hself HandlerFunc + hall HandlerFunc + path string + mws []middleware + annots Annotations + redirectTrailingSlash bool + ignoreTrailingSlash bool +} + +// Handle calls the handler with the provided [Context]. See also [HandleMiddleware]. +func (r *Route) Handle(c Context) { + r.hbase(c) +} + +// HandleMiddleware calls the handler with route-specific middleware applied, using the provided [Context]. +func (r *Route) HandleMiddleware(c Context, _ ...struct{}) { + // The variadic parameter is intentionally added to prevent this method from having the same signature as HandlerFunc. + // This avoids accidental use of HandleMiddleware where a HandlerFunc is required. + r.hself(c) +} + +// Path returns the route path. +func (r *Route) Path() string { + return r.path +} + +// Annotations returns a range iterator over annotations associated with the route. +func (r *Route) Annotations() iter.Seq[Annotation] { + return func(yield func(Annotation) bool) { + for _, a := range r.annots { + if !yield(a) { + return + } + } + } +} + +// RedirectTrailingSlashEnabled returns whether the route is configured to automatically +// redirect requests that include or omit a trailing slash. +// This api is EXPERIMENTAL and is likely to change in future release. +func (r *Route) RedirectTrailingSlashEnabled() bool { + return r.redirectTrailingSlash +} + +// IgnoreTrailingSlashEnabled returns whether the route is configured to ignore +// trailing slashes in requests when matching routes. +// This api is EXPERIMENTAL and is likely to change in future release. +func (r *Route) IgnoreTrailingSlashEnabled() bool { + return r.ignoreTrailingSlash +} + +// ClientIPStrategyEnabled returns whether the route is configured with a [ClientIPStrategy]. +// This api is EXPERIMENTAL and is likely to change in future release. +func (r *Route) ClientIPStrategyEnabled() bool { + _, ok := r.ipStrategy.(noClientIPStrategy) + return !ok +} + +func (r *Route) hydrateParams(path string, params *Params) bool { + rLen := len(r.path) + pLen := len(path) + var i, j int + state := stateDefault + + // Note that we assume that this is a valid route (validated with parseRoute). +OUTER: + for i < rLen && j < pLen { + switch state { + case stateParam: + startPath := j + idx := strings.IndexByte(path[j:], slashDelim) + if idx > 0 { + j += idx + } else if idx < 0 { + j += len(path[j:]) + } else { + // segment is empty + return false + } + + startRoute := i + idx = strings.IndexByte(r.path[i:], slashDelim) + if idx >= 0 { + i += idx + } else { + i += len(r.path[i:]) + } + + *params = append(*params, Param{ + Key: r.path[startRoute : i-1], + Value: path[startPath:j], + }) + + state = stateDefault + + default: + if r.path[i] == '{' { + i++ + state = stateParam + continue + } + + if r.path[i] == '*' { + state = stateCatchAll + break OUTER + } + + if r.path[i] == path[j] { + i++ + j++ + continue + } + + return false + } + } + + if state == stateCatchAll || (i < rLen && r.path[i] == '*') { + *params = append(*params, Param{ + Key: r.path[i+2 : rLen-1], + Value: path[j:], + }) + return true + } + + if i == rLen && j == pLen { + return true + } + + return false +} diff --git a/route_test.go b/route_test.go new file mode 100644 index 0000000..0cdb3f6 --- /dev/null +++ b/route_test.go @@ -0,0 +1,226 @@ +package fox + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http/httptest" + "testing" +) + +func TestRoute_HydrateParams(t *testing.T) { + cases := []struct { + name string + path string + route *Route + wantParams Params + want bool + }{ + { + name: "static route match", + path: "/foo/bar", + route: &Route{path: "/foo/bar"}, + wantParams: Params{}, + want: true, + }, + { + name: "static route do not match", + path: "/foo/bar", + route: &Route{path: "/foo/ba"}, + wantParams: Params{}, + want: false, + }, + { + name: "static route do not match", + path: "/foo/bar", + route: &Route{path: "/foo/barr"}, + wantParams: Params{}, + want: false, + }, + { + name: "static route do not match", + path: "/foo/bar", + route: &Route{path: "/foo/bax"}, + wantParams: Params{}, + want: false, + }, + { + name: "strict trailing slash", + path: "/foo/bar", + route: &Route{path: "/foo/bar/"}, + wantParams: Params{}, + want: false, + }, + { + name: "strict trailing slash with param and", + path: "/foo/bar", + route: &Route{path: "/foo/{1}/"}, + wantParams: Params{ + { + Key: "1", + Value: "bar", + }, + }, + want: false, + }, + { + name: "strict trailing slash with param", + path: "/foo/bar/", + route: &Route{path: "/foo/{2}"}, + wantParams: Params{ + { + Key: "2", + Value: "bar", + }, + }, + want: false, + }, + { + name: "strict trailing slash", + path: "/foo/bar/", + route: &Route{path: "/foo/bar"}, + wantParams: Params{}, + want: false, + }, + { + name: "multi route params and catch all", + path: "/foo/ab:1/baz/123/y/bo/lo", + route: &Route{path: "/foo/ab:{bar}/baz/{x}/{y}/*{zo}"}, + wantParams: Params{ + { + Key: "bar", + Value: "1", + }, + { + Key: "x", + Value: "123", + }, + { + Key: "y", + Value: "y", + }, + { + Key: "zo", + Value: "bo/lo", + }, + }, + want: true, + }, + { + name: "path with wildcard should be parsed", + path: "/foo/ab:{bar}/baz/{x}/{y}/*{zo}", + route: &Route{path: "/foo/ab:{bar}/baz/{x}/{y}/*{zo}"}, + wantParams: Params{ + { + Key: "bar", + Value: "{bar}", + }, + { + Key: "x", + Value: "{x}", + }, + { + Key: "y", + Value: "{y}", + }, + { + Key: "zo", + Value: "*{zo}", + }, + }, + want: true, + }, + { + name: "empty param end range", + path: "/foo/", + route: &Route{path: "/foo/{bar}"}, + wantParams: Params{}, + want: false, + }, + { + name: "empty param mid range", + path: "/foo//baz", + route: &Route{path: "/foo/{bar}/baz"}, + wantParams: Params{}, + want: false, + }, + { + name: "multiple slash", + path: "/foo/bar///baz", + route: &Route{path: "/foo/{bar}/baz"}, + wantParams: Params{ + { + Key: "bar", + Value: "bar", + }, + }, + want: false, + }, + { + name: "param at end range", + path: "/foo/baz", + route: &Route{path: "/foo/{bar}"}, + wantParams: Params{ + { + Key: "bar", + Value: "baz", + }, + }, + want: true, + }, + { + name: "full path catch all", + path: "/foo/bar/baz", + route: &Route{path: "/*{args}"}, + wantParams: Params{ + { + Key: "args", + Value: "foo/bar/baz", + }, + }, + want: true, + }, + } + + params := make(Params, 0) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + params = params[:0] + got := tc.route.hydrateParams(tc.path, ¶ms) + assert.Equal(t, tc.want, got) + assert.Equal(t, tc.wantParams, params) + }) + } + +} + +func TestRoute_HandleMiddlewareMalloc(t *testing.T) { + f := New() + for _, rte := range githubAPI { + require.NoError(t, f.Tree().Handle(rte.method, rte.path, emptyHandler)) + } + + for _, rte := range githubAPI { + req := httptest.NewRequest(rte.method, rte.path, nil) + w := httptest.NewRecorder() + r, c, _ := f.Lookup(&recorder{ResponseWriter: w}, req) + allocs := testing.AllocsPerRun(100, func() { + r.HandleMiddleware(c) + }) + c.Close() + assert.Equal(t, float64(0), allocs) + } +} + +func TestRoute_HydrateParamsMalloc(t *testing.T) { + rte := &Route{ + path: "/foo/ab:{bar}/baz/{x}/{y}/*{zo}", + } + path := "/foo/ab:1/baz/123/y/bo/lo" + params := make(Params, 0, 4) + + allocs := testing.AllocsPerRun(100, func() { + rte.hydrateParams(path, ¶ms) + params = params[:0] + }) + assert.Equal(t, float64(0), allocs) +}