From 0a5e515349304cd07cb306197815bcb251b99570 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Thu, 5 Sep 2024 15:27:18 +0530 Subject: [PATCH] Send connectivity state to root listener --- .../{generic_producer.go => producer.go} | 52 ++++++++++++++----- health/producer.go | 43 +++++++-------- internal/testutils/balancer.go | 12 +++++ 3 files changed, 70 insertions(+), 37 deletions(-) rename health/genericproducer/{generic_producer.go => producer.go} (61%) diff --git a/health/genericproducer/generic_producer.go b/health/genericproducer/producer.go similarity index 61% rename from health/genericproducer/generic_producer.go rename to health/genericproducer/producer.go index c578ef0a78ea..7755723654c8 100644 --- a/health/genericproducer/generic_producer.go +++ b/health/genericproducer/producer.go @@ -1,3 +1,5 @@ +// Package genericproducer provides a balancer.Producer that is used to publish +// and subscribe to health state updates. package genericproducer import ( @@ -26,7 +28,7 @@ type broadcastingListner struct { } func (l *broadcastingListner) OnStateChange(scs balancer.SubConnState) { - l.p.serializer.TrySchedule(func(ctx context.Context) { + l.p.serializer.TrySchedule(func(_ context.Context) { l.p.healthState = scs for lis := range l.listeners { lis.OnStateChange(scs) @@ -44,17 +46,22 @@ func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { }, serializer: grpcsync.NewCallbackSerializer(ctx), } + p.connectivityListener = &connectivityListener{p: p} p.broadcastingListener = &broadcastingListner{ p: p, listeners: make(map[balancer.StateListener]bool), } p.rootListener = p.broadcastingListener return p, sync.OnceFunc(func() { - p.serializer.TrySchedule(func(ctx context.Context) { + p.serializer.TrySchedule(func(_ context.Context) { if len(p.broadcastingListener.listeners) > 0 { logger.Errorf("Health Producer closing with %d listeners remaining in list", len(p.broadcastingListener.listeners)) } p.broadcastingListener.listeners = nil + if p.sc != nil { + p.sc.UnregisterConnectivityListner(p.connectivityListener) + p.connectivityListener = nil + } }) cancel() <-p.serializer.Done() @@ -63,13 +70,30 @@ func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { type producer struct { cci any // grpc.ClientConnInterface - opts *balancer.HealthCheckOptions healthState balancer.SubConnState serializer *grpcsync.CallbackSerializer rootListener balancer.StateListener broadcastingListener *broadcastingListner + connectivityListener *connectivityListener + sc balancer.SubConn +} + +type connectivityListener struct { + p *producer + connectivityState balancer.SubConnState } +func (l *connectivityListener) OnStateChange(state balancer.SubConnState) { + l.p.serializer.TrySchedule(func(_ context.Context) { + l.connectivityState = state + l.p.rootListener.OnStateChange(state) + }) +} + +// RegisterListener is used by health consumers to start listening for health +// updates. It returns a function to unregister the listener and manage +// ref counting. It must be called by consumers when they no longer required the +// listener. func RegisterListener(l balancer.StateListener, sc balancer.SubConn) func() { pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) p := pr.(*producer) @@ -77,20 +101,26 @@ func RegisterListener(l balancer.StateListener, sc balancer.SubConn) func() { p.unregisterListener(l) closeFn() } - p.serializer.TrySchedule(func(ctx context.Context) { + p.serializer.TrySchedule(func(_ context.Context) { + if p.sc == nil { + p.sc = sc + sc.RegisterConnectivityListner(p.connectivityListener) + } p.broadcastingListener.listeners[l] = true l.OnStateChange(p.healthState) }) return unregister } -// Adds a Sender to beginning of the chain, gives the next sender in the chain to send -// updates. +// SwapRootListener sets the given listener as the root of the listener chain. +// It returns the previous root of the chain. The producer must process calls +// to the registered listener in a passthrough manner by calling the returned +// listener every time it received an update. func SwapRootListener(newListener balancer.StateListener, sc balancer.SubConn) (balancer.StateListener, func()) { pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) p := pr.(*producer) senderCh := make(chan balancer.StateListener, 1) - p.serializer.ScheduleOr(func(ctx context.Context) { + p.serializer.ScheduleOr(func(_ context.Context) { oldSender := p.rootListener p.rootListener = newListener senderCh <- oldSender @@ -100,16 +130,14 @@ func SwapRootListener(newListener balancer.StateListener, sc balancer.SubConn) ( oldSender := <-senderCh // Send an update on the root listener to allow the new producer to set // update the state present in listener down the chain if required. - p.serializer.TrySchedule(func(ctx context.Context) { - p.rootListener.OnStateChange(balancer.SubConnState{ - ConnectivityState: connectivity.Ready, - }) + p.serializer.TrySchedule(func(_ context.Context) { + p.rootListener.OnStateChange(p.connectivityListener.connectivityState) }) return oldSender, closeFn } func (p *producer) unregisterListener(l balancer.StateListener) { - p.serializer.TrySchedule(func(ctx context.Context) { + p.serializer.TrySchedule(func(_ context.Context) { delete(p.broadcastingListener.listeners, l) }) } diff --git a/health/producer.go b/health/producer.go index e8ef0d621beb..1a926c366b3d 100644 --- a/health/producer.go +++ b/health/producer.go @@ -22,16 +22,20 @@ type producerBuilder struct{} var producerBuilderSingleton *producerBuilder -type subConnStateListener struct { +type connectivityStateListener struct { p *healthServiceProducer } -func (l *subConnStateListener) OnStateChange(state balancer.SubConnState) { +func (l *connectivityStateListener) OnStateChange(state balancer.SubConnState) { l.p.mu.Lock() defer l.p.mu.Unlock() + defer func() { + // Propogate updates down the listener chain. + l.p.updateHealthStateLocked(l.p.healthState) + }() prevState := l.p.connectivityState l.p.connectivityState = state.ConnectivityState - if prevState == state.ConnectivityState || prevState == connectivity.Shutdown { + if prevState == state.ConnectivityState { return } if prevState == connectivity.Ready { @@ -40,15 +44,11 @@ func (l *subConnStateListener) OnStateChange(state balancer.SubConnState) { l.p.stopClientFn() l.p.stopClientFn = nil } - l.p.running = false - l.p.listener.OnStateChange(balancer.SubConnState{ + l.p.currentAttemptMarker = nil + l.p.updateHealthStateLocked(balancer.SubConnState{ ConnectivityState: connectivity.Idle, }) - } else if state.ConnectivityState == connectivity.Ready && l.p.listener != nil { - l.p.running = true - l.p.listener.OnStateChange(balancer.SubConnState{ - ConnectivityState: connectivity.Connecting, - }) + } else if state.ConnectivityState == connectivity.Ready { l.p.startHealthCheckLocked() } } @@ -83,24 +83,15 @@ type healthServiceProducer struct { mu sync.Mutex connectivityState connectivity.State healthState balancer.SubConnState - subConnStateListener balancer.StateListener listener balancer.StateListener unregisterConnListener func() opts *balancer.HealthCheckOptions stopClientFn func() - running bool -} - -type noOpListener struct { - p *healthServiceProducer -} - -func (l *noOpListener) OnStateChange(_ balancer.SubConnState) { - l.p.mu.Lock() - defer l.p.mu.Unlock() - l.p.listener.OnStateChange(l.p.healthState) + currentAttemptMarker *struct{} } +// EnableHealthCheck enabled the health check service client to perform health +// checks for the subchannel. func EnableHealthCheck(opts balancer.HealthCheckOptions, sc balancer.SubConn) func() { pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) p := pr.(*healthServiceProducer) @@ -110,8 +101,8 @@ func EnableHealthCheck(opts balancer.HealthCheckOptions, sc balancer.SubConn) fu return closeFn } var closeGenericProducer func() - p.listener, closeGenericProducer = genericproducer.SwapRootListener(&noOpListener{p: p}, sc) - ls := &subConnStateListener{ + p.listener, closeGenericProducer = genericproducer.SwapRootListener(&connectivityStateListener{p: p}, sc) + ls := &connectivityStateListener{ p: p, } sc.RegisterConnectivityListner(ls) @@ -146,11 +137,13 @@ func (p *healthServiceProducer) startHealthCheckLocked() { newStream := func(method string) (any, error) { return p.cc.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method) } + marker := &struct{}{} + p.currentAttemptMarker = marker setConnectivityState := func(state connectivity.State, err error) { p.mu.Lock() defer p.mu.Unlock() - if !p.running { + if p.currentAttemptMarker != marker { return } p.updateHealthStateLocked(balancer.SubConnState{ diff --git a/internal/testutils/balancer.go b/internal/testutils/balancer.go index 3a1b7d31be51..5bdf9b52f8b0 100644 --- a/internal/testutils/balancer.go +++ b/internal/testutils/balancer.go @@ -65,6 +65,18 @@ func NewTestSubConn(id string) *TestSubConn { // UpdateAddresses is a no-op. func (tsc *TestSubConn) UpdateAddresses([]resolver.Address) {} +// RegisterConnectivityListner registers a listener. +func (tsc *TestSubConn) RegisterConnectivityListner(sl balancer.StateListener) { + oldLis := tsc.stateListener + tsc.stateListener = func(state balancer.SubConnState) { + oldLis(state) + sl.OnStateChange(state) + } +} + +// UnregisterConnectivityListner is a no-op. +func (tsc *TestSubConn) UnregisterConnectivityListner(_ balancer.StateListener) {} + // Connect is a no-op. func (tsc *TestSubConn) Connect() { tsc.connectCalled.Fire()