From e29f18029a17afe0f036543dfbb189926971bcfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Corn=C3=A9=20de=20Jong?= <5366568-cornedejong@users.noreply.gitlab.com> Date: Sun, 26 May 2024 09:39:04 +0200 Subject: [PATCH 1/3] added simple initial implementation of router in request context --- mux.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mux.go b/mux.go index c9396e91..eb03d48d 100644 --- a/mux.go +++ b/mux.go @@ -94,6 +94,9 @@ type routeConf struct { // If true, the http.Request context will not contain the Route. omitRouteFromContext bool + // if true, the the http.Request context will not contain the router + omitRouterFromContext bool + // Manager for the variables from host and path. regexp routeRegexpGroup @@ -207,6 +210,10 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } else { req = requestWithRouteAndVars(req, match.Route, match.Vars) } + + if !r.omitRouterFromContext { + req = requestWithRouter(req, r) + } } } @@ -443,6 +450,7 @@ type contextKey int const ( varsKey contextKey = iota routeKey + routerKey ) // Vars returns the route variables for the current request, if any. @@ -464,6 +472,13 @@ func CurrentRoute(r *http.Request) *Route { return nil } +func CurrentRouter(r *http.Request) *Router { + if rv := r.Context().Value(routerKey); rv != nil { + return rv.(*Router) + } + return nil +} + // requestWithVars adds the matched vars to the request ctx. // It shortcuts the operation when the vars are empty. func requestWithVars(r *http.Request, vars map[string]string) *http.Request { @@ -486,6 +501,11 @@ func requestWithRouteAndVars(r *http.Request, route *Route, vars map[string]stri return r.WithContext(ctx) } +func requestWithRouter(r *http.Request, router *Router) *http.Request { + ctx := context.WithValue(r.Context(), routerKey, router) + return r.WithContext(ctx) +} + // ---------------------------------------------------------------------------- // Helpers // ---------------------------------------------------------------------------- From b392eeae18e5df21281640139b1f9f54554cf1e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Corn=C3=A9=20de=20Jong?= <5366568-cornedejong@users.noreply.gitlab.com> Date: Tue, 28 May 2024 21:36:01 +0200 Subject: [PATCH 2/3] added conf method and defined initial test func --- mux.go | 9 +++++++++ mux_test.go | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/mux.go b/mux.go index eb03d48d..eaff897a 100644 --- a/mux.go +++ b/mux.go @@ -286,6 +286,15 @@ func (r *Router) OmitRouteFromContext(value bool) *Router { return r } +// OmitRouterFromContext defines the behavior of omitting the Router from the +// http.Request context. +// +// RouterFromRequest will yield nil with this option. +func (r *Router) OmitRouterFromContext(value bool) *Router { + r.omitRouterFromContext = value + return r +} + // UseEncodedPath tells the router to match the encoded original path // to the routes. // For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to". diff --git a/mux_test.go b/mux_test.go index 0845d7f7..b078f5f8 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1768,6 +1768,10 @@ func TestPanicOnCapturingGroups(t *testing.T) { NewRouter().NewRoute().Path("/{type:(promo|special)}/{promoId}.json") } +func TestRouterInContext(t *testing.T) { + // TODO Write tests for router in context +} + // ---------------------------------------------------------------------------- // Helpers // ---------------------------------------------------------------------------- From 40d0f3cb1622263a16a0f7729240606f3ae18d7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Corn=C3=A9=20de=20Jong?= <5366568-cornedejong@users.noreply.gitlab.com> Date: Fri, 31 May 2024 18:41:26 +0200 Subject: [PATCH 3/3] added test for router in context --- mux_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/mux_test.go b/mux_test.go index b078f5f8..bac758bc 100644 --- a/mux_test.go +++ b/mux_test.go @@ -1769,7 +1769,70 @@ func TestPanicOnCapturingGroups(t *testing.T) { } func TestRouterInContext(t *testing.T) { - // TODO Write tests for router in context + router := NewRouter() + router.HandleFunc("/r1", func(w http.ResponseWriter, r *http.Request) { + contextRouter := CurrentRouter(r) + if contextRouter == nil { + t.Fatal("Router not found in context") + return + } + + route := contextRouter.Get("r2") + if route == nil { + t.Fatal("Route with name not found") + return + } + + url, err := route.URL() + if err != nil { + t.Fatal("Error while getting url for r2: ", err) + return + } + + _, err = w.Write([]byte(url.String())) + if err != nil { + t.Fatalf("Failed writing HTTP response: %v", err) + } + }).Name("r1") + + noRouterMsg := []byte("no-router") + haveRouterMsg := []byte("have-router") + router.HandleFunc("/r2", func(w http.ResponseWriter, r *http.Request) { + var msg []byte + + contextRouter := CurrentRouter(r) + if contextRouter == nil { + msg = noRouterMsg + } else { + msg = haveRouterMsg + } + + _, err := w.Write(msg) + if err != nil { + t.Fatalf("Failed writing HTTP response: %v", err) + } + }).Name("r2") + + t.Run("router in request context get route by name", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/r1") + + router.ServeHTTP(rw, req) + if !bytes.Equal(rw.Body.Bytes(), []byte("/r2")) { + t.Fatalf("Expected output to be '/r1' but got '%s'", rw.Body.String()) + } + }) + + t.Run("omit router from request context", func(t *testing.T) { + rw := NewRecorder() + req := newRequest("GET", "/r2") + + router.OmitRouterFromContext(true) + router.ServeHTTP(rw, req) + if !bytes.Equal(rw.Body.Bytes(), noRouterMsg) { + t.Fatal("Router not omitted from context") + } + }) } // ----------------------------------------------------------------------------