From 85f83c51075eaf05abb8efe350d4caf91a2e470a Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Wed, 4 Sep 2024 11:22:28 +0530 Subject: [PATCH] Generic health check via producer --- balancer/balancer.go | 16 ++ balancer/base/balancer_test.go | 1 + balancer/endpointsharding/endpointsharding.go | 8 +- balancer/pickfirstleaf/pickfirstleaf.go | 216 +++++++++++++----- balancer_wrapper.go | 54 ++++- health/producer.go | 144 ++++++++++++ health/producer/generic_producer.go | 105 +++++++++ internal/internal.go | 3 +- internal/testutils/balancer.go | 6 +- test/healthcheck_test.go | 116 +++++++--- 10 files changed, 565 insertions(+), 104 deletions(-) create mode 100644 health/producer.go create mode 100644 health/producer/generic_producer.go diff --git a/balancer/balancer.go b/balancer/balancer.go index b181f386a1ba..df5cc3246b25 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -152,6 +152,8 @@ type SubConn interface { // indicate the shutdown operation. This may be delivered before // in-progress RPCs are complete and the actual connection is closed. Shutdown() + RegisterConnectivityListner(StateListener) + UnregisterConnectivityListner(StateListener) } // NewSubConnOptions contains options to create new SubConn. @@ -261,6 +263,8 @@ type BuildOptions struct { // metrics. Balancer implementations which do not register metrics on // metrics registry and record on them can ignore this field. MetricsRecorder estats.MetricsRecorder + + HealthCheckOptions HealthCheckOptions } // Builder creates a balancer. @@ -462,3 +466,15 @@ type ProducerBuilder interface { // other methods to provide additional functionality, e.g. configuration or // subscription registration. type Producer any + +type StateListener interface { + OnStateChange(SubConnState) +} + +// HealthCheckOptions are the options to configure the health check producer. 
+type HealthCheckOptions struct { + DisableHealthCheckDialOpt bool + ServiceName func() string + HealthCheckFunc internal.HealthChecker + EnableHealthCheck bool +} diff --git a/balancer/base/balancer_test.go b/balancer/base/balancer_test.go index 8a97b4220a5c..2ba5011586fa 100644 --- a/balancer/base/balancer_test.go +++ b/balancer/base/balancer_test.go @@ -41,6 +41,7 @@ func (c *testClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewS func (c *testClientConn) UpdateState(balancer.State) {} type testSubConn struct { + balancer.SubConn updateState func(balancer.SubConnState) } diff --git a/balancer/endpointsharding/endpointsharding.go b/balancer/endpointsharding/endpointsharding.go index df5b5abe73cc..149ef6ed86b0 100644 --- a/balancer/endpointsharding/endpointsharding.go +++ b/balancer/endpointsharding/endpointsharding.go @@ -116,7 +116,13 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState bal = child.(*balancerWrapper) } else { bal = &balancerWrapper{ - childState: ChildState{Endpoint: endpoint}, + childState: ChildState{ + Endpoint: endpoint, + State: balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable), + }, + }, ClientConn: es.cc, es: es, } diff --git a/balancer/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirstleaf/pickfirstleaf.go index 3b72c6d7e304..ce835c4684e8 100644 --- a/balancer/pickfirstleaf/pickfirstleaf.go +++ b/balancer/pickfirstleaf/pickfirstleaf.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/health/producer" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" internalgrpclog "google.golang.org/grpc/internal/grpclog" @@ -61,15 +62,17 @@ const logPrefix = "[pick-first-leaf-lb %p] " type pickfirstBuilder struct{} -func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { +func (pickfirstBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { ctx, cancel := context.WithCancel(context.Background()) b := &pickfirstBalancer{ - cc: cc, - addressList: addressList{}, - subConns: resolver.NewAddressMap(), - serializer: grpcsync.NewCallbackSerializer(ctx), - serializerCancel: cancel, - state: connectivity.Connecting, + cc: cc, + addressList: addressList{}, + subConns: resolver.NewAddressMap(), + serializer: grpcsync.NewCallbackSerializer(ctx), + serializerCancel: cancel, + concludedState: connectivity.Connecting, + rawConnectivityState: connectivity.Connecting, + healthCheckOpts: opts.HealthCheckOptions, } b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) return b @@ -105,20 +108,24 @@ type scData struct { // The following fields should only be accessed from a serializer callback // to ensure synchronization. 
- state connectivity.State - lastErr error + rawConnectivityState connectivity.State + healthState balancer.SubConnState + lastErr error + unregisterHealthListener func() + closeHealthServiceProducer func() } func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { sd := &scData{ - state: connectivity.Idle, - addr: addr, + rawConnectivityState: connectivity.Idle, + healthState: balancer.SubConnState{ConnectivityState: connectivity.Idle}, + addr: addr, } sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ StateListener: func(state balancer.SubConnState) { // Store the state and delegate. b.serializer.TrySchedule(func(_ context.Context) { - sd.state = state.ConnectivityState + sd.rawConnectivityState = state.ConnectivityState b.updateSubConnState(sd, state) }) }, @@ -126,10 +133,26 @@ func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { if err != nil { return nil, err } + // Start the health check service if its configured. + if internal.EnableHealthCheckViaProducer != nil { + sd.closeHealthServiceProducer = internal.EnableHealthCheckViaProducer.(func(balancer.HealthCheckOptions, balancer.SubConn) func())(b.healthCheckOpts, sc) + } sd.subConn = sc return sd, nil } +func (sd *scData) cleanup() { + if sd.unregisterHealthListener != nil { + sd.unregisterHealthListener() + sd.unregisterHealthListener = nil + } + if sd.closeHealthServiceProducer != nil { + sd.closeHealthServiceProducer() + sd.closeHealthServiceProducer = nil + } + sd.subConn.Shutdown() +} + type pickfirstBalancer struct { // The following fields are initialized at build time and read-only after // that and therefore do not need to be guarded by a mutex. @@ -143,12 +166,14 @@ type pickfirstBalancer struct { // The serializer is used to ensure synchronization of updates triggered // from the idle picker and the already serialized resolver, // subconn state updates. - serializer *grpcsync.CallbackSerializer - serializerCancel func() - state connectivity.State - subConns *resolver.AddressMap // scData for active subonns mapped by address. - addressList addressList - firstPass bool + serializer *grpcsync.CallbackSerializer + serializerCancel func() + concludedState connectivity.State + rawConnectivityState connectivity.State + subConns *resolver.AddressMap // scData for active subonns mapped by address. + addressList addressList + firstPass bool + healthCheckOpts balancer.HealthCheckOptions } func (b *pickfirstBalancer) ResolverError(err error) { @@ -162,12 +187,12 @@ func (b *pickfirstBalancer) resolverError(err error) { if b.logger.V(2) { b.logger.Infof("Received error from the name resolver: %v", err) } - if b.state == connectivity.Shutdown { + if b.rawConnectivityState == connectivity.Shutdown { return } // The picker will not change since the balancer does not currently // report an error. 
- if b.state != connectivity.TransientFailure { + if b.rawConnectivityState != connectivity.TransientFailure { if b.logger.V(2) { b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.") } @@ -176,7 +201,8 @@ func (b *pickfirstBalancer) resolverError(err error) { b.closeSubConns() b.addressList.updateEndpointList(nil) - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.TransientFailure + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)}, }) @@ -196,13 +222,13 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState // updateClientConnState handles clientConn state changes. // Only executed in the context of a serializer callback. func (b *pickfirstBalancer) updateClientConnState(state balancer.ClientConnState) error { - if b.state == connectivity.Shutdown { + if b.rawConnectivityState == connectivity.Shutdown { return errBalancerClosed } if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 { // Cleanup state pertaining to the previous resolver state. // Treat an empty address list like an error by calling b.ResolverError. - b.state = connectivity.TransientFailure + b.rawConnectivityState = connectivity.TransientFailure b.resolverError(errors.New("produced zero addresses")) return balancer.ErrBadResolverState } @@ -249,7 +275,7 @@ func (b *pickfirstBalancer) updateClientConnState(state balancer.ClientConnState prevAddr := b.addressList.currentAddress() prevAddrsCount := b.addressList.size() b.addressList.updateEndpointList(newEndpoints) - if b.state == connectivity.Ready && b.addressList.seekTo(prevAddr) { + if b.rawConnectivityState == connectivity.Ready && b.addressList.seekTo(prevAddr) { return nil } @@ -261,15 +287,15 @@ func (b *pickfirstBalancer) updateClientConnState(state balancer.ClientConnState // we should still enter CONNECTING because the sticky TF behaviour mentioned // in A62 applies only when the TRANSIENT_FAILURE is reported due to connectivity // failures. - if b.state == connectivity.Ready || b.state == connectivity.Connecting || prevAddrsCount == 0 { + if b.rawConnectivityState == connectivity.Ready || b.rawConnectivityState == connectivity.Connecting || prevAddrsCount == 0 { // Start connection attempt at first address. - b.state = connectivity.Connecting - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.Connecting + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) b.requestConnection() - } else if b.state == connectivity.TransientFailure { + } else if b.rawConnectivityState == connectivity.TransientFailure { // If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until // we're READY. See A62. b.requestConnection() @@ -286,7 +312,7 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b func (b *pickfirstBalancer) Close() { b.serializer.TrySchedule(func(_ context.Context) { b.closeSubConns() - b.state = connectivity.Shutdown + b.rawConnectivityState = connectivity.Shutdown }) b.serializerCancel() <-b.serializer.Done() @@ -296,7 +322,7 @@ func (b *pickfirstBalancer) Close() { // by the idlePicker and clientConn so access to variables should be synchronized. 
func (b *pickfirstBalancer) ExitIdle() { b.serializer.TrySchedule(func(_ context.Context) { - if b.state == connectivity.Idle { + if b.rawConnectivityState == connectivity.Idle { b.requestConnection() } }) @@ -305,7 +331,7 @@ func (b *pickfirstBalancer) ExitIdle() { // Only executed in the context of a serializer callback. func (b *pickfirstBalancer) closeSubConns() { for _, sd := range b.subConns.Values() { - sd.(*scData).subConn.Shutdown() + sd.(*scData).cleanup() } b.subConns = resolver.NewAddressMap() } @@ -356,7 +382,7 @@ func (b *pickfirstBalancer) reconcileSubConns(newEndpoints []resolver.Endpoint) continue } val, _ := b.subConns.Get(oldAddr) - val.(*scData).subConn.Shutdown() + val.(*scData).cleanup() b.subConns.Delete(oldAddr) } } @@ -367,9 +393,10 @@ func (b *pickfirstBalancer) reconcileSubConns(newEndpoints []resolver.Endpoint) func (b *pickfirstBalancer) shutdownRemaining(selected *scData) { for _, v := range b.subConns.Values() { sd := v.(*scData) - if sd.subConn != selected.subConn { - sd.subConn.Shutdown() + if sd.subConn == selected.subConn { + continue } + sd.cleanup() } for _, k := range b.subConns.Keys() { b.subConns.Delete(k) @@ -383,7 +410,7 @@ func (b *pickfirstBalancer) shutdownRemaining(selected *scData) { // attempted. // Only executed in the context of a serializer callback. func (b *pickfirstBalancer) requestConnection() { - if !b.addressList.isValid() || b.state == connectivity.Shutdown { + if !b.addressList.isValid() || b.rawConnectivityState == connectivity.Shutdown { return } curAddr := b.addressList.currentAddress() @@ -399,9 +426,9 @@ func (b *pickfirstBalancer) requestConnection() { b.logger.Warningf("Failed to create a subConn for address %v: %v", curAddr.String(), err) // The LB policy remains in TRANSIENT_FAILURE until a new resolver // update is received. - b.state = connectivity.TransientFailure b.addressList.reset() - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.TransientFailure + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: fmt.Errorf("failed to create a new subConn: %v", err)}, }) @@ -411,7 +438,7 @@ func (b *pickfirstBalancer) requestConnection() { } scd := sd.(*scData) - switch scd.state { + switch scd.rawConnectivityState { case connectivity.Idle: scd.subConn.Connect() case connectivity.TransientFailure: @@ -420,9 +447,12 @@ func (b *pickfirstBalancer) requestConnection() { return } b.requestConnection() - case connectivity.Ready: - // Should never happen. - b.logger.Errorf("Requesting a connection even though we have a READY subconn") + default: + // Wait for the current subconn to change state. It could be in READY if + // we're waiting for the health update to arrive. 
+ if b.logger.V(2) { + b.logger.Infof("Waiting for subconn with connectivity state %q and health state %q to transition.", scd.rawConnectivityState, scd.healthState) + } } } @@ -447,22 +477,30 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses) return } - b.state = connectivity.Ready - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Ready, - Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, - }) + b.rawConnectivityState = connectivity.Ready + hl := &healthListener{ + scData: sd, + pb: b, + } + sd.unregisterHealthListener = producer.RegisterListener(hl, sd.subConn) return } // If the LB policy is READY, and it receives a subchannel state change, // it means that the READY subchannel has failed. - if b.state == connectivity.Ready && state.ConnectivityState != connectivity.Ready { + if b.rawConnectivityState == connectivity.Ready && state.ConnectivityState != connectivity.Ready { // Once a transport fails, the balancer enters IDLE and starts from // the first address when the picker is used. - b.state = connectivity.Idle + sd.healthState = balancer.SubConnState{ + ConnectivityState: connectivity.Idle, + } + if sd.unregisterHealthListener != nil { + sd.unregisterHealthListener() + sd.unregisterHealthListener = nil + } b.addressList.reset() - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.Idle + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.Idle, Picker: &idlePicker{exitIdle: b.ExitIdle}, }) @@ -476,8 +514,9 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon // If it's in TRANSIENT_FAILURE, stay in TRANSIENT_FAILURE until // it's READY. See A62. // If the balancer is already in CONNECTING, no update is needed. - if b.state == connectivity.Idle { - b.cc.UpdateState(balancer.State{ + if b.rawConnectivityState == connectivity.Idle { + b.rawConnectivityState = connectivity.Connecting + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) @@ -505,9 +544,9 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon if curAddr := b.addressList.currentAddress(); !equalAddressIgnoringBalAttributes(&sd.addr, &curAddr) { return } - b.state = connectivity.Idle b.addressList.reset() - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.Idle + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.Idle, Picker: &idlePicker{exitIdle: b.ExitIdle}, }) @@ -519,7 +558,8 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon switch state.ConnectivityState { case connectivity.TransientFailure: sd.lastErr = state.ConnectionError - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.TransientFailure + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: state.ConnectionError}, }) @@ -531,19 +571,67 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon } } +// Only executed in the context of a serializer callback. +func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData) { + state := sd.healthState + // If the raw connectivity state is not READY, we ignore the health updates. 
+ if state.ConnectivityState == connectivity.Shutdown || sd.rawConnectivityState != connectivity.Ready { + return + } + // Previously relevant subconns can still callback with state updates. + // To prevent pickers from returning these obsolete subconns, this logic + // is included to check if the current list of active subconns includes this + // subconn. + if activeSD, found := b.subConns.Get(sd.addr); !found || activeSD != sd { + return + } + switch state.ConnectivityState { + case connectivity.Ready: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Ready, + Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, + }) + case connectivity.TransientFailure: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: fmt.Errorf("health check failure: %v", state.ConnectionError)}, + }) + default: + // If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until + // we're READY. See A62. + if b.concludedState == connectivity.TransientFailure { + return + } + // The health check will report CONNECTING once the raw connectivity state + // changes to READY. We can avoid sending a new picker since the balancer + // would already be in CONNECTING. + if state.ConnectivityState == connectivity.Connecting && b.concludedState == connectivity.Connecting { + return + } + b.updateBalancerState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + } +} + +func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) { + b.concludedState = newState.ConnectivityState + b.cc.UpdateState(newState) +} + // Only executed in the context of a serializer callback. func (b *pickfirstBalancer) endFirstPass(lastErr error) { b.firstPass = false - b.state = connectivity.TransientFailure - - b.cc.UpdateState(balancer.State{ + b.rawConnectivityState = connectivity.TransientFailure + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: lastErr}, }) // Start re-connecting all the subconns that are already in IDLE. 
for _, v := range b.subConns.Values() { sd := v.(*scData) - if sd.state == connectivity.Idle { + if sd.rawConnectivityState == connectivity.Idle { sd.subConn.Connect() } } @@ -640,3 +728,15 @@ func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { a.Attributes.Equal(b.Attributes) && a.Metadata == b.Metadata } + +type healthListener struct { + pb *pickfirstBalancer + scData *scData +} + +func (hl *healthListener) OnStateChange(state balancer.SubConnState) { + hl.pb.serializer.TrySchedule(func(context.Context) { + hl.scData.healthState = state + hl.pb.updateSubConnHealthState(hl.scData) + }) +} diff --git a/balancer_wrapper.go b/balancer_wrapper.go index 5877b71533bc..60084067b4e4 100644 --- a/balancer_wrapper.go +++ b/balancer_wrapper.go @@ -83,6 +83,17 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper { ChannelzParent: cc.channelz, Target: cc.parsedTarget, MetricsRecorder: cc.metricsRecorderList, + HealthCheckOptions: balancer.HealthCheckOptions{ + HealthCheckFunc: cc.dopts.healthCheckFunc, + DisableHealthCheckDialOpt: cc.dopts.disableHealthCheck, + ServiceName: func() string { + cfg := cc.healthCheckConfig() + if cfg == nil { + return "" + } + return cfg.ServiceName + }, + }, }, serializer: grpcsync.NewCallbackSerializer(ctx), serializerCancel: cancel, @@ -183,10 +194,11 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer return nil, err } acbw := &acBalancerWrapper{ - ccb: ccb, - ac: ac, - producers: make(map[balancer.ProducerBuilder]*refCountedProducer), - stateListener: opts.StateListener, + ccb: ccb, + ac: ac, + producers: make(map[balancer.ProducerBuilder]*refCountedProducer), + stateListener: opts.StateListener, + connectivityListeners: make(map[balancer.StateListener]bool), } ac.acbw = acbw return acbw, nil @@ -256,8 +268,9 @@ type acBalancerWrapper struct { ccb *ccBalancerWrapper // read-only stateListener func(balancer.SubConnState) - mu sync.Mutex - producers map[balancer.ProducerBuilder]*refCountedProducer + mu sync.Mutex + producers map[balancer.ProducerBuilder]*refCountedProducer + connectivityListeners map[balancer.StateListener]bool } // updateState is invoked by grpc to push a subConn state update to the @@ -275,6 +288,12 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve setConnectedAddress(&scs, curAddr) } acbw.stateListener(scs) + + acbw.mu.Lock() + defer acbw.mu.Unlock() + for lis := range acbw.connectivityListeners { + lis.OnStateChange(scs) + } }) } @@ -353,3 +372,26 @@ func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) ( } return pData.producer, grpcsync.OnceFunc(unref) } + +func (acbw *acBalancerWrapper) RegisterConnectivityListner(l balancer.StateListener) { + acbw.ccb.serializer.TrySchedule(func(ctx context.Context) { + if ctx.Err() != nil || acbw.ccb.balancer == nil { + return + } + acbw.connectivityListeners[l] = true + acbw.ac.mu.Lock() + defer acbw.ac.mu.Unlock() + l.OnStateChange(balancer.SubConnState{ + ConnectivityState: acbw.ac.state, + }) + }) +} + +func (acbw *acBalancerWrapper) UnregisterConnectivityListner(l balancer.StateListener) { + acbw.ccb.serializer.TrySchedule(func(ctx context.Context) { + if ctx.Err() != nil || acbw.ccb.balancer == nil { + return + } + delete(acbw.connectivityListeners, l) + }) +} diff --git a/health/producer.go b/health/producer.go new file mode 100644 index 000000000000..b8a68eaf84e8 --- /dev/null +++ b/health/producer.go @@ -0,0 +1,144 @@ +package health + +import ( + "context" + "sync" + + 
"google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/health/producer" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/status" +) + +func init() { + producerBuilderSingleton = &producerBuilder{} + internal.EnableHealthCheckViaProducer = EnableHealtCheck +} + +type producerBuilder struct{} + +var producerBuilderSingleton *producerBuilder + +type subConnStateListener struct { + p *healthServiceProducer +} + +func (l *subConnStateListener) OnStateChange(state balancer.SubConnState) { + l.p.mu.Lock() + defer l.p.mu.Unlock() + prevState := l.p.connectivityState + l.p.connectivityState = state.ConnectivityState + if prevState == state.ConnectivityState || prevState == connectivity.Shutdown { + return + } + if prevState == connectivity.Ready { + // Connection failure, stop health check. + if l.p.stopClientFn != nil { + l.p.stopClientFn() + l.p.stopClientFn = nil + } + } else if state.ConnectivityState == connectivity.Ready && l.p.sender != nil { + l.p.startHealthCheckLocked() + } +} + +// Build constructs and returns a producer and its cleanup function +func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { + p := &healthServiceProducer{ + cc: cci.(grpc.ClientConnInterface), + mu: sync.Mutex{}, + connectivityState: connectivity.Idle, + } + return p, sync.OnceFunc(func() { + p.mu.Lock() + defer p.mu.Unlock() + p.connectivityState = connectivity.Shutdown + if p.stopClientFn != nil { + p.stopClientFn() + p.stopClientFn = nil + } + if p.unregisterConnListener != nil { + p.unregisterConnListener() + p.unregisterConnListener = nil + } + }) +} + +type healthServiceProducer struct { + cc grpc.ClientConnInterface + mu sync.Mutex + connectivityState connectivity.State + subConnStateListener balancer.StateListener + sender producer.Sender + oldSender producer.Sender + unregisterConnListener func() + opts *balancer.HealthCheckOptions + stopClientFn func() +} + +func EnableHealtCheck(opts balancer.HealthCheckOptions, sc balancer.SubConn) func() { + pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) + p := pr.(*healthServiceProducer) + p.mu.Lock() + defer p.mu.Unlock() + if p.sender != nil || p.connectivityState == connectivity.Shutdown { + return closeFn + } + var closeGenericProducer func() + p.sender, closeGenericProducer = producer.SwapSender(func(scs balancer.SubConnState) { + // Block all the updates. The health check service is supposed to be + // first producer in the chain. 
+ }, sc) + ls := &subConnStateListener{ + p: p, + } + sc.RegisterConnectivityListner(ls) + p.unregisterConnListener = func() { + sc.UnregisterConnectivityListner(ls) + } + p.opts = &opts + return func() { + closeFn() + closeGenericProducer() + } +} + +func (p *healthServiceProducer) startHealthCheckLocked() { + serviceName := p.opts.ServiceName() + if p.opts.DisableHealthCheckDialOpt || !p.opts.EnableHealthCheck || serviceName == "" { + p.sender(balancer.SubConnState{ConnectivityState: connectivity.Ready}) + return + } + if p.opts.HealthCheckFunc == nil { + logger.Error("Health check is requested but health check function is not set.") + p.sender(balancer.SubConnState{ConnectivityState: connectivity.Ready}) + return + } + ctx, cancel := context.WithCancel(context.Background()) + p.stopClientFn = cancel + newStream := func(method string) (any, error) { + return p.cc.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method) + } + + setConnectivityState := func(state connectivity.State, err error) { + p.sender(balancer.SubConnState{ + ConnectivityState: state, + ConnectionError: err, + }) + } + + go func() { + err := p.opts.HealthCheckFunc(ctx, newStream, setConnectivityState, serviceName) + if err == nil { + return + } + if status.Code(err) == codes.Unimplemented { + logger.Error("Subchannel health check is unimplemented at server side, thus health check is disabled\n") + } else { + logger.Errorf("Health checking failed: %v\n", err) + } + }() +} diff --git a/health/producer/generic_producer.go b/health/producer/generic_producer.go new file mode 100644 index 000000000000..3d28a6feae53 --- /dev/null +++ b/health/producer/generic_producer.go @@ -0,0 +1,105 @@ +package producer + +import ( + "context" + "sync" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/grpcsync" +) + +var logger = grpclog.Component("health_producer") + +func init() { + producerBuilderSingleton = &producerBuilder{} +} + +type producerBuilder struct{} +type Sender func(balancer.SubConnState) + +var producerBuilderSingleton *producerBuilder + +// Build constructs and returns a producer and its cleanup function +func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { + ctx, cancel := context.WithCancel(context.Background()) + p := &producer{ + cci: cci, + healthState: balancer.SubConnState{ + ConnectivityState: connectivity.Idle, + }, + serializer: grpcsync.NewCallbackSerializer(ctx), + listeners: make(map[balancer.StateListener]bool), + } + p.sender = func(scs balancer.SubConnState) { + p.serializer.TrySchedule(func(ctx context.Context) { + p.healthState = scs + for lis := range p.listeners { + lis.OnStateChange(scs) + } + }) + } + return p, sync.OnceFunc(func() { + p.serializer.TrySchedule(func(ctx context.Context) { + if len(p.listeners) > 0 { + logger.Errorf("Health Producer closing with %d listeners remaining in list", len(p.listeners)) + } + p.listeners = nil + }) + cancel() + <-p.serializer.Done() + }) +} + +type producer struct { + cci any // grpc.ClientConnInterface + listeners map[balancer.StateListener]bool + opts *balancer.HealthCheckOptions + healthState balancer.SubConnState + serializer *grpcsync.CallbackSerializer + sender Sender + senderSwapped bool +} + +func RegisterListener(l balancer.StateListener, sc balancer.SubConn) func() { + pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) + p := pr.(*producer) + unregister := func() { + p.unregisterListener(l) + closeFn() + } + 
p.serializer.TrySchedule(func(ctx context.Context) { + p.listeners[l] = true + if !p.senderSwapped { + l.OnStateChange(balancer.SubConnState{ConnectivityState: connectivity.Ready}) + } else if p.healthState.ConnectivityState != connectivity.Idle { + l.OnStateChange(p.healthState) + } + }) + return unregister +} + +// Adds a Sender to beginning of the chain, gives the next sender in the chain to send +// updates. +func SwapSender(newSender Sender, sc balancer.SubConn) (Sender, func()) { + pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) + p := pr.(*producer) + senderCh := make(chan Sender, 1) + p.serializer.ScheduleOr(func(ctx context.Context) { + p.senderSwapped = true + oldSender := p.sender + p.sender = newSender + senderCh <- oldSender + }, func() { + close(senderCh) + }) + oldSender := <-senderCh + return oldSender, closeFn +} + +func (p *producer) unregisterListener(l balancer.StateListener) { + p.serializer.TrySchedule(func(ctx context.Context) { + delete(p.listeners, l) + }) +} diff --git a/internal/internal.go b/internal/internal.go index 433e697f184f..dbe5360a885f 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -32,7 +32,8 @@ var ( // WithHealthCheckFunc is set by dialoptions.go WithHealthCheckFunc any // func (HealthChecker) DialOption // HealthCheckFunc is used to provide client-side LB channel health checking - HealthCheckFunc HealthChecker + HealthCheckFunc HealthChecker + EnableHealthCheckViaProducer any // func(balancer.HealthCheckOptions, balancer.SubConn) func() // BalancerUnregister is exported by package balancer to unregister a balancer. BalancerUnregister func(name string) // KeepaliveMinPingTime is the minimum ping interval. This must be 10s by diff --git a/internal/testutils/balancer.go b/internal/testutils/balancer.go index c65be16be4b6..ce116a9bc0d9 100644 --- a/internal/testutils/balancer.go +++ b/internal/testutils/balancer.go @@ -32,6 +32,7 @@ import ( // TestSubConn implements the SubConn interface, to be used in tests. type TestSubConn struct { + balancer.SubConn tcc *BalancerClientConn // the CC that owns this SubConn id string ConnectCh chan struct{} @@ -63,8 +64,8 @@ func (tsc *TestSubConn) Connect() { } // GetOrBuildProducer is a no-op. -func (tsc *TestSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) { - return nil, nil +func (tsc *TestSubConn) GetOrBuildProducer(builder balancer.ProducerBuilder) (balancer.Producer, func()) { + return builder.Build(nil) } // UpdateState pushes the state to the listener, if one is registered. 
@@ -72,7 +73,6 @@ func (tsc *TestSubConn) UpdateState(state balancer.SubConnState) { <-tsc.connectCalled.Done() if tsc.stateListener != nil { tsc.stateListener(state) - return } } diff --git a/test/healthcheck_test.go b/test/healthcheck_test.go index b03c47a31426..a6071665c4e3 100644 --- a/test/healthcheck_test.go +++ b/test/healthcheck_test.go @@ -28,12 +28,15 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/pickfirst" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/health" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" @@ -46,7 +49,42 @@ import ( testpb "google.golang.org/grpc/interop/grpc_testing" ) -var testHealthCheckFunc = internal.HealthCheckFunc +const healthCheckingPetiolePolicyName = "health_checking_petiole_policy" + +var ( + testHealthCheckFunc = internal.HealthCheckFunc + + // healthCheckTestPolicyName is the LB policy used for testing the health check + // service. + healthCheckTestPolicyName = "round_robin" +) + +func init() { + balancer.Register(&healthCheckingPetiolePolicyBuilder{}) + // Until dualstack changes are not implemented and round_robin doesn't + // delegate to pickfirst, we test a fake LB policy that delegates to pickfirst. + // to verify health checking works as expected through pickfirst. + if envconfig.NewPickFirstEnabled { + healthCheckTestPolicyName = healthCheckingPetiolePolicyName + } +} + +type healthCheckingPetiolePolicyBuilder struct{} + +func (b *healthCheckingPetiolePolicyBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + opts.HealthCheckOptions.EnableHealthCheck = true + return &healthCheckingPetiolePolicy{ + balancer.Get(pickfirst.Name).Build(cc, opts), + } +} + +func (b *healthCheckingPetiolePolicyBuilder) Name() string { + return healthCheckingPetiolePolicyName +} + +type healthCheckingPetiolePolicy struct { + balancer.Balancer +} func newTestHealthServer() *testHealthServer { return newTestHealthServerWithWatchFunc(defaultWatchFunc) @@ -204,12 +242,12 @@ func (s) TestHealthCheckWatchStateChange(t *testing.T) { cc, r := setupClient(t, nil) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -259,12 +297,12 @@ func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) { cc, r := setupClient(t, nil) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -286,12 +324,12 @@ func (s) TestHealthCheckWithGoAway(t *testing.T) { tc := testgrpc.NewTestServiceClient(cc) 
r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -364,12 +402,12 @@ func (s) TestHealthCheckWithConnClose(t *testing.T) { tc := testgrpc.NewTestServiceClient(cc) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -412,12 +450,12 @@ func (s) TestHealthCheckWithAddrConnDrain(t *testing.T) { hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper() cc, r := setupClient(t, &clientConfig{testHealthCheckFuncWrapper: testHealthCheckFuncWrapper}) tc := testgrpc.NewTestServiceClient(cc) - sc := parseServiceConfig(t, r, `{ + sc := parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName)) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, ServiceConfig: sc, @@ -494,12 +532,12 @@ func (s) TestHealthCheckWithClientConnClose(t *testing.T) { tc := testgrpc.NewTestServiceClient(cc) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -561,12 +599,12 @@ func (s) TestHealthCheckWithoutSetConnectivityStateCalledAddrConnShutDown(t *tes // The serviceName "delay" is specially handled at server side, where response will not be sent // back to client immediately upon receiving the request (client should receive no response until // test ends). - sc := parseServiceConfig(t, r, `{ + sc := parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "delay" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName)) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, ServiceConfig: sc, @@ -626,12 +664,12 @@ func (s) TestHealthCheckWithoutSetConnectivityStateCalled(t *testing.T) { // test ends). 
r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "delay" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) select { case <-hcExitChan: @@ -667,12 +705,12 @@ func testHealthCheckDisableWithDialOption(t *testing.T, addr string) { tc := testgrpc.NewTestServiceClient(cc) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: addr}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "foo" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -775,12 +813,12 @@ func (s) TestHealthCheckChannelzCountingCallSuccess(t *testing.T) { _, r := setupClient(t, nil) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "channelzSuccess" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) if err := verifyResultWithDelay(func() (bool, error) { cm, _ := channelz.GetTopChannels(0, 0) @@ -824,12 +862,12 @@ func (s) TestHealthCheckChannelzCountingCallFailure(t *testing.T) { _, r := setupClient(t, nil) r.UpdateState(resolver.State{ Addresses: []resolver.Address{{Addr: lis.Addr().String()}}, - ServiceConfig: parseServiceConfig(t, r, `{ + ServiceConfig: parseServiceConfig(t, r, fmt.Sprintf(`{ "healthCheckConfig": { "serviceName": "channelzFailure" }, - "loadBalancingConfig": [{"round_robin":{}}] -}`)}) + "loadBalancingConfig": [{"%s":{}}] +}`, healthCheckTestPolicyName))}) if err := verifyResultWithDelay(func() (bool, error) { cm, _ := channelz.GetTopChannels(0, 0) @@ -938,7 +976,15 @@ func testHealthCheckSuccess(t *testing.T, e env) { // TestHealthCheckFailure invokes the unary Check() RPC on the health server // with an expired context and expects the RPC to fail. func (s) TestHealthCheckFailure(t *testing.T) { - for _, e := range listTestEnv() { + envs := listTestEnv() + if envconfig.NewPickFirstEnabled { + envs = append([]env{ + {name: "tcp-clear", network: "tcp", balancer: healthCheckTestPolicyName}, + {name: "tcp-tls", network: "tcp", security: "tls", balancer: healthCheckTestPolicyName}, + }, envs...) + + } + for _, e := range envs { testHealthCheckFailure(t, e) } }
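
The patch keeps the existing entry points for client-side health checking: the service config's healthCheckConfig.serviceName plus an LB policy that opts in (round_robin, or pick_first wrapped by a petiole policy as in the test changes above). A minimal client-side sketch of that wiring follows; the target "backend.example.com:50051" is a placeholder, not part of the patch.

package main

import (
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	// Linking in the health package registers the client health-check
	// implementation (and, with this patch, the producer-based hook).
	_ "google.golang.org/grpc/health"
)

func main() {
	// Health checking starts only when the service config names a
	// health-check service and the selected LB policy enables it.
	cc, err := grpc.NewClient("dns:///backend.example.com:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(`{
  "healthCheckConfig": { "serviceName": "foo" },
  "loadBalancingConfig": [{ "round_robin": {} }]
}`))
	if err != nil {
		log.Fatalf("grpc.NewClient failed: %v", err)
	}
	defer cc.Close()
}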
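
Inside an LB policy, the health-check service layer is attached per SubConn through the internal.EnableHealthCheckViaProducer hook, mirroring what pickfirstleaf's newSCData does above. A minimal sketch under that assumption; the helper name newHealthCheckedSubConn is illustrative and not part of the patch.

package examplelb

import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/resolver"
)

// newHealthCheckedSubConn creates a SubConn and, when the health-check
// service hook is registered (i.e. the grpc/health package is linked into
// the binary), layers the health service producer on top of the generic
// health producer. It returns the SubConn plus a function that tears the
// health layer down.
func newHealthCheckedSubConn(cc balancer.ClientConn, addr resolver.Address, opts balancer.HealthCheckOptions, listener func(balancer.SubConnState)) (balancer.SubConn, func(), error) {
	sc, err := cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{StateListener: listener})
	if err != nil {
		return nil, nil, err
	}
	closeHealth := func() {}
	if internal.EnableHealthCheckViaProducer != nil {
		enable := internal.EnableHealthCheckViaProducer.(func(balancer.HealthCheckOptions, balancer.SubConn) func())
		closeHealth = enable(opts, sc)
	}
	return sc, closeHealth, nil
}

The returned close function plays the role of scData.closeHealthServiceProducer in the patch and should run before SubConn.Shutdown, as scData.cleanup does.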
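
A policy reads the effective health of a READY SubConn through the new generic producer in health/producer, the same way pickfirstleaf's healthListener does once the raw connectivity state reaches READY. A sketch with an illustrative watchHealth helper:

package examplelb

import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/health/producer"
)

// healthWatcher implements balancer.StateListener and forwards health updates
// (READY, CONNECTING, or TRANSIENT_FAILURE with a connection error) to the
// policy's own callback.
type healthWatcher struct {
	onUpdate func(balancer.SubConnState)
}

func (w *healthWatcher) OnStateChange(state balancer.SubConnState) {
	w.onUpdate(state)
}

// watchHealth registers a listener on the SubConn's generic health producer.
// The returned function unregisters the listener and releases the producer
// reference; the policy should call it when it stops using the SubConn, as
// scData.cleanup does in the patch.
func watchHealth(sc balancer.SubConn, onUpdate func(balancer.SubConnState)) func() {
	return producer.RegisterListener(&healthWatcher{onUpdate: onUpdate}, sc)
}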
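
The new SubConn methods RegisterConnectivityListner and UnregisterConnectivityListner let a producer react to raw connectivity transitions; the health service producer uses them to start the health-check stream when the connection becomes READY and to stop it when the connection is lost. A small illustrative wrapper, assuming nothing beyond those two methods:

package examplelb

import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/connectivity"
)

// connectivityWatcher observes a SubConn's raw connectivity state, invoking
// onReady when the connection becomes READY and onLost when it leaves READY.
type connectivityWatcher struct {
	ready   bool
	onReady func()
	onLost  func()
}

func (w *connectivityWatcher) OnStateChange(s balancer.SubConnState) {
	if s.ConnectivityState == connectivity.Ready {
		if !w.ready {
			w.ready = true
			w.onReady()
		}
		return
	}
	if w.ready {
		w.ready = false
		w.onLost()
	}
}

// watchConnectivity registers the watcher with the SubConn; the returned
// function unregisters it. Callbacks arrive on the channel's balancer
// serializer, so no additional locking is needed here.
func watchConnectivity(sc balancer.SubConn, onReady, onLost func()) func() {
	w := &connectivityWatcher{onReady: onReady, onLost: onLost}
	sc.RegisterConnectivityListner(w)
	return func() { sc.UnregisterConnectivityListner(w) }
}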