diff --git a/gateway/config.go b/gateway/config.go index 3d8e4dfd7abd..29e5f254dd1d 100644 --- a/gateway/config.go +++ b/gateway/config.go @@ -12,14 +12,22 @@ type ( Upstreams []Upstream } + // HttpClientConf is the configuration for an HTTP client. + HttpClientConf struct { + Target string + Prefix string `json:",optional"` + Timeout int64 `json:",default=3000"` + } + // RouteMapping is a mapping between a gateway route and an upstream rpc method. RouteMapping struct { // Method is the HTTP method, like GET, POST, PUT, DELETE. Method string // Path is the HTTP path. Path string - // RpcPath is the gRPC rpc method, with format of package.service/method - RpcPath string + // RpcPath is the gRPC rpc method, with format of package.service/method, optional. + // If the mapping is for HTTP, it's not necessary. + RpcPath string `json:",optional"` } // Upstream is the configuration for an upstream. @@ -27,12 +35,14 @@ type ( // Name is the name of the upstream. Name string `json:",optional"` // Grpc is the target of the upstream. - Grpc zrpc.RpcClientConf + Grpc *zrpc.RpcClientConf `json:",optional"` + // Http is the target of the upstream. + Http *HttpClientConf `json:",optional=!grpc"` // ProtoSets is the file list of proto set, like [hello.pb]. // if your proto file import another proto file, you need to write multi-file slice, // like [hello.pb, common.pb]. ProtoSets []string `json:",optional"` - // Mappings is the mapping between gateway routes and Upstream rpc methods. + // Mappings is the mapping between gateway routes and Upstream methods. // Keep it blank if annotations are added in rpc methods. Mappings []RouteMapping `json:",optional"` } diff --git a/gateway/server.go b/gateway/server.go index 51c11a352f16..f2b9ac86e045 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -3,22 +3,29 @@ package gateway import ( "context" "fmt" + "io" "net/http" + "net/url" "strings" + "time" "github.com/fullstorydev/grpcurl" "github.com/golang/protobuf/jsonpb" "github.com/jhump/protoreflect/grpcreflect" + "github.com/zeromicro/go-zero/core/logc" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/mr" "github.com/zeromicro/go-zero/core/threading" "github.com/zeromicro/go-zero/gateway/internal" "github.com/zeromicro/go-zero/rest" + "github.com/zeromicro/go-zero/rest/httpc" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/zrpc" "google.golang.org/grpc/codes" ) +const defaultHttpScheme = "http" + type ( // Server is a gateway server. Server struct { @@ -83,52 +90,11 @@ func (s *Server) build() error { source <- up } }, func(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) { - var cli zrpc.Client - if s.dialer != nil { - cli = s.dialer(up.Grpc) - } else { - cli = zrpc.MustNewClient(up.Grpc) - } - s.conns = append(s.conns, cli) - - source, err := s.createDescriptorSource(cli, up) - if err != nil { - cancel(fmt.Errorf("%s: %w", up.Name, err)) - return - } - - methods, err := internal.GetMethods(source) - if err != nil { - cancel(fmt.Errorf("%s: %w", up.Name, err)) - return - } - - resolver := grpcurl.AnyResolverFromDescriptorSource(source) - for _, m := range methods { - if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 { - writer.Write(rest.Route{ - Method: m.HttpMethod, - Path: m.HttpPath, - Handler: s.buildHandler(source, resolver, cli, m.RpcPath), - }) - } - } - - methodSet := make(map[string]struct{}) - for _, m := range methods { - methodSet[m.RpcPath] = struct{}{} - } - for _, m := range up.Mappings { - if _, ok := methodSet[m.RpcPath]; !ok { - cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath)) - return - } - - writer.Write(rest.Route{ - Method: strings.ToUpper(m.Method), - Path: m.Path, - Handler: s.buildHandler(source, resolver, cli, m.RpcPath), - }) + // up.Grpc and up.Http are exclusive + if up.Grpc != nil { + s.buildGrpcRoute(up, writer, cancel) + } else if up.Http != nil { + s.buildHttpRoute(up, writer) } }, func(pipe <-chan rest.Route, cancel func(error)) { for route := range pipe { @@ -137,7 +103,7 @@ func (s *Server) build() error { }) } -func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver, +func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver, cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { parser, err := internal.NewRequestParser(r, resolver) @@ -160,31 +126,119 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A } } -func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) { - var source grpcurl.DescriptorSource - var err error +func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) { + var cli zrpc.Client + if s.dialer != nil { + cli = s.dialer(*up.Grpc) + } else { + cli = zrpc.MustNewClient(*up.Grpc) + } + s.conns = append(s.conns, cli) - if len(up.ProtoSets) > 0 { - source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...) + source, err := createDescriptorSource(cli, up) + if err != nil { + cancel(fmt.Errorf("%s: %w", up.Name, err)) + return + } + + methods, err := internal.GetMethods(source) + if err != nil { + cancel(fmt.Errorf("%s: %w", up.Name, err)) + return + } + + resolver := grpcurl.AnyResolverFromDescriptorSource(source) + for _, m := range methods { + if len(m.HttpMethod) > 0 && len(m.HttpPath) > 0 { + writer.Write(rest.Route{ + Method: m.HttpMethod, + Path: m.HttpPath, + Handler: s.buildGrpcHandler(source, resolver, cli, m.RpcPath), + }) + } + } + + methodSet := make(map[string]struct{}) + for _, m := range methods { + methodSet[m.RpcPath] = struct{}{} + } + for _, m := range up.Mappings { + if _, ok := methodSet[m.RpcPath]; !ok { + cancel(fmt.Errorf("%s: rpc method %s not found", up.Name, m.RpcPath)) + return + } + + writer.Write(rest.Route{ + Method: strings.ToUpper(m.Method), + Path: m.Path, + Handler: s.buildGrpcHandler(source, resolver, cli, m.RpcPath), + }) + } +} + +func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(httpx.ContentType, httpx.JsonContentType) + req, err := buildRequestWithNewTarget(r, target) if err != nil { - return nil, err + httpx.ErrorCtx(r.Context(), w, err) + return + } + + if target.Timeout > 0 { + timeout := time.Duration(target.Timeout) * time.Millisecond + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + req = req.WithContext(ctx) + } + + resp, err := httpc.DoRequest(req) + if err != nil { + httpx.ErrorCtx(r.Context(), w, err) + return + } + defer resp.Body.Close() + + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + w.WriteHeader(resp.StatusCode) + if _, err = io.Copy(w, resp.Body); err != nil { + // log the error with original request info + logc.Error(r.Context(), err) } - } else { - client := grpcreflect.NewClientAuto(context.Background(), cli.Conn()) - source = grpcurl.DescriptorSourceFromServer(context.Background(), client) } +} - return source, nil +func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) { + for _, m := range up.Mappings { + writer.Write(rest.Route{ + Method: strings.ToUpper(m.Method), + Path: m.Path, + Handler: s.buildHttpHandler(up.Http), + }) + } } func (s *Server) ensureUpstreamNames() error { for i := 0; i < len(s.upstreams); i++ { - target, err := s.upstreams[i].Grpc.BuildTarget() - if err != nil { - return err + if len(s.upstreams[i].Name) > 0 { + continue } - s.upstreams[i].Name = target + if s.upstreams[i].Grpc != nil { + target, err := s.upstreams[i].Grpc.BuildTarget() + if err != nil { + return err + } + + s.upstreams[i].Name = target + } else if s.upstreams[i].Http != nil { + s.upstreams[i].Name = s.upstreams[i].Http.Target + } } return nil @@ -207,6 +261,50 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server) } } +func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) { + u := *r.URL + u.Host = target.Target + if len(u.Scheme) == 0 { + u.Scheme = defaultHttpScheme + } + + if len(target.Prefix) > 0 { + var err error + u.Path, err = url.JoinPath(target.Prefix, u.Path) + if err != nil { + return nil, err + } + } + + return &http.Request{ + Method: r.Method, + URL: &u, + Header: r.Header.Clone(), + Proto: r.Proto, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + ContentLength: r.ContentLength, + Body: io.NopCloser(r.Body), + }, nil +} + +func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) { + var source grpcurl.DescriptorSource + var err error + + if len(up.ProtoSets) > 0 { + source, err = grpcurl.DescriptorSourceFromProtoSets(up.ProtoSets...) + if err != nil { + return nil, err + } + } else { + client := grpcreflect.NewClientAuto(context.Background(), cli.Conn()) + source = grpcurl.DescriptorSourceFromServer(context.Background(), client) + } + + return source, nil +} + // withDialer sets a dialer to create a gRPC client. func withDialer(dialer func(conf zrpc.RpcClientConf) zrpc.Client) func(*Server) { return func(s *Server) { diff --git a/gateway/server_test.go b/gateway/server_test.go index 68b56ade80af..3655fab31764 100644 --- a/gateway/server_test.go +++ b/gateway/server_test.go @@ -2,9 +2,12 @@ package gateway import ( "context" + "errors" + "io" "log" "net" "net/http" + "net/http/httptest" "testing" "time" @@ -65,7 +68,7 @@ func TestMustNewServer(t *testing.T) { RpcPath: "mock.DepositService/Deposit", }, }, - Grpc: zrpc.RpcClientConf{ + Grpc: &zrpc.RpcClientConf{ Endpoints: []string{"foo"}, Timeout: 1000, Middlewares: zrpc.ClientMiddlewaresConf{ @@ -98,7 +101,7 @@ func TestServer_ensureUpstreamNames(t *testing.T) { var s = Server{ upstreams: []Upstream{ { - Grpc: zrpc.RpcClientConf{ + Grpc: &zrpc.RpcClientConf{ Target: "target", }, }, @@ -113,7 +116,7 @@ func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) { var s = Server{ upstreams: []Upstream{ { - Grpc: zrpc.RpcClientConf{ + Grpc: &zrpc.RpcClientConf{ Etcd: discov.EtcdConf{}, }, }, @@ -125,3 +128,193 @@ func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) { s.Start() }) } + +func TestHttpToHttp(t *testing.T) { + server := startTestServer(t) + defer server.Close() + + var c GatewayConf + assert.NoError(t, conf.FillDefault(&c)) + c.DevServer.Host = "localhost" + c.Host = "localhost" + c.Port = 18882 + + s := MustNewServer(c) + s.upstreams = []Upstream{ + { + Name: "test", + Mappings: []RouteMapping{ + { + Method: "get", + Path: "/api/ping", + }, + }, + Http: &HttpClientConf{ + Target: "localhost:45678", + Timeout: 3000, + }, + }, + { + Mappings: []RouteMapping{ + { + Method: "get", + Path: "/ping", + }, + }, + Http: &HttpClientConf{ + Target: "localhost:45678", + Prefix: "/api", + }, + }, + } + + go s.Start() + defer s.Stop() + + time.Sleep(time.Millisecond * 200) + + t.Run("/api/ping", func(t *testing.T) { + resp, err := httpc.Do(context.Background(), http.MethodGet, + "http://localhost:18882/api/ping", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + if assert.NoError(t, err) { + assert.Equal(t, "pong", string(body)) + } + }) + + t.Run("/ping", func(t *testing.T) { + resp, err := httpc.Do(context.Background(), http.MethodGet, + "http://localhost:18882/ping", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + if assert.NoError(t, err) { + assert.Equal(t, "pong", string(body)) + } + }) + + t.Run("no upstream", func(t *testing.T) { + resp, err := httpc.Do(context.Background(), http.MethodGet, + "http://localhost:18882/ping/bad", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) +} + +func TestHttpToHttpBadUpstream(t *testing.T) { + var c GatewayConf + assert.NoError(t, conf.FillDefault(&c)) + c.DevServer.Host = "localhost" + c.Host = "localhost" + c.Port = 18883 + + s := MustNewServer(c) + s.upstreams = []Upstream{ + { + Mappings: []RouteMapping{ + { + Method: "get", + Path: "/api/ping", + }, + }, + Http: &HttpClientConf{ + Target: "localhost:45678", + Prefix: "\x7f/api", + }, + }, + } + + go s.Start() + defer s.Stop() + + time.Sleep(time.Millisecond * 200) + + t.Run("/api/ping", func(t *testing.T) { + resp, err := httpc.Do(context.Background(), http.MethodGet, + "http://localhost:18883/api/ping", nil) + assert.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) +} + +func TestHttpToHttpBadWriter(t *testing.T) { + t.Run("bad url", func(t *testing.T) { + handler := new(Server).buildHttpHandler(&HttpClientConf{ + Target: "http://example.com", + Timeout: 3000, + }) + w := httptest.NewRecorder() + handler.ServeHTTP(&badResponseWriter{w}, + httptest.NewRequest(http.MethodGet, "http://localhost:18884", nil)) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("bad url", func(t *testing.T) { + var c GatewayConf + assert.NoError(t, conf.FillDefault(&c)) + c.DevServer.Host = "localhost" + c.Host = "localhost" + c.Port = 18884 + + s := MustNewServer(c) + s.upstreams = []Upstream{ + { + Mappings: []RouteMapping{ + { + Method: "get", + Path: "/api/ping", + }, + }, + Http: &HttpClientConf{ + Target: "localhost:45678", + Prefix: "\x7f/api", + }, + }, + } + + go s.Start() + defer s.Stop() + + handler := new(Server).buildHttpHandler(&HttpClientConf{ + Target: "localhost:18884", + Timeout: 3000, + }) + w := httptest.NewRecorder() + handler.ServeHTTP(&badResponseWriter{w}, + httptest.NewRequest(http.MethodGet, "http://localhost:18884/api/ping", nil)) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) +} + +// Handler function for the root route +func pingHandler(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("pong")) +} + +func startTestServer(t *testing.T) *http.Server { + http.HandleFunc("/api/ping", pingHandler) + + server := &http.Server{ + Addr: ":45678", + Handler: http.DefaultServeMux, + } + + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("failed to start server: %v", err) + } + }() + + return server +} + +type badResponseWriter struct { + http.ResponseWriter +} + +func (w *badResponseWriter) Write([]byte) (int, error) { + return 0, errors.New("bad writer") +} diff --git a/rest/httpc/internal/metricsinterceptor_test.go b/rest/httpc/internal/metricsinterceptor_test.go index 413619f99a51..07872f19d937 100644 --- a/rest/httpc/internal/metricsinterceptor_test.go +++ b/rest/httpc/internal/metricsinterceptor_test.go @@ -6,15 +6,11 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/logx" ) func TestMetricsInterceptor(t *testing.T) { - c := gomock.NewController(t) - defer c.Finish() - logx.Disable() svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/rest/httpc/requests.go b/rest/httpc/requests.go index 181b53e87b47..bc1251596671 100644 --- a/rest/httpc/requests.go +++ b/rest/httpc/requests.go @@ -183,7 +183,6 @@ func request(r *http.Request, cli client) (*http.Response, error) { for i := len(respHandlers) - 1; i >= 0; i-- { respHandlers[i](resp, err) } - if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) diff --git a/tools/goctl/pkg/parser/api/scanner/scanner.go b/tools/goctl/pkg/parser/api/scanner/scanner.go index 992566f152c3..f906b9e3605f 100644 --- a/tools/goctl/pkg/parser/api/scanner/scanner.go +++ b/tools/goctl/pkg/parser/api/scanner/scanner.go @@ -617,7 +617,7 @@ func NewScanner(filename string, src interface{}) (*Scanner, error) { } if len(data) == 0 { - return nil, fmt.Errorf("filename: %s,missing input", filename) + return nil, fmt.Errorf("filename: %s, missing input", filename) } var runeList []rune