-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add context timeout middleware (#2380)
Add context timeout middleware Co-authored-by: Erhan Akpınar <[email protected]> Co-authored-by: @erhanakp
- Loading branch information
1 parent
08093a4
commit 82a964c
Showing
2 changed files
with
298 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package middleware | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"time" | ||
|
||
"github.com/labstack/echo/v4" | ||
) | ||
|
||
// ContextTimeoutConfig defines the config for ContextTimeout middleware. | ||
type ContextTimeoutConfig struct { | ||
// Skipper defines a function to skip middleware. | ||
Skipper Skipper | ||
|
||
// ErrorHandler is a function when error aries in middeware execution. | ||
ErrorHandler func(err error, c echo.Context) error | ||
|
||
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout | ||
Timeout time.Duration | ||
} | ||
|
||
// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client | ||
// when underlying method returns context.DeadlineExceeded error. | ||
func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { | ||
return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) | ||
} | ||
|
||
// ContextTimeoutWithConfig returns a Timeout middleware with config. | ||
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { | ||
mw, err := config.ToMiddleware() | ||
if err != nil { | ||
panic(err) | ||
} | ||
return mw | ||
} | ||
|
||
// ToMiddleware converts Config to middleware. | ||
func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { | ||
if config.Timeout == 0 { | ||
return nil, errors.New("timeout must be set") | ||
} | ||
if config.Skipper == nil { | ||
config.Skipper = DefaultSkipper | ||
} | ||
if config.ErrorHandler == nil { | ||
config.ErrorHandler = func(err error, c echo.Context) error { | ||
if err != nil && errors.Is(err, context.DeadlineExceeded) { | ||
return echo.ErrServiceUnavailable.WithInternal(err) | ||
} | ||
return err | ||
} | ||
} | ||
|
||
return func(next echo.HandlerFunc) echo.HandlerFunc { | ||
return func(c echo.Context) error { | ||
if config.Skipper(c) { | ||
return next(c) | ||
} | ||
|
||
timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) | ||
defer cancel() | ||
|
||
c.SetRequest(c.Request().WithContext(timeoutContext)) | ||
|
||
if err := next(c); err != nil { | ||
return config.ErrorHandler(err, c) | ||
} | ||
return nil | ||
} | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
package middleware | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/labstack/echo/v4" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestContextTimeoutSkipper(t *testing.T) { | ||
t.Parallel() | ||
m := ContextTimeoutWithConfig(ContextTimeoutConfig{ | ||
Skipper: func(context echo.Context) bool { | ||
return true | ||
}, | ||
Timeout: 10 * time.Millisecond, | ||
}) | ||
|
||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
e := echo.New() | ||
c := e.NewContext(req, rec) | ||
|
||
err := m(func(c echo.Context) error { | ||
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { | ||
return err | ||
} | ||
|
||
return errors.New("response from handler") | ||
})(c) | ||
|
||
// if not skipped we would have not returned error due context timeout logic | ||
assert.EqualError(t, err, "response from handler") | ||
} | ||
|
||
func TestContextTimeoutWithTimeout0(t *testing.T) { | ||
t.Parallel() | ||
assert.Panics(t, func() { | ||
ContextTimeout(time.Duration(0)) | ||
}) | ||
} | ||
|
||
func TestContextTimeoutErrorOutInHandler(t *testing.T) { | ||
t.Parallel() | ||
m := ContextTimeoutWithConfig(ContextTimeoutConfig{ | ||
// Timeout has to be defined or the whole flow for timeout middleware will be skipped | ||
Timeout: 10 * time.Millisecond, | ||
}) | ||
|
||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
e := echo.New() | ||
c := e.NewContext(req, rec) | ||
|
||
rec.Code = 1 // we want to be sure that even 200 will not be sent | ||
err := m(func(c echo.Context) error { | ||
// this error must not be written to the client response. Middlewares upstream of timeout middleware must be able | ||
// to handle returned error and this can be done only then handler has not yet committed (written status code) | ||
// the response. | ||
return echo.NewHTTPError(http.StatusTeapot, "err") | ||
})(c) | ||
|
||
assert.Error(t, err) | ||
assert.EqualError(t, err, "code=418, message=err") | ||
assert.Equal(t, 1, rec.Code) | ||
assert.Equal(t, "", rec.Body.String()) | ||
} | ||
|
||
func TestContextTimeoutSuccessfulRequest(t *testing.T) { | ||
t.Parallel() | ||
m := ContextTimeoutWithConfig(ContextTimeoutConfig{ | ||
// Timeout has to be defined or the whole flow for timeout middleware will be skipped | ||
Timeout: 10 * time.Millisecond, | ||
}) | ||
|
||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
e := echo.New() | ||
c := e.NewContext(req, rec) | ||
|
||
err := m(func(c echo.Context) error { | ||
return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) | ||
})(c) | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, http.StatusCreated, rec.Code) | ||
assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String()) | ||
} | ||
|
||
func TestContextTimeoutTestRequestClone(t *testing.T) { | ||
t.Parallel() | ||
req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) | ||
req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) | ||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||
rec := httptest.NewRecorder() | ||
|
||
m := ContextTimeoutWithConfig(ContextTimeoutConfig{ | ||
// Timeout has to be defined or the whole flow for timeout middleware will be skipped | ||
Timeout: 1 * time.Second, | ||
}) | ||
|
||
e := echo.New() | ||
c := e.NewContext(req, rec) | ||
|
||
err := m(func(c echo.Context) error { | ||
// Cookie test | ||
cookie, err := c.Request().Cookie("cookie") | ||
if assert.NoError(t, err) { | ||
assert.EqualValues(t, "cookie", cookie.Name) | ||
assert.EqualValues(t, "value", cookie.Value) | ||
} | ||
|
||
// Form values | ||
if assert.NoError(t, c.Request().ParseForm()) { | ||
assert.EqualValues(t, "value", c.Request().FormValue("form")) | ||
} | ||
|
||
// Query string | ||
assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) | ||
return nil | ||
})(c) | ||
|
||
assert.NoError(t, err) | ||
} | ||
|
||
func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { | ||
t.Parallel() | ||
|
||
timeout := 10 * time.Millisecond | ||
m := ContextTimeoutWithConfig(ContextTimeoutConfig{ | ||
Timeout: timeout, | ||
}) | ||
|
||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
e := echo.New() | ||
c := e.NewContext(req, rec) | ||
|
||
err := m(func(c echo.Context) error { | ||
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { | ||
return err | ||
} | ||
return c.String(http.StatusOK, "Hello, World!") | ||
})(c) | ||
|
||
assert.IsType(t, &echo.HTTPError{}, err) | ||
assert.Error(t, err) | ||
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) | ||
assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) | ||
} | ||
|
||
func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { | ||
t.Parallel() | ||
|
||
timeoutErrorHandler := func(err error, c echo.Context) error { | ||
if err != nil { | ||
if errors.Is(err, context.DeadlineExceeded) { | ||
return &echo.HTTPError{ | ||
Code: http.StatusServiceUnavailable, | ||
Message: "Timeout! change me", | ||
} | ||
} | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
timeout := 10 * time.Millisecond | ||
m := ContextTimeoutWithConfig(ContextTimeoutConfig{ | ||
Timeout: timeout, | ||
ErrorHandler: timeoutErrorHandler, | ||
}) | ||
|
||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
|
||
e := echo.New() | ||
c := e.NewContext(req, rec) | ||
|
||
err := m(func(c echo.Context) error { | ||
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) | ||
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output | ||
// difference over 500microseconds (0.5millisecond) response seems to be reliable | ||
|
||
if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { | ||
return err | ||
} | ||
|
||
// The Request Context should have a Deadline set by http.ContextTimeoutHandler | ||
if _, ok := c.Request().Context().Deadline(); !ok { | ||
assert.Fail(t, "No timeout set on Request Context") | ||
} | ||
return c.String(http.StatusOK, "Hello, World!") | ||
})(c) | ||
|
||
assert.IsType(t, &echo.HTTPError{}, err) | ||
assert.Error(t, err) | ||
assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) | ||
assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message) | ||
} | ||
|
||
func sleepWithContext(ctx context.Context, d time.Duration) error { | ||
timer := time.NewTimer(d) | ||
|
||
defer func() { | ||
_ = timer.Stop() | ||
}() | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return context.DeadlineExceeded | ||
case <-timer.C: | ||
return nil | ||
} | ||
} |