diff --git a/middleware/allow_content_type.go b/middleware/allow_content_type.go new file mode 100644 index 000000000..630390c7b --- /dev/null +++ b/middleware/allow_content_type.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "mime" + "net/http" + "slices" + + "github.com/labstack/echo/v4" +) + +func AllowContentType(contentTypes ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + var acceptTypes = "" + for i, contentType := range contentTypes { + acceptTypes += contentType + + if i != len(contentTypes)-1 { + acceptTypes += "," + } + } + c.Response().Header().Set("Accept", acceptTypes) + + mediaType, _, err := mime.ParseMediaType(c.Request().Header.Get("Content-Type")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid content-type value") + } + if slices.Contains(contentTypes, mediaType) { + return next(c) + } + return echo.NewHTTPError(http.StatusUnsupportedMediaType) + } + } +} diff --git a/middleware/allow_content_type_test.go b/middleware/allow_content_type_test.go new file mode 100644 index 000000000..8c84bb0f5 --- /dev/null +++ b/middleware/allow_content_type_test.go @@ -0,0 +1,32 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestAllowContentType(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + + h := AllowContentType("application/json")(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Test valid content type + req.Header.Add("Content-Type", "application/json") + c := e.NewContext(req, res) + assert.NoError(t, h(c)) + assert.Equal(t, "application/json", res.Header().Get("Accept")) + + // Test invalid content type + req.Header.Set("Content-Type", "text/plain") + c = e.NewContext(req, res) + he := h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusUnsupportedMediaType, he.Code) +}