-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(flow): add keyed flow implementation (#141)
- Loading branch information
Showing
3 changed files
with
305 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
package flow | ||
|
||
import ( | ||
"sync" | ||
|
||
"github.com/reugn/go-streams" | ||
) | ||
|
||
// Keyed represents a flow where stream elements are partitioned by key | ||
// using a provided key selector function. | ||
type Keyed[K comparable, V any] struct { | ||
keySelector func(V) K | ||
keyedFlows map[K]streams.Flow | ||
operators []func() streams.Flow | ||
in chan any | ||
out chan any | ||
} | ||
|
||
// Verify Keyed satisfies the Flow interface. | ||
var _ streams.Flow = (*Keyed[int, any])(nil) | ||
|
||
// NewKeyed returns a new Keyed operator, which takes a stream and splits it | ||
// into multiple streams based on the keys extracted from the elements using | ||
// the keySelector function. | ||
// | ||
// Each of these individual streams is then transformed by the provided chain | ||
// of operators, and the results are sent to the output channel. | ||
// | ||
// If no operators are provided, NewKeyed will panic. | ||
func NewKeyed[K comparable, V any]( | ||
keySelector func(V) K, operators ...func() streams.Flow, | ||
) *Keyed[K, V] { | ||
if len(operators) == 0 { | ||
panic("at least one operator supplier is required") | ||
} | ||
keyedFlow := &Keyed[K, V]{ | ||
keySelector: keySelector, | ||
keyedFlows: make(map[K]streams.Flow), | ||
operators: operators, | ||
in: make(chan any), | ||
out: make(chan any), | ||
} | ||
|
||
// start stream processing | ||
go keyedFlow.stream() | ||
|
||
return keyedFlow | ||
} | ||
|
||
// stream routes incoming elements to keyed workflows and consolidates | ||
// the results into the output channel of the Keyed flow. | ||
func (k *Keyed[K, V]) stream() { | ||
var wg sync.WaitGroup | ||
for element := range k.in { | ||
// extract element's key using the selector | ||
key := k.keySelector(element.(V)) | ||
// retrieve the keyed flow for the key | ||
keyedFlow := k.getKeyedFlow(key, &wg) | ||
// send the element downstream | ||
keyedFlow.In() <- element | ||
} | ||
|
||
// close all keyed streams | ||
for _, keyedFlow := range k.keyedFlows { | ||
close(keyedFlow.In()) | ||
} | ||
|
||
// wait for all keyed streams to complete | ||
wg.Wait() | ||
close(k.out) | ||
} | ||
|
||
// Via streams data to a specified Flow and returns it. | ||
func (k *Keyed[K, V]) Via(flow streams.Flow) streams.Flow { | ||
go k.transmit(flow) | ||
return flow | ||
} | ||
|
||
// To streams data to a specified Sink. | ||
func (k *Keyed[K, V]) To(sink streams.Sink) { | ||
k.transmit(sink) | ||
} | ||
|
||
// Out returns the output channel of the Keyed operator. | ||
func (k *Keyed[K, V]) Out() <-chan any { | ||
return k.out | ||
} | ||
|
||
// In returns the input channel of the Keyed operator. | ||
func (k *Keyed[K, V]) In() chan<- any { | ||
return k.in | ||
} | ||
|
||
// transmit submits keyed elements to the next Inlet. | ||
func (k *Keyed[K, V]) transmit(inlet streams.Inlet) { | ||
for keyed := range k.out { | ||
inlet.In() <- keyed | ||
} | ||
close(inlet.In()) | ||
} | ||
|
||
// getKeyedFlow retrieves a keyed workflow associated with the provided key. | ||
// If the workflow has not yet been initiated, it will be created and a | ||
// goroutine will be launched to handle the stream. | ||
func (k *Keyed[K, V]) getKeyedFlow(key K, wg *sync.WaitGroup) streams.Flow { | ||
// try to retrieve the keyed flow from the map | ||
keyedWorkflow, ok := k.keyedFlows[key] | ||
if !ok { // this is the first element for the key | ||
wg.Add(1) | ||
// build the workflow | ||
keyedWorkflow = k.operators[0]() | ||
workflowTail := keyedWorkflow | ||
for _, operatorFactory := range k.operators[1:] { | ||
workflowTail = workflowTail.Via(operatorFactory()) | ||
} | ||
// start processing incoming stream elements | ||
go func() { | ||
defer wg.Done() | ||
for e := range workflowTail.Out() { | ||
k.out <- e | ||
} | ||
}() | ||
// associate the key with the workflow | ||
k.keyedFlows[key] = keyedWorkflow | ||
} | ||
return keyedWorkflow | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package flow_test | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/reugn/go-streams" | ||
ext "github.com/reugn/go-streams/extension" | ||
"github.com/reugn/go-streams/flow" | ||
"github.com/reugn/go-streams/internal/assert" | ||
) | ||
|
||
// keyedElement is a test fixture carrying the key used for partitioning and
// a textual payload.
type keyedElement struct {
	key  int    // partition key
	data string // payload; also the element's string representation
}

// String returns the payload of the element.
func (e *keyedElement) String() string {
	return e.data
}

// newKeyedElement creates an element with the given key whose payload is the
// default textual form of that key.
func newKeyedElement(key int) *keyedElement {
	element := new(keyedElement)
	element.key = key
	element.data = fmt.Sprint(key)
	return element
}
|
||
func TestKeyed(t *testing.T) { | ||
in := make(chan any, 30) | ||
out := make(chan any, 20) | ||
|
||
source := ext.NewChanSource(in) | ||
sink := ext.NewChanSink(out) | ||
keyed := flow.NewKeyed(func(e keyedElement) int { return e.key }, | ||
func() streams.Flow { | ||
return flow.NewBatch[keyedElement](4, 10*time.Millisecond) | ||
}) | ||
|
||
inputValues := values(makeElements(30)) | ||
ingestSlice(inputValues, in) | ||
close(in) | ||
|
||
go func() { | ||
source. | ||
Via(keyed). | ||
To(sink) | ||
}() | ||
|
||
outputValues := readSlice[[]keyedElement](sink.Out) | ||
|
||
assert.Equal(t, 20, len(outputValues)) | ||
|
||
var sum int | ||
for _, batch := range outputValues { | ||
for _, v := range batch { | ||
sum += v.key | ||
} | ||
} | ||
assert.Equal(t, 292, sum) | ||
} | ||
|
||
func TestKeyed_Ptr(t *testing.T) { | ||
in := make(chan any, 30) | ||
out := make(chan any, 20) | ||
|
||
source := ext.NewChanSource(in) | ||
sink := ext.NewChanSink(out) | ||
keyed := flow.NewKeyed(func(e *keyedElement) int { return e.key }, | ||
func() streams.Flow { | ||
return flow.NewBatch[*keyedElement](4, 10*time.Millisecond) | ||
}) | ||
assert.NotEqual(t, keyed.Out(), nil) | ||
|
||
inputValues := makeElements(30) | ||
ingestSlice(inputValues, in) | ||
close(in) | ||
|
||
go func() { | ||
source. | ||
Via(keyed). | ||
Via(flow.NewPassThrough()). // Via coverage | ||
To(sink) | ||
}() | ||
|
||
outputValues := readSlice[[]*keyedElement](sink.Out) | ||
fmt.Println(outputValues) | ||
|
||
assert.Equal(t, 20, len(outputValues)) | ||
|
||
var sum int | ||
for _, batch := range outputValues { | ||
for _, v := range batch { | ||
sum += v.key | ||
} | ||
} | ||
assert.Equal(t, 292, sum) | ||
} | ||
|
||
func TestKeyed_MultipleOperators(t *testing.T) { | ||
in := make(chan any, 30) | ||
out := make(chan any, 20) | ||
|
||
source := ext.NewChanSource(in) | ||
sink := ext.NewChanSink(out) | ||
|
||
keyedFlow := flow.NewKeyed(func(e *keyedElement) int { return e.key }, | ||
func() streams.Flow { | ||
return flow.NewBatch[*keyedElement](4, 10*time.Millisecond) | ||
}, | ||
func() streams.Flow { | ||
return flow.NewMap(func(b []*keyedElement) int { | ||
var sum int | ||
for _, v := range b { | ||
sum += v.key | ||
} | ||
return sum | ||
}, 1) | ||
}) | ||
collectFlow := flow.NewTumblingWindow[int](25 * time.Millisecond) | ||
sumFlow := flow.NewMap(sum, 1) | ||
|
||
inputValues := makeElements(30) | ||
go ingestSlice(inputValues, in) | ||
go closeDeferred(in, 10*time.Millisecond) | ||
|
||
go func() { | ||
source. | ||
Via(keyedFlow).Via(collectFlow).Via(sumFlow). | ||
To(sink) | ||
}() | ||
|
||
outputValues := readSlice[int](sink.Out) | ||
|
||
assert.Equal(t, 1, len(outputValues)) | ||
assert.Equal(t, 292, outputValues[0]) | ||
} | ||
|
||
func TestKeyed_InvalidArguments(t *testing.T) { | ||
assert.Panics(t, func() { | ||
flow.NewKeyed(func(e *keyedElement) int { return e.key }) | ||
}) | ||
} | ||
|
||
func makeElements(n int) []*keyedElement { | ||
elements := make([]*keyedElement, n) | ||
denominators := []int{3, 7, 10} | ||
outer: | ||
for i := range elements { | ||
for _, d := range denominators { | ||
if i%d == 0 { | ||
elements[i] = newKeyedElement(d) | ||
continue outer | ||
} | ||
} | ||
elements[i] = newKeyedElement(i) | ||
} | ||
fmt.Println(elements) | ||
return elements | ||
} | ||
|
||
// values dereferences every pointer of the given slice and returns the
// pointed-to values in the same order.
func values[T any](pointers []*T) []T {
	dereferenced := make([]T, 0, len(pointers))
	for i := range pointers {
		dereferenced = append(dereferenced, *pointers[i])
	}
	return dereferenced
}
|
||
// sum returns the total of the given integers; the total of an empty or nil
// slice is zero.
func sum(values []int) int {
	total := 0
	for i := 0; i < len(values); i++ {
		total += values[i]
	}
	return total
}