From 07edc7e5b0187826553fc8ff6d01aebc1d08cb75 Mon Sep 17 00:00:00 2001 From: Sylvain Muller Date: Tue, 8 Oct 2024 23:17:58 +0200 Subject: [PATCH] Expose writer deadline as part of the ResponseWriter interface (#39) * feat(writer): add support for SetWriteDeadline and SetReadDeadline * feat(writer): add support for SetWriteDeadline and SetReadDeadline * feat(writer): improve code coverage --- helpers_test.go | 7 + response_writer.go | 33 +++++ response_writer_test.go | 287 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 327 insertions(+) create mode 100644 response_writer_test.go diff --git a/helpers_test.go b/helpers_test.go index f1c788d..52e6571 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -37,4 +38,10 @@ func TestNewTestContext(t *testing.T) { err = c.Writer().Push("foo", nil) assert.ErrorIs(t, err, http.ErrNotSupported) + + err = c.Writer().SetReadDeadline(time.Time{}) + assert.ErrorIs(t, err, http.ErrNotSupported) + + err = c.Writer().SetWriteDeadline(time.Time{}) + assert.ErrorIs(t, err, http.ErrNotSupported) } diff --git a/response_writer.go b/response_writer.go index 66841f9..2ae118b 100644 --- a/response_writer.go +++ b/response_writer.go @@ -21,6 +21,7 @@ import ( "runtime" "strings" "sync" + "time" ) var _ ResponseWriter = (*recorder)(nil) @@ -53,6 +54,16 @@ type ResponseWriter interface { // Push initiates an HTTP/2 server push. Push returns http.ErrNotSupported if the client has disabled push or if push // is not supported on the underlying connection. See http.Pusher for more details. Push(target string, opts *http.PushOptions) error + // SetReadDeadline sets the deadline for reading the entire request, including the body. Reads from the request + // body after the deadline has been exceeded will return an error. A zero value means no deadline. Setting the read + // deadline after it has been exceeded will not extend it. If SetReadDeadline is not supported, it returns + // an error matching http.ErrNotSupported. + SetReadDeadline(deadline time.Time) error + // SetWriteDeadline sets the deadline for writing the response. Writes to the response body after the deadline has + // been exceeded will not block, but may succeed if the data has been buffered. A zero value means no deadline. + // Setting the write deadline after it has been exceeded will not extend it. If SetWriteDeadline is not supported, + // it returns an error matching http.ErrNotSupported. + SetWriteDeadline(deadline time.Time) error } const notWritten = -1 @@ -184,6 +195,28 @@ func (r *recorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, ErrNotSupported() } +// SetReadDeadline sets the deadline for reading the entire request, including the body. Reads from the request +// body after the deadline has been exceeded will return an error. A zero value means no deadline. Setting the read +// deadline after it has been exceeded will not extend it. If SetReadDeadline is not supported, it returns +// an error matching http.ErrNotSupported. +func (r *recorder) SetReadDeadline(deadline time.Time) error { + if w, ok := r.ResponseWriter.(interface{ SetReadDeadline(time.Time) error }); ok { + return w.SetReadDeadline(deadline) + } + return ErrNotSupported() +} + +// SetWriteDeadline sets the deadline for writing the response. Writes to the response body after the deadline has +// been exceeded will not block, but may succeed if the data has been buffered. A zero value means no deadline. +// Setting the write deadline after it has been exceeded will not extend it. If SetWriteDeadline is not supported, +// it returns an error matching http.ErrNotSupported. +func (r *recorder) SetWriteDeadline(deadline time.Time) error { + if w, ok := r.ResponseWriter.(interface{ SetWriteDeadline(time.Time) error }); ok { + return w.SetWriteDeadline(deadline) + } + return ErrNotSupported() +} + type noUnwrap struct { ResponseWriter } diff --git a/response_writer_test.go b/response_writer_test.go new file mode 100644 index 0000000..330e7a2 --- /dev/null +++ b/response_writer_test.go @@ -0,0 +1,287 @@ +package fox + +import ( + "bufio" + "errors" + "github.com/stretchr/testify/assert" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +type flushErrorWriterFunc func() error + +func (f flushErrorWriterFunc) FlushError() error { + return f() +} + +type flushWriterFunc func() + +func (f flushWriterFunc) Flush() { + f() +} + +type hijackWriterFunc func() (net.Conn, *bufio.ReadWriter, error) + +func (f hijackWriterFunc) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return f() +} + +type pushWriterFunc func(target string, opts *http.PushOptions) error + +func (f pushWriterFunc) Push(target string, opts *http.PushOptions) error { + return f(target, opts) +} + +type deadlineWriterFunc func(deadline time.Time) error + +func (f deadlineWriterFunc) SetReadDeadline(deadline time.Time) error { + return f(deadline) +} + +func (f deadlineWriterFunc) SetWriteDeadline(deadline time.Time) error { + return f(deadline) +} + +func TestRecorder_FlushError(t *testing.T) { + type flushError interface { + FlushError() error + } + + cases := []struct { + name string + rec *recorder + assert func(t *testing.T, w ResponseWriter) + }{ + { + name: "implement FlushError and flush returns error", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + flushError + }{ + ResponseWriter: httptest.NewRecorder(), + flushError: flushErrorWriterFunc(func() error { + return errors.New("error") + }), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + assert.Error(t, w.FlushError()) + }, + }, + { + name: "implement Flusher and flush return nil", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + http.Flusher + }{ + ResponseWriter: httptest.NewRecorder(), + Flusher: flushWriterFunc(func() {}), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + assert.Nil(t, w.FlushError()) + }, + }, + { + name: "does not implement flusher and return http.ErrNotSupported", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + }{ + ResponseWriter: httptest.NewRecorder(), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + assert.ErrorIs(t, w.FlushError(), http.ErrNotSupported) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.rec) + }) + } +} + +func TestRecorder_Hijack(t *testing.T) { + cases := []struct { + name string + rec *recorder + assert func(t *testing.T, w ResponseWriter) + }{ + { + name: "implements Hijacker and hijack returns no error", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + http.Hijacker + }{ + ResponseWriter: httptest.NewRecorder(), + Hijacker: hijackWriterFunc(func() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil + }), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + _, _, err := w.Hijack() + assert.NoError(t, err) + }, + }, + { + name: "does not implement Hijacker and return http.ErrNotSupported", + rec: &recorder{ + ResponseWriter: httptest.NewRecorder(), + }, + assert: func(t *testing.T, w ResponseWriter) { + _, _, err := w.Hijack() + assert.ErrorIs(t, err, http.ErrNotSupported) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.rec) + }) + } +} + +func TestRecorder_Push(t *testing.T) { + cases := []struct { + name string + rec *recorder + assert func(t *testing.T, w ResponseWriter) + }{ + { + name: "implements Pusher and push returns no error", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + http.Pusher + }{ + ResponseWriter: httptest.NewRecorder(), + Pusher: pushWriterFunc(func(target string, opts *http.PushOptions) error { + return nil + }), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + assert.NoError(t, w.Push("/path", nil)) + }, + }, + { + name: "does not implement Pusher and return http.ErrNotSupported", + rec: &recorder{ + ResponseWriter: httptest.NewRecorder(), + }, + assert: func(t *testing.T, w ResponseWriter) { + err := w.Push("/path", nil) + assert.ErrorIs(t, err, http.ErrNotSupported) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.rec) + }) + } +} + +func TestRecorder_SetReadDeadline(t *testing.T) { + type deadlineWriter interface { + SetReadDeadline(time.Time) error + } + + cases := []struct { + name string + rec *recorder + assert func(t *testing.T, w ResponseWriter) + }{ + { + name: "implements SetReadDeadline and returns no error", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + deadlineWriter + }{ + ResponseWriter: httptest.NewRecorder(), + deadlineWriter: deadlineWriterFunc(func(deadline time.Time) error { + return nil + }), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + assert.NoError(t, w.SetReadDeadline(time.Now())) + }, + }, + { + name: "does not implement SetReadDeadline and returns http.ErrNotSupported", + rec: &recorder{ + ResponseWriter: httptest.NewRecorder(), + }, + assert: func(t *testing.T, w ResponseWriter) { + err := w.SetReadDeadline(time.Now()) + assert.ErrorIs(t, err, http.ErrNotSupported) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.rec) + }) + } +} + +func TestRecorder_SetWriteDeadline(t *testing.T) { + type deadlineWriter interface { + SetWriteDeadline(time.Time) error + } + + cases := []struct { + name string + rec *recorder + assert func(t *testing.T, w ResponseWriter) + }{ + { + name: "implements SetWriteDeadline and returns no error", + rec: &recorder{ + ResponseWriter: struct { + http.ResponseWriter + deadlineWriter + }{ + ResponseWriter: httptest.NewRecorder(), + deadlineWriter: deadlineWriterFunc(func(deadline time.Time) error { + return nil + }), + }, + }, + assert: func(t *testing.T, w ResponseWriter) { + assert.NoError(t, w.SetWriteDeadline(time.Now())) + }, + }, + { + name: "does not implement SetWriteDeadline and returns http.ErrNotSupported", + rec: &recorder{ + ResponseWriter: httptest.NewRecorder(), + }, + assert: func(t *testing.T, w ResponseWriter) { + err := w.SetWriteDeadline(time.Now()) + assert.ErrorIs(t, err, http.ErrNotSupported) + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.assert(t, tc.rec) + }) + } +}