Skip to content

Commit b126417

Browse files
committed
fix race condition in Block/Nursery and always run provided routine
1 parent c6ba6d2 commit b126417

File tree

2 files changed

+15
-45
lines changed

2 files changed

+15
-45
lines changed

nursery.go

+14-11
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@ func catchPanics(routineDone chan<- error) {
5555

5656
// Go implements Nursery.
5757
func (n *nursery) Go(routine func() error) {
58-
n.routinesCount.Add(1)
58+
new := n.routinesCount.Add(1)
59+
if new < 2 {
60+
panic("use of nursery after end of block")
61+
}
62+
5963
if n.limiter == nil {
6064
select {
6165
case n.goRoutine <- routine:
@@ -69,9 +73,6 @@ func (n *nursery) Go(routine func() error) {
6973
case n.limiter <- struct{}{}:
7074
// We are below our limit.
7175
n.goNew(routine)
72-
case <-n.Done():
73-
// Context canceled.
74-
n.routinesCount.Add(-1)
7576
case n.goRoutine <- routine:
7677
// Successfully reused a goroutine.
7778
}
@@ -82,6 +83,8 @@ func (n *nursery) goNew(routine Routine) {
8283
go func() {
8384
defer catchPanics(n.errors)
8485
for r := range n.goRoutine {
86+
// TODO: add option to skip routine if context is canceled.
87+
8588
err := r()
8689
if err != nil {
8790
n.onError(err)
@@ -90,13 +93,8 @@ func (n *nursery) goNew(routine Routine) {
9093
}
9194
}()
9295

93-
select {
94-
case <-n.Done():
95-
// Context canceled.
96-
n.routinesCount.Add(-1)
97-
case n.goRoutine <- routine:
98-
// routine forwarded.
99-
}
96+
// Execute routine.
97+
n.goRoutine <- routine
10098
}
10199

102100
// Block starts a nursery block that returns when all goroutines have returned.
@@ -106,6 +104,9 @@ func (n *nursery) goNew(routine Routine) {
106104
// goroutines to handle context cancellation. Error returned by block closure
107105
// always trigger a context cancellation and is returned if it occurs before a
108106
// default goroutine error handler is called.
107+
// Calling [Nursery].Go() after end of block always panic. Calling [Nursery].Go
108+
// after context is canceled still runs the provided function, you're responsible
109+
// for handling cancellation.
109110
func Block(block func(n Nursery) error, opts ...BlockOption) (err error) {
110111
n := newNursery()
111112
for _, opt := range opts {
@@ -130,6 +131,7 @@ func Block(block func(n Nursery) error, opts ...BlockOption) (err error) {
130131
}
131132

132133
// Start block.
134+
n.routinesCount.Add(1) // Bypass end of block check.
133135
n.Go(func() error {
134136
e := block(n)
135137
if e != nil {
@@ -138,6 +140,7 @@ func Block(block func(n Nursery) error, opts ...BlockOption) (err error) {
138140
err = e
139141
})
140142
}
143+
n.routinesCount.Add(-1) // Restore end of block check.
141144
return nil
142145
})
143146

nursery_test.go

+1-34
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func TestNursery(t *testing.T) {
154154
if panicValue == nil {
155155
t.Fatal("use of nursery after end of block didn't panic")
156156
}
157-
if panicValue.(error).Error() != "send on closed channel" {
157+
if panicValue.(string) != "use of nursery after end of block" {
158158
t.Fatal("use of nursery after end of block didn't panicked with ErrNurseryDone")
159159
}
160160
})
@@ -302,39 +302,6 @@ func TestNursery(t *testing.T) {
302302
}
303303
})
304304

305-
t.Run("LimiterWithContextCancel", func(t *testing.T) {
306-
ctx, cancel := context.WithCancel(context.Background())
307-
308-
blockCh := make(chan struct{})
309-
doneCh := make(chan struct{})
310-
311-
err := Block(func(n Nursery) error {
312-
// Fill the limiter to capacity with a long-running goroutine
313-
n.Go(func() error {
314-
close(blockCh)
315-
<-doneCh
316-
return nil
317-
})
318-
319-
<-blockCh
320-
cancel()
321-
322-
// Try to add another goroutine - this should hit the context.Done() case
323-
// since the limiter is full and the context is canceled
324-
n.Go(func() error {
325-
t.Fatal("this goroutine should not run")
326-
return nil
327-
})
328-
329-
close(doneCh)
330-
return nil
331-
}, WithMaxGoroutines(1), WithContext(ctx))
332-
333-
if err != nil {
334-
t.Error(err)
335-
}
336-
})
337-
338305
t.Run("WithTimeout", func(t *testing.T) {
339306
start := time.Now()
340307
err := Block(func(n Nursery) error {

0 commit comments

Comments
 (0)