forked from zeromicro/go-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: rest.WithChain to replace builtin middlewares (zeromicro#2033)
* feat: rest.WithChain to replace builtin middlewares * chore: add comments * chore: refine code
- Loading branch information
Showing
6 changed files
with
322 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package chain | ||
|
||
// This is a modified version of https://github.com/justinas/alice | ||
// The original code is licensed under the MIT license. | ||
// It's modified for couple reasons: | ||
// - Added the Chain interface | ||
// - Added support for the Chain.Prepend(...) method | ||
|
||
import "net/http" | ||
|
||
type ( | ||
// Chain defines a chain of middleware. | ||
Chain interface { | ||
Append(middlewares ...Middleware) Chain | ||
Prepend(middlewares ...Middleware) Chain | ||
Then(h http.Handler) http.Handler | ||
ThenFunc(fn http.HandlerFunc) http.Handler | ||
} | ||
|
||
// Middleware is an HTTP middleware. | ||
Middleware func(http.Handler) http.Handler | ||
|
||
// chain acts as a list of http.Handler middlewares. | ||
// chain is effectively immutable: | ||
// once created, it will always hold | ||
// the same set of middlewares in the same order. | ||
chain struct { | ||
middlewares []Middleware | ||
} | ||
) | ||
|
||
// New creates a new Chain, memorizing the given list of middleware middlewares. | ||
// New serves no other function, middlewares are only called upon a call to Then() or ThenFunc(). | ||
func New(middlewares ...Middleware) Chain { | ||
return chain{middlewares: append(([]Middleware)(nil), middlewares...)} | ||
} | ||
|
||
// Append extends a chain, adding the specified middlewares as the last ones in the request flow. | ||
// | ||
// c := chain.New(m1, m2) | ||
// c.Append(m3, m4) | ||
// // requests in c go m1 -> m2 -> m3 -> m4 | ||
func (c chain) Append(middlewares ...Middleware) Chain { | ||
return chain{middlewares: join(c.middlewares, middlewares)} | ||
} | ||
|
||
// Prepend extends a chain by adding the specified chain as the first one in the request flow. | ||
// | ||
// c := chain.New(m3, m4) | ||
// c1 := chain.New(m1, m2) | ||
// c.Prepend(c1) | ||
// // requests in c go m1 -> m2 -> m3 -> m4 | ||
func (c chain) Prepend(middlewares ...Middleware) Chain { | ||
return chain{middlewares: join(middlewares, c.middlewares)} | ||
} | ||
|
||
// Then chains the middleware and returns the final http.Handler. | ||
// New(m1, m2, m3).Then(h) | ||
// is equivalent to: | ||
// m1(m2(m3(h))) | ||
// When the request comes in, it will be passed to m1, then m2, then m3 | ||
// and finally, the given handler | ||
// (assuming every middleware calls the following one). | ||
// | ||
// A chain can be safely reused by calling Then() several times. | ||
// stdStack := chain.New(ratelimitHandler, csrfHandler) | ||
// indexPipe = stdStack.Then(indexHandler) | ||
// authPipe = stdStack.Then(authHandler) | ||
// Note that middlewares are called on every call to Then() or ThenFunc() | ||
// and thus several instances of the same middleware will be created | ||
// when a chain is reused in this way. | ||
// For proper middleware, this should cause no problems. | ||
// | ||
// Then() treats nil as http.DefaultServeMux. | ||
func (c chain) Then(h http.Handler) http.Handler { | ||
if h == nil { | ||
h = http.DefaultServeMux | ||
} | ||
|
||
for i := range c.middlewares { | ||
h = c.middlewares[len(c.middlewares)-1-i](h) | ||
} | ||
|
||
return h | ||
} | ||
|
||
// ThenFunc works identically to Then, but takes | ||
// a HandlerFunc instead of a Handler. | ||
// | ||
// The following two statements are equivalent: | ||
// c.Then(http.HandlerFunc(fn)) | ||
// c.ThenFunc(fn) | ||
// | ||
// ThenFunc provides all the guarantees of Then. | ||
func (c chain) ThenFunc(fn http.HandlerFunc) http.Handler { | ||
// This nil check cannot be removed due to the "nil is not nil" common mistake in Go. | ||
// Required due to: https://stackoverflow.com/questions/33426977/how-to-golang-check-a-variable-is-nil | ||
if fn == nil { | ||
return c.Then(nil) | ||
} | ||
return c.Then(fn) | ||
} | ||
|
||
func join(a, b []Middleware) []Middleware { | ||
mids := make([]Middleware, 0, len(a)+len(b)) | ||
mids = append(mids, a...) | ||
mids = append(mids, b...) | ||
return mids | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package chain | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"reflect" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
// A constructor for middleware | ||
// that writes its own "tag" into the RW and does nothing else. | ||
// Useful in checking if a chain is behaving in the right order. | ||
func tagMiddleware(tag string) Middleware { | ||
return func(h http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Write([]byte(tag)) | ||
h.ServeHTTP(w, r) | ||
}) | ||
} | ||
} | ||
|
||
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer), | ||
// but the best we can do. | ||
func funcsEqual(f1, f2 interface{}) bool { | ||
val1 := reflect.ValueOf(f1) | ||
val2 := reflect.ValueOf(f2) | ||
return val1.Pointer() == val2.Pointer() | ||
} | ||
|
||
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Write([]byte("app\n")) | ||
}) | ||
|
||
func TestNew(t *testing.T) { | ||
c1 := func(h http.Handler) http.Handler { | ||
return nil | ||
} | ||
|
||
c2 := func(h http.Handler) http.Handler { | ||
return http.StripPrefix("potato", nil) | ||
} | ||
|
||
slice := []Middleware{c1, c2} | ||
c := New(slice...) | ||
for k := range slice { | ||
assert.True(t, funcsEqual(c.(chain).middlewares[k], slice[k]), | ||
"New does not add constructors correctly") | ||
} | ||
} | ||
|
||
func TestThenWorksWithNoMiddleware(t *testing.T) { | ||
assert.True(t, funcsEqual(New().Then(testApp), testApp), | ||
"Then does not work with no middleware") | ||
} | ||
|
||
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) { | ||
assert.Equal(t, http.DefaultServeMux, New().Then(nil), | ||
"Then does not treat nil as DefaultServeMux") | ||
} | ||
|
||
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) { | ||
assert.Equal(t, http.DefaultServeMux, New().ThenFunc(nil), | ||
"ThenFunc does not treat nil as DefaultServeMux") | ||
} | ||
|
||
func TestThenFuncConstructsHandlerFunc(t *testing.T) { | ||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(200) | ||
}) | ||
chained := New().ThenFunc(fn) | ||
rec := httptest.NewRecorder() | ||
|
||
chained.ServeHTTP(rec, (*http.Request)(nil)) | ||
|
||
assert.Equal(t, reflect.TypeOf((http.HandlerFunc)(nil)), reflect.TypeOf(chained), | ||
"ThenFunc does not construct HandlerFunc") | ||
} | ||
|
||
func TestThenOrdersHandlersCorrectly(t *testing.T) { | ||
t1 := tagMiddleware("t1\n") | ||
t2 := tagMiddleware("t2\n") | ||
t3 := tagMiddleware("t3\n") | ||
|
||
chained := New(t1, t2, t3).Then(testApp) | ||
|
||
w := httptest.NewRecorder() | ||
r, err := http.NewRequest("GET", "/", nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
chained.ServeHTTP(w, r) | ||
|
||
assert.Equal(t, "t1\nt2\nt3\napp\n", w.Body.String(), | ||
"Then does not order handlers correctly") | ||
} | ||
|
||
func TestAppendAddsHandlersCorrectly(t *testing.T) { | ||
c := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) | ||
c = c.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n")) | ||
h := c.Then(testApp) | ||
|
||
w := httptest.NewRecorder() | ||
r, err := http.NewRequest("GET", "/", nil) | ||
assert.Nil(t, err) | ||
|
||
h.ServeHTTP(w, r) | ||
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(), | ||
"Append does not add handlers correctly") | ||
} | ||
|
||
func TestExtendAddsHandlersCorrectly(t *testing.T) { | ||
c := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")) | ||
c = c.Prepend(tagMiddleware("t1\n"), tagMiddleware("t2\n")) | ||
h := c.Then(testApp) | ||
|
||
w := httptest.NewRecorder() | ||
r, err := http.NewRequest("GET", "/", nil) | ||
assert.Nil(t, err) | ||
|
||
h.ServeHTTP(w, r) | ||
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", w.Body.String(), | ||
"Extend does not add handlers in correctly") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.