Skip to content

Commit

Permalink
test(flow): improve coverage and minor refactoring (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
reugn authored Aug 18, 2024
1 parent af55736 commit cdcb6c0
Show file tree
Hide file tree
Showing 17 changed files with 673 additions and 146 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ jobs:

- name: Upload coverage to Codecov
if: ${{ matrix.go-version == '1.18.x' }}
run: bash <(curl -s https://codecov.io/bash)
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ issues:
- errcheck
- unparam
- prealloc
- funlen
36 changes: 22 additions & 14 deletions flow/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,25 @@ import (
)

// Batch processor breaks a stream of elements into batches based on size or timing.
// When the maximum batch size is reached or the batch time is elapsed, and the current buffer
// is not empty, a new batch will be emitted.
// When the maximum batch size is reached or the batch time is elapsed, and the
// current buffer is not empty, a new batch will be emitted.
// Note: once a batch is sent downstream, the timer will be reset.
// T indicates the incoming element type, and the outgoing element type is []T.
type Batch[T any] struct {
maxBatchSize int
timeInterval time.Duration
in chan any
out chan any
buffer []T
}

// Verify Batch satisfies the Flow interface.
var _ streams.Flow = (*Batch[any])(nil)

// NewBatch returns a new Batch operator using the specified maximum batch size and the
// time interval.
// NewBatch returns a new Batch operator using the specified maximum batch size and
// the time interval.
// T specifies the incoming element type, and the outgoing element type is []T.
//
// NewBatch will panic if the maxBatchSize argument is not positive.
func NewBatch[T any](maxBatchSize int, timeInterval time.Duration) *Batch[T] {
if maxBatchSize < 1 {
Expand All @@ -35,7 +37,10 @@ func NewBatch[T any](maxBatchSize int, timeInterval time.Duration) *Batch[T] {
timeInterval: timeInterval,
in: make(chan any),
out: make(chan any),
buffer: make([]T, 0, maxBatchSize),
}

// start stream processing
go batchFlow.batchStream()

return batchFlow
Expand Down Expand Up @@ -76,34 +81,37 @@ func (b *Batch[T]) batchStream() {
ticker := time.NewTicker(b.timeInterval)
defer ticker.Stop()

batch := make([]T, 0, b.maxBatchSize)
for {
select {
case element, ok := <-b.in:
if ok {
batch = append(batch, element.(T))
b.buffer = append(b.buffer, element.(T))
// dispatch the batch if the maximum batch size has been reached
if len(batch) >= b.maxBatchSize {
b.out <- batch
batch = make([]T, 0, b.maxBatchSize)
if len(b.buffer) >= b.maxBatchSize {
b.flush()
}
// reset the ticker
ticker.Reset(b.timeInterval)
} else {
// send the available buffer elements as a new batch, close the
// output channel and return
if len(batch) > 0 {
b.out <- batch
if len(b.buffer) > 0 {
b.flush()
}
close(b.out)
return
}
case <-ticker.C:
// timeout; dispatch and reset the buffer
if len(batch) > 0 {
b.out <- batch
batch = make([]T, 0, b.maxBatchSize)
if len(b.buffer) > 0 {
b.flush()
}
}
}
}

// flush sends the elements in the buffer downstream and resets the buffer.
func (b *Batch[T]) flush() {
b.out <- b.buffer
b.buffer = make([]T, 0, b.maxBatchSize)
}
43 changes: 37 additions & 6 deletions flow/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,10 @@ func TestBatch(t *testing.T) {
go func() {
source.
Via(batch).
Via(flow.NewMap(retransmitStringSlice, 1)). // test generic return type
To(sink)
}()

var outputValues [][]string
for e := range sink.Out {
outputValues = append(outputValues, e.([]string))
}
outputValues := readSlice[[]string](sink.Out)
fmt.Println(outputValues)

assert.Equal(t, 3, len(outputValues)) // [[a b c d] [e f g] [h]]
Expand All @@ -48,7 +44,42 @@ func TestBatch(t *testing.T) {
assert.Equal(t, []string{"h"}, outputValues[2])
}

func TestBatchInvalidArguments(t *testing.T) {
func TestBatch_Ptr(t *testing.T) {
in := make(chan any)
out := make(chan any)

source := ext.NewChanSource(in)
batch := flow.NewBatch[*string](4, 40*time.Millisecond)
sink := ext.NewChanSink(out)
assert.NotEqual(t, batch.Out(), nil)

inputValues := ptrSlice([]string{"a", "b", "c", "d", "e", "f", "g"})
go func() {
for _, e := range inputValues {
ingestDeferred(e, in, 5*time.Millisecond)
}
}()
go ingestDeferred(ptr("h"), in, 90*time.Millisecond)
go closeDeferred(in, 100*time.Millisecond)

go func() {
source.
Via(batch).
Via(flow.NewPassThrough()). // Via coverage
To(sink)
}()

outputValues := readSlice[[]*string](sink.Out)
fmt.Println(outputValues)

assert.Equal(t, 3, len(outputValues)) // [[a b c d] [e f g] [h]]

assert.Equal(t, ptrSlice([]string{"a", "b", "c", "d"}), outputValues[0])
assert.Equal(t, ptrSlice([]string{"e", "f", "g"}), outputValues[1])
assert.Equal(t, ptrSlice([]string{"h"}), outputValues[2])
}

func TestBatch_InvalidArguments(t *testing.T) {
assert.Panics(t, func() {
flow.NewBatch[string](0, time.Second)
})
Expand Down
72 changes: 72 additions & 0 deletions flow/filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package flow_test

import (
"testing"

"github.com/reugn/go-streams"
ext "github.com/reugn/go-streams/extension"
"github.com/reugn/go-streams/flow"
"github.com/reugn/go-streams/internal/assert"
)

func TestFilter(t *testing.T) {
tests := []struct {
name string
filterFlow streams.Flow
ptr bool
}{
{
name: "values",
filterFlow: flow.NewFilter(func(e int) bool {
return e%2 != 0
}, 1),
ptr: false,
},
{
name: "pointers",
filterFlow: flow.NewFilter(func(e *int) bool {
return *e%2 != 0
}, 1),
ptr: true,
},
}
input := []int{1, 2, 3, 4, 5}
expected := []int{1, 3, 5}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
in := make(chan any, 5)
out := make(chan any, 5)

source := ext.NewChanSource(in)
sink := ext.NewChanSink(out)

if tt.ptr {
ingestSlice(ptrSlice(input), in)
} else {
ingestSlice(input, in)
}
close(in)

source.
Via(tt.filterFlow).
To(sink)

if tt.ptr {
output := readSlicePtr[int](out)
assert.Equal(t, ptrSlice(expected), output)
} else {
output := readSlice[int](out)
assert.Equal(t, expected, output)
}
})
}
}

func TestFilter_NonPositiveParallelism(t *testing.T) {
assert.Panics(t, func() {
flow.NewFilter(filterNotContainsA, 0)
})
assert.Panics(t, func() {
flow.NewFilter(filterNotContainsA, -1)
})
}
94 changes: 94 additions & 0 deletions flow/flat_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package flow_test

import (
"strings"
"testing"

"github.com/reugn/go-streams"
ext "github.com/reugn/go-streams/extension"
"github.com/reugn/go-streams/flow"
"github.com/reugn/go-streams/internal/assert"
)

func TestFlatMap(t *testing.T) {
tests := []struct {
name string
flatMapFlow streams.Flow
inPtr bool
outPtr bool
}{
{
name: "val-val",
inPtr: false,
flatMapFlow: flow.NewFlatMap(func(in string) []string {
return []string{in, strings.ToUpper(in)}
}, 1),
outPtr: false,
},
{
name: "ptr-val",
inPtr: true,
flatMapFlow: flow.NewFlatMap(func(in *string) []string {
return []string{*in, strings.ToUpper(*in)}
}, 1),
outPtr: false,
},
{
name: "ptr-ptr",
inPtr: true,
flatMapFlow: flow.NewFlatMap(func(in *string) []*string {
upper := strings.ToUpper(*in)
return []*string{in, &upper}
}, 1),
outPtr: true,
},
{
name: "val-ptr",
inPtr: false,
flatMapFlow: flow.NewFlatMap(func(in string) []*string {
upper := strings.ToUpper(in)
return []*string{&in, &upper}
}, 1),
outPtr: true,
},
}
input := []string{"a", "b", "c"}
expected := []string{"a", "A", "b", "B", "c", "C"}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
in := make(chan any, 3)
out := make(chan any, 6)

source := ext.NewChanSource(in)
sink := ext.NewChanSink(out)

if tt.inPtr {
ingestSlice(ptrSlice(input), in)
} else {
ingestSlice(input, in)
}
close(in)

source.
Via(tt.flatMapFlow).
To(sink)

if tt.outPtr {
output := readSlicePtr[string](out)
assert.Equal(t, ptrSlice(expected), output)
} else {
output := readSlice[string](out)
assert.Equal(t, expected, output)
}
})
}
}

func TestFlatMap_NonPositiveParallelism(t *testing.T) {
assert.Panics(t, func() {
flow.NewFlatMap(addAsterisk, 0)
})
assert.Panics(t, func() {
flow.NewFlatMap(addAsterisk, -1)
})
}
Loading

0 comments on commit cdcb6c0

Please sign in to comment.