Skip to content

Commit 58ee40e

Browse files
committed
simplify design
1 parent c3efa49 commit 58ee40e

File tree

3 files changed

+75
-88
lines changed

3 files changed

+75
-88
lines changed

nursery.go

+69-81
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"runtime/debug"
8+
"sync/atomic"
89
)
910

1011
var (
@@ -21,76 +22,23 @@ type Nursery interface {
2122

2223
type nursery struct {
2324
context.Context
24-
cancel func()
25-
panics chan any
26-
errors chan error
27-
28-
maxRoutines int
29-
routineDone chan error
30-
31-
goRoutine chan func() error
32-
onError func(error)
25+
cancel func()
26+
onError func(error)
27+
errors chan error
28+
limiter limiter
29+
goRoutine chan func() error
30+
routinesCount atomic.Int32
3331
}
3432

35-
func newNursery(ctx context.Context) *nursery {
36-
ctx, cancel := context.WithCancel(ctx)
33+
func newNursery() *nursery {
3734
n := &nursery{
38-
Context: ctx,
39-
cancel: cancel,
40-
panics: make(chan any),
35+
Context: nil,
36+
cancel: nil,
4137
errors: make(chan error),
38+
limiter: nil,
4239
goRoutine: make(chan func() error),
4340
}
4441

45-
// Event loop.
46-
go func() {
47-
done := false
48-
routinesCount := 0
49-
routineDone := make(chan error)
50-
for !done {
51-
handleRoutineDone := func(routineValue error) {
52-
routinesCount--
53-
if gpanic, isPanic := routineValue.(GoroutinePanic); isPanic {
54-
// Cancel all routines.
55-
n.cancel()
56-
n.panics <- gpanic
57-
} else if routineValue != nil {
58-
n.errors <- routineValue
59-
}
60-
if routinesCount == 0 {
61-
close(routineDone)
62-
close(n.panics)
63-
close(n.errors)
64-
n.cancel()
65-
done = true
66-
}
67-
}
68-
69-
// We can spawn routine.
70-
if routinesCount < n.maxRoutines || n.maxRoutines <= 0 {
71-
select {
72-
case routine := <-n.goRoutine:
73-
routinesCount++
74-
go func() {
75-
defer catchPanics(routineDone)
76-
err := routine()
77-
if err != nil && n.onError != nil {
78-
n.onError(err)
79-
}
80-
routineDone <- err
81-
}()
82-
83-
case routineValue := <-routineDone:
84-
handleRoutineDone(routineValue)
85-
}
86-
} else {
87-
// We can't spawn routine.
88-
routineValue := <-routineDone
89-
handleRoutineDone(routineValue)
90-
}
91-
}
92-
}()
93-
9442
return n
9543
}
9644

@@ -116,36 +64,74 @@ func (n *nursery) mustNotBeDone() {
11664
func (n *nursery) Go(routine func() error) {
11765
n.mustNotBeDone()
11866

67+
n.routinesCount.Add(1)
68+
if n.limiter == nil {
69+
select {
70+
case n.goRoutine <- routine:
71+
// Successfully reused a goroutine.
72+
default:
73+
// No goroutine available, spawn a new one.
74+
n.goNew(routine)
75+
}
76+
} else {
77+
select {
78+
case n.limiter <- struct{}{}:
79+
// We are below our limit.
80+
n.goNew(routine)
81+
case n.goRoutine <- routine:
82+
// Successfully reused a goroutine.
83+
}
84+
}
85+
}
86+
87+
func (n *nursery) goNew(routine func() error) {
88+
go func() {
89+
defer catchPanics(n.errors)
90+
for {
91+
select {
92+
case <-n.Done():
93+
// Nursery is done, we can free this goroutine.
94+
return
95+
case r := <-n.goRoutine:
96+
n.errors <- r()
97+
}
98+
}
99+
}()
100+
119101
n.goRoutine <- routine
120102
}
121103

122-
// Block starts a nursery block that returns when all goroutines have
123-
// returned. If a goroutine panic, context is canceled and panic is immediately
124-
// forwarded without waiting for other goroutines to handle context cancellation.
125-
// Errors returned by goroutines are joined and returned at the end of the block.
104+
// Block starts a nursery block that returns when all goroutines have returned.
105+
// If a goroutine panic, context is canceled and panic is immediately forwarded
106+
// without waiting for other goroutines to handle context cancellation. Errors
107+
// returned by goroutines are joined and returned at the end of the block.
126108
func Block(block func(n Nursery) error, opts ...BlockOption) (err error) {
127-
n := newNursery(context.Background())
128-
109+
n := newNursery()
129110
for _, opt := range opts {
130111
opt(n)
131112
}
132113

114+
// Default context.
115+
if n.Context == nil {
116+
n.Context, n.cancel = context.WithCancel(context.Background())
117+
}
118+
defer n.cancel()
119+
120+
// Start block.
133121
n.Go(func() error {
134-
block(n)
135-
return nil
122+
return block(n)
136123
})
137124

138-
// Wait for all routine to be done.
139-
loop:
125+
// Event loop.
140126
for {
141-
select {
142-
case panicValue := <-n.panics:
143-
if panicValue != nil {
144-
panic(panicValue)
145-
}
146-
break loop
147-
case e := <-n.errors:
148-
err = errors.Join(err, e)
127+
e := <-n.errors
128+
if panicValue, isPanic := e.(GoroutinePanic); isPanic {
129+
panic(panicValue)
130+
}
131+
err = errors.Join(err, e)
132+
count := n.routinesCount.Add(-1)
133+
if count == 0 {
134+
break
149135
}
150136
}
151137

@@ -176,3 +162,5 @@ func (gp GoroutinePanic) Unwrap() error {
176162

177163
return nil
178164
}
165+
166+
type limiter chan struct{}

nursery_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ func TestNursery(t *testing.T) {
142142
ctx, cancel := context.WithCancel(context.Background())
143143
Block(func(n Nursery) error {
144144
n.Go(func() error {
145-
time.Sleep(10 * time.Millisecond)
146-
cancel()
145+
time.AfterFunc(10*time.Millisecond, cancel)
147146
return nil
148147
})
149148

options.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ type BlockOption func(cfg *nursery)
1212
// the given one.
1313
func WithContext(ctx context.Context) BlockOption {
1414
return func(n *nursery) {
15-
n.Context = ctx
15+
n.Context, n.cancel = context.WithCancel(ctx)
1616
}
1717
}
1818

@@ -40,12 +40,12 @@ func WithErrorHandler(handler func(error)) BlockOption {
4040
}
4141
}
4242

43-
// WithCancelOnError returns a nursery block option that
43+
// WithCancelOnError returns a nursery block option that sets error handler to
44+
// inner context cancel function.
4445
func WithCancelOnError() BlockOption {
4546
return func(n *nursery) {
46-
cancel := n.cancel
4747
n.onError = func(err error) {
48-
cancel()
48+
n.cancel()
4949
}
5050
}
5151
}
@@ -60,6 +60,6 @@ func WithMaxGoroutines(max int) BlockOption {
6060
}
6161

6262
// +1 because block function is a routine also.
63-
n.maxRoutines = max + 1
63+
n.limiter = make(chan struct{}, max+1)
6464
}
6565
}

0 commit comments

Comments
 (0)