Make least_request delegate to pickfirst
arjan-bal committed Dec 25, 2024
1 parent 724f450 commit a3a0c34
Showing 1 changed file with 129 additions and 56 deletions.
185 changes: 129 additions & 56 deletions balancer/leastrequest/leastrequest.go
@@ -23,23 +23,35 @@ import (
     "encoding/json"
     "fmt"
     rand "math/rand/v2"
+    "sync"
     "sync/atomic"
 
     "google.golang.org/grpc/balancer"
-    "google.golang.org/grpc/balancer/base"
+    "google.golang.org/grpc/balancer/endpointsharding"
+    "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
+    "google.golang.org/grpc/connectivity"
     "google.golang.org/grpc/grpclog"
+    internalgrpclog "google.golang.org/grpc/internal/grpclog"
+    "google.golang.org/grpc/resolver"
     "google.golang.org/grpc/serviceconfig"
 )
 
-// randuint32 is a global to stub out in tests.
-var randuint32 = rand.Uint32
-
 // Name is the name of the least request balancer.
 const Name = "least_request_experimental"
 
-var logger = grpclog.Component("least-request")
+var (
+    // randuint32 is a global to stub out in tests.
+    randuint32 = rand.Uint32
+    endpointShardingLBConfig serviceconfig.LoadBalancingConfig
+    logger = grpclog.Component("least-request")
+)
 
 func init() {
+    var err error
+    endpointShardingLBConfig, err = endpointsharding.ParseConfig(json.RawMessage(endpointsharding.PickFirstConfig))
+    if err != nil {
+        logger.Fatal(err)
+    }
     balancer.Register(bb{})
 }

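The init above parses the pick_first child config once, failing fast on error, and registers the policy under least_request_experimental. A client then opts in through service config; a minimal, hedged usage sketch (the target is invented; the choiceCount JSON field comes from gRFC A48, which also specifies clamping the value to the range [2, 10]):

package main

import (
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"

    // Blank import for side effects: registers least_request_experimental.
    _ "google.golang.org/grpc/balancer/leastrequest"
)

func main() {
    // choiceCount: 2 means each pick samples two random endpoints and
    // routes to the one with fewer outstanding RPCs.
    conn, err := grpc.NewClient(
        "dns:///example.com:50051", // assumed target
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithDefaultServiceConfig(
            `{"loadBalancingConfig": [{"least_request_experimental": {"choiceCount": 2}}]}`),
    )
    if err != nil {
        panic(err)
    }
    defer conn.Close()
}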
@@ -80,104 +92,165 @@ func (bb) Name() string {
 }
 
 func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
-    b := &leastRequestBalancer{scRPCCounts: make(map[balancer.SubConn]*atomic.Int32)}
-    baseBuilder := base.NewBalancerBuilder(Name, b, base.Config{HealthCheck: true})
-    b.Balancer = baseBuilder.Build(cc, bOpts)
+    b := &leastRequestBalancer{
+        ClientConn:        cc,
+        endpointRPCCounts: resolver.NewEndpointMap(),
+        choiceCount:       2,
+    }
+    b.child = endpointsharding.NewBalancer(b, bOpts)
+    b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
+    b.logger.Infof("Created")
     return b
 }
 
 type leastRequestBalancer struct {
-    // Embeds balancer.Balancer because it needs to intercept
-    // UpdateClientConnState to learn about choiceCount.
-    balancer.Balancer
+    // Embeds balancer.ClientConn because it needs to intercept UpdateState
+    // calls from the child balancer.
+    balancer.ClientConn
+    child  balancer.Balancer
+    logger *internalgrpclog.PrefixLogger
 
+    mu          sync.Mutex
     choiceCount uint32
-    scRPCCounts map[balancer.SubConn]*atomic.Int32 // Hold onto RPC counts to keep track for subsequent picker updates.
+    // endpointRPCCounts holds RPC counts to keep track for subsequent picker
+    // updates.
+    endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32
 }
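Keying counters by endpoint rather than by SubConn is what lets a count survive connection churn within the same endpoint. A sketch of the map behavior this relies on (per the resolver.EndpointMap docs, keys compare by the unordered set of addresses; endpointMapSketch and the addresses are invented for illustration):

func endpointMapSketch() {
    m := resolver.NewEndpointMap()
    ep := resolver.Endpoint{Addresses: []resolver.Address{
        {Addr: "10.0.0.1:80"}, {Addr: "10.0.0.2:80"},
    }}
    m.Set(ep, new(atomic.Int32))

    // The same address set in a different order finds the same entry.
    reordered := resolver.Endpoint{Addresses: []resolver.Address{
        {Addr: "10.0.0.2:80"}, {Addr: "10.0.0.1:80"},
    }}
    if val, ok := m.Get(reordered); ok {
        val.(*atomic.Int32).Add(1) // values are interface{}; assert to the stored type
    }
}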

+// Close implements balancer.Balancer.
+func (lrb *leastRequestBalancer) Close() {
+    lrb.child.Close()
+    lrb.endpointRPCCounts = nil
+}
+
+// ResolverError implements balancer.Balancer.
+func (lrb *leastRequestBalancer) ResolverError(err error) {
+    lrb.child.ResolverError(err)
+}
+
+// UpdateSubConnState implements balancer.Balancer.
+func (lrb *leastRequestBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
+    lrb.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", sc, state)
+}
+
-func (lrb *leastRequestBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
-    lrCfg, ok := s.BalancerConfig.(*LBConfig)
+func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
+    lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
     if !ok {
-        logger.Errorf("least-request: received config with unexpected type %T: %v", s.BalancerConfig, s.BalancerConfig)
+        logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
         return balancer.ErrBadResolverState
     }
 
+    lrb.mu.Lock()
     lrb.choiceCount = lrCfg.ChoiceCount
-    return lrb.Balancer.UpdateClientConnState(s)
+    lrb.mu.Unlock()
+    // Enable the health listener in pickfirst children for client side health
+    // checks and outlier detection, if configured.
+    ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState)
+    ccs.BalancerConfig = endpointShardingLBConfig
+    return lrb.child.UpdateClientConnState(ccs)
 }

-type scWithRPCCount struct {
-    sc      balancer.SubConn
+type pickerWithRPCCount struct {
+    picker  balancer.Picker
     numRPCs *atomic.Int32
 }
 
-func (lrb *leastRequestBalancer) Build(info base.PickerBuildInfo) balancer.Picker {
-    if logger.V(2) {
-        logger.Infof("least-request: Build called with info: %v", info)
+func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
+    childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
+    var readyChildren []endpointsharding.ChildState
+    for _, child := range childStates {
+        if child.State.ConnectivityState == connectivity.Ready {
+            readyChildren = append(readyChildren, child)
+        }
     }
-    if len(info.ReadySCs) == 0 {
-        return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
+
+    // If no ready pickers are present, simply defer to the round robin picker
+    // from endpoint sharding, which will round robin across the most relevant
+    // pick first children in the highest precedence connectivity state.
+    if len(readyChildren) == 0 {
+        lrb.ClientConn.UpdateState(state)
+        return
     }
 
-    for sc := range lrb.scRPCCounts {
-        if _, ok := info.ReadySCs[sc]; !ok { // If no longer ready, no more need for the ref to count active RPCs.
-            delete(lrb.scRPCCounts, sc)
-        }
+    if logger.V(2) {
+        lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyChildren)
    }
 
+    // Reconcile endpoints.
+    newEndpoints := resolver.NewEndpointMap() // endpoint -> nil
+    for _, child := range readyChildren {
+        newEndpoints.Set(child.Endpoint, nil)
+    }
+
-    // Create new refs if needed.
-    for sc := range info.ReadySCs {
-        if _, ok := lrb.scRPCCounts[sc]; !ok {
-            lrb.scRPCCounts[sc] = new(atomic.Int32)
+    // If endpoints are no longer ready, no need to count their active RPCs.
+    for _, endpoint := range lrb.endpointRPCCounts.Keys() {
+        if _, ok := newEndpoints.Get(endpoint); !ok {
+            lrb.endpointRPCCounts.Delete(endpoint)
        }
    }
 
-    // Copy refs to counters into picker.
-    scs := make([]scWithRPCCount, 0, len(info.ReadySCs))
-    for sc := range info.ReadySCs {
-        scs = append(scs, scWithRPCCount{
-            sc:      sc,
-            numRPCs: lrb.scRPCCounts[sc], // guaranteed to be present due to algorithm
+    pickers := make([]pickerWithRPCCount, 0, len(readyChildren))
+    for _, child := range readyChildren {
+        var counter *atomic.Int32
+        if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok {
+            // Create new counts if needed.
+            counter = new(atomic.Int32)
+            lrb.endpointRPCCounts.Set(child.Endpoint, counter)
+        } else {
+            counter = val.(*atomic.Int32)
+        }
+        pickers = append(pickers, pickerWithRPCCount{
+            picker:  child.State.Picker,
+            numRPCs: counter,
         })
     }
 
-    return &picker{
-        choiceCount: lrb.choiceCount,
-        subConns:    scs,
-    }
-}
+    lrb.ClientConn.UpdateState(balancer.State{
+        Picker: &picker{
+            choiceCount:         lrb.choiceCount,
+            pickersWithRPCCount: pickers,
+        },
+        ConnectivityState: connectivity.Ready,
+    })
+}
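A design point worth noting above: counters live in the balancer's endpointRPCCounts map and are created only when absent, so an endpoint's in-flight count persists across picker rebuilds instead of resetting to zero on every child state update. The get-or-create step in isolation (counterFor is a hypothetical helper, not part of this file):

// counterFor returns the existing RPC counter for ep, creating and
// registering a new one only if none is present.
func counterFor(m *resolver.EndpointMap, ep resolver.Endpoint) *atomic.Int32 {
    if val, ok := m.Get(ep); ok {
        return val.(*atomic.Int32)
    }
    c := new(atomic.Int32)
    m.Set(ep, c)
    return c
}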

 type picker struct {
     // choiceCount is the number of random candidates sampled to find the one
     // with the fewest outstanding requests.
     choiceCount uint32
-    // Built out when it receives the list of ready SubConns.
-    subConns []scWithRPCCount
+    // Built out when it receives the list of ready child pickers.
+    pickersWithRPCCount []pickerWithRPCCount
 }
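The picker implements power-of-two-choices selection, generalized to choiceCount samples as described in gRFC A48. A self-contained sketch of the selection rule used by Pick below (p2cIndex is illustrative only; it assumes a non-empty counts slice, mirroring Pick, which only runs with at least one ready child):

// p2cIndex draws choiceCount indices uniformly at random (repeats are
// possible) and returns the one with the fewest outstanding RPCs. Ties
// keep the earlier sample, matching the strict < comparison in Pick.
func p2cIndex(counts []int32, choiceCount uint32) uint32 {
    best := randuint32() % uint32(len(counts))
    for i := uint32(1); i < choiceCount; i++ {
        idx := randuint32() % uint32(len(counts))
        if counts[idx] < counts[best] {
            best = idx
        }
    }
    return best
}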

-func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
-    var pickedSC *scWithRPCCount
+func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
+    var pickedEndpoint *pickerWithRPCCount
     var pickedSCNumRPCs int32
     for i := 0; i < int(p.choiceCount); i++ {
-        index := randuint32() % uint32(len(p.subConns))
-        sc := p.subConns[index]
-        n := sc.numRPCs.Load()
-        if pickedSC == nil || n < pickedSCNumRPCs {
-            pickedSC = &sc
+        index := randuint32() % uint32(len(p.pickersWithRPCCount))
+        child := p.pickersWithRPCCount[index]
+        n := child.numRPCs.Load()
+        if pickedEndpoint == nil || n < pickedSCNumRPCs {
+            pickedEndpoint = &child
             pickedSCNumRPCs = n
         }
     }
+    result, err := pickedEndpoint.picker.Pick(pInfo)
+    if err != nil {
+        return result, err
+    }
     // "The counter for a subchannel should be atomically incremented by one
     // after it has been successfully picked by the picker." - A48
-    pickedSC.numRPCs.Add(1)
+    pickedEndpoint.numRPCs.Add(1)
     // "the picker should add a callback for atomically decrementing the
     // subchannel counter once the RPC finishes (regardless of Status code)." -
     // A48.
-    done := func(balancer.DoneInfo) {
-        pickedSC.numRPCs.Add(-1)
+    originalDone := result.Done
+    result.Done = func(info balancer.DoneInfo) {
+        pickedEndpoint.numRPCs.Add(-1)
+        if originalDone != nil {
+            originalDone(info)
+        }
     }
-    return balancer.PickResult{
-        SubConn: pickedSC.sc,
-        Done:    done,
-    }, nil
+    return result, nil
 }
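Also notable in Pick: since the pick is now delegated to the chosen endpoint's pick_first picker, the counter-decrement callback must chain the child's own Done callback rather than replace it; dropping it would silently break the child's bookkeeping. The wrapping pattern in isolation (wrapDone is a hypothetical helper, not part of this file):

// wrapDone layers extra per-RPC completion work onto a child picker's
// result while preserving the child's original Done callback.
func wrapDone(res balancer.PickResult, onFinish func()) balancer.PickResult {
    childDone := res.Done
    res.Done = func(info balancer.DoneInfo) {
        onFinish() // e.g. decrement the endpoint's in-flight counter
        if childDone != nil {
            childDone(info)
        }
    }
    return res
}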
