diff --git a/mux.go b/mux.go index eaff897a..2975a9fa 100644 --- a/mux.go +++ b/mux.go @@ -12,6 +12,7 @@ import ( "net/url" "path" "regexp" + "strings" ) var ( @@ -30,7 +31,12 @@ var ( // NewRouter returns a new router instance. func NewRouter() *Router { - return &Router{namedRoutes: make(map[string]*Route)} + return &Router{ + namedRoutes: make(map[string]*Route), + routeConf: routeConf{ + methodMatcher: methodDefaultMatcher{}, + }, + } } // Router registers routes to be matched and dispatches a handler. @@ -107,6 +113,9 @@ type routeConf struct { buildScheme string buildVarsFunc BuildVarsFunc + + // Holds the default method matcher + methodMatcher matcher } // returns an effective deep copy of `routeConf` @@ -306,6 +315,24 @@ func (r *Router) UseEncodedPath() *Router { return r } +// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods. +func (r *Router) MatchMethodDefault() *Router { + r.methodMatcher = methodDefaultMatcher{} + return r +} + +// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods. +func (r *Router) MatchMethodCaseInsensitive() *Router { + r.methodMatcher = methodCaseInsensitiveMatcher{} + return r +} + +// MatchMethodExact defines the behaviour of matching exact request methods. +func (r *Router) MatchMethodExact() *Router { + r.methodMatcher = methodCaseExactMatcher{} + return r +} + // ---------------------------------------------------------------------------- // Route factories // ---------------------------------------------------------------------------- @@ -611,6 +638,13 @@ func matchInArray(arr []string, value string) bool { return false } +func sliceToUpper(slice []string) []string { + for k, v := range slice { + slice[k] = strings.ToUpper(v) + } + return slice +} + // matchMapWithString returns true if the given key/value pairs exist in a given map. func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool { for k, v := range toCheck { diff --git a/mux_test.go b/mux_test.go index bac758bc..f1e872b7 100644 --- a/mux_test.go +++ b/mux_test.go @@ -2136,6 +2136,147 @@ func TestNoMatchMethodErrorHandler(t *testing.T) { } } +func TestMethodMatchingCaseInsensitiveOnRoute(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/", func1).Methods("get") + + req, _ := http.NewRequest("get", "http://localhost/", nil) + match := new(RouteMatch) + matched := r.Match(req, match) + + if matched { + t.Error("Should not have matched route for methods") + } + + if match.MatchErr != ErrMethodMismatch { + t.Error("Should get ErrMethodMismatch error") + } + + resp := NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusMethodNotAllowed { + t.Errorf("Expecting code %v", 405) + } + + // Add matching route + r.HandleFunc("/", func1).MethodsCaseInsensitive("GET") + + match = new(RouteMatch) + matched = r.Match(req, match) + + if !matched { + t.Error("Should have matched route") + } + + if match.MatchErr != nil { + t.Error("Should not have any matching error. Found:", match.MatchErr) + } +} + +func TestMethodMatchingCaseInsensitiveOnRouter(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/", func1).Methods("get") + + req, _ := http.NewRequest("get", "http://localhost/", nil) + match := new(RouteMatch) + matched := r.Match(req, match) + + if matched { + t.Error("Should not have matched route for methods") + } + + if match.MatchErr != ErrMethodMismatch { + t.Error("Should get ErrMethodMismatch error") + } + + resp := NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusMethodNotAllowed { + t.Errorf("Expecting code %v", 405) + } + + r.MatchMethodCaseInsensitive() + r.HandleFunc("/a", func1).Methods("get").Name("t") + req, _ = http.NewRequest("get", "http://localhost/a", nil) + + match = new(RouteMatch) + matched = r.Match(req, match) + + if !matched { + t.Error("Should have matched route") + } + + if match.MatchErr != nil { + t.Error("Should not have any matching error. Found:", match.MatchErr) + } +} + +func TestMethodMatchingCaseExact(t *testing.T) { + func1 := func(w http.ResponseWriter, r *http.Request) {} + + r := NewRouter() + r.HandleFunc("/a", func1).Methods("get") + r.HandleFunc("/b", func1).MethodsCaseExact("get") + + req, _ := http.NewRequest("get", "http://localhost/a", nil) + match := new(RouteMatch) + matched := r.Match(req, match) + + if matched { + t.Error("Should not have matched route for method") + } + + if match.MatchErr != ErrMethodMismatch { + t.Error("Should get ErrMethodMismatch error") + } + + resp := NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusMethodNotAllowed { + t.Errorf("Expecting code %v", 405) + } + + req, _ = http.NewRequest("GET", "http://localhost/b", nil) + match = new(RouteMatch) + matched = r.Match(req, match) + + if matched { + t.Error("Should not have matched route for method") + } + + if match.MatchErr != ErrMethodMismatch { + t.Error("Should get ErrMethodMismatch error") + } + + resp = NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusMethodNotAllowed { + t.Errorf("Expecting code %v", 405) + } + + resp = NewRecorder() + r.ServeHTTP(resp, req) + if resp.Code != http.StatusMethodNotAllowed { + t.Errorf("Expecting code %v", 405) + } + + req, _ = http.NewRequest("get", "http://localhost/b", nil) + match = new(RouteMatch) + matched = r.Match(req, match) + + if !matched { + t.Error("Should have matched route") + } + + if match.MatchErr != nil { + t.Error("Should not have any matching error. Found:", match.MatchErr) + } +} + func TestMultipleDefinitionOfSamePathWithDifferentMethods(t *testing.T) { emptyHandler := func(w http.ResponseWriter, r *http.Request) {} diff --git a/old_test.go b/old_test.go index 96dbe337..ff843c53 100644 --- a/old_test.go +++ b/old_test.go @@ -281,29 +281,29 @@ var hostMatcherTests = []hostMatcherTest{ } type methodMatcherTest struct { - matcher methodMatcher + matcher methodDefaultMatcher method string result bool } var methodMatcherTests = []methodMatcherTest{ { - matcher: methodMatcher([]string{"GET", "POST", "PUT"}), + matcher: ([]string{"GET", "POST", "PUT"}), method: "GET", result: true, }, { - matcher: methodMatcher([]string{"GET", "POST", "PUT"}), + matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}), method: "POST", result: true, }, { - matcher: methodMatcher([]string{"GET", "POST", "PUT"}), + matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}), method: "PUT", result: true, }, { - matcher: methodMatcher([]string{"GET", "POST", "PUT"}), + matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}), method: "DELETE", result: false, }, diff --git a/route.go b/route.go index d10401e9..6964fd22 100644 --- a/route.go +++ b/route.go @@ -54,7 +54,13 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool { // Match everything. for _, m := range r.matchers { if matched := m.Match(req, match); !matched { - if _, ok := m.(methodMatcher); ok { + if _, ok := m.(methodDefaultMatcher); ok { + matchErr = ErrMethodMismatch + continue + } else if _, ok := m.(methodCaseInsensitiveMatcher); ok { + matchErr = ErrMethodMismatch + continue + } else if _, ok := m.(methodCaseExactMatcher); ok { matchErr = ErrMethodMismatch continue } @@ -385,10 +391,27 @@ func (r *Route) MatcherFunc(f MatcherFunc) *Route { // Methods -------------------------------------------------------------------- -// methodMatcher matches the request against HTTP methods. -type methodMatcher []string +// methodDefaultMatcher matches the request against HTTP methods. +// The supplied methods will be transformed to uppercase. The request method not. +type methodDefaultMatcher []string + +func (m methodDefaultMatcher) Match(r *http.Request, match *RouteMatch) bool { + return matchInArray(m, r.Method) +} + +// methodMatcher matches the request against HTTP methods without case sensitivity. +// Both the supplied methods as well as the request method will be transformed to uppercase. +type methodCaseInsensitiveMatcher []string + +func (m methodCaseInsensitiveMatcher) Match(r *http.Request, match *RouteMatch) bool { + return matchInArray(m, strings.ToUpper(r.Method)) +} + +// methodCaseExactMatcher matches the request against HTTP methods exactly. +// No transformation of supplied methods or the request method is applied. +type methodCaseExactMatcher []string -func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool { +func (m methodCaseExactMatcher) Match(r *http.Request, match *RouteMatch) bool { return matchInArray(m, r.Method) } @@ -396,10 +419,36 @@ func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool { // It accepts a sequence of one or more methods to be matched, e.g.: // "GET", "POST", "PUT". func (r *Route) Methods(methods ...string) *Route { - for k, v := range methods { - methods[k] = strings.ToUpper(v) + if _, ok := r.methodMatcher.(methodCaseInsensitiveMatcher); ok { + return r.MethodsCaseInsensitive(methods...) + } else if _, ok := r.methodMatcher.(methodCaseExactMatcher); ok { + return r.MethodsCaseExact(methods...) + } else { + return r.MethodsDefault(methods...) } - return r.addMatcher(methodMatcher(methods)) +} + +// Methods adds a matcher for HTTP methods. +// It accepts a sequence of one or more methods to be matched, e.g.: +// "GET", "POST", "PUT". +func (r *Route) MethodsDefault(methods ...string) *Route { + return r.addMatcher(methodDefaultMatcher(sliceToUpper(methods))) +} + +// MethodsCaseInsensitive adds a matcher for HTTP methods without case sensitivity. +// This will override the initial config on the router for 'matchMethodCaseInsensitive' +// It accepts a sequence of one or more methods to be matched, e.g.: +// "GET", "POST", "PUT". +func (r *Route) MethodsCaseInsensitive(methods ...string) *Route { + return r.addMatcher(methodCaseInsensitiveMatcher(sliceToUpper(methods))) +} + +// MethodsCaseInsensitive adds a matcher for exact HTTP methods with no transformation. +// This will override the initial config on the router for 'matchMethodCaseInsensitive' +// It accepts a sequence of one or more methods to be matched, e.g.: +// "GET", "POST", "PUT". +func (r *Route) MethodsCaseExact(methods ...string) *Route { + return r.addMatcher(methodCaseExactMatcher(methods)) } // Path ----------------------------------------------------------------------- @@ -769,7 +818,11 @@ func (r *Route) GetMethods() ([]string, error) { return nil, r.err } for _, m := range r.matchers { - if methods, ok := m.(methodMatcher); ok { + if methods, ok := m.(methodDefaultMatcher); ok { + return []string(methods), nil + } else if methods, ok := m.(methodCaseInsensitiveMatcher); ok { + return []string(methods), nil + } else if methods, ok := m.(methodCaseExactMatcher); ok { return []string(methods), nil } } @@ -826,3 +879,21 @@ func (r *Route) buildVars(m map[string]string) map[string]string { } return m } + +// MatchMethodDefault defines the behaviour of matching request methods with the default matcher on this route. +func (r *Route) MatchMethodDefault() *Route { + r.methodMatcher = methodDefaultMatcher{} + return r +} + +// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods on this route. +func (r *Route) MatchMethodCaseInsensitive() *Route { + r.methodMatcher = methodCaseInsensitiveMatcher{} + return r +} + +// MatchMethodCaseExact defines the behaviour of matching exact request methods on this route. +func (r *Route) MatchMethodCaseExact(value bool) *Route { + r.methodMatcher = methodCaseExactMatcher{} + return r +}