diff --git a/picker_wrapper_test.go b/picker_wrapper_test.go index 33eb86f94b6f..88ad912c46b1 100644 --- a/picker_wrapper_test.go +++ b/picker_wrapper_test.go @@ -21,6 +21,7 @@ package grpc import ( "context" "fmt" + "sync" "sync/atomic" "testing" "time" @@ -80,6 +81,8 @@ func (s) TestBlockingPick(t *testing.T) { var finishedCount uint64 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + wg := sync.WaitGroup{} + wg.Add(goroutineCount) for i := goroutineCount; i > 0; i-- { go func() { if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT { @@ -93,6 +96,8 @@ func (s) TestBlockingPick(t *testing.T) { t.Errorf("finished goroutines count: %v, want 0", c) } bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) + // Wait for all pickers to finish before the context is cancelled. + wg.Wait() } func (s) TestBlockingPickNoSubAvailable(t *testing.T) { @@ -102,6 +107,8 @@ func (s) TestBlockingPickNoSubAvailable(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() // All goroutines should block because picker returns no subConn available. + wg := sync.WaitGroup{} + wg.Add(goroutineCount) for i := goroutineCount; i > 0; i-- { go func() { if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT { @@ -115,6 +122,8 @@ func (s) TestBlockingPickNoSubAvailable(t *testing.T) { t.Errorf("finished goroutines count: %v, want 0", c) } bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) + // Wait for all pickers to finish before the context is cancelled. + wg.Wait() } func (s) TestBlockingPickTransientWaitforready(t *testing.T) { @@ -125,6 +134,8 @@ func (s) TestBlockingPickTransientWaitforready(t *testing.T) { defer cancel() // All goroutines should block because picker returns transientFailure and // picks are not failfast. + wg := sync.WaitGroup{} + wg.Add(goroutineCount) for i := goroutineCount; i > 0; i-- { go func() { if tr, _, err := bp.pick(ctx, false, balancer.PickInfo{}); err != nil || tr != testT { @@ -138,6 +149,8 @@ func (s) TestBlockingPickTransientWaitforready(t *testing.T) { t.Errorf("finished goroutines count: %v, want 0", c) } bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) + // Wait for all pickers to finish before the context is cancelled. + wg.Wait() } func (s) TestBlockingPickSCNotReady(t *testing.T) { @@ -147,12 +160,15 @@ func (s) TestBlockingPickSCNotReady(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() // All goroutines should block because subConn is not ready. + wg := sync.WaitGroup{} + wg.Add(goroutineCount) for i := goroutineCount; i > 0; i-- { go func() { if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT { t.Errorf("bp.pick returned non-nil error: %v", err) } atomic.AddUint64(&finishedCount, 1) + wg.Done() }() } time.Sleep(time.Millisecond) @@ -160,4 +176,6 @@ func (s) TestBlockingPickSCNotReady(t *testing.T) { t.Errorf("finished goroutines count: %v, want 0", c) } bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount}) + // Wait for all pickers to finish before the context is cancelled. + wg.Wait() }