Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor zrpc server interceptor builder #4300

Merged
merged 5 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions zrpc/internal/rpcpubserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const (
)

// NewRpcPubServer returns a Server.
func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMiddlewaresConf,
func NewRpcPubServer(etcd discov.EtcdConf, listenOn string,
opts ...ServerOption) (Server, error) {
registerEtcd := func() error {
pubListenOn := figureOutListenOn(listenOn)
Expand All @@ -34,7 +34,7 @@ func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, middlewares ServerMi
}
server := keepAliveServer{
registerEtcd: registerEtcd,
Server: NewRpcServer(listenOn, middlewares, opts...),
Server: NewRpcServer(listenOn, opts...),
}

return server, nil
Expand Down
2 changes: 1 addition & 1 deletion zrpc/internal/rpcpubserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestNewRpcPubServer(t *testing.T) {
User: "user",
Pass: "pass",
ID: 10,
}, "", ServerMiddlewaresConf{})
}, "")
assert.NoError(t, err)
assert.NotPanics(t, func() {
s.Start(nil)
Expand Down
63 changes: 4 additions & 59 deletions zrpc/internal/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import (
"net"

"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/internal/health"
"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
"google.golang.org/grpc"
"google.golang.org/grpc/health/grpc_health_v1"
)
Expand All @@ -19,38 +17,31 @@ type (
ServerOption func(options *rpcServerOptions)

rpcServerOptions struct {
metrics *stat.Metrics
health bool
health bool
}

rpcServer struct {
*baseRpcServer
name string
middlewares ServerMiddlewaresConf
healthManager health.Probe
}
)

// NewRpcServer returns a Server.
func NewRpcServer(addr string, middlewares ServerMiddlewaresConf, opts ...ServerOption) Server {
func NewRpcServer(addr string, opts ...ServerOption) Server {
var options rpcServerOptions
for _, opt := range opts {
opt(&options)
}
if options.metrics == nil {
options.metrics = stat.NewMetrics(addr)
}

return &rpcServer{
baseRpcServer: newBaseRpcServer(addr, &options),
middlewares: middlewares,
healthManager: health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, addr)),
}
}

func (s *rpcServer) SetName(name string) {
s.name = name
s.baseRpcServer.SetName(name)
}

func (s *rpcServer) Start(register RegisterFn) error {
Expand All @@ -59,8 +50,8 @@ func (s *rpcServer) Start(register RegisterFn) error {
return err
}

unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)
unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.unaryInterceptors...)
streamInterceptorOption := grpc.ChainStreamInterceptor(s.streamInterceptors...)

options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
server := grpc.NewServer(options...)
Expand All @@ -87,52 +78,6 @@ func (s *rpcServer) Start(register RegisterFn) error {
return server.Serve(lis)
}

func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
var interceptors []grpc.StreamServerInterceptor

if s.middlewares.Trace {
interceptors = append(interceptors, serverinterceptors.StreamTracingInterceptor)
}
if s.middlewares.Recover {
interceptors = append(interceptors, serverinterceptors.StreamRecoverInterceptor)
}
if s.middlewares.Breaker {
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
}

return append(interceptors, s.streamInterceptors...)
}

func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
var interceptors []grpc.UnaryServerInterceptor

if s.middlewares.Trace {
interceptors = append(interceptors, serverinterceptors.UnaryTracingInterceptor)
}
if s.middlewares.Recover {
interceptors = append(interceptors, serverinterceptors.UnaryRecoverInterceptor)
}
if s.middlewares.Stat {
interceptors = append(interceptors,
serverinterceptors.UnaryStatInterceptor(s.metrics, s.middlewares.StatConf))
}
if s.middlewares.Prometheus {
interceptors = append(interceptors, serverinterceptors.UnaryPrometheusInterceptor)
}
if s.middlewares.Breaker {
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
}

return append(interceptors, s.unaryInterceptors...)
}

// WithMetrics returns a func that sets metrics to a Server.
func WithMetrics(metrics *stat.Metrics) ServerOption {
return func(options *rpcServerOptions) {
options.metrics = metrics
}
}

// WithRpcHealth returns a func that sets rpc health switch to a Server.
func WithRpcHealth(health bool) ServerOption {
return func(options *rpcServerOptions) {
Expand Down
131 changes: 2 additions & 129 deletions zrpc/internal/rpcserver_test.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
package internal

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/internal/mock"
"google.golang.org/grpc"
)

func TestRpcServer(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := NewRpcServer("localhost:54321", ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
}, WithMetrics(metrics), WithRpcHealth(true))
server := NewRpcServer("localhost:54321", WithRpcHealth(true))
server.SetName("mock")
var wg, wgDone sync.WaitGroup
var grpcServer *grpc.Server
Expand Down Expand Up @@ -52,13 +43,7 @@ func TestRpcServer(t *testing.T) {
}

func TestRpcServer_WithBadAddress(t *testing.T) {
server := NewRpcServer("localhost:111111", ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
}, WithRpcHealth(true))
server := NewRpcServer("localhost:111111", WithRpcHealth(true))
server.SetName("mock")
err := server.Start(func(server *grpc.Server) {
mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
Expand All @@ -67,115 +52,3 @@ func TestRpcServer_WithBadAddress(t *testing.T) {

proc.WrapUp()
}

func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
tests := []struct {
name string
r *rpcServer
len int
}{
{
name: "empty",
r: &rpcServer{
baseRpcServer: &baseRpcServer{},
},
len: 0,
},
{
name: "custom",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
unaryInterceptors: []grpc.UnaryServerInterceptor{
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
return nil, nil
},
},
},
},
len: 1,
},
{
name: "middleware",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
unaryInterceptors: []grpc.UnaryServerInterceptor{
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
return nil, nil
},
},
},
middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Stat: true,
Prometheus: true,
Breaker: true,
},
},
len: 6,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
})
}
}

func TestRpcServer_buildStreamInterceptor(t *testing.T) {
tests := []struct {
name string
r *rpcServer
len int
}{
{
name: "empty",
r: &rpcServer{
baseRpcServer: &baseRpcServer{},
},
len: 0,
},
{
name: "custom",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
streamInterceptors: []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return nil
},
},
},
},
len: 1,
},
{
name: "middleware",
r: &rpcServer{
baseRpcServer: &baseRpcServer{
streamInterceptors: []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return nil
},
},
},
middlewares: ServerMiddlewaresConf{
Trace: true,
Recover: true,
Breaker: true,
},
},
len: 4,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
})
}
}
7 changes: 0 additions & 7 deletions zrpc/internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package internal
import (
"time"

"github.com/zeromicro/go-zero/core/stat"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
"google.golang.org/grpc/keepalive"
Expand All @@ -27,7 +26,6 @@ type (
baseRpcServer struct {
address string
health *health.Server
metrics *stat.Metrics
options []grpc.ServerOption
streamInterceptors []grpc.StreamServerInterceptor
unaryInterceptors []grpc.UnaryServerInterceptor
Expand All @@ -42,7 +40,6 @@ func newBaseRpcServer(address string, rpcServerOpts *rpcServerOptions) *baseRpcS
return &baseRpcServer{
address: address,
health: h,
metrics: rpcServerOpts.metrics,
options: []grpc.ServerOption{grpc.KeepaliveParams(keepalive.ServerParameters{
MaxConnectionIdle: defaultConnectionIdleDuration,
})},
Expand All @@ -60,7 +57,3 @@ func (s *baseRpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerI
func (s *baseRpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) {
s.unaryInterceptors = append(s.unaryInterceptors, interceptors...)
}

func (s *baseRpcServer) SetName(name string) {
s.metrics.SetName(name)
}
13 changes: 3 additions & 10 deletions zrpc/internal/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,18 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat"
"google.golang.org/grpc"
)

func TestBaseRpcServer_AddOptions(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
server.SetName("bar")
server := newBaseRpcServer("foo", &rpcServerOptions{})
var opt grpc.EmptyServerOption
server.AddOptions(opt)
assert.Contains(t, server.options, opt)
}

func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
server.SetName("bar")
server := newBaseRpcServer("foo", &rpcServerOptions{})
var vals []int
f := func(_ any, _ grpc.ServerStream, _ *grpc.StreamServerInfo, _ grpc.StreamHandler) error {
vals = append(vals, 1)
Expand All @@ -35,9 +30,7 @@ func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
}

func TestBaseRpcServer_AddUnaryInterceptors(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := newBaseRpcServer("foo", &rpcServerOptions{metrics: metrics})
server.SetName("bar")
server := newBaseRpcServer("foo", &rpcServerOptions{})
var vals []int
f := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
resp any, err error) {
Expand Down
Loading