diff --git a/README.md b/README.md index 75ad314..e39d91d 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ Implemented Flows ([flow](https://github.com/reugn/go-streams/tree/master/flow) * SlidingWindow * TumblingWindow * SessionWindow +* Keyed Supported Connectors: * Go channels diff --git a/flow/keyed.go b/flow/keyed.go new file mode 100644 index 0000000..6d74854 --- /dev/null +++ b/flow/keyed.go @@ -0,0 +1,127 @@ +package flow + +import ( + "sync" + + "github.com/reugn/go-streams" +) + +// Keyed represents a flow where stream elements are partitioned by key +// using a provided key selector function. +type Keyed[K comparable, V any] struct { + keySelector func(V) K + keyedFlows map[K]streams.Flow + operators []func() streams.Flow + in chan any + out chan any +} + +// Verify Keyed satisfies the Flow interface. +var _ streams.Flow = (*Keyed[int, any])(nil) + +// NewKeyed returns a new Keyed operator, which takes a stream and splits it +// into multiple streams based on the keys extracted from the elements using +// the keySelector function. +// +// Each of these individual streams is then transformed by the provided chain +// of operators, and the results are sent to the output channel. +// +// If no operators are provided, NewKeyed will panic. +func NewKeyed[K comparable, V any]( + keySelector func(V) K, operators ...func() streams.Flow, +) *Keyed[K, V] { + if len(operators) == 0 { + panic("at least one operator supplier is required") + } + keyedFlow := &Keyed[K, V]{ + keySelector: keySelector, + keyedFlows: make(map[K]streams.Flow), + operators: operators, + in: make(chan any), + out: make(chan any), + } + + // start stream processing + go keyedFlow.stream() + + return keyedFlow +} + +// stream routes incoming elements to keyed workflows and consolidates +// the results into the output channel of the Keyed flow. +func (k *Keyed[K, V]) stream() { + var wg sync.WaitGroup + for element := range k.in { + // extract element's key using the selector + key := k.keySelector(element.(V)) + // retrieve the keyed flow for the key + keyedFlow := k.getKeyedFlow(key, &wg) + // send the element downstream + keyedFlow.In() <- element + } + + // close all keyed streams + for _, keyedFlow := range k.keyedFlows { + close(keyedFlow.In()) + } + + // wait for all keyed streams to complete + wg.Wait() + close(k.out) +} + +// Via streams data to a specified Flow and returns it. +func (k *Keyed[K, V]) Via(flow streams.Flow) streams.Flow { + go k.transmit(flow) + return flow +} + +// To streams data to a specified Sink. +func (k *Keyed[K, V]) To(sink streams.Sink) { + k.transmit(sink) +} + +// Out returns the output channel of the Keyed operator. +func (k *Keyed[K, V]) Out() <-chan any { + return k.out +} + +// In returns the input channel of the Keyed operator. +func (k *Keyed[K, V]) In() chan<- any { + return k.in +} + +// transmit submits keyed elements to the next Inlet. +func (k *Keyed[K, V]) transmit(inlet streams.Inlet) { + for keyed := range k.out { + inlet.In() <- keyed + } + close(inlet.In()) +} + +// getKeyedFlow retrieves a keyed workflow associated with the provided key. +// If the workflow has not yet been initiated, it will be created and a +// goroutine will be launched to handle the stream. +func (k *Keyed[K, V]) getKeyedFlow(key K, wg *sync.WaitGroup) streams.Flow { + // try to retrieve the keyed flow from the map + keyedWorkflow, ok := k.keyedFlows[key] + if !ok { // this is the first element for the key + wg.Add(1) + // build the workflow + keyedWorkflow = k.operators[0]() + workflowTail := keyedWorkflow + for _, operatorFactory := range k.operators[1:] { + workflowTail = workflowTail.Via(operatorFactory()) + } + // start processing incoming stream elements + go func() { + defer wg.Done() + for e := range workflowTail.Out() { + k.out <- e + } + }() + // associate the key with the workflow + k.keyedFlows[key] = keyedWorkflow + } + return keyedWorkflow +} diff --git a/flow/keyed_test.go b/flow/keyed_test.go new file mode 100644 index 0000000..66c6bc6 --- /dev/null +++ b/flow/keyed_test.go @@ -0,0 +1,177 @@ +package flow_test + +import ( + "fmt" + "testing" + "time" + + "github.com/reugn/go-streams" + ext "github.com/reugn/go-streams/extension" + "github.com/reugn/go-streams/flow" + "github.com/reugn/go-streams/internal/assert" +) + +type keyedElement struct { + key int + data string +} + +func (e *keyedElement) String() string { + return e.data +} + +func newKeyedElement(key int) *keyedElement { + return &keyedElement{ + key: key, + data: fmt.Sprint(key), + } +} + +func TestKeyed(t *testing.T) { + in := make(chan any, 30) + out := make(chan any, 20) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + keyed := flow.NewKeyed(func(e keyedElement) int { return e.key }, + func() streams.Flow { + return flow.NewBatch[keyedElement](4, 10*time.Millisecond) + }) + + inputValues := values(makeElements(30)) + ingestSlice(inputValues, in) + close(in) + + go func() { + source. + Via(keyed). + To(sink) + }() + + outputValues := readSlice[[]keyedElement](sink.Out) + + assert.Equal(t, 20, len(outputValues)) + + var sum int + for _, batch := range outputValues { + for _, v := range batch { + sum += v.key + } + } + assert.Equal(t, 292, sum) +} + +func TestKeyed_Ptr(t *testing.T) { + in := make(chan any, 30) + out := make(chan any, 20) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + keyed := flow.NewKeyed(func(e *keyedElement) int { return e.key }, + func() streams.Flow { + return flow.NewBatch[*keyedElement](4, 10*time.Millisecond) + }) + assert.NotEqual(t, keyed.Out(), nil) + + inputValues := makeElements(30) + ingestSlice(inputValues, in) + close(in) + + go func() { + source. + Via(keyed). + Via(flow.NewPassThrough()). // Via coverage + To(sink) + }() + + outputValues := readSlice[[]*keyedElement](sink.Out) + fmt.Println(outputValues) + + assert.Equal(t, 20, len(outputValues)) + + var sum int + for _, batch := range outputValues { + for _, v := range batch { + sum += v.key + } + } + assert.Equal(t, 292, sum) +} + +func TestKeyed_MultipleOperators(t *testing.T) { + in := make(chan any, 30) + out := make(chan any, 20) + + source := ext.NewChanSource(in) + sink := ext.NewChanSink(out) + + keyedFlow := flow.NewKeyed(func(e *keyedElement) int { return e.key }, + func() streams.Flow { + return flow.NewBatch[*keyedElement](4, 10*time.Millisecond) + }, + func() streams.Flow { + return flow.NewMap(func(b []*keyedElement) int { + var sum int + for _, v := range b { + sum += v.key + } + return sum + }, 1) + }) + collectFlow := flow.NewTumblingWindow[int](25 * time.Millisecond) + sumFlow := flow.NewMap(sum, 1) + + inputValues := makeElements(30) + go ingestSlice(inputValues, in) + go closeDeferred(in, 10*time.Millisecond) + + go func() { + source. + Via(keyedFlow).Via(collectFlow).Via(sumFlow). + To(sink) + }() + + outputValues := readSlice[int](sink.Out) + + assert.Equal(t, 1, len(outputValues)) + assert.Equal(t, 292, outputValues[0]) +} + +func TestKeyed_InvalidArguments(t *testing.T) { + assert.Panics(t, func() { + flow.NewKeyed(func(e *keyedElement) int { return e.key }) + }) +} + +func makeElements(n int) []*keyedElement { + elements := make([]*keyedElement, n) + denominators := []int{3, 7, 10} +outer: + for i := range elements { + for _, d := range denominators { + if i%d == 0 { + elements[i] = newKeyedElement(d) + continue outer + } + } + elements[i] = newKeyedElement(i) + } + fmt.Println(elements) + return elements +} + +func values[T any](pointers []*T) []T { + values := make([]T, len(pointers)) + for i, ptr := range pointers { + values[i] = *ptr + } + return values +} + +func sum(values []int) int { + var result int + for _, v := range values { + result += v + } + return result +}