From 3d67098841eb91f11a2f578218045103d51e79ce Mon Sep 17 00:00:00 2001 From: grqphical07 <95062977+grqphical07@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:32:22 -0800 Subject: [PATCH 1/4] Added Allow Content Type middleware and tests --- middleware/allow_content_type.go | 35 +++++++++++++++++++++++++++ middleware/allow_content_type_test.go | 29 ++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 middleware/allow_content_type.go create mode 100644 middleware/allow_content_type_test.go diff --git a/middleware/allow_content_type.go b/middleware/allow_content_type.go new file mode 100644 index 000000000..564493a31 --- /dev/null +++ b/middleware/allow_content_type.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" +) + +func AllowContentType(contentTypes ...string) echo.MiddlewareFunc { + if len(contentTypes) == 0 { + panic("echo: allow-content middleware requires at least one content type") + } + allowedContentTypes := make(map[string]struct{}, len(contentTypes)) + for _, ctype := range contentTypes { + allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if c.Request().ContentLength == 0 { + // skip check for empty content body + return next(c) + } + s := strings.ToLower(strings.TrimSpace(c.Request().Header.Get("Content-Type"))) + if i := strings.Index(s, ";"); i > -1 { + s = s[0:i] + } + if _, ok := allowedContentTypes[s]; ok { + return next(c) + } + return echo.NewHTTPError(http.StatusUnsupportedMediaType) + } + } +} diff --git a/middleware/allow_content_type_test.go b/middleware/allow_content_type_test.go new file mode 100644 index 000000000..c8dfd215d --- /dev/null +++ b/middleware/allow_content_type_test.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestAllowContentType(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + + h := AllowContentType("application/json", "text/plain")(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Test valid content type + req.Header.Add("Content-Type", "application/json") + assert.NoError(t, h(c)) + + // Test invalid content type + req.Header.Add("Content-Type", "application/json") + assert.NoError(t, h(c)) +} From a8143921be6766aba789289c2b2b48f11112d693 Mon Sep 17 00:00:00 2001 From: grqphical07 <95062977+grqphical07@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:36:21 -0800 Subject: [PATCH 2/4] added doc comments --- middleware/allow_content_type.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/middleware/allow_content_type.go b/middleware/allow_content_type.go index 564493a31..4c55deabb 100644 --- a/middleware/allow_content_type.go +++ b/middleware/allow_content_type.go @@ -7,6 +7,9 @@ import ( "github.com/labstack/echo/v4" ) +// AllowContentType returns an AllowContentType middleware +// +// It requries at least one content type to be passed in as an argument func AllowContentType(contentTypes ...string) echo.MiddlewareFunc { if len(contentTypes) == 0 { panic("echo: allow-content middleware requires at least one content type") From bb55333e98b3f3af735f7ecbecee15a397be9043 Mon Sep 17 00:00:00 2001 From: grqphical07 <95062977+grqphical07@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:44:58 -0800 Subject: [PATCH 3/4] added body to request in test --- middleware/allow_content_type_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/middleware/allow_content_type_test.go b/middleware/allow_content_type_test.go index c8dfd215d..d2e1c289b 100644 --- a/middleware/allow_content_type_test.go +++ b/middleware/allow_content_type_test.go @@ -1,6 +1,7 @@ package middleware import ( + "bytes" "net/http" "net/http/httptest" "testing" @@ -11,9 +12,8 @@ import ( func TestAllowContentType(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte("Hello World!"))) res := httptest.NewRecorder() - c := e.NewContext(req, res) h := AllowContentType("application/json", "text/plain")(func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -21,9 +21,12 @@ func TestAllowContentType(t *testing.T) { // Test valid content type req.Header.Add("Content-Type", "application/json") + + c := e.NewContext(req, res) assert.NoError(t, h(c)) // Test invalid content type - req.Header.Add("Content-Type", "application/json") + req.Header.Add("Content-Type", "text/html") + c = e.NewContext(req, res) assert.NoError(t, h(c)) } From 61e77eecc4608587660f2604e5ada98b97a6967b Mon Sep 17 00:00:00 2001 From: grqphical07 <95062977+grqphical07@users.noreply.github.com> Date: Sat, 13 Jan 2024 15:56:56 -0800 Subject: [PATCH 4/4] added Accept header value modification --- middleware/allow_content_type.go | 18 ++++++++++++++++++ middleware/allow_content_type_test.go | 4 ++++ 2 files changed, 22 insertions(+) diff --git a/middleware/allow_content_type.go b/middleware/allow_content_type.go index 4c55deabb..b3aef8491 100644 --- a/middleware/allow_content_type.go +++ b/middleware/allow_content_type.go @@ -7,6 +7,21 @@ import ( "github.com/labstack/echo/v4" ) +// generateAcceptHeaderString takes in a list of allowed content types and generates +// a string that can be used in the Accept part of an HTTP header +func generateAcceptHeaderString(allowedContentTypes map[string]struct{}) string { + acceptString := "" + i := 0 + for mimeType := range allowedContentTypes { + acceptString += mimeType + if i != len(allowedContentTypes)-1 { + acceptString += ", " + } + i += 1 + } + return acceptString +} + // AllowContentType returns an AllowContentType middleware // // It requries at least one content type to be passed in as an argument @@ -21,6 +36,9 @@ func AllowContentType(contentTypes ...string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + // Add allowed types to Accept header to tell client what data types are allowed + c.Response().Header().Add("Accept", generateAcceptHeaderString(allowedContentTypes)) + if c.Request().ContentLength == 0 { // skip check for empty content body return next(c) diff --git a/middleware/allow_content_type_test.go b/middleware/allow_content_type_test.go index d2e1c289b..38d445583 100644 --- a/middleware/allow_content_type_test.go +++ b/middleware/allow_content_type_test.go @@ -29,4 +29,8 @@ func TestAllowContentType(t *testing.T) { req.Header.Add("Content-Type", "text/html") c = e.NewContext(req, res) assert.NoError(t, h(c)) + + // Test Accept header + accept := c.Response().Header().Get("Accept") + assert.Equal(t, "application/json, text/plain", accept) }