diff --git a/middleware/cors.go b/middleware/cors.go index c2f995cd2..bd6315644 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -107,13 +107,23 @@ type CORSConfig struct { // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age MaxAge int `yaml:"max_age"` + + // PreflightStatusCode determines the status code to be returned on a + // successful preflight request. + // + // Optional. Default value is http.StatusNoContent(204) + // + // See also: https://fetch.spec.whatwg.org/#ref-for-ok-status + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/OPTIONS#preflighted_requests_in_cors + PreflightStatusCode int `yaml:"preflight_status_code"` } // DefaultCORSConfig is the default CORS middleware config. var DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, + PreflightStatusCode: http.StatusNoContent, } // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. @@ -147,6 +157,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { config.AllowMethods = DefaultCORSConfig.AllowMethods } + if config.PreflightStatusCode == 0 { + config.PreflightStatusCode = DefaultCORSConfig.PreflightStatusCode + } + allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins)) for _, origin := range config.AllowOrigins { if origin == "*" { @@ -214,7 +228,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if !preflight { return next(c) } - return c.NoContent(http.StatusNoContent) + return c.NoContent(config.PreflightStatusCode) } if config.AllowOriginFunc != nil { @@ -264,7 +278,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if !preflight { return echo.ErrUnauthorized } - return c.NoContent(http.StatusNoContent) + return c.NoContent(config.PreflightStatusCode) } res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) @@ -301,7 +315,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if config.MaxAge != 0 { res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) } - return c.NoContent(http.StatusNoContent) + return c.NoContent(config.PreflightStatusCode) } } } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index d77c194c5..c8583bb70 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -683,3 +683,40 @@ func Test_allowOriginFunc(t *testing.T) { } } } + +func TestCORSWithConfig_PreflightStatusCode(t *testing.T) { + tests := []struct { + name string + mw echo.MiddlewareFunc + expectedStatusCode int + }{ + { + name: "ok, preflight with default config returns http.StatusNoContent (204)", + mw: CORS(), + expectedStatusCode: http.StatusNoContent, + }, + { + name: "ok, preflight returning http.StatusOK (200)", + mw: CORSWithConfig(CORSConfig{ + PreflightStatusCode: http.StatusOK, + }), + expectedStatusCode: http.StatusOK, + }, + } + e := echo.New() + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + cors := tc.mw(echo.NotFoundHandler) + err := cors(c) + + assert.NoError(t, err) + assert.Equal(t, rec.Result().StatusCode, tc.expectedStatusCode) + }) + } +}