Skip to content

Commit c6ba6d2

Browse files
committed
fix race conditions and add IsDone utils
1 parent cbcdc98 commit c6ba6d2

File tree

2 files changed

+45
-68
lines changed

2 files changed

+45
-68
lines changed

utils.go

+16-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package conc
33
import (
44
"context"
55
"iter"
6+
"sync"
67
"sync/atomic"
78
"time"
89
)
@@ -16,6 +17,16 @@ func Sleep(ctx context.Context, d time.Duration) {
1617
}
1718
}
1819

20+
// IsDone returns whether provided context is done.
21+
func IsDone(ctx context.Context) bool {
22+
select {
23+
case <-ctx.Done():
24+
return true
25+
default:
26+
return false
27+
}
28+
}
29+
1930
type Job[T any] func(context.Context) (T, error)
2031

2132
// All executes all jobs in separate goroutines and stores each result in
@@ -123,25 +134,18 @@ func doMap[T any, V any](input []T, results []V, f func(context.Context, T) (V,
123134
// Map2 applies f to each key, value pair of input and returns a new slice containing
124135
// mapped results.
125136
func Map2[K comparable, V any](input map[K]V, f func(context.Context, K, V) (K, V, error), opts ...BlockOption) (map[K]V, error) {
126-
results := make(map[K]V)
127-
err := doMap2(input, results, f, opts...)
128-
return results, err
129-
}
130-
131-
// MapInPlace2 applies f to each key, value pair of input and returns modified map.
132-
func Map2InPlace[K comparable, V any](input map[K]V, f func(context.Context, K, V) (K, V, error), opts ...BlockOption) (map[K]V, error) {
133-
err := doMap2(input, input, f, opts...)
134-
return input, err
135-
}
137+
var mu sync.Mutex
136138

137-
func doMap2[K comparable, V any](input map[K]V, results map[K]V, f func(context.Context, K, V) (K, V, error), opts ...BlockOption) error {
138-
return Block(func(n Nursery) error {
139+
results := make(map[K]V)
140+
return results, Block(func(n Nursery) error {
139141
for k, v := range input {
140142
key := k
141143
value := v
142144
n.Go(func() error {
143145
newK, newV, err := f(n, key, value)
146+
mu.Lock()
144147
results[newK] = newV
148+
mu.Unlock()
145149
return err
146150
})
147151
}

utils_test.go

+29-56
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import (
44
"context"
55
"errors"
66
"iter"
7-
"maps"
87
"reflect"
8+
"sync"
99
"testing"
1010
"time"
1111
)
@@ -36,6 +36,19 @@ func TestSleep(t *testing.T) {
3636
})
3737
}
3838

39+
func TestIsDone(t *testing.T) {
40+
ctx, cancel := context.WithCancel(context.Background())
41+
if IsDone(ctx) {
42+
t.Fatal("expected false, got true")
43+
}
44+
45+
cancel()
46+
47+
if !IsDone(ctx) {
48+
t.Fatal("expected true, got false")
49+
}
50+
}
51+
3952
func TestAll(t *testing.T) {
4053
t.Run("NoError", func(t *testing.T) {
4154
jobs := []Job[int]{
@@ -58,9 +71,15 @@ func TestAll(t *testing.T) {
5871
t.Run("JobError", func(t *testing.T) {
5972
expectedErr := errors.New("test error")
6073
jobs := []Job[int]{
61-
func(ctx context.Context) (int, error) { return 1, nil },
62-
func(ctx context.Context) (int, error) { return 0, expectedErr },
63-
func(ctx context.Context) (int, error) { return 3, nil },
74+
func(ctx context.Context) (int, error) {
75+
return 1, nil
76+
},
77+
func(ctx context.Context) (int, error) {
78+
return 0, expectedErr
79+
},
80+
func(ctx context.Context) (int, error) {
81+
return 3, nil
82+
},
6483
}
6584

6685
_, err := All(jobs)
@@ -171,7 +190,7 @@ func TestRange(t *testing.T) {
171190
func TestRange2(t *testing.T) {
172191
t.Run("NoJobError", func(t *testing.T) {
173192
items := map[string]int{"a": 1, "b": 2, "c": 3}
174-
processed := make(map[string]bool)
193+
processed := sync.Map{}
175194

176195
seq := iter.Seq2[string, int](
177196
func(yield func(string, int) bool) {
@@ -186,7 +205,7 @@ func TestRange2(t *testing.T) {
186205
err := Range2(
187206
seq,
188207
func(ctx context.Context, k string, v int) error {
189-
processed[k] = true
208+
processed.Store(k, true)
190209
return nil
191210
},
192211
)
@@ -195,9 +214,10 @@ func TestRange2(t *testing.T) {
195214
t.Errorf("unexpected error: %v", err)
196215
}
197216

198-
expected := map[string]bool{"a": true, "b": true, "c": true}
199-
if !reflect.DeepEqual(processed, expected) {
200-
t.Errorf("got %v, want %v", processed, expected)
217+
for _, k := range []string{"a", "b", "c"} {
218+
if b, ok := processed.Load(k); !ok || b.(bool) == false {
219+
t.Errorf("key %v not processed", k)
220+
}
201221
}
202222
})
203223

@@ -351,50 +371,3 @@ func TestMap2(t *testing.T) {
351371
}
352372
})
353373
}
354-
355-
func TestMap2InPlace(t *testing.T) {
356-
t.Run("NoJobError", func(t *testing.T) {
357-
input := map[string]int{"a": 1, "b": 2, "c": 3}
358-
originalInput := maps.Clone(input)
359-
360-
results, err := Map2InPlace(
361-
input,
362-
func(ctx context.Context, k string, v int) (string, int, error) {
363-
return k, v * 2, nil
364-
},
365-
)
366-
367-
if err != nil {
368-
t.Errorf("unexpected error: %v", err)
369-
}
370-
371-
expected := map[string]int{"a": 2, "b": 4, "c": 6}
372-
if !reflect.DeepEqual(results, expected) {
373-
t.Errorf("got %v, want %v", results, expected)
374-
}
375-
376-
// Verify the input map was modified
377-
if reflect.DeepEqual(input, originalInput) {
378-
t.Error("Map2InPlace did not modify the input map")
379-
}
380-
if !reflect.DeepEqual(input, expected) {
381-
t.Error("Map2InPlace did not correctly modify the input map")
382-
}
383-
})
384-
385-
t.Run("JobError", func(t *testing.T) {
386-
expectedErr := errors.New("test error")
387-
input := map[string]int{"a": 1, "b": 2, "c": 3}
388-
389-
_, err := Map2InPlace(
390-
input,
391-
func(ctx context.Context, k string, v int) (string, int, error) {
392-
return k, v, expectedErr
393-
},
394-
)
395-
396-
if err == nil {
397-
t.Error("expected error, got nil")
398-
}
399-
})
400-
}

0 commit comments

Comments
 (0)