diff --git a/balancer/leastrequest/leastrequest.go b/balancer/leastrequest/leastrequest.go index 6dede1a40b70..088c98f81a78 100644 --- a/balancer/leastrequest/leastrequest.go +++ b/balancer/leastrequest/leastrequest.go @@ -23,23 +23,35 @@ import ( "encoding/json" "fmt" rand "math/rand/v2" + "sync" "sync/atomic" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/endpointsharding" + "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) -// randuint32 is a global to stub out in tests. -var randuint32 = rand.Uint32 - // Name is the name of the least request balancer. const Name = "least_request_experimental" -var logger = grpclog.Component("least-request") +var ( + // randuint32 is a global to stub out in tests. + randuint32 = rand.Uint32 + endpointShardingLBConfig serviceconfig.LoadBalancingConfig + logger = grpclog.Component("least-request") +) func init() { + var err error + endpointShardingLBConfig, err = endpointsharding.ParseConfig(json.RawMessage(endpointsharding.PickFirstConfig)) + if err != nil { + logger.Fatal(err) + } balancer.Register(bb{}) } @@ -80,104 +92,166 @@ func (bb) Name() string { } func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { - b := &leastRequestBalancer{scRPCCounts: make(map[balancer.SubConn]*atomic.Int32)} - baseBuilder := base.NewBalancerBuilder(Name, b, base.Config{HealthCheck: true}) - b.Balancer = baseBuilder.Build(cc, bOpts) + b := &leastRequestBalancer{ + ClientConn: cc, + endpointRPCCounts: resolver.NewEndpointMap(), + choiceCount: 2, + } + b.child = endpointsharding.NewBalancer(b, bOpts) + b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b)) + b.logger.Infof("Created") return b } type leastRequestBalancer struct { - // Embeds balancer.Balancer because needs to intercept UpdateClientConnState - // to learn about choiceCount. - balancer.Balancer + // Embeds balancer.ClientConn because needs to intercept UpdateState calls + // from the child balancer. + balancer.ClientConn + child balancer.Balancer + logger *internalgrpclog.PrefixLogger + mu sync.Mutex choiceCount uint32 - scRPCCounts map[balancer.SubConn]*atomic.Int32 // Hold onto RPC counts to keep track for subsequent picker updates. + // endpointRPCCounts holds RPC counts to keep track for subsequent picker + // updates. + endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32 +} + +// Close implements balancer.Balancer. +func (lrb *leastRequestBalancer) Close() { + lrb.child.Close() + lrb.child = nil + lrb.endpointRPCCounts = nil } -func (lrb *leastRequestBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - lrCfg, ok := s.BalancerConfig.(*LBConfig) +// ResolverError implements balancer.Balancer. +func (lrb *leastRequestBalancer) ResolverError(err error) { + lrb.child.ResolverError(err) +} + +// UpdateSubConnState implements balancer.Balancer. +func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state) +} + +func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { + lrCfg, ok := ccs.BalancerConfig.(*LBConfig) if !ok { - logger.Errorf("least-request: received config with unexpected type %T: %v", s.BalancerConfig, s.BalancerConfig) + logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig) return balancer.ErrBadResolverState } + lrb.mu.Lock() lrb.choiceCount = lrCfg.ChoiceCount - return lrb.Balancer.UpdateClientConnState(s) + lrb.mu.Unlock() + // Enable the health listener in pickfirst children for client side health + // checks and outlier detection, if configured. + ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState) + ccs.BalancerConfig = endpointShardingLBConfig + return lrb.child.UpdateClientConnState(ccs) } -type scWithRPCCount struct { - sc balancer.SubConn +type pickerWithRPCCount struct { + picker balancer.Picker numRPCs *atomic.Int32 } -func (lrb *leastRequestBalancer) Build(info base.PickerBuildInfo) balancer.Picker { - if logger.V(2) { - logger.Infof("least-request: Build called with info: %v", info) +func (lrb *leastRequestBalancer) UpdateState(state balancer.State) { + childStates := endpointsharding.ChildStatesFromPicker(state.Picker) + var readyChildren []endpointsharding.ChildState + for _, child := range childStates { + if child.State.ConnectivityState == connectivity.Ready { + readyChildren = append(readyChildren, child) + } } - if len(info.ReadySCs) == 0 { - return base.NewErrPicker(balancer.ErrNoSubConnAvailable) + + // If no ready pickers are present, simply defer to the round robin picker + // from endpoint sharding, which will round robin across the most relevant + // pick first children in the highest precedence connectivity state. + if len(readyChildren) == 0 { + lrb.ClientConn.UpdateState(state) + return } - for sc := range lrb.scRPCCounts { - if _, ok := info.ReadySCs[sc]; !ok { // If no longer ready, no more need for the ref to count active RPCs. - delete(lrb.scRPCCounts, sc) - } + if logger.V(2) { + lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyChildren) + } + + // Reconcile endpoints. + newEndpoints := resolver.NewEndpointMap() // endpoint -> nil + for _, child := range readyChildren { + newEndpoints.Set(child.Endpoint, nil) } - // Create new refs if needed. - for sc := range info.ReadySCs { - if _, ok := lrb.scRPCCounts[sc]; !ok { - lrb.scRPCCounts[sc] = new(atomic.Int32) + // If endpoints are no longer ready, no need to count their active RPCs. + for _, endpoint := range lrb.endpointRPCCounts.Keys() { + if _, ok := newEndpoints.Get(endpoint); !ok { + lrb.endpointRPCCounts.Delete(endpoint) } } // Copy refs to counters into picker. - scs := make([]scWithRPCCount, 0, len(info.ReadySCs)) - for sc := range info.ReadySCs { - scs = append(scs, scWithRPCCount{ - sc: sc, - numRPCs: lrb.scRPCCounts[sc], // guaranteed to be present due to algorithm + pickers := make([]pickerWithRPCCount, 0, len(readyChildren)) + for _, child := range readyChildren { + var counter *atomic.Int32 + if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok { + // Create new counts if needed. + counter = new(atomic.Int32) + lrb.endpointRPCCounts.Set(child.Endpoint, counter) + } else { + counter = val.(*atomic.Int32) + } + pickers = append(pickers, pickerWithRPCCount{ + picker: child.State.Picker, + numRPCs: counter, }) } - return &picker{ - choiceCount: lrb.choiceCount, - subConns: scs, - } + lrb.ClientConn.UpdateState(balancer.State{ + Picker: &picker{ + choiceCount: lrb.choiceCount, + pickersWithRPCCount: pickers, + }, + ConnectivityState: connectivity.Ready, + }) } type picker struct { // choiceCount is the number of random SubConns to find the one with // the least request. choiceCount uint32 - // Built out when receives list of ready RPCs. - subConns []scWithRPCCount + // Built out when receives list of ready child pickers. + pickersWithRPCCount []pickerWithRPCCount } -func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - var pickedSC *scWithRPCCount +func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) { + var pickedEndpoint *pickerWithRPCCount var pickedSCNumRPCs int32 for i := 0; i < int(p.choiceCount); i++ { - index := randuint32() % uint32(len(p.subConns)) - sc := p.subConns[index] - n := sc.numRPCs.Load() - if pickedSC == nil || n < pickedSCNumRPCs { - pickedSC = &sc + index := randuint32() % uint32(len(p.pickersWithRPCCount)) + child := p.pickersWithRPCCount[index] + n := child.numRPCs.Load() + if pickedEndpoint == nil || n < pickedSCNumRPCs { + pickedEndpoint = &child pickedSCNumRPCs = n } } + result, err := pickedEndpoint.picker.Pick(pInfo) + if err != nil { + return result, err + } // "The counter for a subchannel should be atomically incremented by one // after it has been successfully picked by the picker." - A48 - pickedSC.numRPCs.Add(1) + pickedEndpoint.numRPCs.Add(1) // "the picker should add a callback for atomically decrementing the // subchannel counter once the RPC finishes (regardless of Status code)." - // A48. - done := func(balancer.DoneInfo) { - pickedSC.numRPCs.Add(-1) + originalDone := result.Done + result.Done = func(info balancer.DoneInfo) { + pickedEndpoint.numRPCs.Add(-1) + if originalDone != nil { + originalDone(info) + } } - return balancer.PickResult{ - SubConn: pickedSC.sc, - Done: done, - }, nil + return result, nil }