Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support file server in rest #4244

Merged
merged 2 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions rest/internal/fileserver/filehandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package fileserver

import (
"net/http"
"strings"
)

func Middleware(path, dir string) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(http.Dir(dir))
pathWithTrailSlash := ensureTrailingSlash(path)
pathWithoutTrailSlash := ensureNoTrailingSlash(path)

return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, pathWithTrailSlash) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathWithoutTrailSlash)
fileServer.ServeHTTP(w, r)
} else {
next(w, r)
}
}
}
}

func ensureTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path
}

return path + "/"
}

func ensureNoTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path[:len(path)-1]
}

return path
}
99 changes: 99 additions & 0 deletions rest/internal/fileserver/filehandler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package fileserver

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestMiddleware(t *testing.T) {
tests := []struct {
name string
path string
dir string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/static/",
dir: "./testdata",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Pass through non-matching path",
path: "/static/",
dir: "./testdata",
requestPath: "/other/path",
expectedStatus: http.StatusNotFound,
},
{
name: "Directory with trailing slash",
path: "/assets",
dir: "testdata",
requestPath: "/assets/sample.txt",
expectedStatus: http.StatusOK,
expectedContent: "2",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, tt.dir)
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})

handlerToTest := middleware(nextHandler)

req := httptest.NewRequest("GET", tt.requestPath, nil)
rr := httptest.NewRecorder()

handlerToTest.ServeHTTP(rr, req)

assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
})
}
}

func TestEnsureTrailingSlash(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"path", "path/"},
{"path/", "path/"},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := ensureTrailingSlash(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

func TestEnsureNoTrailingSlash(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"path", "path"},
{"path/", "path"},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := ensureNoTrailingSlash(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
1 change: 1 addition & 0 deletions rest/internal/fileserver/testdata/example.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1
1 change: 1 addition & 0 deletions rest/internal/fileserver/testdata/sample.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2
24 changes: 24 additions & 0 deletions rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal"
"github.com/zeromicro/go-zero/rest/internal/cors"
"github.com/zeromicro/go-zero/rest/internal/fileserver"
"github.com/zeromicro/go-zero/rest/router"
)

Expand Down Expand Up @@ -170,6 +171,13 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt
}
}

// WithFileServer returns a RunOption to serve files from given dir with given path.
func WithFileServer(path, dir string) RunOption {
return func(server *Server) {
server.router = newFileServingRouter(server.router, path, dir)
}
}

// WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption {
return func(r *featuredRoutes) {
Expand Down Expand Up @@ -337,3 +345,19 @@ func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...s
func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.middleware(c.Router.ServeHTTP)(w, r)
}

type fileServingRouter struct {
httpx.Router
middleware Middleware
}

func newFileServingRouter(router httpx.Router, path, dir string) httpx.Router {
return &fileServingRouter{
Router: router,
middleware: fileserver.Middleware(path, dir),
}
}

func (f *fileServingRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
f.middleware(f.Router.ServeHTTP)(w, r)
}
50 changes: 50 additions & 0 deletions rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,56 @@ func TestWithMiddleware(t *testing.T) {
}, m)
}

func TestWithFileServerMiddleware(t *testing.T) {
tests := []struct {
name string
path string
dir string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/assets/",
dir: "./testdata",
requestPath: "/assets/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "example content",
},
{
name: "Pass through non-matching path",
path: "/static/",
dir: "./testdata",
requestPath: "/other/path",
expectedStatus: http.StatusNotFound,
},
{
name: "Directory with trailing slash",
path: "/static",
dir: "testdata",
requestPath: "/static/sample.txt",
expectedStatus: http.StatusOK,
expectedContent: "sample content",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := MustNewServer(RestConf{}, WithFileServer(tt.path, tt.dir))
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder()

server.ServeHTTP(rr, req)

assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
})
}
}

func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string)
rt := router.NewRouter()
Expand Down
1 change: 1 addition & 0 deletions rest/testdata/example.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
example content
1 change: 1 addition & 0 deletions rest/testdata/sample.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sample content