diff --git a/docs/layouts/shortcodes/generated/execution_config_configuration.html b/docs/layouts/shortcodes/generated/execution_config_configuration.html index 770f2f5b52499..710ba45503ca1 100644 --- a/docs/layouts/shortcodes/generated/execution_config_configuration.html +++ b/docs/layouts/shortcodes/generated/execution_config_configuration.html @@ -8,6 +8,18 @@ + +
table.exec.async-agg.buffer-capacity

Streaming + 10 + Integer + The max number of async i/o operations that the async aggregate function can trigger. + + +
table.exec.async-agg.timeout

Streaming + 3 min + Duration + The timeout for the asynchronous aggregate operation to complete. +
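(Editorial note: a minimal, hedged sketch of how the two new options above could be set from application code; the TableEnvironment setup and the chosen values are assumptions for illustration, not part of this change.)

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class AsyncAggConfigExample {
    public static void main(String[] args) {
        TableEnvironment tableEnv =
                TableEnvironment.create(EnvironmentSettings.inStreamingMode());
        // Allow up to 50 in-flight bundled aggregate invocations per operator.
        tableEnv.getConfig().set("table.exec.async-agg.buffer-capacity", "50");
        // Fail the query if an asynchronous aggregate call takes longer than 1 minute.
        tableEnv.getConfig().set("table.exec.async-agg.timeout", "1 min");
    }
}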
table.exec.async-lookup.buffer-capacity

Batch Streaming 100 diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java index 907508c166234..af5af2f90e5ca 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/TypeExtractionUtils.java @@ -34,6 +34,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; import static org.apache.flink.shaded.asm9.org.objectweb.asm.Type.getConstructorDescriptor; import static org.apache.flink.shaded.asm9.org.objectweb.asm.Type.getMethodDescriptor; @@ -373,4 +374,28 @@ public static void validateLambdaType(Class<?> baseClass, Type t) { + "Otherwise the type has to be specified explicitly using type information."); } } + + /** + * Returns true if the given type is the given class itself or a {@link ParameterizedType} + * whose raw type equals that class. + * + * @param clazz The generic class to check against + * @param type The type to be checked + */ + public static boolean isGenericOfClass(Class<?> clazz, Type type) { + Optional<ParameterizedType> parameterized = getParameterizedType(type); + return clazz.equals(type) + || parameterized.isPresent() && clazz.equals(parameterized.get().getRawType()); + } + + /** + * Returns an {@link Optional} containing the given type as a {@link ParameterizedType}, if it is one. + * + * @param type The type to check + * @return an Optional which is present if the type is a ParameterizedType + */ + public static Optional<ParameterizedType> getParameterizedType(Type type) { + return Optional.of(type) + .filter(p -> p instanceof ParameterizedType) + .map(ParameterizedType.class::cast); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java index 69d29bb818c5c..9cd221ecd3650 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java @@ -121,7 +121,7 @@ public abstract class AbstractStreamOperator<OUT> private transient @Nullable MailboxExecutor mailboxExecutor; - private transient @Nullable MailboxWatermarkProcessor<OUT> watermarkProcessor; + protected transient @Nullable MailboxWatermarkProcessor<OUT> watermarkProcessor; // ---------------- key/value state ------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/MailboxWatermarkProcessor.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/MailboxWatermarkProcessor.java index fb498f65f07ee..4b9a8a4a4dbfa 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/MailboxWatermarkProcessor.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/MailboxWatermarkProcessor.java @@ -63,17 +63,22 @@ public MailboxWatermarkProcessor( } public void emitWatermarkInsideMailbox(Watermark mark) throws Exception { + emitWatermarkInsideMailbox(mark, output::emitWatermark); + } + + public void emitWatermarkInsideMailbox(Watermark mark, WatermarkEmitter watermarkEmitter) + throws Exception { maxInputWatermark = new Watermark(Math.max(maxInputWatermark.getTimestamp(), mark.getTimestamp())); - emitWatermarkInsideMailbox(); + emitWatermarkInsideMailbox(watermarkEmitter); } - private void emitWatermarkInsideMailbox() throws Exception { + private void
emitWatermarkInsideMailbox(WatermarkEmitter watermarkEmitter) throws Exception { // Try to progress min watermark as far as we can. if (internalTimeServiceManager.tryAdvanceWatermark( maxInputWatermark, mailboxExecutor::shouldInterrupt)) { // In case output watermark has fully progressed emit it downstream. - output.emitWatermark(maxInputWatermark); + watermarkEmitter.emitWatermark(maxInputWatermark); } else if (!progressWatermarkScheduled) { progressWatermarkScheduled = true; // We still have work to do, but we need to let other mails to be processed first. @@ -81,7 +86,7 @@ private void emitWatermarkInsideMailbox() throws Exception { MailboxExecutor.MailOptions.deferrable(), () -> { progressWatermarkScheduled = false; - emitWatermarkInsideMailbox(); + emitWatermarkInsideMailbox(watermarkEmitter); }, "emitWatermarkInsideMailbox"); } else { @@ -91,4 +96,12 @@ private void emitWatermarkInsideMailbox() throws Exception { LOG.debug("emitWatermarkInsideMailbox is already scheduled, skipping."); } } + + /** Interface to emit a watermark after all the timers have been fired. */ + @Internal + public interface WatermarkEmitter { + + /** Emit a watermark. */ + void emitWatermark(Watermark watermark) throws Exception; + } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java index a1847e93b9e77..abce2a9eaf47f 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/ExecutionConfigOptions.java @@ -443,6 +443,25 @@ public class ExecutionConfigOptions { "The max number of async retry attempts to make before task " + "execution is failed."); + // ------------------------------------------------------------------------ + // Bundled Aggregate Options + // ------------------------------------------------------------------------ + @Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING) + public static final ConfigOption<Integer> TABLE_EXEC_ASYNC_AGG_BUFFER_CAPACITY = + key("table.exec.async-agg.buffer-capacity") + .intType() + .defaultValue(10) + .withDescription( + "The max number of async i/o operations that the async aggregate function can trigger."); + + @Documentation.TableOption(execMode = Documentation.ExecMode.STREAMING) + public static final ConfigOption<Duration> TABLE_EXEC_ASYNC_AGG_TIMEOUT = + key("table.exec.async-agg.timeout") + .durationType() + .defaultValue(Duration.ofMinutes(3)) + .withDescription( + "The timeout for the asynchronous aggregate operation to complete."); + // ------------------------------------------------------------------------ // MiniBatch Options // ------------------------------------------------------------------------ diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BundledAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BundledAggregateFunction.java new file mode 100644 index 0000000000000..bf8de3103fd98 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BundledAggregateFunction.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.functions.agg.BundledKeySegment; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; + +import java.util.concurrent.CompletableFuture; + +/** The bundled interface to be implemented by {@link AggregateFunction}s that may support bundling. */ +@PublicEvolving +public interface BundledAggregateFunction extends FunctionDefinition { + + /** + * Whether the implementor supports bundling. This allows them to programmatically decide whether + * to use the bundling or non-bundling interface. + */ + boolean canBundle(); + + /** Whether the implementor supports retraction. */ + default boolean canRetract() { + return false; + } + + default void bundledAccumulateRetract( + CompletableFuture<BundledKeySegmentApplied> future, BundledKeySegment segment) + throws Exception { + throw new UnsupportedOperationException( + "This aggregate function does not support bundled calls."); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java index 25d7e4ea7cfb7..fd0169169e532 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java @@ -33,6 +33,8 @@ import org.apache.flink.table.functions.SpecializedFunction.ExpressionEvaluator; import org.apache.flink.table.functions.SpecializedFunction.ExpressionEvaluatorFactory; import org.apache.flink.table.functions.SpecializedFunction.SpecializedContext; +import org.apache.flink.table.functions.agg.BundledKeySegment; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; import org.apache.flink.table.functions.python.utils.PythonFunctionUtils; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.extraction.ExtractionUtils; @@ -50,11 +52,14 @@ import java.util.Arrays; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.getAllDeclaredMethods; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.getParameterizedType; +import static org.apache.flink.api.java.typeutils.TypeExtractionUtils.isGenericOfClass; import static org.apache.flink.util.Preconditions.checkState; /** @@ -81,6 +86,8 @@ public final class UserDefinedFunctionHelper { public static final String AGGREGATE_MERGE = "merge"; + public static final String AGGREGATE_BUNDLED =
"bundledAccumulateRetract"; + public static final String TABLE_AGGREGATE_ACCUMULATE = "accumulate"; public static final String TABLE_AGGREGATE_RETRACT = "retract"; @@ -483,9 +490,14 @@ private static void validateImplementationMethods( } else if (AsyncTableFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, true, false, ASYNC_TABLE_EVAL); } else if (AggregateFunction.class.isAssignableFrom(functionClass)) { - validateImplementationMethod(functionClass, true, false, AGGREGATE_ACCUMULATE); - validateImplementationMethod(functionClass, true, true, AGGREGATE_RETRACT); - validateImplementationMethod(functionClass, true, true, AGGREGATE_MERGE); + if (BundledAggregateFunction.class.isAssignableFrom(functionClass)) { + validateImplementationMethod(functionClass, true, false, AGGREGATE_BUNDLED); + validateBundledImplementationMethod(functionClass, AGGREGATE_BUNDLED); + } else { + validateImplementationMethod(functionClass, true, false, AGGREGATE_ACCUMULATE); + validateImplementationMethod(functionClass, true, true, AGGREGATE_RETRACT); + validateImplementationMethod(functionClass, true, true, AGGREGATE_MERGE); + } } else if (TableAggregateFunction.class.isAssignableFrom(functionClass)) { validateImplementationMethod(functionClass, true, false, TABLE_AGGREGATE_ACCUMULATE); validateImplementationMethod(functionClass, true, true, TABLE_AGGREGATE_RETRACT); @@ -540,6 +552,50 @@ private static void validateImplementationMethod( } } + private static void validateBundledImplementationMethod( + Class clazz, String... methodNameOptions) { + final Set nameSet = new HashSet<>(Arrays.asList(methodNameOptions)); + final List methods = getAllDeclaredMethods(clazz); + for (Method method : methods) { + if (!nameSet.contains(method.getName())) { + continue; + } + + if (!method.getReturnType().equals(Void.TYPE)) { + throw new ValidationException( + String.format( + "Method '%s' of function class '%s' must be void.", + method.getName(), clazz.getName())); + } + + boolean foundSignature = false; + if (method.getParameterCount() == 2) { + Type firstParam = method.getGenericParameterTypes()[0]; + Type secondType = method.getGenericParameterTypes()[1]; + if (isGenericOfClass(CompletableFuture.class, firstParam) + && isGenericOfClass(BundledKeySegment.class, secondType)) { + Optional parameterizedFirst = + getParameterizedType(firstParam); + if (parameterizedFirst.isPresent() + && parameterizedFirst.get().getActualTypeArguments().length > 0) { + firstParam = parameterizedFirst.get().getActualTypeArguments()[0]; + if (BundledKeySegmentApplied.class.equals(firstParam)) { + foundSignature = true; + } + } + } + } + + if (!foundSignature) { + throw new ValidationException( + String.format( + "Method '%s' of function class '%s' must have signature " + + "void bundledAccumulateRetract(CompletableFuture future, BundledKeySegment segment).", + method.getName(), clazz.getName())); + } + } + } + private static void validateAsyncImplementationMethod( Class clazz, String... 
methodNameOptions) { final Set nameSet = new HashSet<>(Arrays.asList(methodNameOptions)); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/agg/BundledKeySegment.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/agg/BundledKeySegment.java new file mode 100644 index 0000000000000..48a53b3ad64c1 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/agg/BundledKeySegment.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.agg; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.data.RowData; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; + +/** One segment of a bundled aggregate call, where all rows in the segment are for the same key. */ +@PublicEvolving +public class BundledKeySegment { + + /** The common key of the segment. */ + private final RowData key; + + /** The rows, where all rows are for the common key. */ + private final List rows; + + /** The accumulator value under the current key. Can be null. */ + private final List accumulators; + + /** + * If set, returns the updated value after each row is applied rather than only the final value. + */ + private final boolean updatedValuesAfterEachRow; + + public BundledKeySegment( + RowData key, + List rows, + @Nullable RowData accumulator, + boolean updatedValuesAfterEachRow) { + this.key = key; + this.rows = rows; + this.accumulators = + accumulator == null + ? Collections.emptyList() + : Collections.singletonList(accumulator); + this.updatedValuesAfterEachRow = updatedValuesAfterEachRow; + } + + public RowData getKey() { + return key; + } + + public List getRows() { + return rows; + } + + @Nullable + public RowData getAccumulator() { + Preconditions.checkState(accumulators.size() <= 1); + return accumulators.isEmpty() ? 
null : accumulators.get(0); + } + + public List getAccumulatorsToMerge() { + return accumulators; + } + + public boolean getUpdatedValuesAfterEachRow() { + return updatedValuesAfterEachRow; + } + + public static BundledKeySegment of( + RowData key, + List rows, + @Nullable RowData accumulator, + boolean updatedValuesAfterEachRow) { + return new BundledKeySegment(key, rows, accumulator, updatedValuesAfterEachRow); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/agg/BundledKeySegmentApplied.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/agg/BundledKeySegmentApplied.java new file mode 100644 index 0000000000000..e501638284af2 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/agg/BundledKeySegmentApplied.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.agg; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.data.RowData; + +import java.util.Collections; +import java.util.List; + +/** + * The result of applying rows to the provided accumulator(s) in {@link BundledKeySegment}. The + * result is an updated accumulator and values evaluated at the start, end, and potentially after + * each row. + */ +@PublicEvolving +public class BundledKeySegmentApplied { + + /** The accumulator value after all rows are applied. */ + private final RowData accumulator; + + /** The value before any rows are applied. */ + private final RowData startingValue; + + /** The value after all rows are applied. */ + private final RowData finalValue; + + /** The value after each row is applied. 
*/ + private final List<RowData> updatedValuesAfterEachRow; + + public BundledKeySegmentApplied( + RowData accumulator, + RowData startingValue, + RowData finalValue, + List<RowData> updatedValuesAfterEachRow) { + this.accumulator = accumulator; + this.startingValue = startingValue; + this.finalValue = finalValue; + this.updatedValuesAfterEachRow = updatedValuesAfterEachRow; + } + + public static BundledKeySegmentApplied of( + RowData accumulator, + RowData startingValue, + RowData finalValue, + List<RowData> updatedValuesAfterEachRow) { + return new BundledKeySegmentApplied( + accumulator, startingValue, finalValue, updatedValuesAfterEachRow); + } + + public static BundledKeySegmentApplied of(RowData accumulator) { + return new BundledKeySegmentApplied(accumulator, null, null, Collections.emptyList()); + } + + public RowData getAccumulator() { + return accumulator; + } + + public RowData getStartingValue() { + return startingValue; + } + + public RowData getFinalValue() { + return finalValue; + } + + public List<RowData> getUpdatedValuesAfterEachRow() { + return updatedValuesAfterEachRow; + } + + @Override + public String toString() { + return "{" + accumulator + "," + startingValue + "," + finalValue + "}"; + } +} diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/functions/UserDefinedFunctionHelperTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/functions/UserDefinedFunctionHelperTest.java index 2aa2d3a938b25..6402200bc034d 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/functions/UserDefinedFunctionHelperTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/functions/UserDefinedFunctionHelperTest.java @@ -22,7 +22,10 @@ import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.catalog.CatalogFunction; import org.apache.flink.table.catalog.FunctionLanguage; +import org.apache.flink.table.functions.agg.BundledKeySegment; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; import org.apache.flink.table.resource.ResourceUri; +import org.apache.flink.types.Row; import org.apache.flink.util.Collector; import org.junit.jupiter.api.Test; @@ -32,6 +35,7 @@ import javax.annotation.Nullable; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -135,22 +139,39 @@ private static List<TestSpec> testSpecs() { + PrivateMethodScalarFunction.class.getName() + "' is not public."), TestSpec.forClass(ValidAsyncScalarFunction.class).expectSuccess(), + TestSpec.forClass(ValidAggBundledFunction.class).expectSuccess(), TestSpec.forInstance(new ValidAsyncScalarFunction()).expectSuccess(), + TestSpec.forInstance(new ValidAggBundledFunction()).expectSuccess(), TestSpec.forClass(PrivateAsyncScalarFunction.class) .expectErrorMessage( "Function class '" + PrivateAsyncScalarFunction.class.getName() + "' is not public."), + TestSpec.forClass(PrivateAggBundledFunction.class) + .expectErrorMessage( + "Function class '" + + PrivateAggBundledFunction.class.getName() + + "' is not public."), TestSpec.forClass(MissingImplementationAsyncScalarFunction.class) .expectErrorMessage( "Function class '" + MissingImplementationAsyncScalarFunction.class.getName() + "' does not implement a method named 'eval'."), + TestSpec.forClass(MissingImplementationAggBundledFunction.class) + .expectErrorMessage( + "Function class '" + + MissingImplementationAggBundledFunction.class.getName() + + "' does not implement a method named
'bundledAccumulateRetract'."), TestSpec.forClass(PrivateMethodAsyncScalarFunction.class) .expectErrorMessage( "Method 'eval' of function class '" + PrivateMethodAsyncScalarFunction.class.getName() + "' is not public."), + TestSpec.forClass(PrivateMethodAggBundledFunction.class) + .expectErrorMessage( + "Method 'bundledAccumulateRetract' of function class '" + + PrivateMethodAggBundledFunction.class.getName() + + "' is not public."), TestSpec.forClass(NonVoidAsyncScalarFunction.class) .expectErrorMessage( "Method 'eval' of function class '" @@ -161,6 +182,16 @@ private static List testSpecs() { "Method 'eval' of function class '" + NoFutureAsyncScalarFunction.class.getName() + "' must have a first argument of type java.util.concurrent.CompletableFuture."), + TestSpec.forClass(BadSignatureAggBundledFunction.class) + .expectErrorMessage( + "Method 'bundledAccumulateRetract' of function class '" + + BadSignatureAggBundledFunction.class.getName() + + "' must have signature void bundledAccumulateRetract(CompletableFuture future, BundledKeySegment segment)."), + TestSpec.forClass(BadSignatureAggBundledFunction2.class) + .expectErrorMessage( + "Method 'bundledAccumulateRetract' of function class '" + + BadSignatureAggBundledFunction2.class.getName() + + "' must have signature void bundledAccumulateRetract(CompletableFuture future, BundledKeySegment segment)."), TestSpec.forInstance(new ValidTableAggregateFunction()).expectSuccess(), TestSpec.forInstance(new MissingEmitTableAggregateFunction()) .expectErrorMessage( @@ -295,20 +326,64 @@ public static class ValidAsyncScalarFunction extends AsyncScalarFunction { public void eval(CompletableFuture future, int i) {} } + public abstract static class AggBundledFunctionBase extends AggregateFunction + implements BundledAggregateFunction { + + @Override + public Long getValue(Row accumulator) { + return null; + } + + @Override + public Row createAccumulator() { + return null; + } + + public boolean canBundle() { + return true; + } + } + + /** Valid aggregate bundled function. */ + public static class ValidAggBundledFunction extends AggBundledFunctionBase { + public void bundledAccumulateRetract( + CompletableFuture future, BundledKeySegment segment) + throws Exception {} + } + private static class PrivateAsyncScalarFunction extends AsyncScalarFunction { public void eval(CompletableFuture future, int i) {} } + private static class PrivateAggBundledFunction extends AggBundledFunctionBase { + public void bundledAccumulateRetract( + CompletableFuture future, BundledKeySegment segment) + throws Exception {} + } + /** No implementation method. */ public static class MissingImplementationAsyncScalarFunction extends AsyncScalarFunction { // nothing to do } + /** No implementation method. */ + public static class MissingImplementationAggBundledFunction extends AggBundledFunctionBase { + // nothing to do + } + /** Implementation method is private. */ public static class PrivateMethodAsyncScalarFunction extends AsyncScalarFunction { private void eval(CompletableFuture future, int i) {} } + /** Implementation method is private. */ + public static class PrivateMethodAggBundledFunction extends AggBundledFunctionBase { + private List bundledAccumulateRetract( + List batch) throws Exception { + return null; + } + } + /** Implementation method isn't void. 
*/ public static class NonVoidAsyncScalarFunction extends AsyncScalarFunction { public String eval(CompletableFuture future, int i) { @@ -321,6 +396,20 @@ public static class NoFutureAsyncScalarFunction extends AsyncScalarFunction { public void eval(int i) {} } + /** First argument is wrong type. */ + public static class BadSignatureAggBundledFunction extends AggBundledFunctionBase { + public void bundledAccumulateRetract( + Collection future, BundledKeySegment segment) + throws Exception {} + } + + /** Second argument is wrong type. */ + public static class BadSignatureAggBundledFunction2 extends AggBundledFunctionBase { + public void bundledAccumulateRetract( + CompletableFuture future, Integer segment) + throws Exception {} + } + /** Valid table aggregate function. */ public static class ValidTableAggregateFunction extends TableAggregateFunction { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/agg/BundledResultCombiner.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/agg/BundledResultCombiner.java new file mode 100644 index 0000000000000..fee50ea54c285 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/agg/BundledResultCombiner.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.codegen.agg; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SupplierWithException; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * Combines each {@link BundledKeySegmentApplied} with the corresponding bundled and non-bundled + * results, producing a single combined {@link BundledKeySegmentApplied}. This allows an operator + * to use a uniform bundled API regardless of how many bundled or non-bundled aggregate functions + * are used. + * + * <p>For example, imagine the following query: SELECT B1_SUM(a), B2_COUNT(b), SUM(c), COUNT(d) FROM + * T; + + * <p>
Here, B1_SUM and B2_COUNT are bundled aggregate functions, and SUM and COUNT are non-bundled + * aggregates. The bundled calls will be separate while both the non-bundled are produced using the + * conventional {@link org.apache.flink.table.runtime.generated.AggsHandleFunction}. The result of + * all three are then combined to a single {@link BundledKeySegmentApplied}. + */ +public class BundledResultCombiner implements Serializable { + private static final long serialVersionUID = -3486860542451993040L; + + private final RowType accTypeInfo; + private final RowType valueType; + + /** + * Creates a new {@link Combiner} factory. + * + * @param accTypeInfo the accumulator type information + * @param valueType the value type information + */ + public BundledResultCombiner(RowType accTypeInfo, RowType valueType) { + this.accTypeInfo = accTypeInfo; + this.valueType = valueType; + } + + /** Creates a new {@link Combiner} instance. */ + public Combiner newCombiner() { + return new Combiner(accTypeInfo, valueType); + } + + /** + * Combines all kinds of aggregates into a single result conforming to the bundled interface, as + * if it were not bundled at all. + */ + public static class Combiner + implements Serializable, SupplierWithException { + + private static final long serialVersionUID = -3348403465172218589L; + + private final RowType accTypeInfo; + private final RowType valueType; + + private final List updates = new ArrayList<>(); + + private Combiner(RowType accTypeInfo, RowType valueType) { + this.accTypeInfo = accTypeInfo; + this.valueType = valueType; + } + + /** + * Adds a new update to the combiner. Note that one of the two update types should be + * present. + * + * @param index the index of the update + * @param bundledDataKeySegmentUpdate the bundled data key segment update + * @param nonBundledResult the non-bundled result + * @param shouldIncludeValue whether the value is included in the result (or internal) + * @param isBundled whether the result is bundled + * @param accIndexStart the start index of the accumulator + * @param accIndexEnd the end index of the accumulator + */ + public void add( + final int index, + final Optional> + bundledDataKeySegmentUpdate, + final Optional nonBundledResult, + final boolean shouldIncludeValue, + final boolean isBundled, + final int accIndexStart, + final int accIndexEnd) { + final CompletableFuture updateToCombine; + if (bundledDataKeySegmentUpdate.isPresent()) { + updateToCombine = bundledDataKeySegmentUpdate.get(); + } else { + Preconditions.checkArgument(nonBundledResult.isPresent()); + updateToCombine = + CompletableFuture.completedFuture( + nonBundledResult + .get() + .asBundledDataKeySegmentUpdate( + index, + shouldIncludeValue, + indexList(accIndexStart, accIndexEnd), + accTypeInfo, + valueType)); + } + updates.add(new UpdateMetadata(index, updateToCombine, shouldIncludeValue, isBundled)); + } + + /** + * Combines all the updates added to the combiner. + * + * @return the combined updates + */ + public BundledKeySegmentApplied combine() throws Exception { + BundledKeySegmentApplied result = + BundledResultCombiner.combineUpdates( + accTypeInfo, valueType, updates.toArray(new UpdateMetadata[0])); + updates.clear(); + return result; + } + + @Override + public BundledKeySegmentApplied get() throws Exception { + return combine(); + } + } + + private static BundledKeySegmentApplied combineUpdates( + RowType accTypeInfo, RowType valueType, UpdateMetadata... 
updates) throws Exception { + List ordered = new ArrayList<>(updates.length); + for (int i = 0; i < updates.length; i++) { + ordered.add(null); + } + for (UpdateMetadata updateMetadata : updates) { + ordered.set(updateMetadata.index, updateMetadata); + } + Map shouldIncludeValues = + ordered.stream().collect(Collectors.toMap(p -> p.index, p -> p.shouldIncludeValue)); + Map isBundled = + ordered.stream().collect(Collectors.toMap(p -> p.index, p -> p.isBundled)); + List allUpdates = + ordered.stream() + .map( + p -> { + // This should only be invoked after the futures have all + // completed, so calling join shouldn't block. + return p.update.join(); + }) + .collect(Collectors.toList()); + return combineSegment(shouldIncludeValues, isBundled, accTypeInfo, valueType, allUpdates); + } + + private static BundledKeySegmentApplied combineSegment( + Map shouldIncludeValues, + Map isBundled, + RowType accumulatorType, + RowType valueType, + List ithEntries) { + List accs = new ArrayList<>(); + List startingValues = new ArrayList<>(); + List finalValues = new ArrayList<>(); + List> updatedValuesAfterEachRow = new ArrayList<>(); + for (int i = 0; i < ithEntries.size(); i++) { + BundledKeySegmentApplied update = ithEntries.get(i); + // Non bundled accumulators are already wrapped in a row to contain them, so should not + // create another layer. + boolean avoidWrappingInRow = !isBundled.get(i); + accs.add( + avoidWrappingInRow + ? update.getAccumulator() + : GenericRowData.of(update.getAccumulator())); + + if (shouldIncludeValues.get(i)) { + RowData startingValue = update.getStartingValue(); + RowData finalValue = update.getFinalValue(); + + startingValues.add(startingValue); + finalValues.add(finalValue); + + if (updatedValuesAfterEachRow.isEmpty()) { + for (int j = 0; j < update.getUpdatedValuesAfterEachRow().size(); j++) { + updatedValuesAfterEachRow.add(new ArrayList<>(ithEntries.size())); + updatedValuesAfterEachRow + .get(j) + .add(update.getUpdatedValuesAfterEachRow().get(j)); + } + } else { + Preconditions.checkState( + updatedValuesAfterEachRow.size() + == update.getUpdatedValuesAfterEachRow().size()); + for (int j = 0; j < update.getUpdatedValuesAfterEachRow().size(); j++) { + updatedValuesAfterEachRow + .get(j) + .add(update.getUpdatedValuesAfterEachRow().get(j)); + } + } + } + } + + final List updatedValuesAfterEachRowFinal = + updatedValuesAfterEachRow.stream() + .map(list -> mergeAllFields(list, valueType)) + .collect(Collectors.toList()); + + return new BundledKeySegmentApplied( + mergeAllFields(accs, accumulatorType), + mergeAllFields(startingValues, valueType), + mergeAllFields(finalValues, valueType), + updatedValuesAfterEachRowFinal); + } + + // Merges all fields from the given list of rows into a single row. 
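+ // (Editorial, illustrative example:) merging rows (1, 2) and (3L) under the row type
+ // (INT, INT, BIGINT) yields the single row (1, 2, 3L). Each field getter is created from the
+ // type at the overall output position, so `types` must describe the concatenated row.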
+ private static GenericRowData mergeAllFields(List rowData, RowType types) { + int size = rowData.stream().mapToInt(RowData::getArity).sum(); + final Object[] fieldByPosition = new Object[size]; + int total = 0; + for (RowData rd : rowData) { + for (int pos = 0; pos < rd.getArity(); pos++) { + final Object value = + RowData.createFieldGetter(types.getTypeAt(total), pos).getFieldOrNull(rd); + fieldByPosition[total++] = value; + } + } + return GenericRowData.of(fieldByPosition); + } + + private static class UpdateMetadata { + + private final int index; + private final CompletableFuture update; + private final boolean shouldIncludeValue; + private final boolean isBundled; + + public UpdateMetadata( + int index, + CompletableFuture update, + boolean shouldIncludeValue, + boolean isBundled) { + this.index = index; + this.update = update; + this.shouldIncludeValue = shouldIncludeValue; + this.isBundled = isBundled; + } + } + + private static List indexList(int start, int end) { + return java.util.stream.IntStream.range(start, end) + .boxed() + .collect(java.util.stream.Collectors.toList()); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/agg/NonBundledAggregateUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/agg/NonBundledAggregateUtil.java new file mode 100644 index 0000000000000..e220f4312b054 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/agg/NonBundledAggregateUtil.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.codegen.agg; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.functions.agg.BundledKeySegment; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; +import org.apache.flink.table.runtime.generated.AggsHandleFunction; +import org.apache.flink.table.types.logical.RowType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** Utility class for executing non-bundled aggregate functions. */ +public class NonBundledAggregateUtil { + + /** + * Executes the given non-bundled aggregate function calls for each segment in the list. + * + * @param handle The aggregate function to execute. + * @param segment The segment to execute the function for. + * @return The results of the function calls. 
+ */ + public static NonBundledSegmentResult executeAsBundle( + AggsHandleFunction handle, BundledKeySegment segment) throws Exception { + RowData acc = segment.getAccumulator(); + + if (acc == null) { + acc = handle.createAccumulators(); + } + handle.setAccumulators(acc); + RowData startingValues = handle.getValue(); + List updatedValuesAfterEachRow = new ArrayList<>(); + for (org.apache.flink.table.data.RowData row : segment.getRows()) { + if (org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg(row)) { + handle.accumulate(row); + } else { + handle.retract(row); + } + if (segment.getUpdatedValuesAfterEachRow()) { + updatedValuesAfterEachRow.add(handle.getValue()); + } + } + + acc = handle.getAccumulators(); + + org.apache.flink.table.data.RowData finalValues = handle.getValue(); + + return new NonBundledSegmentResult( + acc, startingValues, finalValues, updatedValuesAfterEachRow); + } + + /** Result of all non-bundled aggregate function calls for this segment. */ + public static class NonBundledSegmentResult { + private final RowData accumulator; + private final RowData startingValue; + private final RowData finalValue; + private final List updatedValuesAfterEachRow; + + public NonBundledSegmentResult( + RowData accumulator, + RowData startingValue, + RowData finalValue, + List updatedValuesAfterEachRow) { + this.accumulator = accumulator; + this.startingValue = startingValue; + this.finalValue = finalValue; + this.updatedValuesAfterEachRow = updatedValuesAfterEachRow; + } + + /** + * Convert this result to a {@link BundledKeySegmentApplied} for the given call index. This + * extracts the accumulator and values as though the call was done with a bundled call, + * giving a view of them so that they can be combined with the bundled results. + * + * @param index The index of the call. + * @param shouldIncludeValue Whether the value should be included or ignored. + * @param accumulatorFields The fields to include in the view of the accumulator. + * @param accTypeInfo The type information for the accumulator. + * @param valueType The type information for the values. + * @return The bundled data key segment update. + */ + public BundledKeySegmentApplied asBundledDataKeySegmentUpdate( + int index, + boolean shouldIncludeValue, + List accumulatorFields, + RowType accTypeInfo, + RowType valueType) { + return BundledKeySegmentApplied.of( + getAccumulatorViewForFields(accumulator, accumulatorFields, accTypeInfo), + shouldIncludeValue ? getValueForIndex(valueType, startingValue, index) : null, + shouldIncludeValue ? getValueForIndex(valueType, finalValue, index) : null, + shouldIncludeValue + ? updatedValuesAfterEachRow.stream() + .map(v -> getValueForIndex(valueType, v, index)) + .collect(Collectors.toList()) + : Collections.emptyList()); + } + + /** Wraps the accumulator fields in a row. */ + private RowData getAccumulatorViewForFields( + RowData accumulator, List accumulatorFields, RowType accumulatorType) { + return GenericRowData.of( + accumulatorFields.stream() + .mapToInt(v -> v) + .mapToObj( + v -> + RowData.createFieldGetter( + accumulatorType.getTypeAt(v), v) + .getFieldOrNull(accumulator)) + .toArray()); + } + + /** Extracts the value for the given index from the whole value. 
*/ + private static RowData getValueForIndex(RowType valueType, RowData inputValue, int index) { + return GenericRowData.of( + RowData.createFieldGetter(valueType.getTypeAt(index), index) + .getFieldOrNull(inputValue)); + } + + public RowData getAccumulator() { + return accumulator; + } + + public RowData getStartingValue() { + return startingValue; + } + + public RowData getFinalValue() { + return finalValue; + } + + public List getUpdatedValuesAfterEachRow() { + return updatedValuesAfterEachRow; + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecBundledGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecBundledGroupAggregate.java new file mode 100644 index 0000000000000..97f1c77e388fd --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecBundledGroupAggregate.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.nodes.exec.stream; + +import org.apache.flink.FlinkVersion; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory; +import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.EqualiserCodeGenerator; +import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata; +import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; +import org.apache.flink.table.planner.plan.nodes.exec.StateMetadata; +import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.BundledAggUtil; +import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction; +import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.aggregate.async.BundledAggregateAsyncFunction; +import org.apache.flink.table.runtime.operators.aggregate.async.KeyedAsyncWaitOperatorFactory; +import org.apache.flink.table.runtime.operators.aggregate.async.queue.KeyedAsyncOutputMode; +import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonInclude; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import org.apache.calcite.rel.core.AggregateCall; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Stream {@link ExecNode} which matches along with join a Java/Scala user defined table function. 
+ */ +@ExecNodeMetadata( + name = "stream-exec-bundled-group-aggregate", + version = 1, + consumedOptions = {}, + producedTransformations = + StreamExecBundledGroupAggregate.BUNDLED_GROUP_AGGREGATE_TRANSFORMATION, + minPlanVersion = FlinkVersion.v1_19, + minStateVersion = FlinkVersion.v1_19) +public class StreamExecBundledGroupAggregate extends StreamExecAggregateBase { + + private static final Logger LOG = + LoggerFactory.getLogger(StreamExecBundledGroupAggregate.class); + + public static final String BUNDLED_GROUP_AGGREGATE_TRANSFORMATION = "bundled-group-aggregate"; + + public static final String STATE_NAME = "bundledGroupAggregateState"; + + @JsonProperty(FIELD_NAME_GROUPING) + private final int[] grouping; + + @JsonProperty(FIELD_NAME_AGG_CALLS) + private final AggregateCall[] aggCalls; + + /** Each element indicates whether the corresponding agg call needs `retract` method. */ + @JsonProperty(FIELD_NAME_AGG_CALL_NEED_RETRACTIONS) + private final boolean[] aggCallNeedRetractions; + + /** Whether this node will generate UPDATE_BEFORE messages. */ + @JsonProperty(FIELD_NAME_GENERATE_UPDATE_BEFORE) + private final boolean generateUpdateBefore; + + /** Whether this node consumes retraction messages. */ + @JsonProperty(FIELD_NAME_NEED_RETRACTION) + private final boolean needRetraction; + + @Nullable + @JsonProperty(FIELD_NAME_STATE) + @JsonInclude(JsonInclude.Include.NON_NULL) + private final List stateMetadataList; + + public StreamExecBundledGroupAggregate( + TableConfig tableConfig, + int[] grouping, + AggregateCall[] aggCalls, + boolean[] aggCallNeedRetractions, + boolean generateUpdateBefore, + boolean needRetraction, + @Nullable Long stateTtlFromHint, + InputProperty inputProperty, + RowType outputType, + String description) { + this( + ExecNodeContext.newNodeId(), + ExecNodeContext.newContext(StreamExecGroupAggregate.class), + ExecNodeContext.newPersistedConfig(StreamExecGroupAggregate.class, tableConfig), + grouping, + aggCalls, + aggCallNeedRetractions, + generateUpdateBefore, + needRetraction, + StateMetadata.getOneInputOperatorDefaultMeta( + stateTtlFromHint, tableConfig, STATE_NAME), + Collections.singletonList(inputProperty), + outputType, + description); + } + + @JsonCreator + public StreamExecBundledGroupAggregate( + @JsonProperty(FIELD_NAME_ID) int id, + @JsonProperty(FIELD_NAME_TYPE) ExecNodeContext context, + @JsonProperty(FIELD_NAME_CONFIGURATION) ReadableConfig persistedConfig, + @JsonProperty(FIELD_NAME_GROUPING) int[] grouping, + @JsonProperty(FIELD_NAME_AGG_CALLS) AggregateCall[] aggCalls, + @JsonProperty(FIELD_NAME_AGG_CALL_NEED_RETRACTIONS) boolean[] aggCallNeedRetractions, + @JsonProperty(FIELD_NAME_GENERATE_UPDATE_BEFORE) boolean generateUpdateBefore, + @JsonProperty(FIELD_NAME_NEED_RETRACTION) boolean needRetraction, + @Nullable @JsonProperty(FIELD_NAME_STATE) List stateMetadataList, + @JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List inputProperties, + @JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType, + @JsonProperty(FIELD_NAME_DESCRIPTION) String description) { + super(id, context, persistedConfig, inputProperties, outputType, description); + this.grouping = checkNotNull(grouping); + this.aggCalls = checkNotNull(aggCalls); + this.aggCallNeedRetractions = checkNotNull(aggCallNeedRetractions); + checkArgument(aggCalls.length == aggCallNeedRetractions.length); + this.generateUpdateBefore = generateUpdateBefore; + this.needRetraction = needRetraction; + this.stateMetadataList = stateMetadataList; + } + + private void setJobMetadata(TableConfig tableConfig, 
LogicalType[] accTypes) { + RowType internalAccType = RowType.of(accTypes); + tableConfig.addJobParameter("internalAccType", internalAccType.asSerializableString()); + } + + @Override + protected Transformation translateToPlanInternal( + PlannerBase planner, ExecNodeConfig config) { + final long stateRetentionTime = + StateMetadata.getStateTtlForOneInputOperator(config, stateMetadataList); + if (grouping.length > 0 && stateRetentionTime < 0) { + LOG.warn( + "No state retention interval configured for a query which accumulates state. " + + "Please provide a query configuration with valid retention interval to prevent excessive " + + "state size. You may specify a retention time of 0 to not clean up the state."); + } + + final ExecEdge inputEdge = getInputEdges().get(0); + final Transformation inputTransform = + (Transformation) inputEdge.translateToPlan(planner); + final RowType inputRowType = (RowType) inputEdge.getOutputType(); + + final AggsHandlerCodeGenerator generator = + new AggsHandlerCodeGenerator( + new CodeGeneratorContext( + config, planner.getFlinkContext().getClassLoader()), + planner.createRelBuilder(), + JavaScalaConversionUtil.toScala(inputRowType.getChildren()), + // TODO: heap state backend do not copy key currently, + // we have to copy input field + // TODO: copy is not need when state backend is rocksdb, + // improve this in future + // TODO: but other operators do not copy this input field..... + true); + + final AggregateInfoList aggInfoList = + AggregateUtil.transformToStreamAggregateInfoList( + planner.getTypeFactory(), + inputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), + aggCallNeedRetractions, + needRetraction, + true, + true); + + generator.needAccumulate().needBundled(); + + if (needRetraction) { + generator.needRetract(); + } + + final GeneratedAggsHandleFunction aggsHandler = + generator.generateAggsHandler("GroupAggsHandler", aggInfoList); + + final LogicalType[] accTypes = + Arrays.stream(aggInfoList.getAccTypes()) + .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType) + .toArray(LogicalType[]::new); + final LogicalType[] aggValueTypes = + Arrays.stream(aggInfoList.getActualValueTypes()) + .map(LogicalTypeDataTypeConverter::fromDataTypeToLogicalType) + .toArray(LogicalType[]::new); + final GeneratedRecordEqualiser recordEqualiser = + new EqualiserCodeGenerator( + aggValueTypes, planner.getFlinkContext().getClassLoader()) + .generateRecordEqualiser("GroupAggValueEqualiser"); + final int inputCountIndex = aggInfoList.getIndexOfCountStar(); + + BundledAggregateAsyncFunction asyncFunction = + new BundledAggregateAsyncFunction( + aggsHandler, + recordEqualiser, + accTypes, + inputRowType, + inputCountIndex, + generateUpdateBefore, + stateRetentionTime); + + final OneInputStreamOperatorFactory operator = + new KeyedAsyncWaitOperatorFactory<>( + asyncFunction, + BundledAggUtil.timeout(config), + BundledAggUtil.bufferCapacity(config), + KeyedAsyncOutputMode.ORDERED); + + // partitioned aggregation + final OneInputTransformation transform = + ExecNodeUtil.createOneInputTransformation( + inputTransform, + createTransformationMeta(BUNDLED_GROUP_AGGREGATE_TRANSFORMATION, config), + operator, + InternalTypeInfo.of(getOutputType()), + inputTransform.getParallelism(), + false); + + // set KeyType and Selector for state + final RowDataKeySelector selector = + KeySelectorUtil.getRowDataSelector( + planner.getFlinkContext().getClassLoader(), + grouping, + InternalTypeInfo.of(inputRowType)); + transform.setStateKeySelector(selector); + 
transform.setStateKeyType(selector.getProducedType()); + + return transform; + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalBundledGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalBundledGroupAggregate.java new file mode 100644 index 0000000000000..d2577ffca494a --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalBundledGroupAggregate.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.physical.stream; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.planner.calcite.FlinkTypeFactory; +import org.apache.flink.table.planner.plan.PartialFinalType; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.InputProperty; +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecBundledGroupAggregate; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils; +import org.apache.flink.table.planner.plan.utils.RelDescriptionWriterImpl; +import org.apache.flink.table.planner.plan.utils.RelExplainUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.hint.RelHint; +import org.apache.calcite.rel.type.RelDataType; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.List; +import java.util.Optional; + +import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTableConfig; +import static org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory; + +/** + * Stream physical RelNode for {@link org.apache.flink.table.functions.BundledAggregateFunction}.
+ */
+public class StreamPhysicalBundledGroupAggregate extends StreamPhysicalGroupAggregateBase {
+    private final RelOptCluster cluster;
+    private final RelDataType outputRowType;
+    private final int[] grouping;
+    private final List<AggregateCall> aggCalls;
+    private final PartialFinalType partialFinalType;
+    private final List<RelHint> hints;
+
+    public StreamPhysicalBundledGroupAggregate(
+            RelOptCluster cluster,
+            RelTraitSet traitSet,
+            RelNode inputRel,
+            RelDataType outputRowType,
+            int[] grouping,
+            List<AggregateCall> aggCalls,
+            PartialFinalType partialFinalType,
+            List<RelHint> hints) {
+        super(
+                cluster,
+                traitSet,
+                inputRel,
+                grouping,
+                JavaScalaConversionUtil.toScala(aggCalls),
+                hints);
+        this.cluster = cluster;
+        this.outputRowType = outputRowType;
+        this.grouping = grouping;
+        this.aggCalls = aggCalls;
+        this.partialFinalType = partialFinalType;
+        this.hints = hints;
+    }
+
+    @Override
+    public RelWriter explainTerms(RelWriter pw) {
+        RelDataType inputRowType = getInput().getRowType();
+        AggregateInfoList aggInfoList =
+                AggregateUtil.deriveAggregateInfoList(
+                        this, grouping.length, JavaScalaConversionUtil.toScala(aggCalls));
+        return super.explainTerms(pw)
+                .itemIf(
+                        "groupBy",
+                        RelExplainUtil.fieldToString(grouping, inputRowType),
+                        grouping.length > 0)
+                .itemIf(
+                        "partialFinalType",
+                        partialFinalType,
+                        partialFinalType != PartialFinalType.NONE)
+                .item(
+                        "select",
+                        RelExplainUtil.streamGroupAggregationToString(
+                                inputRowType,
+                                getRowType(),
+                                aggInfoList,
+                                grouping,
+                                JavaScalaConversionUtil.toScala(Optional.empty()),
+                                false,
+                                false));
+    }
+
+    @Override
+    public RelDataType deriveRowType() {
+        return outputRowType;
+    }
+
+    @Override
+    public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
+        return new StreamPhysicalBundledGroupAggregate(
+                cluster,
+                traitSet,
+                inputs.get(0),
+                outputRowType,
+                grouping,
+                aggCalls,
+                partialFinalType,
+                hints);
+    }
+
+    @Override
+    public String getRelDetailedDescription() {
+        StringWriter sw = new StringWriter();
+        PrintWriter pw = new PrintWriter(sw);
+        RelDescriptionWriterImpl relWriter = new RelDescriptionWriterImpl(pw);
+        this.explain(relWriter);
+        return sw.toString();
+    }
+
+    @Override
+    public ExecNode translateToExecNode() {
+        boolean[] aggCallNeedRetractions =
+                AggregateUtil.deriveAggCallNeedRetractions(
+                        this, grouping.length, JavaScalaConversionUtil.toScala(aggCalls));
+        boolean generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this);
+        boolean needRetraction = !ChangelogPlanUtils.inputInsertOnly(this);
+        TableConfig tableConfig = unwrapTableConfig(this);
+        setJobMetadata(tableConfig);
+        return new StreamExecBundledGroupAggregate(
+                tableConfig,
+                grouping,
+                aggCalls.toArray(new AggregateCall[0]),
+                aggCallNeedRetractions,
+                generateUpdateBefore,
+                needRetraction,
+                null,
+                InputProperty.DEFAULT,
+                FlinkTypeFactory.toLogicalRowType(getRowType()),
+                getRelDetailedDescription());
+    }
+
+    private void setJobMetadata(TableConfig tableConfig) {
+        RowType inputRowType =
+                unwrapTypeFactory(getInput()).toLogicalRowType(getInput().getRowType());
+        LogicalType[] inputFieldTypes = InternalTypeInfo.of(inputRowType).toRowFieldTypes();
+        LogicalType[] keyFieldTypes = new LogicalType[grouping.length];
+        for (int i = 0; i < grouping.length; ++i) {
+            keyFieldTypes[i] = inputFieldTypes[grouping[i]];
+        }
+        RowType keyType = RowType.of(keyFieldTypes);
+        tableConfig.addJobParameter("keyType", keyType.asSerializableString());
+    }
+
+    @Override
+    public boolean requireWatermark() {
+        return false;
+    }
+}
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalBundledGroupAggregateRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalBundledGroupAggregateRule.java new file mode 100644 index 0000000000000..63d70c3c02988 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalBundledGroupAggregateRule.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.physical.stream; + +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.planner.plan.nodes.FlinkConventions; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalBundledGroupAggregate; +import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; +import org.apache.flink.table.planner.plan.utils.BundledAggUtil; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.convert.ConverterRule; +import org.apache.calcite.rel.core.Aggregate; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Rule that converts {@link FlinkLogicalAggregate} to {@link StreamPhysicalBundledGroupAggregate}. 
+ */ +public class StreamPhysicalBundledGroupAggregateRule extends ConverterRule { + + public static final RelOptRule INSTANCE = + new StreamPhysicalBundledGroupAggregateRule( + Config.INSTANCE.withConversion( + FlinkLogicalAggregate.class, + FlinkConventions.LOGICAL(), + FlinkConventions.STREAM_PHYSICAL(), + "StreamPhysicalBundledGroupAggregateRule")); + + public StreamPhysicalBundledGroupAggregateRule(Config config) { + super(config); + } + + @Override + public boolean matches(RelOptRuleCall call) { + FlinkLogicalAggregate agg = call.rel(0); + + if (agg.getGroupType() != Aggregate.Group.SIMPLE) { + throw new TableException("GROUPING SETS are currently not supported."); + } + + return agg.getAggCallList().stream().anyMatch(BundledAggUtil::containsBatchAggCall); + } + + @Override + public @Nullable RelNode convert(RelNode rel) { + FlinkLogicalAggregate agg = (FlinkLogicalAggregate) rel; + FlinkRelDistribution requiredDistribution; + if (agg.getGroupCount() != 0) { + requiredDistribution = FlinkRelDistribution.hash(agg.getGroupSet().asList(), true); + } else { + requiredDistribution = FlinkRelDistribution.SINGLETON(); + } + RelTraitSet requiredTraitSet = + rel.getCluster() + .getPlanner() + .emptyTraitSet() + .replace(requiredDistribution) + .replace(FlinkConventions.STREAM_PHYSICAL()); + RelTraitSet providedTraitSet = + rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL()); + RelNode newInput = RelOptRule.convert(agg.getInput(), requiredTraitSet); + + return new StreamPhysicalBundledGroupAggregate( + rel.getCluster(), + providedTraitSet, + newInput, + rel.getRowType(), + agg.getGroupSet().toArray(), + agg.getAggCallList(), + agg.partialFinalType(), + agg.getHints()); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/BundledAggUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/BundledAggUtil.java new file mode 100644 index 0000000000000..3b0491bd6a0a6 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/BundledAggUtil.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.utils; + +import org.apache.flink.configuration.ReadableConfig; +import org.apache.flink.table.api.config.ExecutionConfigOptions; +import org.apache.flink.table.functions.BundledAggregateFunction; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction; + +import org.apache.calcite.rel.core.AggregateCall; + +/** Utility class for bundled aggregate functions. 
*/
+public class BundledAggUtil {
+    public static boolean containsBatchAggCall(AggregateCall aggregateCall) {
+        if (aggregateCall.getAggregation() instanceof BridgingSqlAggFunction) {
+            BridgingSqlAggFunction bridgingSqlAggFunction =
+                    (BridgingSqlAggFunction) aggregateCall.getAggregation();
+            return bridgingSqlAggFunction.getDefinition().getKind() == FunctionKind.AGGREGATE
+                    && (bridgingSqlAggFunction.getDefinition() instanceof BundledAggregateFunction
+                            && ((BundledAggregateFunction) bridgingSqlAggFunction.getDefinition())
+                                    .canBundle());
+        }
+        return false;
+    }
+
+    /** Returns the async buffer capacity (max in-flight bundles) from the config. */
+    public static int bufferCapacity(ReadableConfig config) {
+        int capacity = config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_AGG_BUFFER_CAPACITY);
+        if (capacity <= 0) {
+            throw new IllegalArgumentException(
+                    ExecutionConfigOptions.TABLE_EXEC_ASYNC_AGG_BUFFER_CAPACITY + " must be > 0.");
+        }
+        return capacity;
+    }
+
+    /** Returns the async timeout from the config, in milliseconds. */
+    public static long timeout(ReadableConfig config) {
+        return config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_AGG_TIMEOUT).toMillis();
+    }
+}
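
Reviewer note: the two options consumed above are the new `table.exec.async-agg.*` settings documented at the top of this change. A minimal sketch of how a pipeline author might tune them; the option constants are the ones referenced in this utility, everything else is standard Table API, and the values shown are just the defaults from this change.

    import java.time.Duration;

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.ExecutionConfigOptions;

    public class AsyncAggConfigExample {
        public static void main(String[] args) {
            TableEnvironment tEnv =
                    TableEnvironment.create(EnvironmentSettings.inStreamingMode());
            // At most 10 in-flight async bundles per operator.
            tEnv.getConfig()
                    .set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_AGG_BUFFER_CAPACITY, 10);
            // Fail an async bundle that does not complete within 3 minutes.
            tEnv.getConfig()
                    .set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_AGG_TIMEOUT, Duration.ofMinutes(3));
        }
    }

diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
index 5da8128a21aa1..29e7d345b2cea 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen
 
 import org.apache.flink.api.common.functions.RuntimeContext
 import org.apache.flink.api.common.serialization.SerializerConfigImpl
 import org.apache.flink.core.memory.MemorySegment
+import org.apache.flink.streaming.api.functions.async.ResultFuture
 import org.apache.flink.table.data._
 import org.apache.flink.table.data.binary._
 import org.apache.flink.table.data.binary.BinaryRowDataUtil.BYTE_ARRAY_BASE_OFFSET
@@ -27,11 +28,16 @@ import org.apache.flink.table.data.util.DataFormatConverters
 import org.apache.flink.table.data.util.DataFormatConverters.IdentityConverter
 import org.apache.flink.table.data.utils.JoinedRowData
 import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.functions.agg.{BundledKeySegment, BundledKeySegmentApplied}
 import org.apache.flink.table.legacy.types.logical.TypeInformationRawType
 import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputFieldUnboxing, generateNonNullField}
+import org.apache.flink.table.planner.codegen.agg.{BundledResultCombiner, NonBundledAggregateUtil}
+import org.apache.flink.table.planner.codegen.agg.NonBundledAggregateUtil.NonBundledSegmentResult
 import org.apache.flink.table.planner.codegen.calls.BuiltInMethods.BINARY_STRING_DATA_FROM_STRING
 import org.apache.flink.table.runtime.dataview.StateDataViewStore
 import org.apache.flink.table.runtime.generated._
+import org.apache.flink.table.runtime.generated.{AggsHandleFunction, GeneratedHashFunction, HashFunction, NamespaceAggsHandleFunction, TableAggsHandleFunction}
+import org.apache.flink.table.runtime.operators.aggregate.async.MultiDelegatingAsyncResultFuture
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
 import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
 import org.apache.flink.table.runtime.util.{MurmurHashUtil, 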
TimeWindowUtil} @@ -97,6 +103,18 @@ object CodeGenUtils { val ROW_DATA: String = className[RowData] + val BUNDLED_KEY_SEGMENT: String = className[BundledKeySegment] + + val BUNDLED_KEY_SEGMENT_APPLIED: String = className[BundledKeySegmentApplied] + + val NON_BUNDLED_SEGMENT_UTIL: String = className[NonBundledAggregateUtil] + + val BUNDLED_RESULT_COMBINER: String = className[BundledResultCombiner] + + val RESULT_FUTURE: String = className[ResultFuture[_]] + + val MULTI_DELEGATING_FUTURE: String = className[MultiDelegatingAsyncResultFuture] + val JOINED_ROW: String = className[JoinedRowData] val GENERIC_ROW: String = className[GenericRowData] diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala index cbfbd6ff51482..65ba8bad690fa 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggCodeGen.scala @@ -39,6 +39,8 @@ trait AggCodeGen { def retract(generator: ExprCodeGenerator): String + def bundledAccumulateRetract(generator: ExprCodeGenerator): String + def merge(generator: ExprCodeGenerator): String def getValue(generator: ExprCodeGenerator): GeneratedExpression @@ -48,5 +50,6 @@ trait AggCodeGen { needRetract: Boolean = false, needMerge: Boolean = false, needReset: Boolean = false, + needBundled: Boolean = false, needEmitValue: Boolean = false): Unit } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala index deec91f894a0b..28f4671f0f735 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala @@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.table.api.{DataTypes, TableException} import org.apache.flink.table.data.GenericRowData import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction, TableAggregateFunction, UserDefinedFunctionHelper} +import org.apache.flink.table.functions.{BundledAggregateFunction, DeclarativeAggregateFunction, ImperativeAggregateFunction, TableAggregateFunction, UserDefinedFunctionHelper} import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector import org.apache.flink.table.planner.JLong import org.apache.flink.table.planner.codegen._ @@ -77,6 +77,7 @@ class AggsHandlerCodeGenerator( /** Aggregates informations */ private var accTypeInfo: RowType = _ private var aggBufferSize: Int = _ + private var accIndicesMap: Map[Int, (Int, Int)] = _ private var mergedAccExternalTypes: Array[DataType] = _ private var mergedAccOffset: Int = 0 @@ -86,6 +87,7 @@ class AggsHandlerCodeGenerator( private var isAccumulateNeeded = false private var isRetractNeeded = false + private var isBundledNeeded = false private var isMergeNeeded = false private var isWindowSizeNeeded = false private var isIncrementalUpdateNeeded = false @@ -169,6 +171,18 @@ class AggsHandlerCodeGenerator( this } + /** + 
* Tells the generator to generate the `bundledAccumulateRetract(..)` method for the
+   * [[AggsHandleFunction]] and [[NamespaceAggsHandleFunction]]. By default, the
+   * `bundledAccumulateRetract(..)` method is not generated.
+   *
+   * @return this generator instance
+   */
+  def needBundled(): AggsHandlerCodeGenerator = {
+    this.isBundledNeeded = true
+    this
+  }
+
   /**
    * Whether to update acc result incrementally. The value is true only for TableAggregateFunction
    * with emitUpdateWithRetract method implemented.
@@ -227,6 +241,10 @@
       mergedAccExternalTypes = aggInfoList.getAccTypes
     }
 
+    if (accIndicesMap == null) {
+      accIndicesMap = createAccumulatorIndicesMap(aggInfoList)
+    }
+
     val aggCodeGens = aggInfoList.aggInfos.map {
       aggInfo =>
         val filterExpr =
@@ -244,6 +262,16 @@
             inputFieldTypes,
             constants,
             relBuilder)
+        case _: BundledAggregateFunction =>
+          new BundledImperativeAggCodeGen(
+            ctx,
+            aggInfo,
+            filterExpr,
+            accIndicesMap(aggInfo.aggIndex),
+            inputFieldTypes,
+            constantExprs,
+            relBuilder,
+            copyInputField)
         case _: ImperativeAggregateFunction[_, _] =>
           aggInfo.function match {
             case tableAggFunc: TableAggregateFunction[_, _] =>
@@ -360,6 +388,7 @@
     val retractCode = genRetract()
     val mergeCode = genMerge()
     val getValueCode = genGetValue()
+    val bundledCode = genBundledAccumulateRetract(aggInfoList, None)
 
     val functionName = newName(ctx, name)
@@ -400,6 +429,13 @@
           $retractCode
         }
 
+        @Override
+        public void bundledAccumulateRetract(
+            java.util.concurrent.CompletableFuture<$BUNDLED_KEY_SEGMENT_APPLIED> future,
+            $BUNDLED_KEY_SEGMENT $BATCH_INPUT_TERM) throws Exception {
+          $bundledCode
+        }
+
         @Override
         public void merge($ROW_DATA $MERGED_ACC_TERM) throws Exception {
           $mergeCode
@@ -1055,6 +1091,105 @@
     }
   }
 
+  private def genBundledAccumulateRetract(
+      aggInfoList: AggregateInfoList,
+      namespace: Option[Class[_]]): String = {
+    if (isBundledNeeded) {
+      // validation check
+      checkNeededMethods(needBundled = true)
+
+      val methodName = "bundledAccumulateRetract"
+      ctx.startNewLocalVariableStatement(methodName)
+
+      // bind input1 as inputRow
+      val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL)
+        .bindInput(inputType, inputTerm = BATCH_ENTRY_TERM)
+
+      val imperativeBlock = aggActionCodeGens
+        .map(_.bundledAccumulateRetract(exprGenerator))
+        .mkString("\n")
+
+      val ns = namespace.map(_ => "ns,").getOrElse("")
+
+      val declarativeBlock =
+        s"""
+           |    $NON_BUNDLED_SEGMENT_UTIL.NonBundledSegmentResult nonBundledResult =
+           |        $NON_BUNDLED_SEGMENT_UTIL.executeAsBundle(
+           |            this,
+           |            $ns
+           |            $BATCH_INPUT_TERM);
+           |""".stripMargin
+
+      val combinerFactory = new BundledResultCombiner(accTypeInfo, valueType)
+      val combinerFactoryTerm = ctx.addReusableObject(combinerFactory, "combinerFactory")
+
+      val combineAdds = aggInfoList.aggInfos
+        .map(
+          aggInfo => {
+            val isBundled =
+              aggInfo.function.isInstanceOf[BundledAggregateFunction]
+            val bundledResults = if (isBundled) {
+              s"""java.util.Optional.of(future${aggInfo.aggIndex})""".stripMargin
+            } else {
+              "java.util.Optional.empty()"
+            }
+
+            val nonBundledResults = if (isBundled) {
+              "java.util.Optional.empty()"
+            } else {
+              "java.util.Optional.of(nonBundledResult)"
+            }
+
+            val createFuture = if (isBundled) {
+              s"""
+                 |java.util.concurrent.CompletableFuture future${aggInfo.aggIndex} = df.createAsyncFuture();
+                 |""".stripMargin
+            } else {
+              ""
+            }
+
+            val (accIndexStart, accIndexEnd) = accIndicesMap(aggInfo.aggIndex)
+            s"""
+               |    $createFuture
+               |    combiner.add(
+               |        ${aggInfo.aggIndex},
+               |        $bundledResults,
+               |        $nonBundledResults,
+               |        ${!aggInfoList.countStarInserted || aggInfoList.indexOfCountStar.get != aggInfo.aggIndex},
+               |        $isBundled,
+               |        $accIndexStart, $accIndexEnd
+               |    );
+             """.stripMargin
+          })
+        .mkString("\n")
+
+      var totalBundled = 0
+      aggInfoList.aggInfos.foreach {
+        aggInfo =>
+          if (aggInfo.function.isInstanceOf[BundledAggregateFunction]) {
+            totalBundled += 1
+          }
+      }
+
+      s"""
+         |  ${ctx.reuseLocalVariableCode(methodName)}
+         |  // The combiner takes the results of the declarative block and the async calls, and
+         |  // combines them into the value returned by the future.
+         |  $BUNDLED_RESULT_COMBINER.Combiner combiner = $combinerFactoryTerm.newCombiner();
+         |  $MULTI_DELEGATING_FUTURE df = new $MULTI_DELEGATING_FUTURE(future, $totalBundled);
+         |  df.setResultSupplier(combiner);
+         |$declarativeBlock
+         |$combineAdds
+         |  // This is where the async calls are started, so everything must be set on the combiner
+         |  // before this.
+         |$imperativeBlock
+       """.stripMargin
+    } else {
+      genThrowException(
+        "This function does not require the bundled method, but the bundled method was called.")
+    }
+  }
+
   private def genMerge(): String = {
     if (isMergeNeeded) {
       // validation check
@@ -1272,6 +1407,7 @@
       needAccumulate: Boolean = false,
       needRetract: Boolean = false,
       needMerge: Boolean = false,
+      needBundled: Boolean = false,
       needReset: Boolean = false,
       needEmitValue: Boolean = false): Unit = {
     // check and validate the needed methods
@@ -1293,6 +1429,9 @@ object AggsHandlerCodeGenerator {
   val MERGED_ACC_TERM = "otherAcc"
   val ACCUMULATE_INPUT_TERM = "accInput"
   val RETRACT_INPUT_TERM = "retractInput"
+  val BATCH_INPUT_TERM = "batchInput"
+  val BATCH_ENTRY_TERM = "batchEntry"
+  val BATCH_RETURN_TERM = "batchReturn"
   val WINDOWS_SIZE = "windowSize"
   val DISTINCT_KEY_TERM = "distinctKey"
@@ -1431,4 +1570,20 @@
       ctx.addReusableExternalSerializer(dataType())
     }
   }
+
+  private def createAccumulatorIndicesMap(
+      aggInfoList: AggregateInfoList): scala.collection.immutable.Map[Int, (Int, Int)] = {
+    var accIndex = 0
+    val accumulatorIndexesMap: scala.collection.mutable.Map[Int, (Int, Int)] =
+      scala.collection.mutable.Map[Int, (Int, Int)]()
+    aggInfoList.aggInfos.foreach {
+      aggInfo =>
+        {
+          accumulatorIndexesMap(aggInfo.aggIndex) =
+            (accIndex, accIndex + aggInfo.externalAccTypes.length)
+          accIndex += aggInfo.externalAccTypes.length
+        }
+    }
+    scala.collection.immutable.Map(accumulatorIndexesMap.toSeq: _*)
+  }
 }
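
Reviewer note: a small, self-contained illustration of the accumulator index map that `createAccumulatorIndicesMap` builds above. Each aggregate claims a contiguous `[start, end)` slice of the flattened accumulator row, keyed by its position in the aggregate list; the arities below are made up for the example.

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class AccIndexMapDemo {
        public static void main(String[] args) {
            // External accumulator field counts per aggregate, e.g. SUM -> 1, AVG -> 2.
            int[] accArities = {1, 2, 1};
            Map<Integer, int[]> accIndices = new LinkedHashMap<>();
            int offset = 0;
            for (int aggIndex = 0; aggIndex < accArities.length; aggIndex++) {
                accIndices.put(aggIndex, new int[] {offset, offset + accArities[aggIndex]});
                offset += accArities[aggIndex];
            }
            // Prints 0=[0, 1), 1=[1, 3), 2=[3, 4): agg 1 owns accumulator fields 1 and 2.
            accIndices.forEach(
                    (i, range) -> System.out.println(i + "=[" + range[0] + ", " + range[1] + ")"));
        }
    }

diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/BundledImperativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/BundledImperativeAggCodeGen.scala
new file mode 100644
index 0000000000000..f513c0d1d6722
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/BundledImperativeAggCodeGen.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. 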
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.planner.codegen.agg +import org.apache.flink.table.api.TableException +import org.apache.flink.table.data.RowData +import org.apache.flink.table.expressions.Expression +import org.apache.flink.table.functions.{BundledAggregateFunction, FunctionContext, ImperativeAggregateFunction} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression} +import org.apache.flink.table.planner.codegen.CodeGenUtils._ +import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} +import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._ +import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef +import org.apache.flink.table.planner.expressions.converter.ExpressionConverter +import org.apache.flink.table.planner.plan.utils.AggregateInfo +import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType +import org.apache.flink.table.types.FieldsDataType +import org.apache.flink.table.types.logical.LogicalType + +import org.apache.calcite.tools.RelBuilder + +import scala.collection.mutable.ArrayBuffer + +class BundledImperativeAggCodeGen( + ctx: CodeGeneratorContext, + aggInfo: AggregateInfo, + filterExpression: Option[Expression], + accIndexes: (Int, Int), + inputTypes: Seq[LogicalType], + constantExprs: Seq[GeneratedExpression], + relBuilder: RelBuilder, + inputFieldCopy: Boolean) + extends AggCodeGen { + val function = aggInfo.function.asInstanceOf[ImperativeAggregateFunction[_, _]] + val functionTerm: String = ctx.addReusableFunction( + function, + classOf[FunctionContext], + Seq(s"$STORE_TERM.getRuntimeContext()")) + val aggIndex: Int = aggInfo.aggIndex + val externalAccType = aggInfo.externalAccTypes(0) + private val internalAccType = fromDataTypeToLogicalType(externalAccType) + + /** + * whether the acc type is an internal type. 
Currently we only support GenericRowData as internal + * acc type + */ + val isAccTypeInternal: Boolean = + classOf[RowData].isAssignableFrom(externalAccType.getConversionClass) + + private val externalResultType = aggInfo.externalResultType + private val internalResultType = fromDataTypeToLogicalType(externalResultType) + + private val rexNodeGen = new ExpressionConverter(relBuilder) + + override def createAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = { + Seq(GeneratedExpression("null", NEVER_NULL, NO_CODE, internalAccType, Option.empty)) + } + + override def setAccumulator(generator: ExprCodeGenerator): String = { + "" + } + + override def getAccumulator(generator: ExprCodeGenerator): Seq[GeneratedExpression] = { + Seq(GeneratedExpression("null", NEVER_NULL, NO_CODE, internalAccType, Option.empty)) + } + + override def resetAccumulator(generator: ExprCodeGenerator): String = { + "" + } + + override def setWindowSize(generator: ExprCodeGenerator): String = { + "" + } + + override def accumulate(generator: ExprCodeGenerator): String = { + "" + } + + override def retract(generator: ExprCodeGenerator): String = { + "" + } + + private def bundledParametersCode(generator: ExprCodeGenerator): (String, String) = { + val externalInputTypes = aggInfo.externalArgTypes + var codes: ArrayBuffer[String] = ArrayBuffer.empty[String] + + val inputRowFields = aggInfo.argIndexes.zipWithIndex + .map { + case (f, index) => + if (f >= inputTypes.length) { + // index to constant + val expr = constantExprs(f - inputTypes.length) + genToExternalConverterAll(ctx, externalInputTypes(index), expr) + } else { + // index to input field + val inputRef = { + // called from accumulate + toRexInputRef(relBuilder, f, inputTypes(f)) + } + var inputExpr = generator.generateExpression(inputRef.accept(rexNodeGen)) + if (inputFieldCopy) inputExpr = inputExpr.deepCopy(ctx) + codes += inputExpr.code + genToExternalConverterAll(ctx, externalInputTypes(index), inputExpr) + } + } + .mkString(", ") + + val loop = + s""" + | java.util.List inputRows = new java.util.ArrayList(); + | for ($ROW_DATA $BATCH_ENTRY_TERM : $BATCH_INPUT_TERM.getRows()) { + | ${ctx.reuseInputUnboxingCode(BATCH_ENTRY_TERM)} + | ${ctx.reusePerRecordCode()} + | inputRows.add($GENERIC_ROW.ofKind( + | $BATCH_ENTRY_TERM.getRowKind(), $inputRowFields)); + | } + | + | $BUNDLED_KEY_SEGMENT aggCall${aggInfo.aggIndex} = + | $BUNDLED_KEY_SEGMENT.of( + | $BATCH_INPUT_TERM.getKey(), + | inputRows, + | $BATCH_INPUT_TERM.getAccumulator() == null ? 
+ | null : + | // Grab the segment of the accumulator belonging to this one + | $BATCH_INPUT_TERM.getAccumulator().getRow(${accIndexes._1}, + | ${externalAccType.asInstanceOf[FieldsDataType].getChildren.size()}), + | $BATCH_INPUT_TERM.getUpdatedValuesAfterEachRow()); + |""".stripMargin + + codes += loop + val inputFields = Seq( + s"""future${aggInfo.aggIndex}""".stripMargin, + s"""aggCall${aggInfo.aggIndex}""".stripMargin) + (inputFields.mkString(", "), codes.mkString("\n")) + } + + override def bundledAccumulateRetract(generator: ExprCodeGenerator): String = { + val (parameters, code) = bundledParametersCode(generator) + + val call = + s""" + | $functionTerm.bundledAccumulateRetract($parameters); + """.stripMargin + filterExpression match { + case None => + s""" + |$code + |$call + """.stripMargin + case Some(expr) => + throw new UnsupportedOperationException( + "Filter operations not handled on bundled aggregates yet"); + } + } + + override def merge(generator: ExprCodeGenerator): String = { + "" + } + + override def getValue(generator: ExprCodeGenerator): GeneratedExpression = { + GeneratedExpression("null", NEVER_NULL, NO_CODE, internalResultType, Option.empty) + } + + override def checkNeededMethods( + needAccumulate: Boolean, + needRetract: Boolean, + needMerge: Boolean, + needReset: Boolean, + needBundled: Boolean, + needEmitValue: Boolean): Unit = { + + if (needRetract) { + function match { + case f: BundledAggregateFunction if !f.canRetract => + throw new UnsupportedOperationException( + s"Retract functionality is not implemented for the aggregate function: ${f.getClass.getName}") + case _ => + } + } + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala index 898708d2ec5e2..77b6ea76835ea 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala @@ -180,6 +180,10 @@ class DeclarativeAggCodeGen( } } + override def bundledAccumulateRetract(generator: ExprCodeGenerator): String = { + "" + } + def retract(generator: ExprCodeGenerator): String = { val isDistinctMerge = generator.input1Term.startsWith(DISTINCT_KEY_TERM) val resolvedExprs = function.retractExpressions @@ -290,6 +294,7 @@ class DeclarativeAggCodeGen( needRetract: Boolean = false, needMerge: Boolean = false, needReset: Boolean = false, + needBundled: Boolean = false, needEmitValue: Boolean = false): Unit = { // skip the check for DeclarativeAggregateFunction for now } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala index 234631e995d34..a655660535e8f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala @@ -366,6 +366,7 @@ class DistinctAggCodeGen( needRetract: Boolean, needMerge: Boolean, needReset: Boolean, + needBundled: Boolean = false, needEmitValue: Boolean): Unit = { if (needMerge) { // see merge method for more information @@ -960,4 +961,8 
@@ class DistinctAggCodeGen(
     throw new TableException(
       "Distinct shouldn't set window size, this is a bug, please file a issue.")
   }
+
+  override def bundledAccumulateRetract(generator: ExprCodeGenerator): String = {
+    throw new TableException("Distinct shouldn't call bundledAccumulateRetract, this is a bug, please file an issue.")
+  }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
index 14f5849967d6c..bc2f7eddf5e4d 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.codegen.agg
 
 import org.apache.flink.table.data.{GenericRowData, RowData, UpdatableRowData}
 import org.apache.flink.table.expressions.Expression
-import org.apache.flink.table.functions.{FunctionContext, ImperativeAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.{BundledAggregateFunction, FunctionContext, ImperativeAggregateFunction, UserDefinedFunctionHelper}
 import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression}
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
@@ -228,6 +228,10 @@
     }
   }
 
+  override def bundledAccumulateRetract(generator: ExprCodeGenerator): String = {
+    ""
+  }
+
   def merge(generator: ExprCodeGenerator): String = {
     val accIterTerm = s"agg${aggIndex}_acc_iter"
     ctx.addReusableMember(s"private final $SINGLE_ITERABLE $accIterTerm = new $SINGLE_ITERABLE();")
@@ -456,6 +460,7 @@
       needRetract: Boolean = false,
       needMerge: Boolean = false,
       needReset: Boolean = false,
+      needBundled: Boolean = false,
       needEmitValue: Boolean = false): Unit = {
     val functionName = String.valueOf(aggInfo.agg.getAggregation)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala
index 0560bcedc5838..c1680d9ea9e8f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala
@@ -188,7 +188,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti
         // ignore required trait from context, because sink is the true root
         sink.copy(sinkTrait, children).asInstanceOf[StreamPhysicalRel]
 
-      case agg: StreamPhysicalGroupAggregate =>
+      case agg: StreamPhysicalGroupAggregateBase =>
         // agg support all changes in input
         val children = visitChildren(agg, ModifyKindSetTrait.ALL_CHANGES)
         val inputModifyKindSet = getModifyKindSet(children.head)
@@ -334,7 +334,7 @@
         createNewNode(over, children, providedTrait, requiredTrait, requester)
 
       case _: StreamPhysicalTemporalSort | _: StreamPhysicalIntervalJoin |
-          _: 
StreamPhysicalPythonOverAggregate => + _: StreamPhysicalOverAggregateBase | _: StreamPhysicalPythonOverAggregate => // TemporalSort, IntervalJoin only support consuming insert-only // and producing insert-only changes val children = visitChildren(rel, ModifyKindSetTrait.INSERT_ONLY) @@ -592,7 +592,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti } visitSink(sink, sinkRequiredTraits) - case _: StreamPhysicalGroupAggregate | _: StreamPhysicalGroupTableAggregate | + case _: StreamPhysicalGroupAggregateBase | _: StreamPhysicalGroupTableAggregate | _: StreamPhysicalLimit | _: StreamPhysicalPythonGroupAggregate | _: StreamPhysicalPythonGroupTableAggregate | _: StreamPhysicalGroupWindowAggregateBase | _: StreamPhysicalWindowAggregate | _: StreamPhysicalOverAggregate => @@ -605,7 +605,8 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti case _: StreamPhysicalWindowRank | _: StreamPhysicalWindowDeduplicate | _: StreamPhysicalTemporalSort | _: StreamPhysicalMatch | _: StreamPhysicalIntervalJoin | - _: StreamPhysicalPythonOverAggregate | _: StreamPhysicalWindowJoin => + _: StreamPhysicalOverAggregateBase | _: StreamPhysicalPythonOverAggregate | + _: StreamPhysicalWindowJoin => // WindowRank, WindowDeduplicate, Deduplicate, TemporalSort, CEP, // and IntervalJoin, WindowJoin require nothing about UpdateKind. val children = visitChildren(rel, UpdateKindTrait.NONE) @@ -1073,7 +1074,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti val fullDelete = fullDeleteOrNone(childModifyKindSet) visitSink(sink, Seq(fullDelete)) - case _: StreamPhysicalGroupAggregate | _: StreamPhysicalGroupTableAggregate | + case _: StreamPhysicalGroupAggregateBase | _: StreamPhysicalGroupTableAggregate | _: StreamPhysicalLimit | _: StreamPhysicalPythonGroupAggregate | _: StreamPhysicalPythonGroupTableAggregate | _: StreamPhysicalGroupWindowAggregateBase | _: StreamPhysicalWindowAggregate | _: StreamPhysicalSort | _: StreamPhysicalRank | diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index 74219150ed0fa..dce3807eb2675 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -454,6 +454,7 @@ object FlinkStreamRuleSets { StreamPhysicalGroupTableAggregateRule.INSTANCE, StreamPhysicalPythonGroupAggregateRule.INSTANCE, StreamPhysicalPythonGroupTableAggregateRule.INSTANCE, + StreamPhysicalBundledGroupAggregateRule.INSTANCE, // over agg StreamPhysicalOverAggregateRule.INSTANCE, StreamPhysicalPythonOverAggregateRule.INSTANCE, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala index 3495ebea06e4b..2c662586b264d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala +++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalGroupAggregateRule.scala @@ -22,8 +22,8 @@ import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalGroupAggregate +import org.apache.flink.table.planner.plan.utils.{AsyncUtil, BundledAggUtil, WindowUtil} import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate -import org.apache.flink.table.planner.plan.utils.WindowUtil import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.RelNode @@ -48,6 +48,10 @@ class StreamPhysicalGroupAggregateRule(config: Config) extends ConverterRule(con return false } + if (agg.getAggCallList.exists(BundledAggUtil.containsBatchAggCall)) { + return false + } + // check not window aggregate !WindowUtil.isValidWindowAggregate(agg) } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/agg/BundledResultCombinerTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/agg/BundledResultCombinerTest.java new file mode 100644 index 0000000000000..b038ce04a99e2 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/agg/BundledResultCombinerTest.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.codegen.agg; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; +import org.apache.flink.table.planner.codegen.agg.BundledResultCombiner.Combiner; +import org.apache.flink.table.planner.codegen.agg.NonBundledAggregateUtil.NonBundledSegmentResult; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for {@link BundledResultCombiner}. 
*/ +public class BundledResultCombinerTest { + + private Combiner prepareOneBundledOneNonBundled(boolean implicitCounter) { + // Simulate SUM + Combiner combiner = + new BundledResultCombiner( + RowType.of(RowType.of(new BigIntType()), new BigIntType()), + RowType.of(new BigIntType(), new BigIntType())) + .newCombiner(); + + combiner.add( + 0, + Optional.of(CompletableFuture.completedFuture(createBundledSum())), + Optional.empty(), + true, + true, + 0, + 1); + combiner.add( + 1, + Optional.empty(), + Optional.of(createNonBundledSum(1)), + !implicitCounter, + false, + 1, + 2); + return combiner; + } + + @Test + public void testOneBundledOneNonBundled() throws Exception { + Combiner combiner = prepareOneBundledOneNonBundled(false); + BundledKeySegmentApplied combined = combiner.combine(); + assertThat(combined.getAccumulator()) + .isEqualTo(GenericRowData.of(GenericRowData.of(2L), 3L)); + assertThat(combined.getStartingValue()).isEqualTo(GenericRowData.of(0L, 1L)); + assertThat(combined.getFinalValue()).isEqualTo(GenericRowData.of(2L, 3L)); + assertThat(combined.getUpdatedValuesAfterEachRow()) + .isEqualTo( + Arrays.asList( + GenericRowData.of(0L, 1L), + GenericRowData.of(1L, 2L), + GenericRowData.of(2L, 3L))); + } + + @Test + public void testOneBundledOneNonBundledImplicitCounter() throws Exception { + Combiner combiner = prepareOneBundledOneNonBundled(true); + BundledKeySegmentApplied combined = combiner.combine(); + assertThat(combined.getAccumulator()) + .isEqualTo(GenericRowData.of(GenericRowData.of(2L), 3L)); + assertThat(combined.getStartingValue()).isEqualTo(GenericRowData.of(0L)); + assertThat(combined.getFinalValue()).isEqualTo(GenericRowData.of(2L)); + assertThat(combined.getUpdatedValuesAfterEachRow()) + .isEqualTo( + Arrays.asList( + GenericRowData.of(0L), + GenericRowData.of(1L), + GenericRowData.of(2L))); + } + + @Test + public void testTwoBundledOneNonBundled() throws Exception { + Combiner combiner = + new BundledResultCombiner( + RowType.of( + RowType.of(new BigIntType()), + RowType.of(new VarCharType(10)), + new BigIntType()), + RowType.of(new BigIntType(), new VarCharType(10), new BigIntType())) + .newCombiner(); + + combiner.add( + 0, + Optional.of(CompletableFuture.completedFuture(createBundledSum())), + Optional.empty(), + true, + true, + 0, + 1); + combiner.add( + 1, + Optional.of(CompletableFuture.completedFuture(createBundledStringAppend())), + Optional.empty(), + true, + true, + 1, + 2); + combiner.add(2, Optional.empty(), Optional.of(createNonBundledSum(2)), true, false, 2, 3); + BundledKeySegmentApplied combined = combiner.combine(); + assertThat(combined.getAccumulator()) + .isEqualTo( + GenericRowData.of( + GenericRowData.of(2L), + GenericRowData.of(StringData.fromString("abc")), + 3L)); + assertThat(combined.getStartingValue()) + .isEqualTo(GenericRowData.of(0L, StringData.fromString("a"), 1L)); + assertThat(combined.getFinalValue()) + .isEqualTo(GenericRowData.of(2L, StringData.fromString("abc"), 3L)); + assertThat(combined.getUpdatedValuesAfterEachRow()) + .isEqualTo( + Arrays.asList( + GenericRowData.of(0L, StringData.fromString("a"), 1L), + GenericRowData.of(1L, StringData.fromString("ab"), 2L), + GenericRowData.of(2L, StringData.fromString("abc"), 3L))); + } + + @Test + public void testTwoBundledOneNonBundledDifferentOrder() throws Exception { + Combiner combiner = + new BundledResultCombiner( + RowType.of( + RowType.of(new BigIntType()), + new BigIntType(), + RowType.of(new VarCharType(10))), + RowType.of(new BigIntType(), new BigIntType(), new 
VarCharType(10))) + .newCombiner(); + + combiner.add( + 0, + Optional.of(CompletableFuture.completedFuture(createBundledSum())), + Optional.empty(), + true, + true, + 0, + 1); + combiner.add(1, Optional.empty(), Optional.of(createNonBundledSum(1)), true, false, 1, 2); + combiner.add( + 2, + Optional.of(CompletableFuture.completedFuture(createBundledStringAppend())), + Optional.empty(), + true, + true, + 2, + 3); + BundledKeySegmentApplied combined = combiner.combine(); + assertThat(combined.getAccumulator()) + .isEqualTo( + GenericRowData.of( + GenericRowData.of(2L), + 3L, + GenericRowData.of(StringData.fromString("abc")))); + assertThat(combined.getStartingValue()) + .isEqualTo(GenericRowData.of(0L, 1L, StringData.fromString("a"))); + assertThat(combined.getFinalValue()) + .isEqualTo(GenericRowData.of(2L, 3L, StringData.fromString("abc"))); + assertThat(combined.getUpdatedValuesAfterEachRow()) + .isEqualTo( + Arrays.asList( + GenericRowData.of(0L, 1L, StringData.fromString("a")), + GenericRowData.of(1L, 2L, StringData.fromString("ab")), + GenericRowData.of(2L, 3L, StringData.fromString("abc")))); + } + + @Test + public void testOneBundledOneNonBundledWithMultiSizeAccumulator() throws Exception { + // Note that the accumulator now has two entries for the non bundled + Combiner combiner = + new BundledResultCombiner( + RowType.of( + RowType.of(new BigIntType()), + new BigIntType(), + new IntType()), + RowType.of(new BigIntType(), new BigIntType())) + .newCombiner(); + + combiner.add( + 0, + Optional.of(CompletableFuture.completedFuture(createBundledSum())), + Optional.empty(), + true, + true, + 0, + 1); + combiner.add( + 1, + Optional.empty(), + Optional.of( + new NonBundledSegmentResult( + GenericRowData.of(null, 2L, 6), + GenericRowData.of(null, 10L), + GenericRowData.of(null, 20L), + Arrays.asList( + GenericRowData.of(null, 10L), + GenericRowData.of(null, 15L), + GenericRowData.of(null, 20L)))), + true, + false, + 1, + 3); + BundledKeySegmentApplied combined = combiner.combine(); + assertThat(combined.getAccumulator()) + .isEqualTo(GenericRowData.of(GenericRowData.of(2L), 2L, 6)); + assertThat(combined.getStartingValue()).isEqualTo(GenericRowData.of(0L, 10L)); + assertThat(combined.getFinalValue()).isEqualTo(GenericRowData.of(2L, 20L)); + assertThat(combined.getUpdatedValuesAfterEachRow()) + .isEqualTo( + Arrays.asList( + GenericRowData.of(0L, 10L), + GenericRowData.of(1L, 15L), + GenericRowData.of(2L, 20L))); + } + + private BundledKeySegmentApplied createBundledSum() { + return new BundledKeySegmentApplied( + GenericRowData.of(2L), + GenericRowData.of(0L), + GenericRowData.of(2L), + Arrays.asList(GenericRowData.of(0L), GenericRowData.of(1L), GenericRowData.of(2L))); + } + + private BundledKeySegmentApplied createBundledStringAppend() { + return new BundledKeySegmentApplied( + GenericRowData.of(StringData.fromString("abc")), + GenericRowData.of(StringData.fromString("a")), + GenericRowData.of(StringData.fromString("abc")), + Arrays.asList( + GenericRowData.of(StringData.fromString("a")), + GenericRowData.of(StringData.fromString("ab")), + GenericRowData.of(StringData.fromString("abc")))); + } + + private NonBundledSegmentResult createNonBundledSum(int index) { + return new NonBundledSegmentResult( + createRowDataWith(index, 3L), + createRowDataWith(index, 1L), + createRowDataWith(index, 3L), + Arrays.asList( + createRowDataWith(index, 1L), + createRowDataWith(index, 2L), + createRowDataWith(index, 3L))); + } + + private GenericRowData createRowDataWith(int index, Object obj) { + 
Object[] fields = new Object[] {null, null, null};
+        fields[index] = obj;
+        return GenericRowData.of(fields);
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/agg/NonBundledAggregateUtilTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/agg/NonBundledAggregateUtilTest.java
new file mode 100644
index 0000000000000..ea59884a3e412
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/agg/NonBundledAggregateUtilTest.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.codegen.agg;
+
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.agg.BundledKeySegment;
+import org.apache.flink.table.planner.codegen.agg.NonBundledAggregateUtil.NonBundledSegmentResult;
+import org.apache.flink.table.runtime.dataview.StateDataViewStore;
+import org.apache.flink.table.runtime.generated.AggsHandleFunction;
+import org.apache.flink.types.RowKind;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link NonBundledAggregateUtil}. */
+public class NonBundledAggregateUtilTest {
+
+    @Test
+    public void testAvg() throws Exception {
+        // Simulates an AVG call whose result is the second value column and whose accumulator occupies the second and third accumulator fields.
+        AggsHandleFunction handle = new AvgAggsHandleFunction(2, 1, 3, 1);
+        NonBundledSegmentResult result =
+                NonBundledAggregateUtil.executeAsBundle(
+                        handle,
+                        new BundledKeySegment(
+                                null,
+                                Arrays.asList(GenericRowData.of(2L), GenericRowData.of(4L)),
+                                null,
+                                false));
+        assertThat(result.getAccumulator()).isEqualTo(GenericRowData.of(null, 2L, 6L));
+        assertThat(result.getStartingValue()).isEqualTo(GenericRowData.of(null, null));
+        assertThat(result.getFinalValue()).isEqualTo(GenericRowData.of(null, 3L));
+
+        result =
+                NonBundledAggregateUtil.executeAsBundle(
+                        handle,
+                        new BundledKeySegment(
+                                null,
+                                Arrays.asList(
+                                        GenericRowData.of(7L),
+                                        GenericRowData.of(9L),
+                                        GenericRowData.ofKind(RowKind.DELETE, 7L)),
+                                GenericRowData.of(null, 2L, 6L),
+                                false));
+        assertThat(result.getAccumulator()).isEqualTo(GenericRowData.of(null, 3L, 15L));
+        assertThat(result.getStartingValue()).isEqualTo(GenericRowData.of(null, 3L));
+        assertThat(result.getFinalValue()).isEqualTo(GenericRowData.of(null, 5L));
+    }
+
+    private static class AvgAggsHandleFunction implements AggsHandleFunction {
+
+        private final int totalValues;
+        private final int valueIndex;
+        private final int totalAccFields;
+        private final int accIndex;
+
+        AvgAggsHandleFunction(int totalValues, int valueIndex, int totalAccFields, int accIndex) {
+            this.totalValues = totalValues;
+            this.valueIndex = valueIndex;
+            this.totalAccFields = totalAccFields;
+            this.accIndex = accIndex;
+        }
+
+        private GenericRowData acc;
+
+        @Override
+        public RowData getValue() throws Exception {
+            Object[] fields = new Object[totalValues];
+            if (!(acc == null || acc.getLong(accIndex) == 0)) {
+                fields[valueIndex] = acc.getLong(accIndex + 1) / acc.getLong(accIndex);
+            }
+            return GenericRowData.of(fields);
+        }
+
+        @Override
+        public void setWindowSize(int windowSize) {}
+
+        @Override
+        public void open(StateDataViewStore store) throws Exception {}
+
+        @Override
+        public void accumulate(RowData input) throws Exception {
+            acc.setField(accIndex, acc.getLong(accIndex) + 1);
+            acc.setField(accIndex + 1, acc.getLong(accIndex + 1) + input.getLong(0));
+        }
+
+        @Override
+        public void retract(RowData input) throws Exception {
+            acc.setField(accIndex, acc.getLong(accIndex) - 1);
+            acc.setField(accIndex + 1, acc.getLong(accIndex + 1) - input.getLong(0));
+        }
+
+        @Override
+        public void merge(RowData accumulators) throws Exception {}
+
+        @Override
+        public void setAccumulators(RowData accumulators) throws Exception {
+            acc = (GenericRowData) accumulators;
+        }
+
+        @Override
+        public void resetAccumulators() throws Exception {}
+
+        @Override
+        public RowData getAccumulators() throws Exception {
+            return acc;
+        }
+
+        @Override
+        public RowData createAccumulators() throws Exception {
+            Object[] fields = new Object[totalAccFields];
+            fields[accIndex] = 0L;
+            fields[accIndex + 1] = 0L;
+            return GenericRowData.of(fields);
+        }
+
+        @Override
+        public void cleanup() throws Exception {}
+
+        @Override
+        public void close() throws Exception {}
+    }
+}
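
Reviewer note: for orientation before the ITCase, a sketch of what a user-defined bundled function can look like. This is illustrative only: it mirrors the SumAggregate used by the tests and follows the BundledKeySegment/BundledKeySegmentApplied shapes exercised above, but the exact hook signatures and null handling are assumptions, not the committed API.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.CompletableFuture;

    import org.apache.flink.table.data.GenericRowData;
    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.data.util.RowDataUtil;
    import org.apache.flink.table.functions.AggregateFunction;
    import org.apache.flink.table.functions.BundledAggregateFunction;
    import org.apache.flink.table.functions.agg.BundledKeySegment;
    import org.apache.flink.table.functions.agg.BundledKeySegmentApplied;

    public class BundledSum extends AggregateFunction<Long, GenericRowData>
            implements BundledAggregateFunction {

        @Override
        public boolean canBundle() {
            return true;
        }

        @Override
        public boolean canRetract() {
            return true;
        }

        // Processes all buffered rows of one key in a single call; a real function could
        // hand the bundle to an external service and complete the future asynchronously.
        public void bundledAccumulateRetract(
                CompletableFuture<BundledKeySegmentApplied> future, BundledKeySegment segment) {
            // The accumulator is this function's slice of the state row; null means a fresh key.
            long sum = segment.getAccumulator() == null ? 0L : segment.getAccumulator().getLong(0);
            RowData startingValue = GenericRowData.of(sum);
            List<RowData> updatedValues = new ArrayList<>();
            for (RowData row : segment.getRows()) {
                // The RowKind distinguishes accumulate (+I/+U) from retract (-U/-D) messages.
                sum += RowDataUtil.isAccumulateMsg(row) ? row.getLong(0) : -row.getLong(0);
                // Could be skipped when segment.getUpdatedValuesAfterEachRow() is false.
                updatedValues.add(GenericRowData.of(sum));
            }
            future.complete(
                    new BundledKeySegmentApplied(
                            GenericRowData.of(sum),
                            startingValue,
                            GenericRowData.of(sum),
                            updatedValues));
        }

        // The regular per-record contract still exists; the planner picks the bundled
        // path when the rule in this change matches.
        @Override
        public GenericRowData createAccumulator() {
            return GenericRowData.of(0L);
        }

        @Override
        public Long getValue(GenericRowData acc) {
            return acc.getLong(0);
        }

        public void accumulate(GenericRowData acc, Long value) {
            acc.setField(0, acc.getLong(0) + value);
        }

        public void retract(GenericRowData acc, Long value) {
            acc.setField(0, acc.getLong(0) - value);
        }
    }

Such a function would be registered the same way as the ones below, e.g. tEnv.createTemporarySystemFunction("MySum", new BundledSum()).

diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/BundledAggregateITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/BundledAggregateITCase.java
new file mode 100644
index 0000000000000..ed8c815dbcec1
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/BundledAggregateITCase.java
@@ -0,0 +1,469 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) 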
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.stream.table; + +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.FunctionHint; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.api.TableResult; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.util.RowDataUtil; +import org.apache.flink.table.functions.AggregateFunction; +import org.apache.flink.table.functions.BundledAggregateFunction; +import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.agg.BundledKeySegment; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; +import org.apache.flink.table.planner.runtime.utils.StreamingTestBase; +import org.apache.flink.types.Row; +import org.apache.flink.types.RowKind; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests {@link BundledAggregateFunction}. 
*/ +public class BundledAggregateITCase extends StreamingTestBase { + + private TableEnvironment tEnv; + + @BeforeEach + public void before() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + tEnv = StreamTableEnvironment.create(env, EnvironmentSettings.inStreamingMode()); + } + + @Test + public void testGroupByWithRetractions() { + tEnv.createTemporarySystemFunction("func", new SumAggregate()); + final List results = + executeSql( + "select v1 % 2, func(v1) from (select k1, LAST_VALUE(v1) as v1 from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by k1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L), + Row.ofKind(RowKind.INSERT, 0, 2L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L), + Row.ofKind(RowKind.DELETE, 0, 2L), + Row.ofKind(RowKind.INSERT, 0, 6L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 5L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 5L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 8L)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testMissingRetractionSupport() { + tEnv.createTemporarySystemFunction("func", new SumAggregateNoRetraction()); + assertThatThrownBy( + () -> + executeSql( + "select v1 % 2, func(v1) from (select k1, LAST_VALUE(v1) as v1 from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by k1) group by v1 % 2")) + .hasMessageContaining( + "Retract functionality is not implemented for the aggregate function"); + } + + @Test + public void testNoRetractionsGlobal() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + final List results = + executeSql( + "select MySum(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1)"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 3L), + Row.ofKind(RowKind.UPDATE_BEFORE, 3L), + Row.ofKind(RowKind.UPDATE_AFTER, 8L), + Row.ofKind(RowKind.UPDATE_BEFORE, 8L), + Row.ofKind(RowKind.UPDATE_AFTER, 14L), + Row.ofKind(RowKind.UPDATE_BEFORE, 14L), + Row.ofKind(RowKind.UPDATE_AFTER, 17L)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testNoRetractionsGlobalFilter() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + assertThatThrownBy( + () -> + executeSql( + "select MySum(v1) FILTER(WHERE v1 % 2 = 0) from (VALUES (1, 1)) AS t (k1, v1)")) + .hasMessageContaining("Filter operations not handled on bundled aggregates yet"); + } + + @Test + public void testGroupByNoRetractions() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + final List results = + executeSql( + "select v1 % 2, MySum(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L), + Row.ofKind(RowKind.INSERT, 0, 2L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L), + Row.ofKind(RowKind.UPDATE_BEFORE, 0, 2L), + Row.ofKind(RowKind.UPDATE_AFTER, 0, 8L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 9L)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testGroupMultipleNoRetractions() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + 
tEnv.createTemporarySystemFunction("MyAvg", new AvgAggregate()); + final List results = + executeSql( + "select v1 % 2, MySum(v1), MyAvg(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L, 1L), + Row.ofKind(RowKind.INSERT, 0, 2L, 2L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L, 3L), + Row.ofKind(RowKind.UPDATE_BEFORE, 0, 2L, 2L), + Row.ofKind(RowKind.UPDATE_AFTER, 0, 8L, 4L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L, 3L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 9L, 3L)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testGroupByWithMultipleWithRetractions() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + tEnv.createTemporarySystemFunction("MyAvg", new AvgAggregate()); + final List results = + executeSql( + "select v1 % 2, MySum(v1), MyAvg(v1) from (select k1, LAST_VALUE(v1) as v1 from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by k1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L, 1L), + Row.ofKind(RowKind.INSERT, 0, 2L, 2L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L, 3L), + Row.ofKind(RowKind.DELETE, 0, 2L, 2L), + Row.ofKind(RowKind.INSERT, 0, 6L, 6L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L, 3L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 5L, 5L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 5L, 5L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 8L, 4L)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testGroupMultipleOneSystem() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + final List results = + executeSql( + "select v1 % 2, MySum(v1), AVG(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L, 1), + Row.ofKind(RowKind.INSERT, 0, 2L, 2), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L, 1), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L, 3), + Row.ofKind(RowKind.UPDATE_BEFORE, 0, 2L, 2), + Row.ofKind(RowKind.UPDATE_AFTER, 0, 8L, 4), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L, 3), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 9L, 3)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testGroupMultipleOneSystemDifferentOrder() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + final List results = + executeSql( + "select v1 % 2, AVG(v1), MySum(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1, 1L), + Row.ofKind(RowKind.INSERT, 0, 2, 2L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 3, 6L), + Row.ofKind(RowKind.UPDATE_BEFORE, 0, 2, 2L), + Row.ofKind(RowKind.UPDATE_AFTER, 0, 4, 8L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 3, 6L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 3, 9L)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testGroupSingleBatchMultipleSystem() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + final List results = + executeSql( + "select v1 % 2, MySum(v1), AVG(v1 + 10), AVG(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L, 
11, 1), + Row.ofKind(RowKind.INSERT, 0, 2L, 12, 2), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L, 11, 1), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L, 13, 3), + Row.ofKind(RowKind.UPDATE_BEFORE, 0, 2L, 12, 2), + Row.ofKind(RowKind.UPDATE_AFTER, 0, 8L, 14, 4), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L, 13, 3), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 9L, 13, 3)); + assertThat(results).containsSequence(expectedRows); + } + + @Test + public void testGroupMultipleOneConventional() { + tEnv.createTemporarySystemFunction("MySum", new SumAggregate()); + tEnv.createTemporarySystemFunction("MyAvg", new AvgAggregateFunction()); + final List results = + executeSql( + "select v1 % 2, MySum(v1), MyAvg(v1) from (VALUES (1, 1), (2, 2), (5, 5), (2, 6), (1, 3)) AS t (k1, v1) group by v1 % 2"); + final List expectedRows = + Arrays.asList( + Row.ofKind(RowKind.INSERT, 1, 1L, 1L), + Row.ofKind(RowKind.INSERT, 0, 2L, 2L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 1L, 1L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 6L, 3L), + Row.ofKind(RowKind.UPDATE_BEFORE, 0, 2L, 2L), + Row.ofKind(RowKind.UPDATE_AFTER, 0, 8L, 4L), + Row.ofKind(RowKind.UPDATE_BEFORE, 1, 6L, 3L), + Row.ofKind(RowKind.UPDATE_AFTER, 1, 9L, 3L)); + assertThat(results).containsSequence(expectedRows); + } + + /** Average bundled aggregate function. */ + @FunctionHint(accumulator = @DataTypeHint("ROW")) + public static class AvgAggregate extends AggregateFunction + implements BundledAggregateFunction { + + private static final long serialVersionUID = 4585229396060575732L; + + @Override + public boolean canBundle() { + return true; + } + + @Override + public boolean canRetract() { + return true; + } + + private Long getAverage(Long sum, Long count) { + if (count == 0) { + return null; + } + return sum / count; + } + + public void bundledAccumulateRetract( + CompletableFuture future, BundledKeySegment segment) + throws Exception { + final GenericRowData acc; + if (segment.getAccumulator() == null) { + acc = GenericRowData.of(0L, 0L); + } else { + acc = (GenericRowData) segment.getAccumulator(); + } + + RowData previousValue = GenericRowData.of(getAverage(acc.getLong(0), acc.getLong(1))); + + List valueUpdates = new ArrayList<>(); + for (RowData row : segment.getRows()) { + if (RowDataUtil.isAccumulateMsg(row)) { + acc.setField(0, acc.getLong(0) + row.getLong(0)); + acc.setField(1, acc.getLong(1) + 1); + } else { + acc.setField(0, acc.getLong(0) - row.getLong(0)); + acc.setField(1, acc.getLong(1) - 1); + } + if (segment.getUpdatedValuesAfterEachRow()) { + valueUpdates.add(GenericRowData.of(getAverage(acc.getLong(0), acc.getLong(1)))); + } + } + + RowData newValue = GenericRowData.of(getAverage(acc.getLong(0), acc.getLong(1))); + future.complete( + new BundledKeySegmentApplied(acc, previousValue, newValue, valueUpdates)); + } + + public void accumulate(Row acc, Long value) { + throw new UnsupportedOperationException(); + } + + @Override + public Long getValue(Row accumulator) { + throw new UnsupportedOperationException(); + } + + @Override + public Row createAccumulator() { + throw new UnsupportedOperationException(); + } + } + + /** Sum bundled aggregate function. 
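As a usage note on the AvgAggregate defined above: when getUpdatedValuesAfterEachRow() is true, the function is expected to record the aggregate value after every input row. A hedged sketch of that contract, assuming an empty starting accumulator:

```java
BundledKeySegment segment =
        new BundledKeySegment(
                null,
                Arrays.asList(GenericRowData.of(4L), GenericRowData.of(8L)),
                null,
                true); // request per-row value updates
CompletableFuture<BundledKeySegmentApplied> f = new CompletableFuture<>();
new AvgAggregate().bundledAccumulateRetract(f, segment);
// Completes with valueUpdates [4L, 6L], startingValue GenericRowData.of(null)
// (empty accumulator), and finalValue 6L (sum 12 / count 2).
```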
*/ + @FunctionHint(accumulator = @DataTypeHint("ROW")) + public static class SumAggregate extends AggregateFunction + implements BundledAggregateFunction { + + private static final long serialVersionUID = -31497660659269145L; + + @Override + public boolean canBundle() { + return true; + } + + @Override + public boolean canRetract() { + return true; + } + + public void bundledAccumulateRetract( + CompletableFuture future, BundledKeySegment segment) + throws Exception { + final GenericRowData acc; + if (segment.getAccumulator() == null) { + acc = GenericRowData.of(0L); + } else { + acc = (GenericRowData) segment.getAccumulator(); + } + + RowData previousValue = GenericRowData.of(acc.getLong(0)); + + List valueUpdates = new ArrayList<>(); + for (RowData row : segment.getRows()) { + if (RowDataUtil.isAccumulateMsg(row)) { + acc.setField(0, acc.getLong(0) + row.getLong(0)); + } else { + acc.setField(0, acc.getLong(0) - row.getLong(0)); + } + if (segment.getUpdatedValuesAfterEachRow()) { + valueUpdates.add(GenericRowData.of(acc.getLong(0))); + } + } + + RowData newValue = GenericRowData.of(acc.getLong(0)); + BundledKeySegmentApplied result = + new BundledKeySegmentApplied(acc, previousValue, newValue, valueUpdates); + future.complete(result); + } + + public void accumulate(Row acc, Long value) { + throw new UnsupportedOperationException(); + } + + @Override + public Long getValue(Row accumulator) { + throw new UnsupportedOperationException(); + } + + @Override + public Row createAccumulator() { + throw new UnsupportedOperationException(); + } + } + + /** Sum bundled aggregate function without retraction. */ + public static class SumAggregateNoRetraction extends SumAggregate { + + @Override + public boolean canRetract() { + return false; + } + } + + private List executeSql(String sql) { + TableResult result = tEnv.executeSql(sql); + final List rows = new ArrayList<>(); + result.collect().forEachRemaining(rows::add); + return rows; + } + + /** Normal average aggregate function. 
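For contrast with the bundled variants above, the conventional function below is driven one row at a time: an insert/update-after input triggers accumulate(...), a delete/update-before input triggers retract(...). Roughly, as a hypothetical driver (the real calls are issued by planner-generated code):

```java
AvgAggregateFunction f = new AvgAggregateFunction();
Row acc = f.createAccumulator(); // (sum = 0, count = 0)
f.accumulate(acc, 4L); // +I 4
f.accumulate(acc, 8L); // +I 8
f.retract(acc, 4L);    // -U 4
// f.getValue(acc) == 8L  (sum 8 / count 1)
```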
*/ + @FunctionHint(accumulator = @DataTypeHint("ROW")) + public static class AvgAggregateFunction extends AggregateFunction { + + @Override + public void open(FunctionContext context) { + System.out.println(); + } + + @Override + public Long getValue(Row acc) { + WrapperAvg wrapper = new WrapperAvg(acc); + if (wrapper.getCount() == 0) { + return null; + } + return (wrapper.getSum() / wrapper.getCount()); + } + + @Override + public Row createAccumulator() { + return Row.of(0L, 0L); + } + + public void accumulate(Row acc, Long value) { + WrapperAvg wrapper = new WrapperAvg(acc); + wrapper.addCount(1); + wrapper.addSum(value); + } + + public void retract(Row acc, Long value) { + WrapperAvg wrapper = new WrapperAvg(acc); + wrapper.addCount(-1); + wrapper.addSum(-value); + } + } + + private static class WrapperAvg { + private final Row acc; + + public WrapperAvg(Row acc) { + this.acc = acc; + } + + public long getSum() { + return (Long) acc.getField(0); + } + + public void addSum(long sum) { + acc.setField(0, getSum() + sum); + } + + public long getCount() { + return (Long) acc.getField(1); + } + + public void addCount(long count) { + acc.setField(1, getCount() + count); + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java index 471af5a0959bf..a18ee77ac4589 100644 --- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/generated/AggsHandleFunctionBase.java @@ -22,8 +22,12 @@ import org.apache.flink.table.data.RowData; import org.apache.flink.table.functions.AggregateFunction; import org.apache.flink.table.functions.TableAggregateFunction; +import org.apache.flink.table.functions.agg.BundledKeySegment; +import org.apache.flink.table.functions.agg.BundledKeySegmentApplied; import org.apache.flink.table.runtime.dataview.StateDataViewStore; +import java.util.concurrent.CompletableFuture; + /** * The base class for handling aggregate or table aggregate functions. * @@ -52,6 +56,18 @@ public interface AggsHandleFunctionBase extends Function { */ void retract(RowData input) throws Exception; + /** + * Batches together both accumulate and retract calls to a single batch for improved + * performance. + * + * @param segment The segment + */ + default void bundledAccumulateRetract( + CompletableFuture future, BundledKeySegment segment) + throws Exception { + throw new UnsupportedOperationException(); + } + /** * Merges the other accumulators into current accumulators. * diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/BundledAggregateAsyncFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/BundledAggregateAsyncFunction.java new file mode 100644 index 0000000000000..f62566a2fd55d --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/BundledAggregateAsyncFunction.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async;
+
+import org.apache.flink.api.common.state.StateTtlConfig;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.agg.BundledKeySegment;
+import org.apache.flink.table.functions.agg.BundledKeySegmentApplied;
+import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
+import org.apache.flink.table.runtime.generated.AggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.RecordCounter;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.RowKind;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.table.data.util.RowDataUtil.isRetractMsg;
+import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
+
+/** Aggregate function used for group-by aggregations, executing the aggregation asynchronously. 
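The changelog emission implemented by handleResponse(...) further down can be condensed into the following self-contained decision sketch. The names here are illustrative only, not part of the change:

```java
import java.util.ArrayList;
import java.util.List;

final class ChangelogDecision {
    enum Kind { INSERT, UPDATE_BEFORE, UPDATE_AFTER, DELETE }

    // isFirstRow:  there was no accumulator for the key yet
    // countIsZero: the last record for the key has been retracted
    // mustEmit:    the value changed, or state TTL is enabled (stateRetentionTime > 0)
    static List<Kind> decide(
            boolean isFirstRow, boolean countIsZero, boolean mustEmit, boolean updateBefore) {
        List<Kind> out = new ArrayList<>();
        if (!countIsZero) {
            if (isFirstRow) {
                out.add(Kind.INSERT); // first result for this key
            } else if (mustEmit) {
                if (updateBefore) {
                    out.add(Kind.UPDATE_BEFORE); // retract the previous result
                }
                out.add(Kind.UPDATE_AFTER); // emit the new result
            } // otherwise: value unchanged, emit nothing
        } else if (!isFirstRow) {
            out.add(Kind.DELETE); // last record for the key was retracted
        }
        return out;
    }
}
```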
*/ +public class BundledAggregateAsyncFunction + extends KeyedAsyncFunctionCommon { + + private final GeneratedAggsHandleFunction genAggsHandle; + private final GeneratedRecordEqualiser genRecordEqualiser; + private final LogicalType[] accTypes; + private final RowType inputType; + private final boolean generateUpdateBefore; + private final long stateRetentionTime; + private RecordCounter recordCounter; + + // function used to handle all aggregates + private transient AggsHandleFunction function = null; + + // stores the accumulators + private transient ValueState accState = null; + + // function used to equal RowData + private transient RecordEqualiser equaliser = null; + private transient OpenContext ctx; + + public BundledAggregateAsyncFunction( + GeneratedAggsHandleFunction genAggsHandle, + GeneratedRecordEqualiser genRecordEqualiser, + LogicalType[] accTypes, + RowType inputType, + int indexOfCountStar, + boolean generateUpdateBefore, + long stateRetentionTime) { + this.genAggsHandle = genAggsHandle; + this.genRecordEqualiser = genRecordEqualiser; + this.accTypes = accTypes; + this.inputType = inputType; + this.recordCounter = RecordCounter.of(indexOfCountStar); + this.generateUpdateBefore = generateUpdateBefore; + this.stateRetentionTime = stateRetentionTime; + } + + @Override + public void open(OpenContext ctx) throws Exception { + super.open(ctx); + this.ctx = ctx; + // instantiate function + StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); + function = genAggsHandle.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + function.open(new PerKeyStateDataViewStore(ctx.getRuntimeContext(), ttlConfig)); + + equaliser = + genRecordEqualiser.newInstance(ctx.getRuntimeContext().getUserCodeClassLoader()); + + InternalTypeInfo accTypeInfo = InternalTypeInfo.ofFields(accTypes); + ValueStateDescriptor accDesc = new ValueStateDescriptor<>("accState", accTypeInfo); + if (ttlConfig.isEnabled()) { + accDesc.enableTimeToLive(ttlConfig); + } + accState = ctx.getRuntimeContext().getState(accDesc); + } + + public void asyncInvokeProtected(RowData input, ResultFuture resultFuture) + throws Exception { + RowData acc = accState.value(); + if (acc == null && isRetractMsg(input)) { + handleResponseForAsyncInvoke( + () -> { + resultFuture.complete(Collections.emptyList()); + }); + return; + } + BundledKeySegment bundledKeySegment = + new BundledKeySegment( + ctx.currentKey(), + Collections.singletonList(input), + accState.value(), + false); + CompletableFuture result = new CompletableFuture<>(); + function.bundledAccumulateRetract(result, bundledKeySegment); + + handleResponseForAsyncInvoke( + result, + resultFuture::completeExceptionally, + bundledKeySegmentApplied -> { + resultFuture.complete( + handleResponse(bundledKeySegment, bundledKeySegmentApplied)); + }); + } + + private Collection handleResponse( + BundledKeySegment keySegment, BundledKeySegmentApplied updatedSegment) + throws Exception { + final boolean isFirstRow = keySegment.getAccumulator() == null; + RowData currentKey = ctx.currentKey(); + + // get previous aggregate result + RowData prevAggValue = updatedSegment.getStartingValue(); + + RowData acc = updatedSegment.getAccumulator(); + + // get new aggregate result + RowData newAggValue = updatedSegment.getFinalValue(); + + List output = new ArrayList<>(); + + if (!recordCounter.recordCountIsZero(acc)) { + // we aggregated at least one record for this key + // update acc to state accState.update(acc); + accState.update(acc); + + // if this was not the first row and 
we have to emit retractions + if (!isFirstRow) { + if (stateRetentionTime > 0 || !equaliser.equals(prevAggValue, newAggValue)) { + // new row is not same with prev row + if (generateUpdateBefore) { + // prepare UPDATE_BEFORE message for previous row + JoinedRowData resultRow = + new JoinedRowData(RowKind.UPDATE_BEFORE, currentKey, prevAggValue); + output.add(resultRow); + } + // prepare UPDATE_AFTER message for new row + JoinedRowData resultRow = + new JoinedRowData(RowKind.UPDATE_AFTER, currentKey, newAggValue); + output.add(resultRow); + } + // new row is same with prev row, no need to output + } else { + // this is the first, output new result + // prepare INSERT message for new row + JoinedRowData resultRow = + new JoinedRowData(RowKind.INSERT, currentKey, newAggValue); + output.add(resultRow); + } + + } else { + // we retracted the last record for this key + // if this is not first row sent out a DELETE message + if (!isFirstRow) { + // prepare DELETE message for previous row + JoinedRowData resultRow = + new JoinedRowData(RowKind.DELETE, currentKey, prevAggValue); + output.add(resultRow); + } + // and clear all state + accState.clear(); + // cleanup dataview under current key + function.cleanup(); + } + return output; + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/CallbackSequencer.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/CallbackSequencer.java new file mode 100644 index 0000000000000..34202a9acad66 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/CallbackSequencer.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +/** + * Sequences callbacks by registering and then invoking them. + * + * @param Metadata passed along to calls. Useful for passing arguments to callbacks. + * @param Context object which has a timer and mailbox available. + */ +public interface CallbackSequencer { + + /** A callback which can be invoked on a per-key basis. */ + interface Callback { + void callback(long timestamp, D data, C ctx) throws Exception; + } + + void callbackWhenNext(C ctx, long timestamp) throws Exception; + + /** + * Adds the caller to the queue and invokes the callback when they are at the front. + * + * @param ctx The context to store. + * @param timestamp The timestamp associated with the callback + * @param metadata Metadata to be passed along with the callback + */ + void callbackWhenNext(C ctx, long timestamp, D metadata) throws Exception; + + /** + * Invokes the callback for the next waiter. 
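To make the register-then-invoke contract concrete, a hedged sketch of the overall calling pattern, assuming the PerKeyCallbackSequencer implementation added later in this PR (the timestamp and metadata values are illustrative placeholders):

```java
PerKeyCallbackSequencer<Long, KeyedAsyncFunction.OpenContext> sequencer =
        new PerKeyCallbackSequencer<>((timestamp, metadata, ctx) -> {
            // start the async work for ctx.currentKey() here
        });
sequencer.callbackWhenNext(ctx, timestamp, 42L); // runs now, or queues behind same-key work
// ... once the async work for this key has completed (on the mailbox thread):
sequencer.notifyNextWaiter(ctx); // releases the next queued caller for the same key
```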
+ * + * @param ctx The context + */ + void notifyNextWaiter(C ctx) throws Exception; +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunction.java new file mode 100644 index 0000000000000..70c292b910ebd --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunction.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.table.runtime.context.ExecutionContext; +import org.apache.flink.util.function.ThrowingRunnable; + +/** An {@link AsyncFunction} that is meant to work on keyed elements. */ +public interface KeyedAsyncFunction extends AsyncFunction { + + /** + * Opens the function. This method is called before any methods are invoked. + * + * @param context The context for this function. + */ + default void open(OpenContext context) throws Exception {} + + /** Closes the function. This method is called after all methods have been invoked. */ + default void close() throws Exception {} + + /** Context passed to {@link #open(OpenContext)}. */ + interface OpenContext extends ExecutionContext { + /** + * Runs the given runnable on the mailbox thread. This is useful to writing single-threaded + * functions so that async results can be processed on the same thread which calls + * asyncInvoke and onTimer. + */ + void runOnMailboxThread(ThrowingRunnable runnable); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunctionCommon.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunctionCommon.java new file mode 100644 index 0000000000000..67f650acf725f --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunctionCommon.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.table.data.RowData; +import org.apache.flink.util.function.ThrowingConsumer; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A common base class for {@link KeyedAsyncFunction} that provides the basic functionality to + * handle async operations. These include handling async invocations and timers. Namely, it provides + * per-key sequencing of async operations and timers. The subclass should implement the calls + * asyncInvokeProtected and onTimerProtected to handle sequenced invocations of asyncInvoke and + * onTimer. + */ +public abstract class KeyedAsyncFunctionCommon + implements KeyedAsyncFunction { + + /** Context for handling async invocations. */ + private transient OpenContext openContext; + + /** Sequencer for handling async invocations. */ + private transient PerKeyCallbackSequencer>, OpenContext> + asyncInvokeSequencer; + + /** Request id for sequencing asyncInvoke calls in order. */ + private transient long requestId = 0; + + public void open(OpenContext ctx) throws Exception { + this.openContext = ctx; + asyncInvokeSequencer = + new PerKeyCallbackSequencer<>( + (reqId, p, context) -> asyncInvokeProtected(p.getKey(), p.getRight())); + } + + @Override + public void asyncInvoke(IN input, ResultFuture resultFuture) throws Exception { + asyncInvokeSequencer.callbackWhenNext( + openContext, requestId++, Pair.of(input, resultFuture)); + } + + /** + * This method is called for each asyncInvoke call. The subclass should implement this method to + * handle an async operation after every asyncInvoke invocation. Implementers should call {@code + * handleResponseForAsyncInvoke} to handle an async response. + * + * @param input The input row. + * @param resultFuture The future to complete once the async operation is done. + */ + protected void asyncInvokeProtected(IN input, ResultFuture resultFuture) + throws Exception {} + + /** + * Handles the response for an async operation done in {@code asyncInvokeProtected}. This method + * should be called after an async operation is started. It will notify the next waiter in the + * sequence after completion. + * + * @param result The result to run. + */ + public void handleResponseForAsyncInvoke(Runnable result) throws Exception { + try { + result.run(); + } finally { + asyncInvokeSequencer.notifyNextWaiter(openContext); + } + } + + /** + * Handles the response for an async operation done in {@code asyncInvokeProtected}. This method + * should be called after an async operation is started. It will notify the next waiter in the + * sequence after completion. + * + * @param future The future to complete once the async operation is done. + * @param handleError A consumer to handle an error. + * @param handleSuccess A consumer to handle a successful response. 
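To make the intended usage concrete, a hedged sketch of a subclass wiring an async call through this method; startAsyncWork and toOutput are hypothetical placeholders, not part of this PR:

```java
@Override
protected void asyncInvokeProtected(RowData input, ResultFuture<RowData> resultFuture)
        throws Exception {
    CompletableFuture<BundledKeySegmentApplied> future = new CompletableFuture<>();
    startAsyncWork(input, future); // hypothetical: completes the future off the mailbox thread
    handleResponseForAsyncInvoke(
            future,
            resultFuture::completeExceptionally, // error path
            applied -> resultFuture.complete(toOutput(applied))); // success, on mailbox thread
}
```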
+ */ + public void handleResponseForAsyncInvoke( + CompletableFuture future, + Consumer handleError, + ThrowingConsumer handleSuccess) { + RowData currentKey = openContext.currentKey(); + future.whenComplete( + (result, t) -> { + if (t != null) { + openContext.runOnMailboxThread( + () -> { + openContext.setCurrentKey(currentKey); + handleError.accept(t); + asyncInvokeSequencer.notifyNextWaiter(openContext); + }); + return; + } + + openContext.runOnMailboxThread( + () -> { + openContext.setCurrentKey(currentKey); + try { + handleSuccess.accept(result); + } catch (Exception e) { + handleError.accept(e); + } finally { + asyncInvokeSequencer.notifyNextWaiter(openContext); + } + }); + }); + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperator.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperator.java new file mode 100644 index 0000000000000..17016c92303de --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperator.java @@ -0,0 +1,537 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.operators.MailboxExecutor; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.async.CollectionSupplier; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.MailboxWatermarkProcessor; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.aggregate.async.queue.KeyedAsyncOutputMode; +import org.apache.flink.table.runtime.operators.aggregate.async.queue.KeyedStreamElementQueue; +import org.apache.flink.table.runtime.operators.aggregate.async.queue.KeyedStreamElementQueueImpl; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.ThrowingConsumer; +import org.apache.flink.util.function.ThrowingRunnable; + +import javax.annotation.Nonnull; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +/** + * The {@link KeyedAsyncWaitOperator} allows to asynchronously process incoming stream records. For + * that the operator creates an {@link ResultFuture} which is passed to an {@link + * KeyedAsyncFunction}. Within the async function, the user can complete the async collector + * arbitrarily. Once the async collector has been completed, the result is emitted by the operator's + * emitter to downstream operators. + * + *

The operator offers different output modes depending on the chosen {@link
+ * KeyedAsyncOutputMode}. To provide exactly-once processing guarantees, the operator stores all
+ * currently in-flight {@link StreamElement}s in per-key state. Upon recovery, the recorded set of
+ * stream elements is replayed.
+ *
+ *
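The per-key state handling referred to here is implemented in snapshotState(...) and initializeState(...) further down; condensed, the snapshot side amounts to:

```java
// Group in-flight elements by key and write them to per-key list state; on
// restore, open() replays them through processElement/processWatermark.
for (Map.Entry<K, List<StreamElement>> entry : queue.valuesByKey().entrySet()) {
    setCurrentKey(entry.getKey());
    getKeyedStateStore()
            .getListState(new ListStateDescriptor<>(STATE_NAME, inStreamElementSerializer))
            .update(entry.getValue());
}
```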

Because {@link KeyedAsyncFunction}s can make use of row-based timers, retries are not supported by
+ * this operator: once the current watermark has moved forward, a timer can no longer be fired
+ * again. If retries are desired, they should be implemented inside the {@link KeyedAsyncFunction}.
+ *
+ *

In case of chaining of this operator, it has to be made sure that the operators in the chain + * are opened tail to head. The reason for this is that an opened {@link KeyedAsyncWaitOperator} + * starts already emitting recovered {@link StreamElement} to downstream operators. + * + * @param Key type for the operator. + * @param Input type for the operator. + * @param Output type for the operator. + */ +@Internal +public class KeyedAsyncWaitOperator + extends AbstractUdfStreamOperator> + implements OneInputStreamOperator, BoundedOneInput { + private static final long serialVersionUID = 1L; + + private static final String STATE_NAME = "_keyed_async_wait_operator_state_"; + + /** Capacity of the stream element queue. */ + private final int capacity; + + /** Output mode for this operator. */ + private final KeyedAsyncOutputMode outputMode; + + /** Timeout for the async collectors. */ + private final long timeout; + + /** {@link TypeSerializer} for inputs while making snapshots. */ + private transient StreamElementSerializer inStreamElementSerializer; + + /** Recovered input stream elements. */ + private transient Map> recoveredStreamElements; + + /** Queue, into which to store the currently in-flight stream elements. */ + private transient KeyedStreamElementQueue queue; + + /** Mailbox executor used to yield while waiting for buffers to empty. */ + private final transient MailboxExecutor mailboxExecutor; + + private transient TimestampedCollector timestampedCollector; + + /** Whether object reuse has been enabled or disabled. */ + private transient boolean isObjectReuseEnabled; + + private transient TimerService timerService; + + public KeyedAsyncWaitOperator( + @Nonnull KeyedAsyncFunction asyncFunction, + long timeout, + int capacity, + @Nonnull KeyedAsyncOutputMode outputMode, + @Nonnull ProcessingTimeService processingTimeService, + @Nonnull MailboxExecutor mailboxExecutor) { + super(asyncFunction); + + Preconditions.checkArgument( + capacity > 0, "The number of concurrent async operation should be greater than 0."); + this.capacity = capacity; + + this.outputMode = Preconditions.checkNotNull(outputMode, "outputMode"); + + this.timeout = timeout; + + this.processingTimeService = Preconditions.checkNotNull(processingTimeService); + + this.mailboxExecutor = mailboxExecutor; + } + + @Override + public void setup( + StreamTask containingTask, + StreamConfig config, + Output> output) { + super.setup(containingTask, config, output); + + this.inStreamElementSerializer = + new StreamElementSerializer<>( + getOperatorConfig().getTypeSerializerIn1(getUserCodeClassloader())); + + switch (outputMode) { + case ORDERED: + queue = KeyedStreamElementQueueImpl.createOrderedQueue(capacity); + break; + default: + throw new IllegalStateException("Unknown async mode: " + outputMode + '.'); + } + + this.timestampedCollector = new TimestampedCollector<>(super.output); + } + + @Override + public void open() throws Exception { + super.open(); + + userFunction.open(new OpenContextImpl(this, getRuntimeContext(), mailboxExecutor)); + + this.isObjectReuseEnabled = getExecutionConfig().isObjectReuseEnabled(); + + if (recoveredStreamElements != null) { + for (Map.Entry> e : recoveredStreamElements.entrySet()) { + List elementList = e.getValue(); + setCurrentKey(e.getKey()); + for (StreamElement element : elementList) { + if (element.isRecord()) { + processElement(element.asRecord()); + } else if (element.isWatermark()) { + processWatermark(element.asWatermark()); + } else if (element.isLatencyMarker()) { + 
processLatencyMarker(element.asLatencyMarker()); + } else { + throw new IllegalStateException( + "Unknown record type " + + element.getClass() + + " encountered while opening the operator."); + } + } + } + recoveredStreamElements = null; + } + } + + @Override + public void close() throws Exception { + userFunction.close(); + } + + @Override + public void processElement(StreamRecord record) throws Exception { + StreamRecord element; + // copy the element avoid the element is reused + if (isObjectReuseEnabled) { + //noinspection unchecked + element = (StreamRecord) inStreamElementSerializer.copy(record); + } else { + element = record; + } + + // add element first to the queue + final ResultFuture entry = addToWorkQueue(element); + + final ResultHandler resultHandler = new ResultHandler(element, entry); + + // register a timeout for the entry if timeout is configured + if (timeout > 0L) { + resultHandler.registerTimeout(getProcessingTimeService(), timeout); + } + + userFunction.asyncInvoke(element.getValue(), resultHandler); + } + + @Override + public void processWatermark(Watermark mark) throws Exception { + MailboxWatermarkProcessor.WatermarkEmitter addToQueue = + w -> { + addToWorkQueue(mark); + + // watermarks are always completed + // if there is no prior element, we can directly emit them + // this also avoids watermarks being held back until the next element has been + // processed + outputCompletedElement(); + }; + + if (watermarkProcessor != null) { + watermarkProcessor.emitWatermarkInsideMailbox(mark, addToQueue); + } else { + if (getTimeServiceManager().isPresent()) { + getTimeServiceManager().get().advanceWatermark(mark); + } + + addToQueue.emitWatermark(mark); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + + for (Map.Entry> entry : queue.valuesByKey().entrySet()) { + setCurrentKey(entry.getKey()); + + ListState partitionableState = + getKeyedStateStore() + .getListState( + new ListStateDescriptor<>( + STATE_NAME, inStreamElementSerializer)); + + try { + partitionableState.update(entry.getValue()); + } catch (Exception e) { + partitionableState.clear(); + + throw new Exception( + "Could not add stream element queue entries to operator state " + + "backend of operator " + + getOperatorName() + + '.', + e); + } + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + List keys = + this.getKeyedStateBackend() + .getKeys(STATE_NAME, VoidNamespace.INSTANCE) + .collect(Collectors.toList()); + recoveredStreamElements = new LinkedHashMap<>(); + for (K key : keys) { + setCurrentKey(key); + ListState state = + context.getKeyedStateStore() + .getListState( + new ListStateDescriptor<>( + STATE_NAME, inStreamElementSerializer)); + + List streamElements = new ArrayList<>(); + for (StreamElement elements : state.get()) { + streamElements.add(elements); + } + recoveredStreamElements.put(key, streamElements); + } + } + + @Override + public void endInput() throws Exception { + // we should finish all in fight delayed retry immediately. + finishInFlightDelayedRetry(); + + // we should wait here for the data in flight to be finished. 
the reason is that
+        // timers will be forbidden from firing after this point, so if an async operation got
+        // stuck, the job would otherwise deadlock because the timeout timer could never fire
+        waitInFlightInputsFinished();
+    }
+
+    /**
+     * Adds the given stream element to the operator's stream element queue. This operation blocks
+     * until the element has been added.
+     *
+     *

Between two insertion attempts, this method yields the execution to the mailbox, such that + * events as well as asynchronous results can be processed. + * + * @param streamElement to add to the operator's queue + * @throws InterruptedException if the current thread has been interrupted while yielding to + * mailbox + * @return a handle that allows to set the result of the async computation for the given + * element. + */ + private ResultFuture addToWorkQueue(StreamElement streamElement) + throws InterruptedException { + Optional> queueEntry; + while (!(queueEntry = queue.tryPut((K) getCurrentKey(), streamElement)).isPresent()) { + mailboxExecutor.yield(); + } + return queueEntry.get(); + } + + private void finishInFlightDelayedRetry() throws Exception {} + + private void waitInFlightInputsFinished() throws InterruptedException { + + while (!queue.isEmpty()) { + mailboxExecutor.yield(); + } + } + + /** + * Outputs one completed element. Watermarks are always completed if it's their turn to be + * processed. + * + *

This method will be called from {@link #processWatermark(Watermark)} and from a mail + * processing the result of an async function call. + */ + private void outputCompletedElement() { + if (queue.hasCompletedElements()) { + // emit only one element to not block the mailbox thread unnecessarily + queue.emitCompletedElement(timestampedCollector); + + // if there are more completed elements, emit them with subsequent mails + if (queue.hasCompletedElements()) { + try { + mailboxExecutor.execute( + this::outputCompletedElement, + "AsyncWaitOperator#outputCompletedElement"); + } catch (RejectedExecutionException mailboxClosedException) { + // This exception can only happen if the operator is cancelled which means all + // pending records can be safely ignored since they will be processed one more + // time after recovery. + LOG.debug( + "Attempt to complete element is ignored since the mailbox rejected the execution.", + mailboxClosedException); + } + } + } + } + + /** Utility method to register timeout timer. */ + private ScheduledFuture registerTimer( + ProcessingTimeService processingTimeService, + long timeout, + ThrowingConsumer callback) { + final long timeoutTimestamp = timeout + processingTimeService.getCurrentProcessingTime(); + + return processingTimeService.registerTimer( + timeoutTimestamp, timestamp -> callback.accept(null)); + } + + /** A handler for the results of a specific input record. */ + private class ResultHandler implements ResultFuture { + /** Optional timeout timer used to signal the timeout to the AsyncFunction. */ + private ScheduledFuture timeoutTimer; + + /** Record for which this result handler exists. Used only to report errors. */ + private final StreamRecord inputRecord; + + /** + * The handle received from the queue to update the entry. Should only be used to inject the + * result; exceptions are handled here. + */ + private final ResultFuture resultFuture; + + /** + * A guard against ill-written AsyncFunction. Additional (parallel) invokations of {@link + * #complete(Collection)} or {@link #completeExceptionally(Throwable)} will be ignored. This + * guard also helps for cases where proper results and timeouts happen at the same time. + */ + private final AtomicBoolean completed = new AtomicBoolean(false); + + ResultHandler(StreamRecord inputRecord, ResultFuture resultFuture) { + this.inputRecord = inputRecord; + this.resultFuture = resultFuture; + } + + @Override + public void complete(Collection results) { + + // already completed (exceptionally or with previous complete call from ill-written + // AsyncFunction), so + // ignore additional result + if (!completed.compareAndSet(false, true)) { + return; + } + + processInMailbox(results); + } + + @Override + public void complete(CollectionSupplier runnable) { + // already completed (exceptionally or with previous complete call from ill-written + // AsyncFunction), so ignore additional result + if (!completed.compareAndSet(false, true)) { + return; + } + mailboxExecutor.execute( + () -> { + // If there is an exception, let it bubble up and fail the job. + processResults(runnable.get()); + }, + "ResultHandler#complete"); + } + + private void processInMailbox(Collection results) { + // move further processing into the mailbox thread + mailboxExecutor.execute( + () -> processResults(results), + "Result in AsyncWaitOperator of input %s", + results); + } + + private void processResults(Collection results) { + // Cancel the timer once we've completed the stream record buffer entry. 
This will + // remove the registered + // timer task + if (timeoutTimer != null) { + // canceling in mailbox thread avoids + // https://issues.apache.org/jira/browse/FLINK-13635 + timeoutTimer.cancel(true); + } + + // update the queue entry with the result + resultFuture.complete(results); + // now output all elements from the queue that have been completed (in the correct + // order) + outputCompletedElement(); + } + + @Override + public void completeExceptionally(Throwable error) { + // already completed, so ignore exception + if (!completed.compareAndSet(false, true)) { + return; + } + + // signal failure through task + getContainingTask() + .getEnvironment() + .failExternally( + new Exception( + "Could not complete the stream element: .", error)); + + // complete with empty result, so that we remove timer and move ahead processing (to + // leave potentially + // blocking section in #addToWorkQueue or #waitInFlightInputsFinished) + processInMailbox(Collections.emptyList()); + } + + private void registerTimeout(ProcessingTimeService processingTimeService, long timeout) { + timeoutTimer = registerTimer(processingTimeService, timeout, t -> timerTriggered()); + } + + private void timerTriggered() throws Exception { + if (!completed.get()) { + userFunction.timeout(inputRecord.getValue(), this); + } + } + } + + private static class OpenContextImpl implements KeyedAsyncFunction.OpenContext { + + private final AbstractStreamOperator operator; + private final RuntimeContext runtimeContext; + private final MailboxExecutor mailboxExecutor; + + public OpenContextImpl( + KeyedAsyncWaitOperator operator, + RuntimeContext runtimeContext, + MailboxExecutor mailboxExecutor) { + this.operator = Preconditions.checkNotNull(operator); + this.runtimeContext = Preconditions.checkNotNull(runtimeContext); + this.mailboxExecutor = mailboxExecutor; + } + + @Override + public void runOnMailboxThread(ThrowingRunnable runnable) { + mailboxExecutor.execute(runnable, "keyedAsyncWaitOperator.runOnMailboxThread"); + } + + @Override + public RowData currentKey() { + return (RowData) operator.getCurrentKey(); + } + + @Override + public void setCurrentKey(RowData key) { + operator.setCurrentKey(key); + } + + @Override + public RuntimeContext getRuntimeContext() { + return runtimeContext; + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperatorFactory.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperatorFactory.java new file mode 100644 index 0000000000000..45b729d1c6f43 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperatorFactory.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.async.AsyncWaitOperator; +import org.apache.flink.streaming.api.operators.legacy.YieldingOperatorFactory; +import org.apache.flink.table.runtime.operators.aggregate.async.queue.KeyedAsyncOutputMode; + +/** + * The factory of {@link KeyedAsyncWaitOperator}. + * + * @param The output type of the operator + */ +public class KeyedAsyncWaitOperatorFactory extends AbstractStreamOperatorFactory + implements OneInputStreamOperatorFactory, YieldingOperatorFactory { + private final KeyedAsyncFunction asyncFunction; + private final long timeout; + private final int capacity; + private final KeyedAsyncOutputMode outputMode; + + public KeyedAsyncWaitOperatorFactory( + KeyedAsyncFunction asyncFunction, + long timeout, + int capacity, + KeyedAsyncOutputMode outputMode) { + this.asyncFunction = asyncFunction; + this.timeout = timeout; + this.capacity = capacity; + this.outputMode = outputMode; + this.chainingStrategy = ChainingStrategy.ALWAYS; + } + + @Override + public > T createStreamOperator( + StreamOperatorParameters parameters) { + KeyedAsyncWaitOperator asyncWaitOperator = + new KeyedAsyncWaitOperator<>( + asyncFunction, + timeout, + capacity, + outputMode, + processingTimeService, + getMailboxExecutor()); + asyncWaitOperator.setup( + parameters.getContainingTask(), + parameters.getStreamConfig(), + parameters.getOutput()); + return (T) asyncWaitOperator; + } + + @Override + public Class getStreamOperatorClass(ClassLoader classLoader) { + return AsyncWaitOperator.class; + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/MultiDelegatingAsyncResultFuture.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/MultiDelegatingAsyncResultFuture.java new file mode 100644 index 0000000000000..6e920d7284e58 --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/MultiDelegatingAsyncResultFuture.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SupplierWithException; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; + +/** + * A class which collects multiple async calls and completes when all of them are done, calling back + * on an underlying {@link CompletableFuture}. + */ +public class MultiDelegatingAsyncResultFuture implements BiConsumer { + + /** The future that is completed when all async calls are done. */ + private final CompletableFuture delegatedResultFuture; + + /** The total number of async calls expected. */ + private final int totalAsyncCalls; + + /** The futures of the individual async calls. */ + private List> futures = new ArrayList<>(); + + private SupplierWithException resultSupplier; + + public MultiDelegatingAsyncResultFuture( + CompletableFuture delegatedResultFuture, int totalAsyncCalls) { + this.delegatedResultFuture = delegatedResultFuture; + this.totalAsyncCalls = totalAsyncCalls; + } + + /** Creates a new future for one of the expected async calls, for it to call when it is done. */ + public CompletableFuture createAsyncFuture() { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + Preconditions.checkState(futures.size() <= totalAsyncCalls, "Too many async calls."); + + if (futures.size() == totalAsyncCalls) { + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).whenComplete(this); + } + return future; + } + + /** + * Sets the supplier that will be called to produce the final result once the async calls + * complete. + */ + public void setResultSupplier(SupplierWithException resultSupplier) { + this.resultSupplier = resultSupplier; + } + + @Override + public void accept(Object object, Throwable throwable) { + if (throwable != null) { + delegatedResultFuture.completeExceptionally(throwable); + } else { + try { + delegatedResultFuture.complete(resultSupplier.get()); + } catch (Throwable t) { + delegatedResultFuture.completeExceptionally(t); + } + } + } +} diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencer.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencer.java new file mode 100644 index 0000000000000..5602356ba7e6d --- /dev/null +++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencer.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencer.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencer.java
new file mode 100644
index 0000000000000..5602356ba7e6d
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencer.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async;
+
+import org.apache.flink.table.data.RowData;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * Sequences callbacks on a per-key basis, so that calls for different keys are allowed to proceed
+ * concurrently while calls for the same key are sequenced serially.
+ *
+ * @param <D> Metadata passed along to calls. Useful for passing arguments to callbacks.
+ * @param <C> Context object which has a timer and mailbox available.
+ */
+public class PerKeyCallbackSequencer<D, C> implements CallbackSequencer<D, C> {
+    private final Callback<D, C> callback;
+    private final Map<RowData, PriorityQueue<Data>> waitList;
+
+    public PerKeyCallbackSequencer(Callback<D, C> callback) {
+        this.callback = callback;
+        this.waitList = new HashMap<>();
+    }
+
+    @Override
+    public void callbackWhenNext(C ctx, long timestamp) throws Exception {
+        callbackWhenNext(ctx, timestamp, null);
+    }
+
+    @Override
+    public void callbackWhenNext(C ctx, long timestamp, D metadata) throws Exception {
+        PriorityQueue<Data> result =
+                waitList.compute(
+                        ctx.currentKey(),
+                        (k, v) -> {
+                            if (v == null) {
+                                v = new PriorityQueue<>();
+                            }
+                            v.add(new Data(timestamp, metadata));
+                            return v;
+                        });
+        if (result.size() == 1) {
+            runNextWaiter(ctx, result);
+        }
+    }
+
+    @Override
+    public void notifyNextWaiter(C ctx) throws Exception {
+        PriorityQueue<Data> pq = waitList.get(ctx.currentKey());
+        pq.poll();
+        if (!pq.isEmpty()) {
+            runNextWaiter(ctx, pq);
+        } else {
+            waitList.remove(ctx.currentKey());
+        }
+    }
+
+    private void runNextWaiter(C ctx, PriorityQueue<Data> pq) throws Exception {
+        if (!pq.isEmpty()) {
+            RowData key = ctx.currentKey();
+            Data data = pq.peek();
+
+            ctx.setCurrentKey(key);
+            callback.callback(data.timestamp, data.data, ctx);
+        }
+    }
+
+    private class Data implements Comparable<Data> {
+        private final long timestamp;
+        private final D data;
+
+        public Data(long timestamp, D data) {
+            this.timestamp = timestamp;
+            this.data = data;
+        }
+
+        @Override
+        public int compareTo(PerKeyCallbackSequencer<D, C>.Data o) {
+            return Long.compare(timestamp, o.timestamp);
+        }
+    }
+}
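The sequencing contract is easiest to see in a sketch. The `CallbackSequencer`, `Callback`, and context types are not fully visible in this excerpt, so `Ctx`, `ctx`, `rowA`/`rowB`, and `startAsyncCall` below are placeholders; the callback shape `(timestamp, metadata, ctx)` is taken from the code above:

```java
// Same-key calls run one at a time; the callback for the second record only
// fires after notifyNextWaiter() reports that the first one finished.
PerKeyCallbackSequencer<RowData, Ctx> sequencer =
        new PerKeyCallbackSequencer<>((timestamp, row, ctx) -> startAsyncCall(row, ctx));

sequencer.callbackWhenNext(ctx, 1L, rowA); // key K: callback runs immediately
sequencer.callbackWhenNext(ctx, 2L, rowB); // key K again: parked in the priority queue

// ... once the first async call completes:
sequencer.notifyNextWaiter(ctx);           // rowB's callback now runs, in timestamp order
```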
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedAsyncOutputMode.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedAsyncOutputMode.java
new file mode 100644
index 0000000000000..668af5ebafee8
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedAsyncOutputMode.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async.queue;
+
+import org.apache.flink.table.runtime.operators.aggregate.async.KeyedAsyncWaitOperator;
+
+/** The mode for outputting results in the {@link KeyedAsyncWaitOperator}. */
+public enum KeyedAsyncOutputMode {
+    /**
+     * Outputs results in the order in which elements were processed, regardless of when they
+     * asynchronously complete.
+     */
+    ORDERED
+}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueue.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueue.java
new file mode 100644
index 0000000000000..45a767accbb2e
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueue.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async.queue;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.table.runtime.operators.aggregate.async.KeyedAsyncWaitOperator;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/** Interface for stream element queues for the {@link KeyedAsyncWaitOperator}. */
+@Internal
+public interface KeyedStreamElementQueue<K, OUT> {
+
+    /**
+     * Tries to put the given element in the queue. This operation succeeds if the queue has
+     * capacity left and fails if the queue is full.
+     *
+     * <p>This method returns a handle to the inserted element that allows setting the result of
+     * the computation.
+     *
+     * @param key the key associated with the element.
+     * @param streamElement the element to be inserted.
+     * @return A handle to the element if successful or {@link Optional#empty()} otherwise.
+     */
+    Optional<KeyedStreamElementQueueEntry<K, OUT>> tryPut(K key, StreamElement streamElement);
+
+    /**
+     * Emits one completed element from the head of this queue into the given output.
+     *
+     * <p>Will not emit any element if no element has been completed (check {@link
+     * #hasCompletedElements()} before entering any critical section).
+     *
+     * @param output the output into which to emit
+     */
+    void emitCompletedElement(TimestampedCollector<OUT> output);
+
+    /**
+     * Checks if there is at least one completed head element.
+     *
+     * @return True if there is a completed head element.
+     */
+    boolean hasCompletedElements();
+
+    /**
+     * Returns the Map of {@link StreamElement} currently contained in this queue for
+     * checkpointing, grouped by key.
+     *
+     * <p>This includes all non-emitted, completed and non-completed elements.
+     *
+     * @return Map of currently contained {@link StreamElement}.
+     */
+    Map<K, List<StreamElement>> valuesByKey();
+
+    /**
+     * Checks whether the queue is empty.
+     *
+     * @return True if the queue is empty; otherwise false.
+     */
+    boolean isEmpty();
+
+    /**
+     * Returns the size of the queue.
+     *
+     * @return The number of elements contained in this queue.
+     */
+    int size();
+}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueEntry.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueEntry.java
new file mode 100644
index 0000000000000..6b0d66bb1be99
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueEntry.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async.queue;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+
+import javax.annotation.Nonnull;
+
+/**
+ * An entry for the {@link KeyedStreamElementQueue}. Each entry stores the {@link StreamElement}
+ * for which it has been instantiated. Furthermore, it allows setting the result of a completed
+ * entry through {@link ResultFuture}.
+ */
+@Internal
+public interface KeyedStreamElementQueueEntry<K, OUT> extends ResultFuture<OUT> {
+
+    /**
+     * True if the stream element queue entry has been completed; otherwise false.
+     *
+     * @return True if the stream element queue entry has been completed; otherwise false.
+     */
+    boolean isDone();
+
+    /**
+     * Emits the results associated with this queue entry.
+     *
+     * @param output the output into which to emit.
+     */
+    void emitResult(TimestampedCollector<OUT> output);
+
+    /**
+     * The input element for this queue entry, for which the calculation is performed
+     * asynchronously.
+     *
+     * @return the input element.
+     */
+    @Nonnull
+    StreamElement getInputElement();
+
+    /** Not supported. Exceptions must be handled in the AsyncWaitOperator. */
+    @Override
+    default void completeExceptionally(Throwable error) {
+        throw new UnsupportedOperationException(
+                "This result future should only be used to set completed results.");
+    }
+
+    /** Not supported. */
+    default void complete(CollectionSupplier<OUT> supplier) {
+        throw new UnsupportedOperationException();
+    }
+
+    /** Returns the key associated with the entry. */
+    K getKey();
+}
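Taken together, the queue and entry interfaces above support the operator's put-then-emit cycle. The sketch below is my illustration of that cycle, with `queue`, `key`, `element`, `mailboxExecutor`, `userFunction`, and `output` assumed from the surrounding operator:

```java
// Try to enqueue; the queue applies backpressure by rejecting when at capacity.
Optional<KeyedStreamElementQueueEntry<RowData, RowData>> handle =
        queue.tryPut(key, element);
while (!handle.isPresent()) {
    mailboxExecutor.yield();             // wait for a slot to free up
    handle = queue.tryPut(key, element);
}
// The entry doubles as the ResultFuture handed to the user function.
userFunction.asyncInvoke(element.getValue(), handle.get());

// Later, from the mailbox thread: drain whatever is complete at the head.
while (queue.hasCompletedElements()) {
    queue.emitCompletedElement(output);
}
```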
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueImpl.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueImpl.java
new file mode 100644
index 0000000000000..fa876f5b1036a
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueImpl.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async.queue;
+
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.TreeSet;
+import java.util.function.Supplier;
+
+/**
+ * Uses completion time to determine when to output results. Note that watermarks are output
+ * immediately after being added, since it is assumed that any timers have been processed and
+ * their outputs produced by the time the watermark is added to this queue and potentially
+ * emitted.
+ *
+ * @param <OUT>
+ */
+public class KeyedStreamElementQueueImpl<K, OUT> implements KeyedStreamElementQueue<K, OUT> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(KeyedStreamElementQueueImpl.class);
+
+    /** Queue for the inserted KeyedStreamElementQueueEntries. */
+    private final TreeSet<CreationOrderEntry<K, OUT>> queue;
+
+    /** Maximum active request capacity of the queue. */
+    private final int activeCapacity;
+
+    /** Flag indicating whether to wait for watermarks before emitting elements. */
+    private final boolean waitForWatermark;
+
+    /** Counter for the creation order of the StreamElementQueueEntries. */
+    private long creationCount = 0;
+
+    /** Number of active (in-flight) requests in the queue. */
+    private int active = 0;
+
+    /** Number of watermarks in the queue. */
+    private int watermarksInQueue = 0;
+
+    private KeyedStreamElementQueueImpl(
+            int activeCapacity,
+            Comparator<CreationOrderEntry<K, OUT>> comparator,
+            boolean waitForWatermark) {
+        this.activeCapacity = activeCapacity;
+        this.queue = new TreeSet<>(comparator);
+        this.waitForWatermark = waitForWatermark;
+    }
+
+    /**
+     * Creates a new {@link KeyedStreamElementQueue} that uses the order of the input elements to
+     * determine the order of output.
+     *
+     * @param activeCapacity The maximum number of in-flight requests that can be made at once.
+     * @param <K> The type of the key.
+     * @param <OUT> The type of the output elements.
+     * @return A new {@link KeyedStreamElementQueue} that uses the order of the input elements to
+     *     determine the order of the elements.
+     */
+    public static <K, OUT> KeyedStreamElementQueue<K, OUT> createOrderedQueue(int activeCapacity) {
+        return new KeyedStreamElementQueueImpl<>(
+                activeCapacity,
+                Comparator.comparingLong(CreationOrderEntry::getCreationOrder),
+                false);
+    }
+
+    @Override
+    public boolean hasCompletedElements() {
+        return !queue.isEmpty()
+                && queue.first().isDone()
+                && (!waitForWatermark || watermarksInQueue > 0);
+    }
+
+    @Override
+    public void emitCompletedElement(TimestampedCollector<OUT> output) {
+        if (hasCompletedElements()) {
+            KeyedStreamElementQueueEntry<K, OUT> element = queue.pollFirst();
+            if (element != null) {
+                if (element.getInputElement().isWatermark()) {
+                    watermarksInQueue--;
+                }
+                element.emitResult(output);
+            }
+        }
+    }
+
+    @Override
+    public Map<K, List<StreamElement>> valuesByKey() {
+        Map<K, List<StreamElement>> map = new HashMap<>(this.queue.size());
+        for (KeyedStreamElementQueueEntry<K, OUT> e : queue) {
+            if (e.getInputElement().isWatermark()) {
+                continue;
+            }
+            map.computeIfAbsent(e.getKey(), k -> new ArrayList<>()).add(e.getInputElement());
+        }
+        return map;
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return queue.isEmpty();
+    }
+
+    @Override
+    public int size() {
+        return queue.size();
+    }
+
+    @Override
+    public Optional<KeyedStreamElementQueueEntry<K, OUT>> tryPut(
+            K key, StreamElement streamElement) {
+        // Watermarks are always put into the queue and don't count towards the capacity.
+        if (active < activeCapacity || streamElement.isWatermark()) {
+            CreationOrderEntry<K, OUT> queueEntry = createEntry(key, streamElement);
+            queue.add(queueEntry);
+
+            LOG.debug(
+                    "Put element into ordered stream element queue. New filling degree "
+                            + "({}/{}).",
+                    queue.size(),
+                    activeCapacity);
+
+            if (streamElement.isRecord()) {
+                active++;
+            } else {
+                watermarksInQueue++;
+            }
+            return Optional.of(queueEntry);
+        } else {
+            LOG.debug(
+                    "Failed to put element into ordered stream element queue because it "
+                            + "was full ({}/{}).",
+                    queue.size(),
+                    activeCapacity);
+
+            return Optional.empty();
+        }
+    }
+
+    private CreationOrderEntry<K, OUT> createEntry(K key, StreamElement streamElement) {
+        if (streamElement.isRecord()) {
+            return new CompletionTimeStreamElementEntry<>(
+                    key, (StreamRecord<?>) streamElement, () -> creationCount++, () -> active--);
+        }
+        if (streamElement.isWatermark()) {
+            return new CompletionTimeWatermarkEntry<>(
+                    (Watermark) streamElement, () -> creationCount++);
+        }
+        throw new UnsupportedOperationException("Cannot enqueue " + streamElement);
+    }
+
+    static class CompletionTimeStreamElementEntry<K, OUT> implements CreationOrderEntry<K, OUT> {
+
+        private final K key;
+
+        private final StreamRecord<?> inputRecord;
+
+        private Collection<OUT> completedElements;
+
+        private final long creationOrder;
+
+        private final Runnable onCompletion;
+
+        CompletionTimeStreamElementEntry(
+                K key,
+                StreamRecord<?> inputRecord,
+                Supplier<Long> creationOrderSupplier,
+                Runnable onCompletion) {
+            this.key = key;
+            this.inputRecord = inputRecord;
+            this.creationOrder = creationOrderSupplier.get();
+            this.onCompletion = onCompletion;
+        }
+
+        @Override
+        public boolean isDone() {
+            return completedElements != null;
+        }
+
+        @Nonnull
+        @Override
+        public StreamRecord<?> getInputElement() {
+            return inputRecord;
+        }
+
+        @Override
+        public void emitResult(TimestampedCollector<OUT> output) {
+            output.setTimestamp(inputRecord);
+            for (OUT r : completedElements) {
+                output.collect(r);
+            }
+        }
+
+        @Override
+        public void complete(Collection<OUT> result) {
+            this.completedElements = Preconditions.checkNotNull(result);
+            onCompletion.run();
+        }
+
+        @Override
+        public K getKey() {
+            return key;
+        }
+
+        @Override
+        public long getCreationOrder() {
+            return creationOrder;
+        }
+    }
+
+    static class CompletionTimeWatermarkEntry<K, OUT> implements CreationOrderEntry<K, OUT> {
+
+        @Nonnull private final Watermark watermark;
+
+        private final long creationOrder;
+
+        public CompletionTimeWatermarkEntry(
+                Watermark watermark, Supplier<Long> creationOrderSupplier) {
+            this.watermark = Preconditions.checkNotNull(watermark);
+            this.creationOrder = creationOrderSupplier.get();
+        }
+
+        @Override
+        public long getCreationOrder() {
+            return creationOrder;
+        }
+
+        @Override
+        public void emitResult(TimestampedCollector<OUT> output) {
+            output.emitWatermark(watermark);
+        }
+
+        @Nonnull
+        @Override
+        public Watermark getInputElement() {
+            return watermark;
+        }
+
+        @Override
+        public K getKey() {
+            return null;
+        }
+
+        @Override
+        public boolean isDone() {
+            return true;
+        }
+
+        @Override
+        public void complete(Collection<OUT> result) {
+            throw new IllegalStateException("Cannot complete a watermark.");
+        }
+    }
+
+    private interface CreationOrderEntry<K, OUT> extends KeyedStreamElementQueueEntry<K, OUT> {
+        long getCreationOrder();
+    }
+}
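A small sketch of the resulting ORDERED semantics. This is illustrative only: the types are assumed, and the entries are completed directly here, whereas in the operator they are completed from async callbacks:

```java
KeyedStreamElementQueue<Integer, String> queue =
        KeyedStreamElementQueueImpl.createOrderedQueue(2);

KeyedStreamElementQueueEntry<Integer, String> first =
        queue.tryPut(1, new StreamRecord<>(10)).get();
KeyedStreamElementQueueEntry<Integer, String> second =
        queue.tryPut(2, new StreamRecord<>(20)).get();

second.complete(Collections.singletonList("second"));
// Nothing can be emitted yet: the head of the queue ('first') is still pending.
assert !queue.hasCompletedElements();

first.complete(Collections.singletonList("first"));
// Now both are emittable, in input order: "first", then "second".
assert queue.hasCompletedElements();
```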
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunctionCommonTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunctionCommonTest.java
new file mode 100644
index 0000000000000..0d50e379b8515
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncFunctionCommonTest.java
@@ -0,0 +1,200 @@
+/*
* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.streaming.api.functions.async.CollectionSupplier; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; + +import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link KeyedAsyncFunctionCommon}. */ +public class KeyedAsyncFunctionCommonTest { + + private static final RowData KEY1 = GenericRowData.of(1); + + private static final RowData KEY2 = GenericRowData.of(2); + + private TestOpenContext openContext; + + @BeforeEach + public void setUp() { + openContext = new TestOpenContext(); + openContext.setCurrentKey(KEY1); + } + + @Test + public void testOneCallback() throws Exception { + TestKeyedAsyncFunctionCommon function = new TestKeyedAsyncFunctionCommon(); + function.open(openContext); + TestResultFuture resultFuture = new TestResultFuture(); + doCall(function, 123, resultFuture); + function.results.get(0).complete(321); + assertThat(resultFuture.collection).containsExactly(321); + } + + @Test + public void testMultipleCallbacks() throws Exception { + TestKeyedAsyncFunctionCommon function = new TestKeyedAsyncFunctionCommon(); + function.open(openContext); + TestResultFuture resultFuture1 = new TestResultFuture(); + doCall(function, 123, resultFuture1); + TestResultFuture resultFuture2 = new TestResultFuture(); + doCall(function, 456, resultFuture2); + TestResultFuture resultFuture3 = new TestResultFuture(); + doCall(function, 789, resultFuture3); + // Each one synchronously adds the next in line + assertThat(function.results).hasSize(1); + function.results.get(0).complete(321); + assertThat(function.results).hasSize(2); + function.results.get(1).complete(654); + assertThat(function.results).hasSize(3); + function.results.get(2).complete(987); + assertThat(resultFuture1.collection).containsExactly(321); + assertThat(resultFuture2.collection).containsExactly(654); + assertThat(resultFuture3.collection).containsExactly(987); + } + + @Test + public void testMultipleCallbacksMultipleKeys() throws Exception { + TestKeyedAsyncFunctionCommon function = new TestKeyedAsyncFunctionCommon(); + function.open(openContext); + openContext.setCurrentKey(KEY1); + TestResultFuture resultFuture11 = new TestResultFuture(); + doCall(function, 123, resultFuture11); + openContext.setCurrentKey(KEY2); + 
TestResultFuture resultFuture21 = new TestResultFuture(); + doCall(function, 333, resultFuture21); + + openContext.setCurrentKey(KEY1); + TestResultFuture resultFuture12 = new TestResultFuture(); + doCall(function, 456, resultFuture12); + openContext.setCurrentKey(KEY2); + TestResultFuture resultFuture22 = new TestResultFuture(); + doCall(function, 444, resultFuture22); + + openContext.setCurrentKey(KEY1); + TestResultFuture resultFuture13 = new TestResultFuture(); + doCall(function, 789, resultFuture13); + openContext.setCurrentKey(KEY2); + TestResultFuture resultFuture23 = new TestResultFuture(); + doCall(function, 555, resultFuture23); + + // Each one synchronously adds the next in line + assertThat(function.results).hasSize(2); + function.results.get(0).complete(321); + function.results.get(1).complete(3333); + assertThat(function.results).hasSize(4); + function.results.get(2).complete(654); + function.results.get(3).complete(4444); + assertThat(function.results).hasSize(6); + function.results.get(4).complete(987); + function.results.get(5).complete(5555); + assertThat(resultFuture11.collection).containsExactly(321); + assertThat(resultFuture21.collection).containsExactly(3333); + assertThat(resultFuture12.collection).containsExactly(654); + assertThat(resultFuture22.collection).containsExactly(4444); + assertThat(resultFuture13.collection).containsExactly(987); + assertThat(resultFuture23.collection).containsExactly(5555); + } + + @Test + public void testError() throws Exception { + TestKeyedAsyncFunctionCommon function = new TestKeyedAsyncFunctionCommon(); + function.open(openContext); + TestResultFuture resultFuture1 = new TestResultFuture(); + doCall(function, 123, resultFuture1); + TestResultFuture resultFuture2 = new TestResultFuture(); + doCall(function, 456, resultFuture2); + TestResultFuture resultFuture3 = new TestResultFuture(); + doCall(function, 789, resultFuture3); + // Each one synchronously adds the next in line + assertThat(function.results).hasSize(1); + function.results.get(0).complete(321); + assertThat(function.results).hasSize(2); + function.results.get(1).completeExceptionally(new RuntimeException("Error!")); + assertThat(function.results).hasSize(3); + function.results.get(2).complete(987); + assertThat(resultFuture1.collection).containsExactly(321); + assertThat(resultFuture2.throwable).hasMessageContaining("Error!"); + assertThat(resultFuture3.collection).containsExactly(987); + } + + private void doCall( + TestKeyedAsyncFunctionCommon function, int input, ResultFuture resultFuture) + throws Exception { + function.asyncInvoke(input, resultFuture); + } + + private static class TestKeyedAsyncFunctionCommon + extends KeyedAsyncFunctionCommon { + + private List> results = new ArrayList<>(); + + @Override + public void asyncInvokeProtected(Integer input, ResultFuture resultFuture) + throws Exception { + // Fake doing an rpc + CompletableFuture future = new CompletableFuture<>(); + results.add(future); + handleResponseForAsyncInvoke( + future, + resultFuture::completeExceptionally, + r -> { + resultFuture.complete(ImmutableList.of(r)); + }); + } + } + + private static class TestResultFuture implements ResultFuture { + + private Collection collection; + private Throwable throwable; + + @Override + public void complete(Collection collection) { + this.collection = collection; + } + + @Override + public void completeExceptionally(Throwable throwable) { + this.throwable = throwable; + } + + @Override + public void complete(CollectionSupplier supplier) { + try { + 
this.collection = supplier.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperatorTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperatorTest.java new file mode 100644 index 0000000000000..408bfce58afd2 --- /dev/null +++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/KeyedAsyncWaitOperatorTest.java @@ -0,0 +1,1245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.aggregate.async; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.testutils.ExpectedTestException; +import org.apache.flink.runtime.operators.testutils.MockEnvironment; +import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.streaming.api.datastream.AsyncDataStream; +import org.apache.flink.streaming.api.functions.async.AsyncFunction; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory; +import org.apache.flink.streaming.api.operators.async.queue.StreamElementQueue; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; +import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness; +import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness; +import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; +import 
org.apache.flink.table.runtime.operators.aggregate.async.queue.KeyedAsyncOutputMode;
+import org.apache.flink.testutils.junit.SharedObjects;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.flink.shaded.guava33.com.google.common.collect.Lists;
+
+import org.hamcrest.Matchers;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link KeyedAsyncWaitOperator}. These test that:
+ *
+ * <ul>
+ *   <li>Process StreamRecords and Watermarks in ORDERED mode
+ *   <li>Process StreamRecords and Watermarks in ROW_TIME mode
+ *   <li>Snapshot state and restore state
+ * </ul>
+ */
+public class KeyedAsyncWaitOperatorTest extends TestLogger {
+
+    private static final Logger LOG = LoggerFactory.getLogger(KeyedAsyncWaitOperatorTest.class);
+
+    private static final long TIMEOUT = 1000L;
+
+    @Rule public Timeout timeoutRule = new Timeout(100, TimeUnit.SECONDS);
+    @Rule public final SharedObjects sharedObjects = SharedObjects.create();
+
+    private abstract static class MyAbstractKeyedAsyncFunction<IN, OUT>
+            implements KeyedAsyncFunction<IN, OUT> {
+        private static final long serialVersionUID = 8522411971886428444L;
+
+        private static final long TERMINATION_TIMEOUT = 5000L;
+        private static final int THREAD_POOL_SIZE = 10;
+
+        static ExecutorService executorService;
+        static int counter = 0;
+
+        @Override
+        public void open(OpenContext openContext) throws Exception {
+            synchronized (MyAbstractKeyedAsyncFunction.class) {
+                if (counter == 0) {
+                    executorService = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+                }
+
+                ++counter;
+            }
+        }
+
+        @Override
+        public void close() throws Exception {
+            freeExecutor();
+        }
+
+        private void freeExecutor() {
+            synchronized (MyAbstractKeyedAsyncFunction.class) {
+                --counter;
+
+                if (counter == 0) {
+                    executorService.shutdown();
+
+                    try {
+                        if (!executorService.awaitTermination(
+                                TERMINATION_TIMEOUT, TimeUnit.MILLISECONDS)) {
+                            executorService.shutdownNow();
+                        }
+                    } catch (InterruptedException interrupted) {
+                        executorService.shutdownNow();
+
+                        Thread.currentThread().interrupt();
+                    }
+                }
+            }
+        }
+    }
+
+    private static class MyAsyncFunction extends MyAbstractKeyedAsyncFunction<Integer, Integer> {
+        private static final long serialVersionUID = -1504699677704123889L;
+
+        @Override
+        public void asyncInvoke(final Integer input, final ResultFuture<Integer> resultFuture)
+                throws Exception {
+            executorService.submit(
+                    new Runnable() {
+                        @Override
+                        public void run() {
+                            resultFuture.complete(Collections.singletonList(input * 2));
+                        }
+                    });
+        }
+    }
+
+    /**
+     * A special {@link AsyncFunction} that does not issue {@link ResultFuture#complete} until the
+     * latch counts to zero. This function is used in testStateSnapshotAndRestore, ensuring that
+     * {@link StreamElement}s can stay in the {@link StreamElementQueue} to be snapshotted while
+     * checkpointing.
+ */ + private static class LazyAsyncFunction extends MyAsyncFunction { + private static final long serialVersionUID = 3537791752703154670L; + + private static CountDownLatch latch; + + public LazyAsyncFunction() { + latch = new CountDownLatch(1); + } + + @Override + public void asyncInvoke(final Integer input, final ResultFuture resultFuture) + throws Exception { + executorService.submit( + new Runnable() { + @Override + public void run() { + try { + latch.await(); + } catch (InterruptedException e) { + // do nothing + } + + resultFuture.complete(Collections.singletonList(input)); + } + }); + } + + public static void countDown() { + latch.countDown(); + } + } + + private static class LazyAsyncFunctionWithRunning extends MyAsyncFunction { + private static final long serialVersionUID = 3537791752703154670L; + private static final AtomicInteger nextToFinish = new AtomicInteger(-1); + private static ConcurrentHashMap running = + new ConcurrentHashMap<>(); + private static ConcurrentHashMap complete = + new ConcurrentHashMap<>(); + private static AtomicInteger numActive = new AtomicInteger(0); + + public LazyAsyncFunctionWithRunning() {} + + @Override + public void asyncInvoke(final Integer input, final ResultFuture resultFuture) + throws Exception { + executorService.submit( + new Runnable() { + @Override + public void run() { + try { + trigger(input, running); + numActive.incrementAndGet(); + synchronized (nextToFinish) { + while (nextToFinish.get() != input) { + nextToFinish.wait(); + } + nextToFinish.set(-1); + nextToFinish.notifyAll(); + } + numActive.decrementAndGet(); + trigger(input, complete); + } catch (InterruptedException e) { + LOG.error("Error while running async function.", e); + } + resultFuture.complete(Collections.singletonList(input)); + } + }); + } + + public static void release(int toRelease) { + synchronized (nextToFinish) { + while (nextToFinish.get() != -1) { + try { + nextToFinish.wait(); + } catch (InterruptedException e) { + LOG.error("Error while waiting to release.", e); + } + } + nextToFinish.set(toRelease); + nextToFinish.notifyAll(); + } + } + + public static boolean complete(Integer input) { + return value(input, complete); + } + + public static boolean running(Integer input) { + return value(input, running); + } + + private static void trigger(int key, ConcurrentHashMap map) { + map.compute( + key, + (k, v) -> { + if (v != null) { + v.set(true); + return v; + } else { + return new AtomicBoolean(true); + } + }); + } + + private static boolean value(int key, ConcurrentHashMap map) { + return map.compute( + key, + (k, v) -> { + if (v != null) { + return v; + } else { + return new AtomicBoolean(false); + } + }) + .get(); + } + + public static long numActive() { + return numActive.get(); + } + } + + private static class InputReusedAsyncFunction + extends MyAbstractKeyedAsyncFunction> { + + private static final long serialVersionUID = 8627909616410487720L; + + @Override + public void asyncInvoke(Tuple1 input, ResultFuture resultFuture) + throws Exception { + executorService.submit( + new Runnable() { + @Override + public void run() { + resultFuture.complete(Collections.singletonList(input.f0 * 2)); + } + }); + } + } + + /** + * A special {@link LazyAsyncFunction} for timeout handling. Complete the result future with 3 + * times the input when the timeout occurred. 
+ */ + private static class IgnoreTimeoutLazyAsyncFunction extends LazyAsyncFunction { + private static final long serialVersionUID = 1428714561365346128L; + + @Override + public void timeout(Integer input, ResultFuture resultFuture) throws Exception { + resultFuture.complete(Collections.singletonList(input * 3)); + } + } + + /** Completes input at half the TIMEOUT and registers timeouts. */ + private static class TimeoutAfterCompletionTestFunction + implements AsyncFunction { + static final AtomicBoolean TIMED_OUT = new AtomicBoolean(false); + static final CountDownLatch COMPLETION_TRIGGER = new CountDownLatch(1); + + @Override + public void asyncInvoke(Integer input, ResultFuture resultFuture) { + ForkJoinPool.commonPool() + .submit( + () -> { + COMPLETION_TRIGGER.await(); + resultFuture.complete(Collections.singletonList(input)); + return null; + }); + } + + @Override + public void timeout(Integer input, ResultFuture resultFuture) { + TIMED_OUT.set(true); + } + } + + /** A {@link Comparator} to compare {@link StreamRecord} while sorting them. */ + private class StreamRecordComparator implements Comparator { + @Override + public int compare(Object o1, Object o2) { + if (o1 instanceof Watermark || o2 instanceof Watermark) { + return 0; + } else { + StreamRecord sr0 = (StreamRecord) o1; + StreamRecord sr1 = (StreamRecord) o2; + + return Long.compare(sr0.getTimestamp(), sr1.getTimestamp()); + } + } + } + + /** Test the KeyedAsyncWaitOperator with unordered mode and event time. */ + @Test + public void testOrdered() throws Exception { + testOutputOrdering(KeyedAsyncOutputMode.ORDERED); + } + + private void testOutputOrdering(KeyedAsyncOutputMode mode) throws Exception { + final OneInputStreamOperatorTestHarness testHarness = + createTestHarness(new MyAsyncFunction(), TIMEOUT, 4, mode); + + final long initialTime = 0L; + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + testHarness.open(); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.processElement(new StreamRecord<>(1, initialTime + 3)); + testHarness.processElement(new StreamRecord<>(2, initialTime + 1)); + testHarness.processElement(new StreamRecord<>(3, initialTime + 2)); + testHarness.processWatermark(new Watermark(initialTime + 3)); + testHarness.processElement(new StreamRecord<>(4, initialTime + 5)); + testHarness.processElement(new StreamRecord<>(5, initialTime + 4)); + testHarness.processWatermark(new Watermark(initialTime + 5)); + } + + // wait until all async collectors in the buffer have been emitted out. 
+ synchronized (testHarness.getCheckpointLock()) { + testHarness.endInput(); + testHarness.close(); + } + + expectedOutput.add(new StreamRecord<>(2, initialTime + 3)); + expectedOutput.add(new StreamRecord<>(4, initialTime + 1)); + expectedOutput.add(new StreamRecord<>(6, initialTime + 2)); + expectedOutput.add(new Watermark(initialTime + 3)); + expectedOutput.add(new StreamRecord<>(8, initialTime + 5)); + expectedOutput.add(new StreamRecord<>(10, initialTime + 4)); + expectedOutput.add(new Watermark(initialTime + 5)); + + if (KeyedAsyncOutputMode.ORDERED == mode) { + TestHarnessUtil.assertOutputEquals( + "Output with watermark was not correct.", + expectedOutput, + testHarness.getOutput()); + } else { + Object[] jobOutputQueue = testHarness.getOutput().toArray(); + + Assert.assertEquals( + "Watermark should be at index 3", + new Watermark(initialTime + 3), + jobOutputQueue[3]); + + TestHarnessUtil.assertOutputEqualsSorted( + "Output for StreamRecords does not match", + expectedOutput, + testHarness.getOutput(), + new StreamRecordComparator()); + } + } + + /** Test the KeyedAsyncWaitOperator with ordered mode. */ + @Test + public void testLimitedActiveOrdered() throws Exception { + testLimitedActive(KeyedAsyncOutputMode.ORDERED); + } + + private void testLimitedActive(KeyedAsyncOutputMode mode) throws Exception { + StreamTaskMailboxTestHarnessBuilder builder = + new StreamTaskMailboxTestHarnessBuilder<>( + OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO) + .addInput(BasicTypeInfo.INT_TYPE_INFO); + try (StreamTaskMailboxTestHarness testHarness = + builder.setupOutputForSingletonOperatorChain( + new KeyedAsyncWaitOperatorFactory<>( + new LazyAsyncFunctionWithRunning(), + // We do a lot of waiting in this test, so give a longer + // timeout. + 5 * TIMEOUT, + 2, + mode)) + .modifyStreamConfig( + streamConfig -> { + streamConfig.setStateKeySerializer(IntSerializer.INSTANCE); + streamConfig.setStatePartitioner(0, new EvenOddKeySelector()); + }) + .build()) { + final long initialTime = 0L; + + testHarness.processElement(new StreamRecord<>(1, initialTime + 2)); + testHarness.processElement(new StreamRecord<>(2, initialTime + 1)); + + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(1)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(2)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.numActive() == 2); + + LazyAsyncFunctionWithRunning.release(1); + + testHarness.processElement(new Watermark(initialTime + 2)); + testHarness.processElement(new StreamRecord<>(3, initialTime + 4)); + + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.complete(1)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(2)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(3)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.numActive() == 2); + + LazyAsyncFunctionWithRunning.release(2); + + testHarness.processElement(new StreamRecord<>(4, initialTime + 3)); + + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.complete(2)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(3)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(4)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.numActive() == 2); + + LazyAsyncFunctionWithRunning.release(3); + testHarness.processElement(new Watermark(initialTime + 4)); + testHarness.processElement(new StreamRecord<>(5, initialTime + 5)); + + testHarness.processUntil(() -> 
LazyAsyncFunctionWithRunning.complete(3)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(4)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.running(5)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.numActive() == 2); + + LazyAsyncFunctionWithRunning.release(4); + testHarness.processElement(new Watermark(initialTime + 5)); + + LazyAsyncFunctionWithRunning.release(5); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.complete(4)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.complete(5)); + testHarness.processUntil(() -> LazyAsyncFunctionWithRunning.numActive() == 0); + + testHarness.processUntil(() -> testHarness.getOutput().size() == 8); + } + } + + static class EvenOddKeySelector implements KeySelector { + private static final long serialVersionUID = -1927524994684581374L; + + @Override + public Integer getKey(Integer value) throws Exception { + return value % 2; + } + } + + static class TupleEventOddKeySelector implements KeySelector, Integer> { + private static final long serialVersionUID = -1927524994684581374L; + + @Override + public Integer getKey(Tuple1 value) throws Exception { + return value.f0 % 2; + } + } + + @Test + public void testStateSnapshotAndRestore() throws Exception { + final OneInputStreamTaskTestHarness testHarness = + new OneInputStreamTaskTestHarness<>( + OneInputStreamTask::new, + 1, + 1, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO); + + testHarness.setupOutputForSingletonOperatorChain(); + + KeyedAsyncWaitOperatorFactory factory = + new KeyedAsyncWaitOperatorFactory<>( + new LazyAsyncFunction(), TIMEOUT, 4, KeyedAsyncOutputMode.ORDERED); + + final StreamConfig streamConfig = testHarness.getStreamConfig(); + OperatorID operatorID = new OperatorID(42L, 4711L); + streamConfig.setStreamOperatorFactory(factory); + streamConfig.setOperatorID(operatorID); + streamConfig.setStateKeySerializer(IntSerializer.INSTANCE); + streamConfig.setStatePartitioner(0, new EvenOddKeySelector()); + + final TestTaskStateManager taskStateManagerMock = testHarness.getTaskStateManager(); + + testHarness.invoke(); + testHarness.waitForTaskRunning(); + + final OneInputStreamTask task = testHarness.getTask(); + + final long initialTime = 0L; + + testHarness.processElement(new StreamRecord<>(1, initialTime + 1)); + testHarness.processElement(new StreamRecord<>(2, initialTime + 2)); + testHarness.processElement(new StreamRecord<>(3, initialTime + 3)); + testHarness.processElement(new StreamRecord<>(4, initialTime + 4)); + + testHarness.waitForInputProcessing(); + + final long checkpointId = 1L; + final long checkpointTimestamp = 1L; + + final CheckpointMetaData checkpointMetaData = + new CheckpointMetaData(checkpointId, checkpointTimestamp); + + task.triggerCheckpointAsync( + checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation()); + + taskStateManagerMock.getWaitForReportLatch().await(); + + assertEquals(checkpointId, taskStateManagerMock.getReportedCheckpointId()); + + LazyAsyncFunction.countDown(); + + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + + // set the keyed state from previous attempt into the restored one + TaskStateSnapshot subtaskStates = taskStateManagerMock.getLastJobManagerTaskStateSnapshot(); + + final OneInputStreamTaskTestHarness restoredTaskHarness = + new OneInputStreamTaskTestHarness<>( + OneInputStreamTask::new, + BasicTypeInfo.INT_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO); + + restoredTaskHarness.setTaskStateSnapshot(checkpointId, 
subtaskStates); + restoredTaskHarness.setupOutputForSingletonOperatorChain(); + + KeyedAsyncWaitOperatorFactory restoredOperator = + new KeyedAsyncWaitOperatorFactory<>( + new MyAsyncFunction(), TIMEOUT, 6, KeyedAsyncOutputMode.ORDERED); + + restoredTaskHarness.getStreamConfig().setStreamOperatorFactory(restoredOperator); + restoredTaskHarness.getStreamConfig().setOperatorID(operatorID); + restoredTaskHarness.getStreamConfig().setStateKeySerializer(IntSerializer.INSTANCE); + restoredTaskHarness.getStreamConfig().setStatePartitioner(0, new EvenOddKeySelector()); + + restoredTaskHarness.invoke(); + restoredTaskHarness.waitForTaskRunning(); + + final OneInputStreamTask restoredTask = restoredTaskHarness.getTask(); + + restoredTaskHarness.processElement(new StreamRecord<>(5, initialTime + 5)); + restoredTaskHarness.processElement(new StreamRecord<>(6, initialTime + 6)); + restoredTaskHarness.processElement(new StreamRecord<>(7, initialTime + 7)); + + // trigger the checkpoint while processing stream elements + restoredTask + .triggerCheckpointAsync( + new CheckpointMetaData(checkpointId, checkpointTimestamp), + CheckpointOptions.forCheckpointWithDefaultLocation()) + .get(); + + restoredTaskHarness.processElement(new StreamRecord<>(8, initialTime + 8)); + + restoredTaskHarness.endInput(); + restoredTaskHarness.waitForTaskCompletion(); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + // Note that the restore only keeps order within the same key, so the output may not be + // identical to the original input order. + // 1, 3 Restored + expectedOutput.add(new StreamRecord<>(2, initialTime + 1)); + expectedOutput.add(new StreamRecord<>(6, initialTime + 3)); + // 2, 4 Restored + expectedOutput.add(new StreamRecord<>(4, initialTime + 2)); + expectedOutput.add(new StreamRecord<>(8, initialTime + 4)); + // New elements + expectedOutput.add(new StreamRecord<>(10, initialTime + 5)); + expectedOutput.add(new StreamRecord<>(12, initialTime + 6)); + expectedOutput.add(new StreamRecord<>(14, initialTime + 7)); + expectedOutput.add(new StreamRecord<>(16, initialTime + 8)); + + // remove CheckpointBarrier which is not expected + restoredTaskHarness.getOutput().removeIf(record -> record instanceof CheckpointBarrier); + + TestHarnessUtil.assertOutputEquals( + "StateAndRestored Test Output was not correct.", + expectedOutput, + restoredTaskHarness.getOutput()); + } + + @SuppressWarnings("rawtypes") + @Test + public void testObjectReused() throws Exception { + TypeSerializer[] fieldSerializers = new TypeSerializer[] {IntSerializer.INSTANCE}; + TupleSerializer inputSerializer = + new TupleSerializer<>(Tuple1.class, fieldSerializers); + KeyedAsyncWaitOperatorFactory, Integer> factory = + new KeyedAsyncWaitOperatorFactory<>( + new InputReusedAsyncFunction(), TIMEOUT, 4, KeyedAsyncOutputMode.ORDERED); + + //noinspection unchecked + final OneInputStreamOperatorTestHarness, Integer> testHarness = + new OneInputStreamOperatorTestHarness(factory, inputSerializer); + // enable object reuse + testHarness.getExecutionConfig().enableObjectReuse(); + testHarness.getStreamConfig().setStateKeySerializer(IntSerializer.INSTANCE); + testHarness.getStreamConfig().setStatePartitioner(0, new TupleEventOddKeySelector()); + testHarness.getStreamConfig().serializeAllConfigs(); + + final long initialTime = 0L; + Tuple1 reusedTuple = new Tuple1<>(); + StreamRecord> reusedRecord = new StreamRecord<>(reusedTuple, -1L); + + testHarness.setup(); + testHarness.open(); + + synchronized (testHarness.getCheckpointLock()) 
{ + reusedTuple.setFields(1); + reusedRecord.setTimestamp(initialTime + 1); + testHarness.processElement(reusedRecord); + + reusedTuple.setFields(2); + reusedRecord.setTimestamp(initialTime + 2); + testHarness.processElement(reusedRecord); + + reusedTuple.setFields(3); + reusedRecord.setTimestamp(initialTime + 3); + testHarness.processElement(reusedRecord); + + reusedTuple.setFields(4); + reusedRecord.setTimestamp(initialTime + 4); + testHarness.processElement(reusedRecord); + } + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(new StreamRecord<>(2, initialTime + 1)); + expectedOutput.add(new StreamRecord<>(4, initialTime + 2)); + expectedOutput.add(new StreamRecord<>(6, initialTime + 3)); + expectedOutput.add(new StreamRecord<>(8, initialTime + 4)); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.endInput(); + testHarness.close(); + } + + TestHarnessUtil.assertOutputEquals( + "StateAndRestoredWithObjectReuse Test Output was not correct.", + expectedOutput, + testHarness.getOutput()); + } + + @Test + public void testAsyncTimeoutFailure() throws Exception { + testAsyncTimeout( + new LazyAsyncFunction(), + Optional.of(TimeoutException.class), + new StreamRecord<>(2, 5L)); + } + + @Test + public void testAsyncTimeoutIgnore() throws Exception { + testAsyncTimeout( + new IgnoreTimeoutLazyAsyncFunction(), + Optional.empty(), + new StreamRecord<>(3, 0L), + new StreamRecord<>(2, 5L)); + } + + private void testAsyncTimeout( + LazyAsyncFunction lazyAsyncFunction, + Optional> expectedException, + StreamRecord... expectedRecords) + throws Exception { + final long timeout = 10L; + + final OneInputStreamOperatorTestHarness testHarness = + createTestHarness(lazyAsyncFunction, timeout, 2, KeyedAsyncOutputMode.ORDERED); + + final MockEnvironment mockEnvironment = testHarness.getEnvironment(); + mockEnvironment.setExpectedExternalFailureCause(Throwable.class); + + final long initialTime = 0L; + final ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + + testHarness.open(); + + testHarness.setProcessingTime(initialTime); + + synchronized (testHarness.getCheckpointLock()) { + testHarness.processElement(new StreamRecord<>(1, initialTime)); + testHarness.setProcessingTime(initialTime + 5L); + testHarness.processElement(new StreamRecord<>(2, initialTime + 5L)); + } + + // trigger the timeout of the first stream record + testHarness.setProcessingTime(initialTime + timeout + 1L); + + // allow the second async stream record to be processed + lazyAsyncFunction.countDown(); + + // wait until all async collectors in the buffer have been emitted out. + synchronized (testHarness.getCheckpointLock()) { + testHarness.endInput(); + testHarness.close(); + } + + expectedOutput.addAll(Arrays.asList(expectedRecords)); + + TestHarnessUtil.assertOutputEquals( + "Output with watermark was not correct.", expectedOutput, testHarness.getOutput()); + + if (expectedException.isPresent()) { + assertTrue(mockEnvironment.getActualExternalFailureCause().isPresent()); + assertTrue( + ExceptionUtils.findThrowable( + mockEnvironment.getActualExternalFailureCause().get(), + expectedException.get()) + .isPresent()); + } + } + + /** + * FLINK-5652 Tests that registered timers are properly canceled upon completion of a {@link + * StreamElement} in order to avoid resource leaks because TriggerTasks hold a reference on the + * StreamRecordQueueEntry. 
+ */ + @Test + public void testTimeoutCleanup() throws Exception { + OneInputStreamOperatorTestHarness harness = + createTestHarness(new MyAsyncFunction(), TIMEOUT, 1, KeyedAsyncOutputMode.ORDERED); + + harness.open(); + + synchronized (harness.getCheckpointLock()) { + harness.processElement(42, 1L); + } + + synchronized (harness.getCheckpointLock()) { + harness.endInput(); + harness.close(); + } + + // check that we actually outputted the result of the single input + assertEquals( + Arrays.asList(new StreamRecord(42 * 2, 1L)), new ArrayList<>(harness.getOutput())); + + // check that we have cancelled our registered timeout + assertEquals(0, harness.getProcessingTimeService().getNumActiveTimers()); + } + + /** + * Checks if timeout has been called after the element has been completed within the timeout. + * + * @see FLINK-22573 + */ + @Test + public void testTimeoutAfterComplete() throws Exception { + StreamTaskMailboxTestHarnessBuilder builder = + new StreamTaskMailboxTestHarnessBuilder<>( + OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO) + .addInput(BasicTypeInfo.INT_TYPE_INFO); + try (StreamTaskMailboxTestHarness harness = + builder.setupOutputForSingletonOperatorChain( + new AsyncWaitOperatorFactory<>( + new TimeoutAfterCompletionTestFunction(), + TIMEOUT, + 1, + AsyncDataStream.OutputMode.UNORDERED)) + .build()) { + harness.processElement(new StreamRecord<>(1)); + // add a timer after AsyncIO added its timer to verify that AsyncIO timer is processed + ScheduledFuture testTimer = + harness.getTimerService() + .registerTimer( + harness.getTimerService().getCurrentProcessingTime() + TIMEOUT, + ts -> {}); + // trigger the regular completion in AsyncIO + TimeoutAfterCompletionTestFunction.COMPLETION_TRIGGER.countDown(); + // wait until all timers have been processed + testTimer.get(); + // handle normal completion call outputting the element in mailbox thread + harness.processAll(); + assertEquals( + Collections.singleton(new StreamRecord<>(1)), + new HashSet<>(harness.getOutput())); + assertFalse("no timeout expected", TimeoutAfterCompletionTestFunction.TIMED_OUT.get()); + } + } + + /** + * FLINK-6435 + * + *

Tests that a user exception triggers the completion of a StreamElementQueueEntry and does + * not wait to until another StreamElementQueueEntry is properly completed before it is + * collected. + */ + @Test + public void testOrderedWaitUserExceptionHandling() throws Exception { + testUserExceptionHandling(KeyedAsyncOutputMode.ORDERED); + } + + private void testUserExceptionHandling(KeyedAsyncOutputMode outputMode) throws Exception { + OneInputStreamOperatorTestHarness harness = + createTestHarness(new UserExceptionAsyncFunction(), TIMEOUT, 2, outputMode); + + harness.getEnvironment().setExpectedExternalFailureCause(Throwable.class); + harness.open(); + + synchronized (harness.getCheckpointLock()) { + harness.processElement(1, 1L); + harness.processWatermark(new Watermark(1)); + } + + synchronized (harness.getCheckpointLock()) { + harness.endInput(); + harness.close(); + } + + assertTrue(harness.getEnvironment().getActualExternalFailureCause().isPresent()); + } + + /** AsyncFunction which completes the result with an {@link Exception}. */ + private static class UserExceptionAsyncFunction + implements KeyedAsyncFunction { + + private static final long serialVersionUID = 6326568632967110990L; + + @Override + public void asyncInvoke(Integer input, ResultFuture resultFuture) + throws Exception { + resultFuture.completeExceptionally(new Exception("Test exception")); + } + + @Override + public void open(OpenContext context) throws Exception {} + } + + /** + * FLINK-6435 + * + *

+    /**
+     * FLINK-6435
+     *
+     * <p>Tests that timeout exceptions are properly handled in ordered output mode. Proper
+     * handling means that a StreamElementQueueEntry is completed in case of a timeout exception.
+     */
+    @Test
+    public void testOrderedWaitTimeoutHandling() throws Exception {
+        testTimeoutExceptionHandling(KeyedAsyncOutputMode.ORDERED);
+    }
+
+    private void testTimeoutExceptionHandling(KeyedAsyncOutputMode outputMode) throws Exception {
+        OneInputStreamOperatorTestHarness<Integer, Integer> harness =
+                createTestHarness(new NoOpAsyncFunction<>(), 10L, 2, outputMode);
+
+        harness.getEnvironment().setExpectedExternalFailureCause(Throwable.class);
+        harness.open();
+
+        synchronized (harness.getCheckpointLock()) {
+            harness.processElement(1, 1L);
+        }
+
+        harness.setProcessingTime(10L);
+
+        synchronized (harness.getCheckpointLock()) {
+            harness.close();
+        }
+    }
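+
+    /*
+     * Note on the timeout path exercised above (illustrative reading, not an API guarantee):
+     * NoOpAsyncFunction never completes its ResultFuture, so advancing the processing time to
+     * the configured timeout is what presumably fires the operator's timeout timer and completes
+     * the queue entry exceptionally:
+     *
+     *     harness.setProcessingTime(10L); // timeout of 10ms elapses, entry fails exceptionally
+     *     // the failure then surfaces through the mock environment's external failure cause
+     */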
+    /**
+     * Tests that the AsyncWaitOperator can restart if the checkpointed queue was full.
+     *
+     * <p>See FLINK-7949
+     */
+    @Test(timeout = 10000)
+    public void testRestartWithFullQueue() throws Exception {
+        final int capacity = 10;
+
+        // 1. create the snapshot which contains capacity + 1 elements
+        final CompletableFuture<Void> trigger = new CompletableFuture<>();
+
+        final OneInputStreamOperatorTestHarness<Integer, Integer> snapshotHarness =
+                createTestHarness(
+                        new ControllableAsyncFunction<>(
+                                trigger), // behaves like a blocking function until triggered
+                        1000L,
+                        capacity,
+                        KeyedAsyncOutputMode.ORDERED);
+
+        snapshotHarness.open();
+
+        final OperatorSubtaskState snapshot;
+
+        final ArrayList<Integer> expectedOutput = new ArrayList<>(capacity);
+        final ArrayList<Integer> odds = new ArrayList<>(capacity / 2);
+        final ArrayList<Integer> evens = new ArrayList<>(capacity / 2);
+
+        try {
+            synchronized (snapshotHarness.getCheckpointLock()) {
+                for (int i = 0; i < capacity; i++) {
+                    snapshotHarness.processElement(i, 0L);
+                    if (i % 2 == 0) {
+                        evens.add(i);
+                    } else {
+                        odds.add(i);
+                    }
+                }
+            }
+
+            expectedOutput.addAll(odds);
+            expectedOutput.addAll(evens);
+
+            synchronized (snapshotHarness.getCheckpointLock()) {
+                // execute the snapshot within the checkpoint lock, because then it is guaranteed
+                // that the lastElementWriter has written the exceeding element
+                snapshot = snapshotHarness.snapshot(0L, 0L);
+            }
+
+            // trigger the computation to make the close call finish
+            trigger.complete(null);
+        } finally {
+            synchronized (snapshotHarness.getCheckpointLock()) {
+                snapshotHarness.close();
+            }
+        }
+
+        // 2. restore the snapshot and check that we complete
+        final OneInputStreamOperatorTestHarness<Integer, Integer> recoverHarness =
+                createTestHarness(
+                        new ControllableAsyncFunction<>(CompletableFuture.completedFuture(null)),
+                        1000L,
+                        capacity,
+                        KeyedAsyncOutputMode.ORDERED);
+
+        recoverHarness.initializeState(snapshot);
+
+        synchronized (recoverHarness.getCheckpointLock()) {
+            recoverHarness.open();
+        }
+
+        synchronized (recoverHarness.getCheckpointLock()) {
+            recoverHarness.endInput();
+            recoverHarness.close();
+        }
+
+        final ConcurrentLinkedQueue<Object> output = recoverHarness.getOutput();
+
+        final List<Integer> outputElements =
+                output.stream()
+                        .map(r -> ((StreamRecord<Integer>) r).getValue())
+                        .collect(Collectors.toList());
+
+        assertThat(outputElements, Matchers.equalTo(expectedOutput));
+    }
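+
+    /*
+     * How the gate in testRestartWithFullQueue works, in isolation (illustrative):
+     *
+     *     CompletableFuture<Void> gate = new CompletableFuture<>();
+     *     gate.thenAccept(v -> resultFuture.complete(Collections.singleton(input)));
+     *     // nothing completes while the gate is open, so the queue fills up to capacity ...
+     *     gate.complete(null); // releases all pending results at once
+     *
+     * This is exactly what ControllableAsyncFunction (defined below) does per element.
+     */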
+
+    @Test
+    public void testIgnoreAsyncOperatorRecordsOnDrain() throws Exception {
+        // given: Async wait operator which is able to collect result futures.
+        StreamTaskMailboxTestHarnessBuilder<Integer> builder =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO);
+        SharedReference<List<ResultFuture<Integer>>> resultFutures =
+                sharedObjects.add(new ArrayList<>());
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                builder.setupOutputForSingletonOperatorChain(
+                                new KeyedAsyncWaitOperatorFactory<>(
+                                        new CollectableFuturesAsyncFunction<>(resultFutures),
+                                        TIMEOUT,
+                                        5,
+                                        KeyedAsyncOutputMode.ORDERED))
+                        .modifyStreamConfig(
+                                streamConfig -> {
+                                    streamConfig.setStateKeySerializer(IntSerializer.INSTANCE);
+                                    streamConfig.setStatePartitioner(0, new EvenOddKeySelector());
+                                })
+                        .build()) {
+
+            // when: Processing at least two elements and completing them in reverse order to
+            // keep the completed queue non-empty.
+            harness.processElement(new StreamRecord<>(1));
+            harness.processElement(new StreamRecord<>(2));
+
+            for (ResultFuture<Integer> resultFuture : Lists.reverse(resultFutures.get())) {
+                resultFuture.complete(Collections.emptyList());
+            }
+
+            // then: All records from the async operator should be ignored during drain since
+            // they will be processed on recovery.
+            harness.finishProcessing();
+            assertTrue(harness.getOutput().isEmpty());
+        }
+    }
+
+    @Test
+    public void testProcessingTimeWithMailboxThreadOrdered() throws Exception {
+        testProcessingTimeWithCallThread(KeyedAsyncOutputMode.ORDERED);
+    }
+
+    @Test
+    public void testProcessingTimeWithMailboxThreadError() throws Exception {
+        StreamTaskMailboxTestHarnessBuilder<Integer> builder =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO);
+        try (StreamTaskMailboxTestHarness<Integer> testHarness =
+                builder.setupOutputForSingletonOperatorChain(
+                                new KeyedAsyncWaitOperatorFactory<>(
+                                        new CallThreadAsyncFunctionError(),
+                                        TIMEOUT,
+                                        4,
+                                        KeyedAsyncOutputMode.ORDERED))
+                        .modifyStreamConfig(
+                                streamConfig -> {
+                                    streamConfig.setStateKeySerializer(IntSerializer.INSTANCE);
+                                    streamConfig.setStatePartitioner(0, new EvenOddKeySelector());
+                                })
+                        .build()) {
+            final long initialTime = 0L;
+            AtomicReference<Throwable> error = new AtomicReference<>();
+            testHarness.getStreamMockEnvironment().setExternalExceptionHandler(error::set);
+
+            // Sometimes processElement invokes the async function immediately, so we should
+            // catch any exception.
+            try {
+                testHarness.processElement(new StreamRecord<>(1, initialTime + 1));
+                while (error.get() == null) {
+                    testHarness.processAll();
+                }
+            } catch (Exception e) {
+                // This simulates a mailbox failure failing the job
+                error.set(e);
+            }
+
+            ExceptionUtils.assertThrowable(error.get(), ExpectedTestException.class);
+
+            testHarness.endInput();
+        }
+    }
+
+    private void testProcessingTimeWithCallThread(KeyedAsyncOutputMode mode) throws Exception {
+        StreamTaskMailboxTestHarnessBuilder<Integer> builder =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO);
+        try (StreamTaskMailboxTestHarness<Integer> testHarness =
+                builder.setupOutputForSingletonOperatorChain(
+                                new KeyedAsyncWaitOperatorFactory<>(
+                                        new CallThreadAsyncFunction(), TIMEOUT, 4, mode))
+                        .modifyStreamConfig(
+                                streamConfig -> {
+                                    streamConfig.setStateKeySerializer(IntSerializer.INSTANCE);
+                                    streamConfig.setStatePartitioner(0, new EvenOddKeySelector());
+                                })
+                        .build()) {
+
+            final long initialTime = 0L;
+            final Queue<Object> expectedOutput = new ArrayDeque<>();
+
+            testHarness.processElement(new StreamRecord<>(1, initialTime + 3));
+            testHarness.processElement(new StreamRecord<>(2, initialTime + 1));
+            testHarness.processElement(new StreamRecord<>(3, initialTime + 2));
+            testHarness.processElement(new Watermark(initialTime + 3));
+
+            expectedOutput.add(new StreamRecord<>(2, initialTime + 3));
+            expectedOutput.add(new StreamRecord<>(4, initialTime + 1));
+            expectedOutput.add(new StreamRecord<>(6, initialTime + 2));
+            expectedOutput.add(new Watermark(initialTime + 3));
+
+            while (testHarness.getOutput().size() < expectedOutput.size()) {
+                testHarness.processAll();
+            }
+
+            TestHarnessUtil.assertOutputEquals(
+                    "ORDERED Output was not correct.", expectedOutput, testHarness.getOutput());
+
+            testHarness.endInput();
+        }
+    }
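+
+    /*
+     * The tests above partition their Integer input with EvenOddKeySelector, which is defined
+     * elsewhere in this change; presumably it keys each record by parity, along the lines of
+     * this hypothetical sketch:
+     *
+     *     public static class EvenOddKeySelector implements KeySelector<Integer, Integer> {
+     *         @Override
+     *         public Integer getKey(Integer value) {
+     *             return value % 2;
+     *         }
+     *     }
+     */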
+
+    private static class CollectableFuturesAsyncFunction<IN>
+            implements KeyedAsyncFunction<IN, IN> {
+
+        private static final long serialVersionUID = -4214078239227288637L;
+
+        private final SharedReference<List<ResultFuture<IN>>> resultFutures;
+
+        private CollectableFuturesAsyncFunction(
+                SharedReference<List<ResultFuture<IN>>> resultFutures) {
+            this.resultFutures = resultFutures;
+        }
+
+        @Override
+        public void asyncInvoke(IN input, ResultFuture<IN> resultFuture) throws Exception {
+            resultFutures.get().add(resultFuture);
+        }
+    }
+
+    private static class ControllableAsyncFunction<IN> implements KeyedAsyncFunction<IN, IN> {
+
+        private static final long serialVersionUID = -4214078239267288636L;
+
+        private transient CompletableFuture<Void> trigger;
+
+        private ControllableAsyncFunction(CompletableFuture<Void> trigger) {
+            this.trigger = Preconditions.checkNotNull(trigger);
+        }
+
+        @Override
+        public void asyncInvoke(IN input, ResultFuture<IN> resultFuture) throws Exception {
+            trigger.thenAccept(v -> resultFuture.complete(Collections.singleton(input)));
+        }
+    }
+
+    private static class NoOpAsyncFunction<IN, OUT> implements KeyedAsyncFunction<IN, OUT> {
+        private static final long serialVersionUID = -3060481953330480694L;
+
+        @Override
+        public void asyncInvoke(IN input, ResultFuture<OUT> resultFuture) throws Exception {
+            // no op
+        }
+    }
+
+    private static <OUT> OneInputStreamOperatorTestHarness<Integer, OUT> createTestHarness(
+            KeyedAsyncFunction<Integer, OUT> function,
+            long timeout,
+            int capacity,
+            KeyedAsyncOutputMode outputMode)
+            throws Exception {
+
+        OneInputStreamOperatorTestHarness<Integer, OUT> result =
+                new OneInputStreamOperatorTestHarness<>(
+                        new KeyedAsyncWaitOperatorFactory<>(
+                                function, timeout, capacity, outputMode),
+                        IntSerializer.INSTANCE);
+        result.getStreamConfig().setStateKeySerializer(IntSerializer.INSTANCE);
+        result.getStreamConfig().setStatePartitioner(0, new EvenOddKeySelector());
+        result.getStreamConfig().serializeAllConfigs();
+        return result;
+    }
+
+    private static <OUT>
+            OneInputStreamOperatorTestHarness<Integer, OUT> createTestHarnessWithRetry(
+                    KeyedAsyncFunction<Integer, OUT> function,
+                    long timeout,
+                    int capacity,
+                    KeyedAsyncOutputMode outputMode)
+                    throws Exception {
+
+        return new OneInputStreamOperatorTestHarness<>(
+                new KeyedAsyncWaitOperatorFactory<>(function, timeout, capacity, outputMode),
+                IntSerializer.INSTANCE);
+    }
+
+    private static class CallThreadAsyncFunction extends MyAbstractKeyedAsyncFunction {
+        private static final long serialVersionUID = -1504699677704123889L;
+
+        @Override
+        public void asyncInvoke(final Integer input, final ResultFuture<Integer> resultFuture)
+                throws Exception {
+            final Thread callThread = Thread.currentThread();
+            executorService.submit(
+                    () ->
+                            resultFuture.complete(
+                                    () -> {
+                                        assertEquals(callThread, Thread.currentThread());
+                                        return Collections.singletonList(input * 2);
+                                    }));
+        }
+    }
+
+    private static class CallThreadAsyncFunctionError
+            extends MyAbstractKeyedAsyncFunction {
+        private static final long serialVersionUID = -1504699677704123889L;
+
+        @Override
+        public void asyncInvoke(final Integer input, final ResultFuture<Integer> resultFuture)
+                throws Exception {
+            executorService.submit(
+                    () ->
+                            resultFuture.complete(
+                                    () -> {
+                                        throw new ExpectedTestException();
+                                    }));
+        }
+    }
+}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencerTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencerTest.java
new file mode 100644
index 0000000000000..bc368738f0394
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/PerKeyCallbackSequencerTest.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async;
+
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+
+import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableList;
+import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableMap;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link PerKeyCallbackSequencer}. */
+public class PerKeyCallbackSequencerTest {
+
+    private static final RowData KEY1 = GenericRowData.of(1);
+    private static final RowData KEY2 = GenericRowData.of(2);
+
+    private Map<RowData, List<Long>> callbackTimestamps = new HashMap<>();
+    private TestOpenContext openContext;
+
+    @BeforeEach
+    public void setUp() {
+        openContext = new TestOpenContext();
+        callbackTimestamps.clear();
+    }
+
+    @Test
+    public void testOneCallback() throws Exception {
+        PerKeyCallbackSequencer sequencer =
+                new PerKeyCallbackSequencer<>(this::callback);
+        openContext.setCurrentKey(KEY1);
+        sequencer.callbackWhenNext(openContext, 123);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(ImmutableMap.of(KEY1, ImmutableList.of(123L)));
+        sequencer.notifyNextWaiter(openContext);
+    }
+
+    @Test
+    public void testManyCallbacksOneKey() throws Exception {
+        PerKeyCallbackSequencer sequencer =
+                new PerKeyCallbackSequencer<>(this::callback);
+        openContext.setCurrentKey(KEY1);
+        sequencer.callbackWhenNext(openContext, 123);
+        sequencer.callbackWhenNext(openContext, 345);
+        sequencer.callbackWhenNext(openContext, 567);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(ImmutableMap.of(KEY1, ImmutableList.of(123L)));
+        sequencer.notifyNextWaiter(openContext);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(
+                        ImmutableMap.of(KEY1, ImmutableList.of(123L, 345L)));
+        sequencer.notifyNextWaiter(openContext);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(
+                        ImmutableMap.of(KEY1, ImmutableList.of(123L, 345L, 567L)));
+    }
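+
+    /*
+     * The contract pinned down by these tests, as far as they show it: the first
+     * callbackWhenNext() for a key runs its callback immediately; later registrations for the
+     * same key queue up, and each notifyNextWaiter() releases exactly one of them in FIFO
+     * order. A typical call site would therefore look like (illustrative):
+     *
+     *     sequencer.callbackWhenNext(ctx, timestamp); // runs now, or once predecessors finish
+     *     // ... asynchronous work for the current record completes ...
+     *     sequencer.notifyNextWaiter(ctx);            // hand the key's "slot" to the next waiter
+     */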
+
+    @Test
+    public void testManyCallbacksTwoKeys() throws Exception {
+        PerKeyCallbackSequencer sequencer =
+                new PerKeyCallbackSequencer<>(this::callback);
+        openContext.setCurrentKey(KEY1);
+        sequencer.callbackWhenNext(openContext, 123);
+        openContext.setCurrentKey(KEY2);
+        sequencer.callbackWhenNext(openContext, 222);
+        openContext.setCurrentKey(KEY1);
+        sequencer.callbackWhenNext(openContext, 345);
+        openContext.setCurrentKey(KEY2);
+        sequencer.callbackWhenNext(openContext, 333);
+        openContext.setCurrentKey(KEY1);
+        sequencer.callbackWhenNext(openContext, 567);
+        openContext.setCurrentKey(KEY2);
+        sequencer.callbackWhenNext(openContext, 444);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(
+                        ImmutableMap.of(
+                                KEY1, ImmutableList.of(123L), KEY2, ImmutableList.of(222L)));
+        openContext.setCurrentKey(KEY1);
+        sequencer.notifyNextWaiter(openContext);
+        openContext.setCurrentKey(KEY2);
+        sequencer.notifyNextWaiter(openContext);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(
+                        ImmutableMap.of(
+                                KEY1,
+                                ImmutableList.of(123L, 345L),
+                                KEY2,
+                                ImmutableList.of(222L, 333L)));
+        openContext.setCurrentKey(KEY1);
+        sequencer.notifyNextWaiter(openContext);
+        openContext.setCurrentKey(KEY2);
+        sequencer.notifyNextWaiter(openContext);
+        assertThat(callbackTimestamps)
+                .containsExactlyInAnyOrderEntriesOf(
+                        ImmutableMap.of(
+                                KEY1,
+                                ImmutableList.of(123L, 345L, 567L),
+                                KEY2,
+                                ImmutableList.of(222L, 333L, 444L)));
+    }
+
+    public void callback(long timestamp, Object data, KeyedAsyncFunction.OpenContext ctx)
+            throws Exception {
+        callbackTimestamps.compute(
+                ctx.currentKey(),
+                (rowData, longs) -> {
+                    if (longs == null) {
+                        longs = new ArrayList<>();
+                    }
+                    longs.add(timestamp);
+                    return longs;
+                });
+    }
+}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/TestOpenContext.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/TestOpenContext.java
new file mode 100644
index 0000000000000..5f8481d4c87f9
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/TestOpenContext.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.util.function.ThrowingRunnable;
+
+/** Test implementation of {@link KeyedAsyncFunction.OpenContext}. */
+public class TestOpenContext implements KeyedAsyncFunction.OpenContext {
+
+    private RowData currentKey;
+
+    public TestOpenContext() {}
+
+    @Override
+    public RowData currentKey() {
+        return currentKey;
+    }
+
+    @Override
+    public void setCurrentKey(RowData key) {
+        this.currentKey = key;
+    }
+
+    @Override
+    public RuntimeContext getRuntimeContext() {
+        return null;
+    }
+
+    @Override
+    public void runOnMailboxThread(ThrowingRunnable<? extends Exception> runnable) {
+        try {
+            runnable.run();
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueImplTest.java b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueImplTest.java
new file mode 100644
index 0000000000000..91e794735ec25
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/aggregate/async/queue/KeyedStreamElementQueueImplTest.java
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async.queue;
+
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.CollectorOutput;
+
+import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link KeyedStreamElementQueueImpl}. */
+public class KeyedStreamElementQueueImplTest {
+
+    @Parameterized.Parameters
+    public static Collection<KeyedAsyncOutputMode> outputModes() {
+        return Arrays.asList(KeyedAsyncOutputMode.ORDERED);
+    }
+
+    private KeyedStreamElementQueue<Integer, Integer> createStreamElementQueue(
+            KeyedAsyncOutputMode outputMode, int capacity) {
+        switch (outputMode) {
+            case ORDERED:
+                return KeyedStreamElementQueueImpl.createOrderedQueue(capacity);
+            default:
+                throw new IllegalStateException("Unknown output mode: " + outputMode);
+        }
+    }
+
+    @ParameterizedTest(name = "outputMode = {0}")
+    @MethodSource("outputModes")
+    public void testPut(KeyedAsyncOutputMode outputMode) {
+        KeyedStreamElementQueue<Integer, Integer> queue = createStreamElementQueue(outputMode, 2);
+
+        Watermark watermark = new Watermark(0L);
+        StreamRecord<Integer> streamRecord = new StreamRecord<>(42, 1L);
+        StreamRecord<Integer> streamRecord2 = new StreamRecord<>(43, 2L);
+        StreamRecord<Integer> streamRecord3 = new StreamRecord<>(44, 3L);
+        Watermark watermark2 = new Watermark(3L);
+
+        // add a watermark and two records; only the records count against the capacity of 2
+        assertTrue(queue.tryPut(0, watermark).isPresent());
+        assertTrue(queue.tryPut(0, streamRecord).isPresent());
+        assertTrue(queue.tryPut(0, streamRecord2).isPresent());
+
+        assertEquals(3, queue.size());
+
+        // queue full, cannot add a new record
+        assertFalse(queue.tryPut(0, streamRecord3).isPresent());
+
+        // queue full, but watermarks can still be added
+        assertTrue(queue.tryPut(0, watermark2).isPresent());
+
+        // check if expected values are returned (for checkpointing)
+        assertEquals(
+                ImmutableMap.of(0, Arrays.asList(streamRecord, streamRecord2)),
+                queue.valuesByKey());
+    }
+
+    @ParameterizedTest(name = "outputMode = {0}")
+    @MethodSource("outputModes")
+    public void testPutManyKeys(KeyedAsyncOutputMode outputMode) {
+        KeyedStreamElementQueue<Integer, Integer> queue = createStreamElementQueue(outputMode, 5);
+
+        Watermark watermark = new Watermark(0L);
+        StreamRecord<Integer> streamRecord = new StreamRecord<>(42, 1L);
+        StreamRecord<Integer> streamRecord2 = new StreamRecord<>(45, 2L);
+        StreamRecord<Integer> streamRecord3 = new StreamRecord<>(46, 3L);
+        StreamRecord<Integer> streamRecord4 = new StreamRecord<>(47, 4L);
+        StreamRecord<Integer> streamRecord5 = new StreamRecord<>(48, 5L);
+        StreamRecord<Integer> streamRecord6 = new StreamRecord<>(49, 6L);
+
+        // add a watermark and five records across two keys to reach the capacity of 5
+        assertTrue(queue.tryPut(0, watermark).isPresent());
+        assertTrue(queue.tryPut(0, streamRecord).isPresent());
+        assertTrue(queue.tryPut(1, streamRecord2).isPresent());
+        assertTrue(queue.tryPut(0, streamRecord3).isPresent());
+        assertTrue(queue.tryPut(1, streamRecord4).isPresent());
+        assertTrue(queue.tryPut(1, streamRecord5).isPresent());
+
+        assertEquals(6, queue.size());
+
+        // queue full, cannot add a new record
+        assertFalse(queue.tryPut(0, streamRecord6).isPresent());
+
+        // check if expected values are returned (for checkpointing)
+        assertEquals(
+                ImmutableMap.of(
+                        0,
+                        Arrays.asList(streamRecord, streamRecord3),
+                        1,
+                        Arrays.asList(streamRecord2, streamRecord4, streamRecord5)),
+                queue.valuesByKey());
+    }
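+
+    /*
+     * Capacity semantics pinned down by the two put tests above, in this queue's current form:
+     * queue.size() counts every element including watermarks, but only stream records count
+     * against the configured capacity, so a watermark can still be inserted into a full queue.
+     * That is why tryPut of a record fails while tryPut of a watermark succeeds.
+     */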
+
+    @ParameterizedTest(name = "outputMode = {0}")
+    @MethodSource("outputModes")
+    public void testPop(KeyedAsyncOutputMode outputMode) {
+        KeyedStreamElementQueue<Integer, Integer> queue = createStreamElementQueue(outputMode, 2);
+
+        // add two elements to reach capacity
+        putSuccessfully(queue, 0, new Watermark(0L));
+        ResultFuture<Integer> recordResult = putSuccessfully(queue, 0, new StreamRecord<>(42, 1L));
+
+        assertEquals(2, queue.size());
+
+        // remove completed elements (watermarks are always completed)
+        assertEquals(Arrays.asList(new Watermark(0L)), popCompleted(queue));
+        assertEquals(1, queue.size());
+
+        // now complete the stream record
+        recordResult.complete(Collections.singleton(43));
+
+        putSuccessfully(queue, 0, new Watermark(1L));
+        assertEquals(
+                Arrays.asList(new StreamRecord<>(43, 1L), new Watermark(1L)), popCompleted(queue));
+        assertEquals(0, queue.size());
+        assertTrue(queue.isEmpty());
+    }
+
+    /** Tests that a put operation fails if the queue is full. */
+    @ParameterizedTest(name = "outputMode = {0}")
+    @MethodSource("outputModes")
+    public void testPutOnFull(KeyedAsyncOutputMode outputMode) throws Exception {
+        final KeyedStreamElementQueue<Integer, Integer> queue =
+                createStreamElementQueue(outputMode, 2);
+
+        // fill up the queue
+        ResultFuture<Integer> resultFuture = putSuccessfully(queue, 0, new StreamRecord<>(42, 0L));
+        putSuccessfully(queue, 0, new Watermark(0L));
+        ResultFuture<Integer> resultFuture2 =
+                putSuccessfully(queue, 0, new StreamRecord<>(43, 1L));
+        assertEquals(3, queue.size());
+
+        // cannot add more
+        putUnsuccessfully(queue, 0, new StreamRecord<>(43, 1L));
+
+        // popping the completed elements frees the queue again
+        resultFuture.complete(Collections.singleton(42 * 42));
+        resultFuture2.complete(Collections.singleton(43 * 43));
+
+        // Output last watermark so row time can complete too
+        putSuccessfully(queue, 0, new Watermark(1L));
+        assertEquals(
+                Arrays.asList(
+                        new StreamRecord<>(42 * 42, 0L),
+                        new Watermark(0L),
+                        new StreamRecord<>(43 * 43, 1L),
+                        new Watermark(1L)),
+                popCompleted(queue));
+
+        // now the put operation should complete
+        putSuccessfully(queue, 0, new StreamRecord<>(43, 1L));
+    }
+
+    /** Tests that two adjacent watermarks can be processed successfully. */
+    @ParameterizedTest(name = "outputMode = {0}")
+    @MethodSource("outputModes")
+    public void testWatermarkOnly(KeyedAsyncOutputMode outputMode) {
+        final KeyedStreamElementQueue<Integer, Integer> queue =
+                createStreamElementQueue(outputMode, 2);
+
+        putSuccessfully(queue, 0, new Watermark(2L));
+        putSuccessfully(queue, 0, new Watermark(5L));
+
+        Assert.assertEquals(2, queue.size());
+        Assert.assertFalse(queue.isEmpty());
+
+        Assert.assertEquals(
+                Arrays.asList(new Watermark(2L), new Watermark(5L)), popCompleted(queue));
+        Assert.assertEquals(0, queue.size());
+        Assert.assertTrue(queue.isEmpty());
+        Assert.assertEquals(Collections.emptyList(), popCompleted(queue));
+    }
+
+    @Test
+    public void testCompletionOrderOrdered() {
+        final KeyedStreamElementQueue<Integer, Integer> queue =
+                createStreamElementQueue(KeyedAsyncOutputMode.ORDERED, 4);
+
+        ResultFuture<Integer> entry1 = putSuccessfully(queue, 0, new StreamRecord<>(1, 0L));
+        ResultFuture<Integer> entry2 = putSuccessfully(queue, 0, new StreamRecord<>(2, 1L));
+        putSuccessfully(queue, 0, new Watermark(2L));
+        ResultFuture<Integer> entry4 = putSuccessfully(queue, 0, new StreamRecord<>(3, 3L));
+
+        Assert.assertEquals(Collections.emptyList(), popCompleted(queue));
+        Assert.assertEquals(4, queue.size());
+        Assert.assertFalse(queue.isEmpty());
+
+        entry2.complete(Collections.singleton(11));
+        entry4.complete(Collections.singleton(13));
+
+        Assert.assertEquals(Collections.emptyList(), popCompleted(queue));
+        Assert.assertEquals(4, queue.size());
+        Assert.assertFalse(queue.isEmpty());
+
+        entry1.complete(Collections.singleton(10));
+
+        List<StreamElement> expected =
+                Arrays.asList(
                        new StreamRecord<>(10, 0L),
+                        new StreamRecord<>(11, 1L),
+                        new Watermark(2L),
+                        new StreamRecord<>(13, 3L));
+        Assert.assertEquals(expected, popCompleted(queue));
+        Assert.assertEquals(0, queue.size());
+        Assert.assertTrue(queue.isEmpty());
+    }
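+
+    /*
+     * testCompletionOrderOrdered shows the ORDERED contract: nothing is emitted until the head
+     * entry completes, after which all consecutively completed entries flush in insertion order.
+     * A minimal reading of the assertions (illustrative):
+     *
+     *     entry2.complete(...); entry4.complete(...); // head (entry1) still pending -> no output
+     *     entry1.complete(...);                       // head done -> 10, 11, watermark, 13 flush
+     */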
+
+    static ResultFuture<Integer> putSuccessfully(
+            KeyedStreamElementQueue<Integer, Integer> queue,
+            Integer key,
+            StreamElement streamElement) {
+        Optional<ResultFuture<Integer>> resultFuture = queue.tryPut(key, streamElement);
+        assertTrue(resultFuture.isPresent());
+        return resultFuture.get();
+    }
+
+    static void putUnsuccessfully(
+            KeyedStreamElementQueue<Integer, Integer> queue,
+            Integer key,
+            StreamElement streamElement) {
+        Optional<ResultFuture<Integer>> resultFuture = queue.tryPut(key, streamElement);
+        assertFalse(resultFuture.isPresent());
+    }
+
+    /**
+     * Pops all completed elements from the head of this queue.
+     *
+     * @return Completed elements, or an empty list if none exist.
+     */
+    static List<StreamElement> popCompleted(KeyedStreamElementQueue<Integer, Integer> queue) {
+        final List<StreamElement> completed = new ArrayList<>();
+        TimestampedCollector<Integer> collector =
+                new TimestampedCollector<>(new CollectorOutput<>(completed));
+        while (queue.hasCompletedElements()) {
+            queue.emitCompletedElement(collector);
+        }
+        collector.close();
+        return completed;
+    }
+}