Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/method case insensitive #764

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/url"
"path"
"regexp"
"strings"
)

var (
Expand All @@ -30,7 +31,12 @@ var (

// NewRouter returns a new router instance.
func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route)}
return &Router{
namedRoutes: make(map[string]*Route),
routeConf: routeConf{
methodMatcher: methodDefaultMatcher{},
},
}
}

// Router registers routes to be matched and dispatches a handler.
Expand Down Expand Up @@ -107,6 +113,9 @@ type routeConf struct {
buildScheme string

buildVarsFunc BuildVarsFunc

// Holds the default method matcher
methodMatcher matcher
}

// returns an effective deep copy of `routeConf`
Expand Down Expand Up @@ -306,6 +315,24 @@ func (r *Router) UseEncodedPath() *Router {
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods.
func (r *Router) MatchMethodDefault() *Router {
r.methodMatcher = methodDefaultMatcher{}
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods.
func (r *Router) MatchMethodCaseInsensitive() *Router {
r.methodMatcher = methodCaseInsensitiveMatcher{}
return r
}

// MatchMethodExact defines the behaviour of matching exact request methods.
func (r *Router) MatchMethodExact() *Router {
r.methodMatcher = methodCaseExactMatcher{}
return r
}

// ----------------------------------------------------------------------------
// Route factories
// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -611,6 +638,13 @@ func matchInArray(arr []string, value string) bool {
return false
}

func sliceToUpper(slice []string) []string {
for k, v := range slice {
slice[k] = strings.ToUpper(v)
}
return slice
}

// matchMapWithString returns true if the given key/value pairs exist in a given map.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
for k, v := range toCheck {
Expand Down
141 changes: 141 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,147 @@ func TestNoMatchMethodErrorHandler(t *testing.T) {
}
}

func TestMethodMatchingCaseInsensitiveOnRoute(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}

r := NewRouter()
r.HandleFunc("/", func1).Methods("get")

req, _ := http.NewRequest("get", "http://localhost/", nil)
match := new(RouteMatch)
matched := r.Match(req, match)

if matched {
t.Error("Should not have matched route for methods")
}

if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}

resp := NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

// Add matching route
r.HandleFunc("/", func1).MethodsCaseInsensitive("GET")

match = new(RouteMatch)
matched = r.Match(req, match)

if !matched {
t.Error("Should have matched route")
}

if match.MatchErr != nil {
t.Error("Should not have any matching error. Found:", match.MatchErr)
}
}

func TestMethodMatchingCaseInsensitiveOnRouter(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}

r := NewRouter()
r.HandleFunc("/", func1).Methods("get")

req, _ := http.NewRequest("get", "http://localhost/", nil)
match := new(RouteMatch)
matched := r.Match(req, match)

if matched {
t.Error("Should not have matched route for methods")
}

if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}

resp := NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

r.MatchMethodCaseInsensitive()
r.HandleFunc("/a", func1).Methods("get").Name("t")
req, _ = http.NewRequest("get", "http://localhost/a", nil)

match = new(RouteMatch)
matched = r.Match(req, match)

if !matched {
t.Error("Should have matched route")
}

if match.MatchErr != nil {
t.Error("Should not have any matching error. Found:", match.MatchErr)
}
}

func TestMethodMatchingCaseExact(t *testing.T) {
func1 := func(w http.ResponseWriter, r *http.Request) {}

r := NewRouter()
r.HandleFunc("/a", func1).Methods("get")
r.HandleFunc("/b", func1).MethodsCaseExact("get")

req, _ := http.NewRequest("get", "http://localhost/a", nil)
match := new(RouteMatch)
matched := r.Match(req, match)

if matched {
t.Error("Should not have matched route for method")
}

if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}

resp := NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

req, _ = http.NewRequest("GET", "http://localhost/b", nil)
match = new(RouteMatch)
matched = r.Match(req, match)

if matched {
t.Error("Should not have matched route for method")
}

if match.MatchErr != ErrMethodMismatch {
t.Error("Should get ErrMethodMismatch error")
}

resp = NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

resp = NewRecorder()
r.ServeHTTP(resp, req)
if resp.Code != http.StatusMethodNotAllowed {
t.Errorf("Expecting code %v", 405)
}

req, _ = http.NewRequest("get", "http://localhost/b", nil)
match = new(RouteMatch)
matched = r.Match(req, match)

if !matched {
t.Error("Should have matched route")
}

if match.MatchErr != nil {
t.Error("Should not have any matching error. Found:", match.MatchErr)
}
}

func TestMultipleDefinitionOfSamePathWithDifferentMethods(t *testing.T) {
emptyHandler := func(w http.ResponseWriter, r *http.Request) {}

Expand Down
10 changes: 5 additions & 5 deletions old_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,29 +281,29 @@ var hostMatcherTests = []hostMatcherTest{
}

type methodMatcherTest struct {
matcher methodMatcher
matcher methodDefaultMatcher
method string
result bool
}

var methodMatcherTests = []methodMatcherTest{
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: ([]string{"GET", "POST", "PUT"}),
method: "GET",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}),
method: "POST",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}),
method: "PUT",
result: true,
},
{
matcher: methodMatcher([]string{"GET", "POST", "PUT"}),
matcher: methodDefaultMatcher([]string{"GET", "POST", "PUT"}),
method: "DELETE",
result: false,
},
Expand Down
87 changes: 79 additions & 8 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
// Match everything.
for _, m := range r.matchers {
if matched := m.Match(req, match); !matched {
if _, ok := m.(methodMatcher); ok {
if _, ok := m.(methodDefaultMatcher); ok {
matchErr = ErrMethodMismatch
continue
} else if _, ok := m.(methodCaseInsensitiveMatcher); ok {
matchErr = ErrMethodMismatch
continue
} else if _, ok := m.(methodCaseExactMatcher); ok {
matchErr = ErrMethodMismatch
continue
}
Expand Down Expand Up @@ -385,21 +391,64 @@ func (r *Route) MatcherFunc(f MatcherFunc) *Route {

// Methods --------------------------------------------------------------------

// methodMatcher matches the request against HTTP methods.
type methodMatcher []string
// methodDefaultMatcher matches the request against HTTP methods.
// The supplied methods will be transformed to uppercase. The request method not.
type methodDefaultMatcher []string

func (m methodDefaultMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.Method)
}

// methodMatcher matches the request against HTTP methods without case sensitivity.
// Both the supplied methods as well as the request method will be transformed to uppercase.
type methodCaseInsensitiveMatcher []string

func (m methodCaseInsensitiveMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, strings.ToUpper(r.Method))
}

// methodCaseExactMatcher matches the request against HTTP methods exactly.
// No transformation of supplied methods or the request method is applied.
type methodCaseExactMatcher []string

func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
func (m methodCaseExactMatcher) Match(r *http.Request, match *RouteMatch) bool {
return matchInArray(m, r.Method)
}

// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) Methods(methods ...string) *Route {
for k, v := range methods {
methods[k] = strings.ToUpper(v)
if _, ok := r.methodMatcher.(methodCaseInsensitiveMatcher); ok {
return r.MethodsCaseInsensitive(methods...)
} else if _, ok := r.methodMatcher.(methodCaseExactMatcher); ok {
return r.MethodsCaseExact(methods...)
} else {
return r.MethodsDefault(methods...)
}
return r.addMatcher(methodMatcher(methods))
}

// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) MethodsDefault(methods ...string) *Route {
return r.addMatcher(methodDefaultMatcher(sliceToUpper(methods)))
}

// MethodsCaseInsensitive adds a matcher for HTTP methods without case sensitivity.
// This will override the initial config on the router for 'matchMethodCaseInsensitive'
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) MethodsCaseInsensitive(methods ...string) *Route {
return r.addMatcher(methodCaseInsensitiveMatcher(sliceToUpper(methods)))
}

// MethodsCaseInsensitive adds a matcher for exact HTTP methods with no transformation.
// This will override the initial config on the router for 'matchMethodCaseInsensitive'
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
func (r *Route) MethodsCaseExact(methods ...string) *Route {
return r.addMatcher(methodCaseExactMatcher(methods))
}

// Path -----------------------------------------------------------------------
Expand Down Expand Up @@ -769,7 +818,11 @@ func (r *Route) GetMethods() ([]string, error) {
return nil, r.err
}
for _, m := range r.matchers {
if methods, ok := m.(methodMatcher); ok {
if methods, ok := m.(methodDefaultMatcher); ok {
return []string(methods), nil
} else if methods, ok := m.(methodCaseInsensitiveMatcher); ok {
return []string(methods), nil
} else if methods, ok := m.(methodCaseExactMatcher); ok {
return []string(methods), nil
}
}
Expand Down Expand Up @@ -826,3 +879,21 @@ func (r *Route) buildVars(m map[string]string) map[string]string {
}
return m
}

// MatchMethodDefault defines the behaviour of matching request methods with the default matcher on this route.
func (r *Route) MatchMethodDefault() *Route {
r.methodMatcher = methodDefaultMatcher{}
return r
}

// MatchMethodCaseInsensitive defines the behaviour of ignoring casing for request methods on this route.
func (r *Route) MatchMethodCaseInsensitive() *Route {
r.methodMatcher = methodCaseInsensitiveMatcher{}
return r
}

// MatchMethodCaseExact defines the behaviour of matching exact request methods on this route.
func (r *Route) MatchMethodCaseExact(value bool) *Route {
r.methodMatcher = methodCaseExactMatcher{}
return r
}