Skip to content

Commit a7f1ca9

Browse files
committed
add http handler response interceptor
1 parent d1c1c2d commit a7f1ca9

File tree

4 files changed

+121
-24
lines changed

4 files changed

+121
-24
lines changed

middleware/recover.go

+17-19
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,30 @@
11
package middleware
22

33
import (
4+
"context"
45
"net/http"
56

67
"github.com/leopoldxx/go-utils/trace"
78
)
89

9-
// Recover middleware will response a default 500 error
10-
// when panic occurred inside of the handler process.
11-
// NOTIC: if there is no a trace handler out the recover handler,
12-
// prepend a default trace handler
13-
func Recover() Middleware {
14-
return func(next http.HandlerFunc) http.HandlerFunc {
15-
return func(w http.ResponseWriter, r *http.Request) {
16-
tracer := trace.GetTraceFromRequest(r)
17-
defer func() {
18-
if err := recover(); err != nil {
19-
tracer.Error("panic:", tracer.Stack())
20-
http.Error(w, "internal error, plz check log!", http.StatusInternalServerError)
21-
}
22-
}()
23-
next(w, r)
24-
}
25-
}
26-
}
27-
2810
// RecoverWithTrace middleware is a RecoverMiddleware wraps with a trace handler
2911
func RecoverWithTrace(name string) Middleware {
3012
return func(next http.HandlerFunc) http.HandlerFunc {
3113
return func(w http.ResponseWriter, r *http.Request) {
14+
var rw *responseWriter
15+
if defaultResponseInterceptor != nil {
16+
rw = &responseWriter{
17+
ResponseWriter: w,
18+
status: http.StatusOK,
19+
}
20+
}
3221
recoverHandler := func(w http.ResponseWriter, r *http.Request) {
3322
tracer := trace.GetTraceFromRequest(r)
23+
if rw, ok := w.(interface {
24+
Record(ctx context.Context, recorder Recorder)
25+
}); ok {
26+
defer rw.Record(r.Context(), defaultResponseInterceptor)
27+
}
3428
defer func() {
3529
if err := recover(); err != nil {
3630
tracer.Error("panic:", tracer.Stack())
@@ -39,6 +33,10 @@ func RecoverWithTrace(name string) Middleware {
3933
}()
4034
next(w, r)
4135
}
36+
if rw != nil {
37+
trace.HandleFunc(name, recoverHandler)(rw, r)
38+
return
39+
}
4240

4341
trace.HandleFunc(name, recoverHandler)(w, r)
4442
}

middleware/recover_test.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ func TestRecover(t *testing.T) {
1818
panic("+_+!")
1919
}
2020

21-
tmw := Trace("recover_test")
22-
rmw := Recover()
21+
rmw := RecoverWithTrace("rwt")
2322

24-
newhh := Chain(tmw, rmw).HandlerFunc(ph)
25-
newhh2 := Chain(tmw, rmw).HandlerFunc(ph2)
23+
SetDefaultResponseInterceptor(NewMultiRecorder(NewLogRecorder(), NewLogRecorder(), NewLogRecorder()))
24+
25+
newhh := Chain(rmw).HandlerFunc(ph)
26+
newhh2 := Chain(rmw).HandlerFunc(ph2)
2627

2728
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
2829
w := httptest.NewRecorder()
@@ -32,6 +33,6 @@ func TestRecover(t *testing.T) {
3233
if w.Code != 500 {
3334
t.Fatal("fail chain handler:", w)
3435
}
35-
t.Log(string(w.Body.Bytes()))
36+
t.Log(len(w.Body.Bytes()), string(w.Body.Bytes()))
3637

3738
}

middleware/response_interceptor.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"sync"
7+
8+
"github.com/leopoldxx/go-utils/trace"
9+
)
10+
11+
var defaultResponseInterceptor Recorder
12+
13+
// SetDefaultResponseInterceptor for http handler response
14+
func SetDefaultResponseInterceptor(r Recorder) {
15+
defaultResponseInterceptor = r
16+
}
17+
18+
type responseWriter struct {
19+
http.ResponseWriter
20+
sync.Mutex
21+
22+
status int
23+
size int
24+
}
25+
26+
func (rs *responseWriter) Header() http.Header {
27+
return rs.ResponseWriter.Header()
28+
}
29+
30+
func (rs *responseWriter) Write(data []byte) (int, error) {
31+
rs.Lock()
32+
rs.size += len(data)
33+
rs.Unlock()
34+
return rs.ResponseWriter.Write(data)
35+
}
36+
37+
func (rs *responseWriter) WriteHeader(status int) {
38+
rs.Lock()
39+
rs.status = status
40+
rs.Unlock()
41+
rs.ResponseWriter.WriteHeader(status)
42+
}
43+
44+
// Recorder for http handler response status & body size
45+
type Recorder interface {
46+
Record(ctx context.Context, statistics Statistics)
47+
}
48+
49+
// NewLogRecorder for log purpose
50+
func NewLogRecorder() Recorder {
51+
return &logRecorder{}
52+
}
53+
54+
type logRecorder struct{}
55+
56+
func (lr logRecorder) Record(ctx context.Context, statistics Statistics) {
57+
tracer := trace.GetTraceFromContext(ctx)
58+
tracer.Infof("%+v", statistics)
59+
}
60+
61+
// NewMultiRecorder will chain MultiRecorder
62+
func NewMultiRecorder(recorders ...Recorder) Recorder {
63+
return &multiRecorder{recorders: recorders}
64+
}
65+
66+
type multiRecorder struct {
67+
recorders []Recorder
68+
}
69+
70+
func (mr multiRecorder) Record(ctx context.Context, statistics Statistics) {
71+
var wg sync.WaitGroup
72+
for i := range mr.recorders {
73+
wg.Add(1)
74+
go func(r Recorder, ctx context.Context, statistics Statistics) {
75+
defer wg.Done()
76+
r.Record(ctx, statistics)
77+
}(mr.recorders[i], ctx, statistics)
78+
}
79+
wg.Wait()
80+
}
81+
82+
// Statistics for http handler response
83+
type Statistics struct {
84+
Status int
85+
BodySize int
86+
}
87+
88+
func (rs *responseWriter) Record(ctx context.Context, recorder Recorder) {
89+
var s Statistics
90+
rs.Lock()
91+
s.Status = rs.status
92+
s.BodySize = rs.size
93+
rs.Unlock()
94+
if recorder != nil {
95+
recorder.Record(ctx, s)
96+
}
97+
}

server/example/main.go

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
)
1313

1414
func main() {
15+
middleware.SetDefaultResponseInterceptor(middleware.NewLogRecorder())
1516
s := server.New(server.ListenAddr(":8001"), server.APIPrefix("/example"))
1617
s.Register(new(filesvr))
1718
s.ListenAndServe()

0 commit comments

Comments
 (0)