diff --git a/balancer/endpointsharding/endpointsharding.go b/balancer/endpointsharding/endpointsharding.go index df5b5abe73cc..149ef6ed86b0 100644 --- a/balancer/endpointsharding/endpointsharding.go +++ b/balancer/endpointsharding/endpointsharding.go @@ -116,7 +116,13 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState bal = child.(*balancerWrapper) } else { bal = &balancerWrapper{ - childState: ChildState{Endpoint: endpoint}, + childState: ChildState{ + Endpoint: endpoint, + State: balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable), + }, + }, ClientConn: es.cc, es: es, } diff --git a/balancer/endpointsharding/endpointsharding_test.go b/balancer/endpointsharding/endpointsharding_test.go index 6b23063b5d9c..f189970434c2 100644 --- a/balancer/endpointsharding/endpointsharding_test.go +++ b/balancer/endpointsharding/endpointsharding_test.go @@ -28,19 +28,24 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils/roundrobin" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/status" testgrpc "google.golang.org/grpc/interop/grpc_testing" ) +const defaultShortTestTimeout = 100 * time.Millisecond + type s struct { grpctest.Tester } @@ -49,16 +54,11 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } -var gracefulSwitchPickFirst serviceconfig.LoadBalancingConfig +var childLBConfig serviceconfig.LoadBalancingConfig var logger = grpclog.Component("endpoint-sharding-test") func init() { - var err error - gracefulSwitchPickFirst, err = 
ParseConfig(json.RawMessage(PickFirstConfig)) - if err != nil { - logger.Fatal(err) - } balancer.Register(fakePetioleBuilder{}) } @@ -99,7 +99,7 @@ func (fp *fakePetiole) UpdateClientConnState(state balancer.ClientConnState) err } return fp.Balancer.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: gracefulSwitchPickFirst, + BalancerConfig: childLBConfig, ResolverState: state.ResolverState, }) } @@ -124,6 +124,11 @@ func (fp *fakePetiole) UpdateState(state balancer.State) { // It also verifies the petiole has access to the raw child state in case it // wants to implement a custom picker. func (s) TestEndpointShardingBasic(t *testing.T) { + var parseErr error + childLBConfig, parseErr = ParseConfig(json.RawMessage(PickFirstConfig)) + if parseErr != nil { + t.Fatalf("Failed to parse child LB config: %v", parseErr) + } backend1 := stubserver.StartTestService(t, nil) defer backend1.Stop() backend2 := stubserver.StartTestService(t, nil) @@ -157,3 +162,54 @@ func (s) TestEndpointShardingBasic(t *testing.T) { t.Fatalf("error in expected round robin: %v", err) } } + +// TestEndpointShardingStuckConnecting verifies that the endpointsharding policy +// correctly handles child policies that haven't given a picker update and doesn't +// panic. 
+func (s) TestEndpointShardingStuckConnecting(t *testing.T) { + childPolicyName := t.Name() + stub.Register(childPolicyName, stub.BalancerFuncs{ + UpdateClientConnState: func(_ *stub.BalancerData, ccs balancer.ClientConnState) error { + t.Logf("Ignoring resolver update to remain in CONNECTING: %v", ccs) + return nil + }, + }) + childLbJSON := json.RawMessage(fmt.Sprintf(`[{%q: {}}]`, childPolicyName)) + var parseErr error + childLBConfig, parseErr = ParseConfig(childLbJSON) + if parseErr != nil { + t.Fatalf("Failed to parse child LB config: %v", parseErr) + } + backend1 := stubserver.StartTestService(t, nil) + defer backend1.Stop() + backend2 := stubserver.StartTestService(t, nil) + defer backend2.Stop() + + mr := manual.NewBuilderWithScheme("e2e-test") + defer mr.Close() + + json := `{"loadBalancingConfig": [{"fake_petiole":{}}]}` + sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json) + mr.InitialState(resolver.State{ + Endpoints: []resolver.Endpoint{ + {Addresses: []resolver.Address{{Addr: backend1.Address}}}, + {Addresses: []resolver.Address{{Addr: backend2.Address}}}, + }, + ServiceConfig: sc, + }) + + cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + log.Fatalf("Failed to dial: %v", err) + } + defer cc.Close() + ctx, cancel := context.WithTimeout(context.Background(), defaultShortTestTimeout) + defer cancel() + client := testgrpc.NewTestServiceClient(cc) + + // Even though the child LB policy hasn't given an picker updates, it is + // assumted that it's in CONNECTING state. + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("EmptyCall() = %s, want %s", status.Code(err), codes.DeadlineExceeded) + } +}