diff --git a/middleware/allow_content_type.go b/middleware/allow_content_type.go new file mode 100644 index 000000000..b3aef8491 --- /dev/null +++ b/middleware/allow_content_type.go @@ -0,0 +1,56 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" +) + +// generateAcceptHeaderString takes in a list of allowed content types and generates +// a string that can be used in the Accept part of an HTTP header +func generateAcceptHeaderString(allowedContentTypes map[string]struct{}) string { + acceptString := "" + i := 0 + for mimeType := range allowedContentTypes { + acceptString += mimeType + if i != len(allowedContentTypes)-1 { + acceptString += ", " + } + i += 1 + } + return acceptString +} + +// AllowContentType returns an AllowContentType middleware +// +// It requries at least one content type to be passed in as an argument +func AllowContentType(contentTypes ...string) echo.MiddlewareFunc { + if len(contentTypes) == 0 { + panic("echo: allow-content middleware requires at least one content type") + } + allowedContentTypes := make(map[string]struct{}, len(contentTypes)) + for _, ctype := range contentTypes { + allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // Add allowed types to Accept header to tell client what data types are allowed + c.Response().Header().Add("Accept", generateAcceptHeaderString(allowedContentTypes)) + + if c.Request().ContentLength == 0 { + // skip check for empty content body + return next(c) + } + s := strings.ToLower(strings.TrimSpace(c.Request().Header.Get("Content-Type"))) + if i := strings.Index(s, ";"); i > -1 { + s = s[0:i] + } + if _, ok := allowedContentTypes[s]; ok { + return next(c) + } + return echo.NewHTTPError(http.StatusUnsupportedMediaType) + } + } +} diff --git a/middleware/allow_content_type_test.go b/middleware/allow_content_type_test.go new file mode 100644 index 000000000..38d445583 --- /dev/null +++ b/middleware/allow_content_type_test.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestAllowContentType(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte("Hello World!"))) + res := httptest.NewRecorder() + + h := AllowContentType("application/json", "text/plain")(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Test valid content type + req.Header.Add("Content-Type", "application/json") + + c := e.NewContext(req, res) + assert.NoError(t, h(c)) + + // Test invalid content type + req.Header.Add("Content-Type", "text/html") + c = e.NewContext(req, res) + assert.NoError(t, h(c)) + + // Test Accept header + accept := c.Response().Header().Get("Accept") + assert.Equal(t, "application/json, text/plain", accept) +}