Skip to content

Commit f3f3d65

Browse files
committed
[FLINK-35854][table] Add LiteralAggFunction
1 parent 2d38713 commit f3f3d65

File tree

2 files changed

+239
-0
lines changed

2 files changed

+239
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.functions.aggfunctions;
20+
21+
import org.apache.flink.table.api.DataTypes;
22+
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
24+
import org.apache.flink.table.functions.DeclarativeAggregateFunction;
25+
import org.apache.flink.table.types.DataType;
26+
import org.apache.flink.table.types.logical.DecimalType;
27+
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
28+
import org.apache.flink.table.types.logical.TimeType;
29+
import org.apache.flink.table.types.logical.TimestampType;
30+
31+
import org.apache.calcite.rex.RexLiteral;
32+
33+
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
34+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
35+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
36+
37+
public class LiteralAggFunction extends DeclarativeAggregateFunction {
38+
39+
private final RexLiteral rexLiteral;
40+
41+
public LiteralAggFunction(RexLiteral rexLiteral) {
42+
this.rexLiteral = rexLiteral;
43+
}
44+
45+
@Override
46+
public int operandCount() {
47+
return 0;
48+
}
49+
50+
@Override
51+
public UnresolvedReferenceExpression[] aggBufferAttributes() {
52+
return new UnresolvedReferenceExpression[] {unresolvedRef("literalAgg")};
53+
}
54+
55+
@Override
56+
public DataType[] getAggBufferTypes() {
57+
return new DataType[] {getResultType()};
58+
}
59+
60+
@Override
61+
public DataType getResultType() {
62+
return DataTypes.BOOLEAN();
63+
}
64+
65+
@Override
66+
public Expression[] initialValuesExpressions() {
67+
return new Expression[] {/* min= */ nullOf(getResultType())};
68+
}
69+
70+
@Override
71+
public Expression[] accumulateExpressions() {
72+
return new Expression[] {literal(rexLiteral.getValue(), DataTypes.BOOLEAN())};
73+
}
74+
75+
@Override
76+
public Expression[] retractExpressions() {
77+
// See MaxAggFunction.retractExpressions
78+
return new Expression[] {literal(rexLiteral.getValue(), DataTypes.BOOLEAN())};
79+
}
80+
81+
@Override
82+
public Expression[] mergeExpressions() {
83+
return new Expression[] {literal(rexLiteral.getValue(), DataTypes.BOOLEAN())};
84+
}
85+
86+
@Override
87+
public Expression getValueExpression() {
88+
return literal(rexLiteral.getValue(), DataTypes.BOOLEAN());
89+
}
90+
91+
/** Built-in Int Min aggregate function. */
92+
public static class IntMinAggFunction extends MinAggFunction {
93+
94+
@Override
95+
public DataType getResultType() {
96+
return DataTypes.INT();
97+
}
98+
}
99+
100+
/** Built-in Byte Min aggregate function. */
101+
public static class ByteMinAggFunction extends MinAggFunction {
102+
@Override
103+
public DataType getResultType() {
104+
return DataTypes.TINYINT();
105+
}
106+
}
107+
108+
/** Built-in Short Min aggregate function. */
109+
public static class ShortMinAggFunction extends MinAggFunction {
110+
@Override
111+
public DataType getResultType() {
112+
return DataTypes.SMALLINT();
113+
}
114+
}
115+
116+
/** Built-in Long Min aggregate function. */
117+
public static class LongMinAggFunction extends MinAggFunction {
118+
@Override
119+
public DataType getResultType() {
120+
return DataTypes.BIGINT();
121+
}
122+
}
123+
124+
/** Built-in Float Min aggregate function. */
125+
public static class FloatMinAggFunction extends MinAggFunction {
126+
@Override
127+
public DataType getResultType() {
128+
return DataTypes.FLOAT();
129+
}
130+
}
131+
132+
/** Built-in Double Min aggregate function. */
133+
public static class DoubleMinAggFunction extends MinAggFunction {
134+
@Override
135+
public DataType getResultType() {
136+
return DataTypes.DOUBLE();
137+
}
138+
}
139+
140+
/** Built-in Decimal Min aggregate function. */
141+
public static class DecimalMinAggFunction extends MinAggFunction {
142+
private final DataType resultType;
143+
144+
public DecimalMinAggFunction(DecimalType decimalType) {
145+
this.resultType = DataTypes.DECIMAL(decimalType.getPrecision(), decimalType.getScale());
146+
}
147+
148+
@Override
149+
public DataType getResultType() {
150+
return resultType;
151+
}
152+
}
153+
154+
/** Built-in Boolean Min aggregate function. */
155+
public static class BooleanLiteralAggFunction extends LiteralAggFunction {
156+
157+
public BooleanLiteralAggFunction(RexLiteral rexLiteral) {
158+
super(rexLiteral);
159+
}
160+
161+
@Override
162+
public DataType getResultType() {
163+
return DataTypes.BOOLEAN();
164+
}
165+
}
166+
167+
/** Built-in String Min aggregate function. */
168+
public static class StringMinAggFunction extends MinAggFunction {
169+
@Override
170+
public DataType getResultType() {
171+
return DataTypes.STRING();
172+
}
173+
}
174+
175+
/** Built-in Date Min aggregate function. */
176+
public static class DateMinAggFunction extends MinAggFunction {
177+
@Override
178+
public DataType getResultType() {
179+
return DataTypes.DATE();
180+
}
181+
}
182+
183+
/** Built-in Time Min aggregate function. */
184+
public static class TimeMinAggFunction extends MinAggFunction {
185+
@Override
186+
public DataType getResultType() {
187+
return DataTypes.TIME(TimeType.DEFAULT_PRECISION);
188+
}
189+
}
190+
191+
/** Built-in Timestamp Min aggregate function. */
192+
public static class TimestampMinAggFunction extends MinAggFunction {
193+
194+
private final TimestampType type;
195+
196+
public TimestampMinAggFunction(TimestampType type) {
197+
this.type = type;
198+
}
199+
200+
@Override
201+
public DataType getResultType() {
202+
return DataTypes.TIMESTAMP(type.getPrecision());
203+
}
204+
}
205+
206+
/** Built-in TimestampLtz Min aggregate function. */
207+
public static class TimestampLtzMinAggFunction extends MinAggFunction {
208+
209+
private final LocalZonedTimestampType type;
210+
211+
public TimestampLtzMinAggFunction(LocalZonedTimestampType type) {
212+
this.type = type;
213+
}
214+
215+
@Override
216+
public DataType getResultType() {
217+
return DataTypes.TIMESTAMP_LTZ(type.getPrecision());
218+
}
219+
}
220+
}

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ import org.apache.flink.table.runtime.functions.aggregate.PercentileAggFunction.
3131
import org.apache.flink.table.types.logical._
3232
import org.apache.flink.table.types.logical.LogicalTypeRoot._
3333

34+
import org.apache.calcite.rel.`type`.RelDataType
3435
import org.apache.calcite.rel.core.AggregateCall
36+
import org.apache.calcite.rex.{RexLiteral, RexNode}
37+
import org.apache.calcite.sql.`type`.SqlTypeName
3538
import org.apache.calcite.sql.{SqlAggFunction, SqlJsonConstructorNullClause, SqlKind, SqlRankFunction}
3639
import org.apache.calcite.sql.fun._
3740

@@ -158,6 +161,9 @@ class AggFunctionFactory(
158161
val onNull = fn.asInstanceOf[SqlJsonArrayAggAggFunction].getNullClause
159162
new JsonArrayAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL)
160163

164+
case a: SqlAggFunction if a.getKind == SqlKind.LITERAL_AGG =>
165+
createLiteralAggFunction(call.getType, call.rexList.get(0))
166+
161167
case udagg: AggSqlFunction =>
162168
// Can not touch the literals, Calcite make them in previous RelNode.
163169
// In here, all inputs are input refs.
@@ -278,6 +284,19 @@ class AggFunctionFactory(
278284
}
279285
}
280286

287+
private def createLiteralAggFunction(
288+
relDataType: RelDataType,
289+
rexNode: RexNode): UserDefinedFunction = {
290+
relDataType.getSqlTypeName match {
291+
case SqlTypeName.BOOLEAN =>
292+
new LiteralAggFunction(rexNode.asInstanceOf[RexLiteral])
293+
case t =>
294+
throw new TableException(
295+
s"Min aggregate function does not support type: ''$t''.\n" +
296+
s"Please re-check the data type.")
297+
}
298+
}
299+
281300
private def createMinAggFunction(
282301
argTypes: Array[LogicalType],
283302
index: Int): UserDefinedFunction = {

0 commit comments

Comments
 (0)