diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
index 22c8f1e3b62f9..6327af906e443 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java
@@ -685,6 +685,7 @@ public RelOptPredicateList inferPredicates(boolean includeEqualityInference) {
case SEMI:
case INNER:
case LEFT:
+ case ANTI:
infer(
leftChildPredicates,
allExprs,
@@ -762,6 +763,7 @@ public RelOptPredicateList inferPredicates(boolean includeEqualityInference) {
leftInferredPredicates,
rightInferredPredicates);
case LEFT:
+ case ANTI:
return RelOptPredicateList.of(
rexBuilder,
RelOptUtil.conjunctions(leftChildPredicates),
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java
index 430ead73b0cce..3f59e33aa8c2a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java
@@ -76,7 +76,7 @@
* because of current Calcite way of inferring constants from IS NOT DISTINCT FROM clashes with
* filter push down.
*
- * Lines 397 ~ 399, Use Calcite 1.32.0 behavior for {@link RexUtil#gatherConstraints(Class,
+ *
Lines 399 ~ 401, Use Calcite 1.32.0 behavior for {@link RexUtil#gatherConstraints(Class,
* RexNode, Map, Set, RexBuilder)}.
*/
public class RexUtil {
@@ -870,8 +870,7 @@ public Void visitCall(RexCall call) {
}
/**
- * Returns whether a given tree contains any input references (both {@link RexInputRef} or
- * {@link RexTableArgCall}).
+ * Returns whether a given tree contains any {link RexInputRef} nodes.
*
* @param node a RexNode tree
*/
@@ -3000,7 +2999,8 @@ public RexNode visitCall(RexCall call) {
if (simplifiedNode.getType().equals(call.getType())) {
return simplifiedNode;
}
- return simplify.rexBuilder.makeCast(call.getType(), simplifiedNode, matchNullability);
+ return simplify.rexBuilder.makeCast(
+ call.getType(), simplifiedNode, matchNullability, false);
}
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlFunction.java
index 7b4d7003c9cd4..1e0a918fbb1cf 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlFunction.java
@@ -277,6 +277,16 @@ private RelDataType deriveType(
getKind(),
validator.getCatalogReader().nameMatcher(),
false);
+
+ // If the call already has an operator and its syntax is SPECIAL, it must
+ // have been created intentionally by the parser.
+ if (function == null
+ && call.getOperator().getSyntax() == SqlSyntax.SPECIAL
+ && call.getOperator() instanceof SqlFunction
+ && validator.getOperatorTable().getOperatorList().contains(call.getOperator())) {
+ function = (SqlFunction) call.getOperator();
+ }
+
try {
// if we have a match on function name and parameter count, but
// couldn't find a function with a COLUMN_LIST type, retry, but
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlUtil.java
index 814cf606650c2..a5bd65a0cb85b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/SqlUtil.java
@@ -34,6 +34,7 @@
import org.apache.calcite.runtime.CalciteContextException;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.runtime.Resources;
+import org.apache.calcite.sql.fun.SqlInOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlOperandMetadata;
@@ -42,6 +43,7 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.util.SqlBasicVisitor;
+import org.apache.calcite.sql.util.SqlVisitor;
import org.apache.calcite.sql.validate.SqlNameMatcher;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.BarfingInvocationHandler;
@@ -1197,6 +1199,48 @@ private static SqlNode createBalancedCall(
return op.createCall(pos, leftNode, rightNode);
}
+ /**
+ * Returns whether a given node contains a {@link SqlInOperator}.
+ *
+ * @param node AST tree
+ */
+ public static boolean containsIn(SqlNode node) {
+ final Predicate callPredicate =
+ call -> call.getOperator() instanceof SqlInOperator;
+ return containsCall(node, callPredicate);
+ }
+
+ /**
+ * Returns whether an AST tree contains a call to an aggregate function.
+ *
+ * @param node AST tree
+ */
+ public static boolean containsAgg(SqlNode node) {
+ final Predicate callPredicate = call -> call.getOperator().isAggregator();
+ return containsCall(node, callPredicate);
+ }
+
+ /** Returns whether an AST tree contains a call that matches a given predicate. */
+ private static boolean containsCall(SqlNode node, Predicate callPredicate) {
+ try {
+ SqlVisitor visitor =
+ new SqlBasicVisitor() {
+ @Override
+ public Void visit(SqlCall call) {
+ if (callPredicate.test(call)) {
+ throw new Util.FoundOne(call);
+ }
+ return super.visit(call);
+ }
+ };
+ node.accept(visitor);
+ return false;
+ } catch (Util.FoundOne e) {
+ Util.swallow(e, null);
+ return true;
+ }
+ }
+
// ~ Inner Classes ----------------------------------------------------------
/**
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
index 636b71364df44..3d24f2fecb999 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
@@ -22,6 +22,7 @@
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.SetMultimap;
import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFamily;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
@@ -36,17 +37,26 @@
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.SqlWriter;
+import org.apache.calcite.sql.type.FlinkSqlTypeMappingRule;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlMonotonicity;
-import org.apache.calcite.sql.validate.SqlValidatorImpl;
import java.text.Collator;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isArray;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isCollection;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isMap;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isRow;
import static org.apache.calcite.util.Static.RESOURCE;
/**
@@ -88,29 +98,112 @@ public class SqlCastFunction extends SqlFunction {
// ~ Constructors -----------------------------------------------------------
public SqlCastFunction() {
- super("CAST", SqlKind.CAST, null, InferTypes.FIRST_KNOWN, null, SqlFunctionCategory.SYSTEM);
+ this(SqlKind.CAST.toString(), SqlKind.CAST);
+ }
+
+ public SqlCastFunction(String name, SqlKind kind) {
+ super(
+ name,
+ kind,
+ returnTypeInference(kind == SqlKind.SAFE_CAST),
+ InferTypes.FIRST_KNOWN,
+ null,
+ SqlFunctionCategory.SYSTEM);
+ checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
}
// ~ Methods ----------------------------------------------------------------
- @Override
- public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
- assert opBinding.getOperandCount() == 2;
- RelDataType ret = opBinding.getOperandType(1);
- RelDataType firstType = opBinding.getOperandType(0);
- ret = opBinding.getTypeFactory().createTypeWithNullability(ret, firstType.isNullable());
- if (opBinding instanceof SqlCallBinding) {
- SqlCallBinding callBinding = (SqlCallBinding) opBinding;
- SqlNode operand0 = callBinding.operand(0);
-
- // dynamic parameters and null constants need their types assigned
- // to them using the type they are casted to.
- if (SqlUtil.isNullLiteral(operand0, false) || (operand0 instanceof SqlDynamicParam)) {
- final SqlValidatorImpl validator = (SqlValidatorImpl) callBinding.getValidator();
- validator.setValidatedNodeType(operand0, ret);
+ static SqlReturnTypeInference returnTypeInference(boolean safe) {
+ return opBinding -> {
+ assert opBinding.getOperandCount() == 2;
+ final RelDataType ret =
+ deriveType(
+ opBinding.getTypeFactory(),
+ opBinding.getOperandType(0),
+ opBinding.getOperandType(1),
+ safe);
+
+ if (opBinding instanceof SqlCallBinding) {
+ final SqlCallBinding callBinding = (SqlCallBinding) opBinding;
+ SqlNode operand0 = callBinding.operand(0);
+
+ // dynamic parameters and null constants need their types assigned
+ // to them using the type they are casted to.
+ if (SqlUtil.isNullLiteral(operand0, false) || operand0 instanceof SqlDynamicParam) {
+ callBinding.getValidator().setValidatedNodeType(operand0, ret);
+ }
}
+ return ret;
+ };
+ }
+
+ /** Derives the type of "CAST(expression AS targetType)". */
+ public static RelDataType deriveType(
+ RelDataTypeFactory typeFactory,
+ RelDataType expressionType,
+ RelDataType targetType,
+ boolean safe) {
+ return typeFactory.createTypeWithNullability(
+ targetType, expressionType.isNullable() || safe);
+ }
+
+ private static RelDataType createTypeWithNullabilityFromExpr(
+ RelDataTypeFactory typeFactory,
+ RelDataType expressionType,
+ RelDataType targetType,
+ boolean safe) {
+ boolean isNullable = expressionType.isNullable() || safe;
+
+ if (isCollection(expressionType)) {
+ RelDataType expressionElementType = expressionType.getComponentType();
+ RelDataType targetElementType = targetType.getComponentType();
+ requireNonNull(expressionElementType, () -> "componentType of " + expressionType);
+ requireNonNull(targetElementType, () -> "componentType of " + targetType);
+ RelDataType newElementType =
+ createTypeWithNullabilityFromExpr(
+ typeFactory, expressionElementType, targetElementType, safe);
+ return isArray(targetType)
+ ? SqlTypeUtil.createArrayType(typeFactory, newElementType, isNullable)
+ : SqlTypeUtil.createMultisetType(typeFactory, newElementType, isNullable);
+ }
+
+ if (isRow(expressionType)) {
+ final int fieldCount = expressionType.getFieldCount();
+ final List typeList = new ArrayList<>(fieldCount);
+ for (int i = 0; i < fieldCount; ++i) {
+ RelDataType expressionElementType = expressionType.getFieldList().get(i).getType();
+ RelDataType targetElementType = targetType.getFieldList().get(i).getType();
+ typeList.add(
+ createTypeWithNullabilityFromExpr(
+ typeFactory, expressionElementType, targetElementType, safe));
+ }
+ return typeFactory.createTypeWithNullability(
+ typeFactory.createStructType(typeList, targetType.getFieldNames()), isNullable);
}
- return ret;
+
+ if (isMap(expressionType)) {
+ RelDataType expressionKeyType =
+ requireNonNull(
+ expressionType.getKeyType(), () -> "keyType of " + expressionType);
+ RelDataType expressionValueType =
+ requireNonNull(
+ expressionType.getValueType(), () -> "valueType of " + expressionType);
+ RelDataType targetKeyType =
+ requireNonNull(targetType.getKeyType(), () -> "keyType of " + targetType);
+ RelDataType targetValueType =
+ requireNonNull(targetType.getValueType(), () -> "valueType of " + targetType);
+
+ RelDataType keyType =
+ createTypeWithNullabilityFromExpr(
+ typeFactory, expressionKeyType, targetKeyType, safe);
+ RelDataType valueType =
+ createTypeWithNullabilityFromExpr(
+ typeFactory, expressionValueType, targetValueType, safe);
+ SqlTypeUtil.createMapType(typeFactory, keyType, valueType, isNullable);
+ }
+
+ return typeFactory.createTypeWithNullability(targetType, isNullable);
}
@Override
@@ -175,7 +268,8 @@ private boolean canCastFrom(RelDataType toType, RelDataType fromType) {
FlinkTypeFactory.toLogicalType(fromType),
FlinkTypeFactory.toLogicalType(toType));
default:
- return SqlTypeUtil.canCastFrom(toType, fromType, true);
+ return SqlTypeUtil.canCastFrom(
+ toType, fromType, FlinkSqlTypeMappingRule.instance());
}
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlTimestampAddFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlTimestampAddFunction.java
deleted file mode 100644
index 97325c64c3a11..0000000000000
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlTimestampAddFunction.java
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.calcite.sql.fun;
-
-import org.apache.calcite.avatica.util.TimeUnit;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.sql.SqlCall;
-import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlFunctionCategory;
-import org.apache.calcite.sql.SqlIntervalQualifier;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.type.OperandTypes;
-import org.apache.calcite.sql.type.SqlReturnTypeInference;
-import org.apache.calcite.sql.type.SqlTypeFamily;
-import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.sql.type.SqlTypeTransforms;
-import org.apache.calcite.sql.validate.SqlValidator;
-import org.apache.calcite.sql.validate.SqlValidatorScope;
-import org.checkerframework.checker.nullness.qual.Nullable;
-
-import static org.apache.calcite.util.Util.first;
-
-/**
- * The TIMESTAMPADD
function, which adds an interval to a datetime (TIMESTAMP, TIME or
- * DATE).
- *
- * The SQL syntax is
- *
- *
- *
- * TIMESTAMPADD(timestamp interval, quantity,
- * datetime)
- *
- *
- *
- * The interval time unit can one of the following literals:
- *
- *
- * - NANOSECOND (and synonym SQL_TSI_FRAC_SECOND)
- *
- MICROSECOND (and synonyms SQL_TSI_MICROSECOND, FRAC_SECOND)
- *
- SECOND (and synonym SQL_TSI_SECOND)
- *
- MINUTE (and synonym SQL_TSI_MINUTE)
- *
- HOUR (and synonym SQL_TSI_HOUR)
- *
- DAY (and synonym SQL_TSI_DAY)
- *
- WEEK (and synonym SQL_TSI_WEEK)
- *
- MONTH (and synonym SQL_TSI_MONTH)
- *
- QUARTER (and synonym SQL_TSI_QUARTER)
- *
- YEAR (and synonym SQL_TSI_YEAR)
- *
- *
- * Returns modified datetime.
- *
- *
This class was copied over from Calcite to fix the return type deduction issue on timestamp
- * with local time zone type (CALCITE-4698).
- */
-public class SqlTimestampAddFunction extends SqlFunction {
-
- private static final int MILLISECOND_PRECISION = 3;
- private static final int MICROSECOND_PRECISION = 6;
-
- private static final SqlReturnTypeInference RETURN_TYPE_INFERENCE =
- opBinding ->
- deduceType(
- opBinding.getTypeFactory(),
- opBinding.getOperandLiteralValue(0, TimeUnit.class),
- opBinding.getOperandType(2));
-
- // BEGIN FLINK MODIFICATION
- // Reason: this method is changed to deduce return type on timestamp with local time zone
- // correctly
- // Whole class should be removed after CALCITE-4698 is fixed
- public static RelDataType deduceType(
- RelDataTypeFactory typeFactory,
- @Nullable TimeUnit timeUnit,
- RelDataType operandType1,
- RelDataType operandType2) {
- final RelDataType type = deduceType(typeFactory, timeUnit, operandType2);
- return typeFactory.createTypeWithNullability(
- type, operandType1.isNullable() || operandType2.isNullable());
- }
-
- static RelDataType deduceType(
- RelDataTypeFactory typeFactory, @Nullable TimeUnit timeUnit, RelDataType datetimeType) {
- final TimeUnit timeUnit2 = first(timeUnit, TimeUnit.EPOCH);
- SqlTypeName typeName = datetimeType.getSqlTypeName();
- switch (timeUnit2) {
- case MILLISECOND:
- return typeFactory.createSqlType(
- typeName, Math.max(MILLISECOND_PRECISION, datetimeType.getPrecision()));
- case MICROSECOND:
- return typeFactory.createSqlType(
- typeName, Math.max(MICROSECOND_PRECISION, datetimeType.getPrecision()));
- case HOUR:
- case MINUTE:
- case SECOND:
- if (datetimeType.getFamily() == SqlTypeFamily.TIME) {
- return datetimeType;
- } else if (datetimeType.getFamily() == SqlTypeFamily.TIMESTAMP) {
- return typeFactory.createSqlType(typeName, datetimeType.getPrecision());
- } else {
- return typeFactory.createSqlType(SqlTypeName.TIMESTAMP);
- }
- default:
- return datetimeType;
- }
- }
-
- @Override
- public void validateCall(
- SqlCall call,
- SqlValidator validator,
- SqlValidatorScope scope,
- SqlValidatorScope operandScope) {
- super.validateCall(call, validator, scope, operandScope);
-
- // This is either a time unit or a time frame:
- //
- // * In "TIMESTAMPADD(YEAR, 2, x)" operand 0 is a SqlIntervalQualifier
- // with startUnit = YEAR and timeFrameName = null.
- //
- // * In "TIMESTAMPADD(MINUTE15, 2, x) operand 0 is a SqlIntervalQualifier
- // with startUnit = EPOCH and timeFrameName = 'MINUTE15'.
- //
- // If the latter, check that timeFrameName is valid.
- validator.validateTimeFrame((SqlIntervalQualifier) call.getOperandList().get(0));
- }
-
- // END FLINK MODIFICATION
-
- /** Creates a SqlTimestampAddFunction. */
- SqlTimestampAddFunction(String name) {
- super(
- name,
- SqlKind.TIMESTAMP_ADD,
- RETURN_TYPE_INFERENCE.andThen(SqlTypeTransforms.TO_NULLABLE),
- null,
- OperandTypes.family(
- SqlTypeFamily.ANY, SqlTypeFamily.INTEGER, SqlTypeFamily.DATETIME),
- SqlFunctionCategory.TIMEDATE);
- }
-}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/CompositeSingleOperandTypeChecker.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/CompositeSingleOperandTypeChecker.java
deleted file mode 100644
index d7214ea876502..0000000000000
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/CompositeSingleOperandTypeChecker.java
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.calcite.sql.type;
-
-import com.google.common.collect.ImmutableList;
-import org.apache.calcite.sql.SqlCallBinding;
-import org.apache.calcite.sql.SqlNode;
-import org.apache.calcite.util.Util;
-import org.checkerframework.checker.nullness.qual.Nullable;
-
-/**
- * Default implementation of {@link org.apache.calcite.sql.type.CompositeOperandTypeChecker}, the
- * class was copied over because of current Calcite issue CALCITE-5380.
- *
- *
Lines 73 ~ 78, 101 ~ 107
- */
-public class CompositeSingleOperandTypeChecker extends CompositeOperandTypeChecker
- implements SqlSingleOperandTypeChecker {
-
- // ~ Constructors -----------------------------------------------------------
-
- /**
- * Creates a CompositeSingleOperandTypeChecker. Outside this package, use {@link
- * SqlSingleOperandTypeChecker#and(SqlSingleOperandTypeChecker)}, {@link OperandTypes#and},
- * {@link OperandTypes#or} and similar.
- */
- CompositeSingleOperandTypeChecker(
- CompositeOperandTypeChecker.Composition composition,
- ImmutableList extends SqlSingleOperandTypeChecker> allowedRules,
- @Nullable String allowedSignatures) {
- super(composition, allowedRules, allowedSignatures, null, null);
- }
-
- // ~ Methods ----------------------------------------------------------------
-
- @SuppressWarnings("unchecked")
- @Override
- public ImmutableList extends SqlSingleOperandTypeChecker> getRules() {
- return (ImmutableList extends SqlSingleOperandTypeChecker>) allowedRules;
- }
-
- @Override
- public boolean checkSingleOperandType(
- SqlCallBinding callBinding, SqlNode node, int iFormalOperand, boolean throwOnFailure) {
- assert allowedRules.size() >= 1;
-
- final ImmutableList extends SqlSingleOperandTypeChecker> rules = getRules();
- if (composition == Composition.SEQUENCE) {
- return rules.get(iFormalOperand)
- .checkSingleOperandType(callBinding, node, 0, throwOnFailure);
- }
-
- int typeErrorCount = 0;
-
- boolean throwOnAndFailure = (composition == Composition.AND) && throwOnFailure;
-
- for (SqlSingleOperandTypeChecker rule : rules) {
- if (!rule.checkSingleOperandType(
- // FLINK MODIFICATION BEGIN
- callBinding,
- node,
- rule.getClass() == FamilyOperandTypeChecker.class ? 0 : iFormalOperand,
- throwOnAndFailure)) {
- // FLINK MODIFICATION END
- typeErrorCount++;
- }
- }
-
- boolean ret;
- switch (composition) {
- case AND:
- ret = typeErrorCount == 0;
- break;
- case OR:
- ret = typeErrorCount < allowedRules.size();
- break;
- default:
- // should never come here
- throw Util.unexpected(composition);
- }
-
- if (!ret && throwOnFailure) {
- // In the case of a composite OR, we want to throw an error
- // describing in more detail what the problem was, hence doing the
- // loop again.
- for (SqlSingleOperandTypeChecker rule : rules) {
- // FLINK MODIFICATION BEGIN
- rule.checkSingleOperandType(
- callBinding,
- node,
- rule.getClass() == FamilyOperandTypeChecker.class ? 0 : iFormalOperand,
- true);
- // FLINK MODIFICATION END
- }
-
- // If no exception thrown, just throw a generic validation signature
- // error.
- throw callBinding.newValidationSignatureError();
- }
-
- return ret;
- }
-}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/FlinkSqlTypeMappingRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/FlinkSqlTypeMappingRule.java
new file mode 100644
index 0000000000000..1c51c2d8c5b43
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/FlinkSqlTypeMappingRule.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.calcite.sql.type;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+/** Rules that determine whether a type is castable from another type. */
+public class FlinkSqlTypeMappingRule implements SqlTypeMappingRule {
+ private static final FlinkSqlTypeMappingRule INSTANCE;
+
+ private final Map> map;
+
+ private FlinkSqlTypeMappingRule(Map> map) {
+ this.map = ImmutableMap.copyOf(map);
+ }
+
+ public static FlinkSqlTypeMappingRule instance() {
+ return Objects.requireNonNull(FLINK_THREAD_PROVIDERS.get(), "flinkThreadProviders");
+ }
+
+ public static FlinkSqlTypeMappingRule instance(
+ Map> map) {
+ return new FlinkSqlTypeMappingRule(map);
+ }
+
+ public Map> getTypeMapping() {
+ return this.map;
+ }
+
+ static {
+ SqlTypeMappingRules.Builder coerceRules = SqlTypeMappingRules.builder();
+ coerceRules.addAll(SqlTypeCoercionRule.lenientInstance().getTypeMapping());
+ Map> map =
+ SqlTypeCoercionRule.lenientInstance().getTypeMapping();
+ Set rule = new HashSet<>();
+ rule.add(SqlTypeName.TINYINT);
+ rule.add(SqlTypeName.SMALLINT);
+ rule.add(SqlTypeName.INTEGER);
+ rule.add(SqlTypeName.BIGINT);
+ rule.add(SqlTypeName.DECIMAL);
+ rule.add(SqlTypeName.FLOAT);
+ rule.add(SqlTypeName.REAL);
+ rule.add(SqlTypeName.DOUBLE);
+ rule.add(SqlTypeName.CHAR);
+ rule.add(SqlTypeName.VARCHAR);
+ rule.add(SqlTypeName.BOOLEAN);
+ rule.add(SqlTypeName.TIMESTAMP);
+ rule.add(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE);
+ coerceRules.add(SqlTypeName.FLOAT, rule);
+ coerceRules.add(SqlTypeName.DOUBLE, rule);
+ coerceRules.add(SqlTypeName.DECIMAL, rule);
+ INSTANCE = new FlinkSqlTypeMappingRule(coerceRules.map);
+ }
+
+ public static final ThreadLocal<@Nullable FlinkSqlTypeMappingRule> FLINK_THREAD_PROVIDERS =
+ ThreadLocal.withInitial(() -> INSTANCE);
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java
index a5f42aaeb1385..14ed16d2e9f77 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java
@@ -39,8 +39,8 @@
*
*
* - Should be removed after fixing CALCITE-6342: Lines 100-102
- *
- Should be removed after fixing CALCITE-6342: Lines 482-494
- *
- Should be removed after fix of FLINK-31350: Lines 561-573.
+ *
- Should be removed after fixing CALCITE-6342: Lines 484-496
+ *
- Should be removed after fix of FLINK-31350: Lines 563-575.
*
*/
public class SqlTypeFactoryImpl extends RelDataTypeFactoryImpl {
@@ -452,7 +452,7 @@ private static void assertBasic(SqlTypeName typeName) {
if (types.size() > (i + 1)) {
RelDataType type1 = types.get(i + 1);
if (SqlTypeUtil.isDatetime(type1)) {
- resultType = type1;
+ resultType = leastRestrictiveIntervalDatetimeType(type1, type);
return createTypeWithNullability(
resultType, nullCount > 0 || nullableCount > 0);
}
@@ -472,8 +472,10 @@ private static void assertBasic(SqlTypeName typeName) {
// datetime +/- interval (or integer) = datetime
if (types.size() > (i + 1)) {
RelDataType type1 = types.get(i + 1);
- if (SqlTypeUtil.isInterval(type1) || SqlTypeUtil.isIntType(type1)) {
- resultType = type;
+ final boolean isInterval1 = SqlTypeUtil.isInterval(type1);
+ final boolean isInt1 = SqlTypeUtil.isIntType(type1);
+ if (isInterval1 || isInt1) {
+ resultType = leastRestrictiveIntervalDatetimeType(type, type1);
return createTypeWithNullability(
resultType, nullCount > 0 || nullableCount > 0);
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
index ccae916bc64e8..2716293cb03be 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
@@ -18,7 +18,6 @@
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
-import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@@ -44,6 +43,7 @@
import org.apache.calcite.runtime.CalciteContextException;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.runtime.Feature;
+import org.apache.calcite.runtime.PairList;
import org.apache.calcite.runtime.Resources;
import org.apache.calcite.schema.ColumnStrategy;
import org.apache.calcite.schema.Table;
@@ -160,25 +160,22 @@
import static org.apache.calcite.sql.validate.SqlNonNullableAccessors.getCondition;
import static org.apache.calcite.sql.validate.SqlNonNullableAccessors.getTable;
import static org.apache.calcite.util.Static.RESOURCE;
+import static org.apache.calcite.util.Util.first;
/**
* Default implementation of {@link SqlValidator}, the class was copied over because of
* CALCITE-4554.
*
- * Lines 200 ~ 203, Flink improves error message for functions without appropriate arguments in
+ *
Lines 197 ~ 200, Flink improves error message for functions without appropriate arguments in
* handleUnresolvedFunction.
*
- *
Lines 2000 ~ 2020, Flink improves error message for functions without appropriate arguments in
+ *
Lines 2012 ~ 2032, Flink improves error message for functions without appropriate arguments in
* handleUnresolvedFunction at {@link SqlValidatorImpl#handleUnresolvedFunction}.
*
- *
Lines 3814 ~ 3818, 6458 ~ 6464 Flink improves Optimize the retrieval of sub-operands in
+ *
Lines 3840 ~ 3844, 6511 ~ 6517 Flink improves Optimize the retrieval of sub-operands in
* SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}.
*
- *
Lines 5196 ~ 5209, Flink enables TIMESTAMP and TIMESTAMP_LTZ for system time period
- * specification type at {@link org.apache.calcite.sql.validate.SqlValidatorImpl#validateSnapshot}.
- *
- *
Lines 5553 ~ 5559, Flink enables TIMESTAMP and TIMESTAMP_LTZ for first orderBy column in
- * matchRecognize at {@link SqlValidatorImpl#validateMatchRecognize}.
+ *
Lines 5246 ~ 5252, FLINK-24352 Add null check for temporal table check on SqlSnapshot.
*/
public class SqlValidatorImpl implements SqlValidatorWithHints {
// ~ Static fields/initializers ---------------------------------------------
@@ -284,7 +281,7 @@ public class SqlValidatorImpl implements SqlValidatorWithHints {
new SqlValidatorImpl.ValidationErrorFunction();
// TypeCoercion instance used for implicit type coercion.
- private TypeCoercion typeCoercion;
+ private final TypeCoercion typeCoercion;
// ~ Constructors -----------------------------------------------------------
@@ -323,14 +320,15 @@ protected SqlValidatorImpl(
TypeCoercion typeCoercion = config.typeCoercionFactory().create(typeFactory, this);
this.typeCoercion = typeCoercion;
- if (config.conformance().allowCoercionStringToArray()) {
- SqlTypeCoercionRule rules =
+ if (config.conformance().allowLenientCoercion()) {
+ final SqlTypeCoercionRule rules =
requireNonNull(
config.typeCoercionRules() != null
? config.typeCoercionRules()
- : SqlTypeCoercionRule.THREAD_PROVIDERS.get());
+ : SqlTypeCoercionRule.THREAD_PROVIDERS.get(),
+ "rules");
- ImmutableSet arrayMapping =
+ final ImmutableSet arrayMapping =
ImmutableSet.builder()
.addAll(
rules.getTypeMapping()
@@ -340,11 +338,11 @@ protected SqlValidatorImpl(
.build();
Map> mapping =
- new HashMap(rules.getTypeMapping());
+ new HashMap<>(rules.getTypeMapping());
mapping.replace(SqlTypeName.ARRAY, arrayMapping);
- rules = SqlTypeCoercionRule.instance(mapping);
+ SqlTypeCoercionRule rules2 = SqlTypeCoercionRule.instance(mapping);
- SqlTypeCoercionRule.THREAD_PROVIDERS.set(rules);
+ SqlTypeCoercionRule.THREAD_PROVIDERS.set(rules2);
} else if (config.typeCoercionRules() != null) {
SqlTypeCoercionRule.THREAD_PROVIDERS.set(config.typeCoercionRules());
}
@@ -388,14 +386,13 @@ public TimeFrameSet getTimeFrameSet() {
public SqlNodeList expandStar(
SqlNodeList selectList, SqlSelect select, boolean includeSystemVars) {
final List list = new ArrayList<>();
- final List> types = new ArrayList<>();
- for (int i = 0; i < selectList.size(); i++) {
- final SqlNode selectItem = selectList.get(i);
+ final PairList types = PairList.of();
+ for (SqlNode selectItem : selectList) {
final RelDataType originalType = getValidatedNodeTypeIfKnown(selectItem);
expandSelectItem(
selectItem,
select,
- Util.first(originalType, unknownType),
+ first(originalType, unknownType),
list,
catalogReader.nameMatcher().createSet(),
types,
@@ -405,7 +402,6 @@ public SqlNodeList expandStar(
return new SqlNodeList(list, SqlParserPos.ZERO);
}
- // implement SqlValidator
@Override
public void declareCursor(SqlSelect select, SqlValidatorScope parentScope) {
cursorSet.add(select);
@@ -414,33 +410,30 @@ public void declareCursor(SqlSelect select, SqlValidatorScope parentScope) {
// the position of the cursor relative to other cursors in that call
FunctionParamInfo funcParamInfo = requireNonNull(functionCallStack.peek(), "functionCall");
Map cursorMap = funcParamInfo.cursorPosToSelectMap;
- int numCursors = cursorMap.size();
- cursorMap.put(numCursors, select);
+ final int cursorCount = cursorMap.size();
+ cursorMap.put(cursorCount, select);
// create a namespace associated with the result of the select
// that is the argument to the cursor constructor; register it
// with a scope corresponding to the cursor
- SelectScope cursorScope = new SelectScope(parentScope, null, select);
+ SelectScope cursorScope = new SelectScope(parentScope, getEmptyScope(), select);
clauseScopes.put(IdPair.of(select, Clause.CURSOR), cursorScope);
final SelectNamespace selectNs = createSelectNamespace(select, select);
final String alias = SqlValidatorUtil.alias(select, nextGeneratedId++);
registerNamespace(cursorScope, alias, selectNs, false);
}
- // implement SqlValidator
@Override
public void pushFunctionCall() {
FunctionParamInfo funcInfo = new FunctionParamInfo();
functionCallStack.push(funcInfo);
}
- // implement SqlValidator
@Override
public void popFunctionCall() {
functionCallStack.pop();
}
- // implement SqlValidator
@Override
public @Nullable String getParentCursor(String columnListParamName) {
FunctionParamInfo funcParamInfo = requireNonNull(functionCallStack.peek(), "functionCall");
@@ -466,8 +459,8 @@ private boolean expandSelectItem(
RelDataType targetType,
List selectItems,
Set aliases,
- List> fields,
- final boolean includeSystemVars) {
+ PairList fields,
+ boolean includeSystemVars) {
final SelectScope scope = (SelectScope) getWhereScope(select);
if (expandStar(selectItems, aliases, fields, includeSystemVars, scope, selectItem)) {
return true;
@@ -510,12 +503,12 @@ private boolean expandSelectItem(
type = requireNonNull(selectScope.nullifyType(stripAs(expanded), type));
}
setValidatedNodeType(expanded, type);
- fields.add(Pair.of(alias, type));
+ fields.add(alias, type);
return false;
}
private static SqlNode expandExprFromJoin(
- SqlJoin join, SqlIdentifier identifier, @Nullable SelectScope scope) {
+ SqlJoin join, SqlIdentifier identifier, SelectScope scope) {
if (join.getConditionType() != JoinConditionType.USING) {
return identifier;
}
@@ -534,15 +527,11 @@ private static SqlNode expandExprFromJoin(
}
assert qualifiedNode.size() == 2;
- final SqlNode finalNode =
- SqlStdOperatorTable.AS.createCall(
- SqlParserPos.ZERO,
- SqlStdOperatorTable.COALESCE.createCall(
- SqlParserPos.ZERO,
- qualifiedNode.get(0),
- qualifiedNode.get(1)),
- new SqlIdentifier(name, SqlParserPos.ZERO));
- return finalNode;
+ return SqlStdOperatorTable.AS.createCall(
+ SqlParserPos.ZERO,
+ SqlStdOperatorTable.COALESCE.createCall(
+ SqlParserPos.ZERO, qualifiedNode.get(0), qualifiedNode.get(1)),
+ new SqlIdentifier(name, SqlParserPos.ZERO));
}
}
@@ -588,7 +577,7 @@ private List deriveNaturalJoinColumnList(SqlJoin join) {
private static SqlNode expandCommonColumn(
SqlSelect sqlSelect,
SqlNode selectItem,
- @Nullable SelectScope scope,
+ SelectScope scope,
SqlValidatorImpl validator) {
if (!(selectItem instanceof SqlIdentifier)) {
return selectItem;
@@ -611,17 +600,13 @@ private static SqlNode expandCommonColumn(
}
private static void validateQualifiedCommonColumn(
- SqlJoin join,
- SqlIdentifier identifier,
- @Nullable SelectScope scope,
- SqlValidatorImpl validator) {
+ SqlJoin join, SqlIdentifier identifier, SelectScope scope, SqlValidatorImpl validator) {
List names = validator.usingNames(join);
if (names == null) {
// Not USING or NATURAL.
return;
}
- requireNonNull(scope, "scope");
// First we should make sure that the first component is the table name.
// Then check whether the qualified identifier contains common column.
for (ScopeChild child : scope.children) {
@@ -645,7 +630,7 @@ private static void validateQualifiedCommonColumn(
private boolean expandStar(
List selectItems,
Set aliases,
- List> fields,
+ PairList fields,
boolean includeSystemVars,
SelectScope scope,
SqlNode node) {
@@ -659,6 +644,11 @@ private boolean expandStar(
final SqlParserPos startPosition = identifier.getParserPosition();
switch (identifier.names.size()) {
case 1:
+ SqlNode from = scope.getNode().getFrom();
+ if (from == null) {
+ throw newValidationError(identifier, RESOURCE.selectStarRequiresFrom());
+ }
+
boolean hasDynamicStruct = false;
for (ScopeChild child : scope.children) {
final int before = fields.size();
@@ -675,8 +665,8 @@ private boolean expandStar(
addToSelectList(
selectItems, aliases, fields, exp, scope, includeSystemVars);
} else {
- final SqlNode from = SqlNonNullableAccessors.getNode(child);
- final SqlValidatorNamespace fromNs = getNamespaceOrThrow(from, scope);
+ final SqlNode from2 = SqlNonNullableAccessors.getNode(child);
+ final SqlValidatorNamespace fromNs = getNamespaceOrThrow(from2, scope);
final RelDataType rowType = fromNs.getRowType();
for (RelDataTypeField field : rowType.getFieldList()) {
String columnName = field.getName();
@@ -706,9 +696,8 @@ private boolean expandStar(
if (!type.isNullable()) {
fields.set(
i,
- Pair.of(
- entry.getKey(),
- typeFactory.createTypeWithNullability(type, true)));
+ entry.getKey(),
+ typeFactory.createTypeWithNullability(type, true));
}
}
}
@@ -716,11 +705,10 @@ private boolean expandStar(
// If NATURAL JOIN or USING is present, move key fields to the front of
// the list, per standard SQL. Disabled if there are dynamic fields.
if (!hasDynamicStruct || Bug.CALCITE_2400_FIXED) {
- SqlNode from =
- requireNonNull(
- scope.getNode().getFrom(),
- () -> "getFrom for " + scope.getNode());
- new Permute(from, 0).permute(selectItems, fields);
+ // If some fields before star identifier,
+ // we should move offset.
+ int offset = calculatePermuteOffset(selectItems);
+ new Permute(from, offset).permute(selectItems, fields);
}
return true;
@@ -767,6 +755,17 @@ private boolean expandStar(
}
}
+ private static int calculatePermuteOffset(List selectItems) {
+ for (int i = 0; i < selectItems.size(); i++) {
+ SqlNode selectItem = selectItems.get(i);
+ SqlNode col = SqlUtil.stripAs(selectItem);
+ if (col.getKind() == SqlKind.IDENTIFIER && selectItem.getKind() != SqlKind.AS) {
+ return i;
+ }
+ }
+ return 0;
+ }
+
private SqlNode maybeCast(SqlNode node, RelDataType currentType, RelDataType desiredType) {
return SqlTypeUtil.equalSansNullability(typeFactory, currentType, desiredType)
? node
@@ -777,7 +776,7 @@ private SqlNode maybeCast(SqlNode node, RelDataType currentType, RelDataType des
private boolean addOrExpandField(
List selectItems,
Set aliases,
- List> fields,
+ PairList fields,
boolean includeSystemVars,
SelectScope scope,
SqlIdentifier id,
@@ -846,7 +845,7 @@ public List lookupHints(SqlNode topNode, SqlParserPos pos) {
*/
void lookupSelectHints(SqlSelect select, SqlParserPos pos, Collection hintList) {
IdInfo info = idPositions.get(pos.toString());
- if ((info == null) || (info.scope == null)) {
+ if (info == null) {
SqlNode fromNode = select.getFrom();
final SqlValidatorScope fromScope = getFromScope(select);
lookupFromHints(fromNode, fromScope, pos, hintList);
@@ -866,7 +865,7 @@ private void lookupSelectHints(
private void lookupFromHints(
@Nullable SqlNode node,
- @Nullable SqlValidatorScope scope,
+ SqlValidatorScope scope,
SqlParserPos pos,
Collection hintList) {
if (node == null) {
@@ -903,7 +902,7 @@ private void lookupFromHints(
private void lookupJoinHints(
SqlJoin join,
- @Nullable SqlValidatorScope scope,
+ SqlValidatorScope scope,
SqlParserPos pos,
Collection hintList) {
SqlNode left = join.getLeft();
@@ -1062,8 +1061,7 @@ private SqlNode validateScopedExpression(SqlNode topNode, SqlValidatorScope scop
}
@Override
- public void validateQuery(
- SqlNode node, @Nullable SqlValidatorScope scope, RelDataType targetRowType) {
+ public void validateQuery(SqlNode node, SqlValidatorScope scope, RelDataType targetRowType) {
final SqlValidatorNamespace ns = getNamespaceOrThrow(node, scope);
if (node.getKind() == SqlKind.TABLESAMPLE) {
List operands = ((SqlCall) node).getOperandList();
@@ -1108,7 +1106,7 @@ protected void validateNamespace(
}
}
- @VisibleForTesting
+ @Override
public SqlValidatorScope getEmptyScope() {
return new EmptyScope(this);
}
@@ -1159,8 +1157,8 @@ public SqlValidatorScope getGroupScope(SqlSelect select) {
}
@Override
- public @Nullable SqlValidatorScope getFromScope(SqlSelect select) {
- return scopes.get(select);
+ public SqlValidatorScope getFromScope(SqlSelect select) {
+ return requireNonNull(scopes.get(select), () -> "no scope for " + select);
}
@Override
@@ -1174,8 +1172,8 @@ public SqlValidatorScope getMatchRecognizeScope(SqlMatchRecognize node) {
}
@Override
- public @Nullable SqlValidatorScope getJoinScope(SqlNode node) {
- return scopes.get(stripAs(node));
+ public SqlValidatorScope getJoinScope(SqlNode node) {
+ return requireNonNull(scopes.get(stripAs(node)), () -> "scope for " + node);
}
@Override
@@ -1183,12 +1181,17 @@ public SqlValidatorScope getOverScope(SqlNode node) {
return getScopeOrThrow(node);
}
+ @Override
+ public SqlValidatorScope getWithScope(SqlNode withItem) {
+ assert withItem.getKind() == SqlKind.WITH_ITEM;
+ return getScopeOrThrow(withItem);
+ }
+
private SqlValidatorScope getScopeOrThrow(SqlNode node) {
return requireNonNull(scopes.get(node), () -> "scope for " + node);
}
- private @Nullable SqlValidatorNamespace getNamespace(
- SqlNode node, @Nullable SqlValidatorScope scope) {
+ private @Nullable SqlValidatorNamespace getNamespace(SqlNode node, SqlValidatorScope scope) {
if (node instanceof SqlIdentifier && scope instanceof DelegatingScope) {
final SqlIdentifier id = (SqlIdentifier) node;
final DelegatingScope idScope = (DelegatingScope) ((DelegatingScope) scope).getParent();
@@ -1282,7 +1285,7 @@ SqlValidatorNamespace getNamespaceOrThrow(SqlNode node) {
* @see #getNamespace(SqlNode)
*/
@API(since = "1.27", status = API.Status.INTERNAL)
- SqlValidatorNamespace getNamespaceOrThrow(SqlNode node, @Nullable SqlValidatorScope scope) {
+ SqlValidatorNamespace getNamespaceOrThrow(SqlNode node, SqlValidatorScope scope) {
return requireNonNull(
getNamespace(node, scope), () -> "namespace for " + node + ", scope " + scope);
}
@@ -2151,7 +2154,7 @@ protected void addToSelectList(
if (!Objects.equals(alias, uniqueAlias)) {
exp = SqlValidatorUtil.addAlias(exp, uniqueAlias);
}
- fieldList.add(Pair.of(uniqueAlias, deriveType(scope, exp)));
+ ((PairList) fieldList).add(uniqueAlias, deriveType(scope, exp));
list.add(exp);
}
@@ -2285,7 +2288,7 @@ protected void registerNamespace(
/**
* Registers scopes and namespaces implied a relational expression in the FROM clause.
*
- * {@code parentScope} and {@code usingScope} are often the same. They differ when the
+ *
{@code parentScope0} and {@code usingScope} are often the same. They differ when the
* namespace are not visible within the parent. (Example needed.)
*
*
Likewise, {@code enclosingNode} and {@code node} are often the same. {@code enclosingNode}
@@ -2293,7 +2296,7 @@ protected void registerNamespace(
* AS alias) or a table sample clause are stripped away to get {@code node}. Both are
* recorded in the namespace.
*
- * @param parentScope Parent scope which this scope turns to in order to resolve objects
+ * @param parentScope0 Parent scope that this scope turns to in order to resolve objects
* @param usingScope Scope whose child list this scope should add itself to
* @param register Whether to register this scope as a child of {@code usingScope}
* @param node Node which namespace is based on
@@ -2308,7 +2311,7 @@ protected void registerNamespace(
* @return registered node, usually the same as {@code node}
*/
private SqlNode registerFrom(
- SqlValidatorScope parentScope,
+ SqlValidatorScope parentScope0,
SqlValidatorScope usingScope,
boolean register,
final SqlNode node,
@@ -2365,19 +2368,22 @@ private SqlNode registerFrom(
}
}
+ final SqlValidatorScope parentScope;
if (lateral) {
SqlValidatorScope s = usingScope;
while (s instanceof JoinScope) {
s = ((JoinScope) s).getUsingScope();
}
final SqlNode node2 = s != null ? s.getNode() : node;
- final TableScope tableScope = new TableScope(parentScope, node2);
+ final TableScope tableScope = new TableScope(parentScope0, node2);
if (usingScope instanceof ListScope) {
for (ScopeChild child : ((ListScope) usingScope).children) {
tableScope.addChild(child.namespace, child.name, child.nullable);
}
}
parentScope = tableScope;
+ } else {
+ parentScope = parentScope0;
}
SqlCall call;
@@ -2394,7 +2400,8 @@ private SqlNode registerFrom(
final boolean needAliasNamespace =
call.operandCount() > 2
|| expr.getKind() == SqlKind.VALUES
- || expr.getKind() == SqlKind.UNNEST;
+ || expr.getKind() == SqlKind.UNNEST
+ || expr.getKind() == SqlKind.COLLECTION_TABLE;
newExpr =
registerFrom(
parentScope,
@@ -2520,6 +2527,8 @@ private SqlNode registerFrom(
if (newRight != right) {
join.setRight(newRight);
}
+ scopes.putIfAbsent(stripAs(join.getRight()), parentScope);
+ scopes.putIfAbsent(stripAs(join.getLeft()), parentScope);
registerSubQueries(joinScope, join.getCondition());
final JoinNamespace joinNamespace = new JoinNamespace(this, join);
registerNamespace(null, null, joinNamespace, forceNullable);
@@ -2785,8 +2794,7 @@ private void registerQuery(
final SqlSelect select = (SqlSelect) node;
final SelectNamespace selectNs = createSelectNamespace(select, enclosingNode);
registerNamespace(usingScope, alias, selectNs, forceNullable);
- final SqlValidatorScope windowParentScope =
- (usingScope != null) ? usingScope : parentScope;
+ final SqlValidatorScope windowParentScope = first(usingScope, parentScope);
SelectScope selectScope = new SelectScope(parentScope, windowParentScope, select);
scopes.put(select, selectScope);
@@ -3046,6 +3054,7 @@ private void registerWith(
boolean checkUpdate) {
final WithNamespace withNamespace = new WithNamespace(this, with, enclosingNode);
registerNamespace(usingScope, alias, withNamespace, forceNullable);
+ scopes.put(with, parentScope);
SqlValidatorScope scope = parentScope;
for (SqlNode withItem_ : with.withList) {
@@ -3355,6 +3364,7 @@ public TimeFrame validateTimeFrame(SqlIntervalQualifier qualifier) {
* @param scope Scope
*/
protected void validateFrom(SqlNode node, RelDataType targetRowType, SqlValidatorScope scope) {
+ requireNonNull(scope, "scope");
requireNonNull(targetRowType, "targetRowType");
switch (node.getKind()) {
case AS:
@@ -3688,8 +3698,7 @@ protected void validateSelect(SqlSelect select, RelDataType targetRowType) {
}
// Make sure that items in FROM clause have distinct aliases.
- final SelectScope fromScope =
- (SelectScope) requireNonNull(getFromScope(select), () -> "fromScope for " + select);
+ final SelectScope fromScope = (SelectScope) getFromScope(select);
List<@Nullable String> names = fromScope.getChildNames();
if (!catalogReader.nameMatcher().isCaseSensitive()) {
//noinspection RedundantTypeArguments
@@ -3822,6 +3831,11 @@ private void checkRollUp(
// we stripped the field access. Recurse to this method, the DOT's operand
// can be another SqlCall, or an SqlIdentifier.
checkRollUp(grandParent, parent, stripDot, scope, contextClause);
+ } else if (stripDot.getKind() == SqlKind.CONVERT
+ || stripDot.getKind() == SqlKind.TRANSLATE) {
+ // only need to check operand[0] for CONVERT or TRANSLATE
+ SqlNode child = ((SqlCall) stripDot).getOperandList().get(0);
+ checkRollUp(parent, current, child, scope, contextClause);
} else {
// ----- FLINK MODIFICATION BEGIN -----
SqlCall call = (SqlCall) stripDot;
@@ -4122,8 +4136,7 @@ protected void validateWindowClause(SqlSelect select) {
return;
}
- final SelectScope windowScope =
- (SelectScope) requireNonNull(getFromScope(select), () -> "fromScope for " + select);
+ final SelectScope windowScope = (SelectScope) getFromScope(select);
// 1. ensure window names are simple
// 2. ensure they are unique within this scope
@@ -4249,12 +4262,6 @@ public void validateSequenceValue(SqlValidatorScope scope, SqlIdentifier id) {
throw newValidationError(id, RESOURCE.notASequence(id.toString()));
}
- @Override
- public @Nullable SqlValidatorScope getWithScope(SqlNode withItem) {
- assert withItem.getKind() == SqlKind.WITH_ITEM;
- return scopes.get(withItem);
- }
-
@Override
public TypeCoercion getTypeCoercion() {
assert config.typeCoercionEnabled();
@@ -4539,7 +4546,7 @@ protected RelDataType validateSelectList(
final SqlValidatorScope selectScope = getSelectScope(select);
final List expandedSelectItems = new ArrayList<>();
final Set aliases = new HashSet<>();
- final List> fieldList = new ArrayList<>();
+ final PairList fieldList = PairList.of();
for (SqlNode selectItem : selectItems) {
if (selectItem instanceof SqlSelect) {
@@ -4635,7 +4642,7 @@ private void handleScalarSubQuery(
SqlSelect selectItem,
List expandedSelectItems,
Set aliasList,
- List> fieldList) {
+ PairList fieldList) {
// A scalar sub-query only has one output column.
if (1 != SqlNonNullableAccessors.getSelectList(selectItem).size()) {
throw newValidationError(selectItem, RESOURCE.onlyScalarSubQueryAllowed());
@@ -4660,7 +4667,7 @@ private void handleScalarSubQuery(
RelDataType nodeType = rec.getFieldList().get(0).getType();
nodeType = typeFactory.createTypeWithNullability(nodeType, true);
- fieldList.add(Pair.of(alias, nodeType));
+ fieldList.add(alias, nodeType);
}
/**
@@ -4679,13 +4686,10 @@ protected RelDataType createTargetRowType(
return baseRowType;
}
List targetFields = baseRowType.getFieldList();
- final List> fields = new ArrayList<>();
+ final PairList fields = PairList.of();
if (append) {
for (RelDataTypeField targetField : targetFields) {
- fields.add(
- Pair.of(
- SqlUtil.deriveAliasFromOrdinal(fields.size()),
- targetField.getType()));
+ fields.add(SqlUtil.deriveAliasFromOrdinal(fields.size()), targetField.getType());
}
}
final Set assignedFields = new HashSet<>();
@@ -4733,6 +4737,7 @@ public void validateInsert(SqlInsert insert) {
validateSelect(sqlSelect, targetRowType);
} else {
final SqlValidatorScope scope = scopes.get(source);
+ requireNonNull(scope, "scope");
validateQuery(source, scope, targetRowType);
}
@@ -5232,14 +5237,13 @@ private void validateSnapshot(
SqlSnapshot snapshot = (SqlSnapshot) node;
SqlNode period = snapshot.getPeriod();
RelDataType dataType = deriveType(requireNonNull(scope, "scope"), period);
- // ----- FLINK MODIFICATION BEGIN -----
- if (!(dataType.getSqlTypeName() == SqlTypeName.TIMESTAMP
- || dataType.getSqlTypeName() == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE)) {
+ if (!SqlTypeUtil.isTimestamp(dataType)) {
throw newValidationError(
period,
Static.RESOURCE.illegalExpressionForTemporal(
dataType.getSqlTypeName().getName()));
}
+ // ----- FLINK MODIFICATION BEGIN -----
if (ns instanceof IdentifierNamespace && ns.resolve() instanceof WithItemNamespace) {
// If the snapshot is used over a CTE, then we don't have a concrete underlying
// table to operate on. This will be rechecked later in the planner rules.
@@ -5587,13 +5591,9 @@ public void validateMatchRecognize(SqlCall call) {
(SqlIdentifier) requireNonNull(firstOrderByColumn, "firstOrderByColumn");
}
RelDataType firstOrderByColumnType = deriveType(scope, identifier);
- // ----- FLINK MODIFICATION BEGIN -----
- if (!(firstOrderByColumnType.getSqlTypeName() == SqlTypeName.TIMESTAMP
- || firstOrderByColumnType.getSqlTypeName()
- == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE)) {
+ if (!SqlTypeUtil.isTimestamp(firstOrderByColumnType)) {
throw newValidationError(interval, RESOURCE.firstColumnOfOrderByMustBeTimestamp());
}
- // ----- FLINK MODIFICATION END -----
SqlNode expand = expand(interval, scope);
RelDataType type = deriveType(scope, expand);
@@ -5632,13 +5632,14 @@ public void validateMatchRecognize(SqlCall call) {
}
}
- List> measureColumns =
+ PairList measureColumns =
validateMeasure(matchRecognize, scope, allRows);
- for (Map.Entry c : measureColumns) {
- if (!typeBuilder.nameExists(c.getKey())) {
- typeBuilder.add(c.getKey(), c.getValue());
- }
- }
+ measureColumns.forEach(
+ (name, type) -> {
+ if (!typeBuilder.nameExists(name)) {
+ typeBuilder.add(name, type);
+ }
+ });
final RelDataType rowType = typeBuilder.build();
if (matchRecognize.getMeasureList().size() == 0) {
@@ -5648,12 +5649,12 @@ public void validateMatchRecognize(SqlCall call) {
}
}
- private List> validateMeasure(
+ private PairList validateMeasure(
SqlMatchRecognize mr, MatchRecognizeScope scope, boolean allRows) {
final List aliases = new ArrayList<>();
final List sqlNodes = new ArrayList<>();
final SqlNodeList measures = mr.getMeasureList();
- final List> fields = new ArrayList<>();
+ final PairList fields = PairList.of();
for (SqlNode measure : measures) {
assert measure instanceof SqlCall;
@@ -5668,7 +5669,7 @@ private List> validateMeasure(
final RelDataType type = deriveType(scope, expand);
setValidatedNodeType(measure, type);
- fields.add(Pair.of(alias, type));
+ fields.add(alias, type);
sqlNodes.add(
SqlStdOperatorTable.AS.createCall(
SqlParserPos.ZERO,
@@ -5757,8 +5758,7 @@ private static String alias(SqlNode item) {
}
public void validatePivot(SqlPivot pivot) {
- final PivotScope scope =
- requireNonNull((PivotScope) getJoinScope(pivot), () -> "joinScope for " + pivot);
+ final PivotScope scope = (PivotScope) getJoinScope(pivot);
final PivotNamespace ns = getNamespaceOrThrow(pivot).unwrap(PivotNamespace.class);
assert ns.rowType == null;
@@ -5774,12 +5774,12 @@ public void validatePivot(SqlPivot pivot) {
// an aggregate or as an axis.
// Aggregates, e.g. "PIVOT (sum(x) AS sum_x, count(*) AS c)"
- final List> aggNames = new ArrayList<>();
+ final PairList<@Nullable String, RelDataType> aggNames = PairList.of();
pivot.forEachAgg(
(alias, call) -> {
call.validate(this, scope);
final RelDataType type = deriveType(scope, call);
- aggNames.add(Pair.of(alias, type));
+ aggNames.add(alias, type);
if (!(call instanceof SqlCall)
|| !(((SqlCall) call).getOperator() instanceof SqlAggFunction)) {
throw newValidationError(call, RESOURCE.pivotAggMalformed());
@@ -5837,8 +5837,7 @@ public void validatePivot(SqlPivot pivot) {
subNode)),
true);
});
- Pair.forEach(
- aggNames,
+ aggNames.forEach(
(aggAlias, aggType) ->
typeBuilder.add(
aggAlias == null ? alias : alias + "_" + aggAlias,
@@ -5850,8 +5849,7 @@ public void validatePivot(SqlPivot pivot) {
}
public void validateUnpivot(SqlUnpivot unpivot) {
- final UnpivotScope scope =
- (UnpivotScope) requireNonNull(getJoinScope(unpivot), () -> "scope for " + unpivot);
+ final UnpivotScope scope = (UnpivotScope) getJoinScope(unpivot);
final UnpivotNamespace ns = getNamespaceOrThrow(unpivot).unwrap(UnpivotNamespace.class);
assert ns.rowType == null;
@@ -5903,7 +5901,7 @@ public void validateUnpivot(SqlUnpivot unpivot) {
columnNames.addAll(unusedColumnNames);
// Gather the name and type of each measure.
- final List> measureNameTypes = new ArrayList<>();
+ final PairList measureNameTypes = PairList.of();
Ord.forEach(
unpivot.measureList,
(measure, i) -> {
@@ -5928,7 +5926,7 @@ public void validateUnpivot(SqlUnpivot unpivot) {
if (!columnNames.add(measureName)) {
throw newValidationError(measure, RESOURCE.unpivotDuplicate(measureName));
}
- measureNameTypes.add(Pair.of(measureName, type));
+ measureNameTypes.add(measureName, type);
});
// Gather the name and type of each axis.
@@ -5942,7 +5940,7 @@ public void validateUnpivot(SqlUnpivot unpivot) {
// The type of 'job' is derived as the least restrictive type of the values
// ('CLERK', 'ANALYST'), namely VARCHAR(7). The derived type of 'deptno' is
// the type of values (10, 20), namely INTEGER.
- final List> axisNameTypes = new ArrayList<>();
+ final PairList axisNameTypes = PairList.of();
Ord.forEach(
unpivot.axisList,
(axis, i) -> {
@@ -5966,7 +5964,7 @@ public void validateUnpivot(SqlUnpivot unpivot) {
if (!columnNames.add(axisName)) {
throw newValidationError(axis, RESOURCE.unpivotDuplicate(axisName));
}
- axisNameTypes.add(Pair.of(axisName, type));
+ axisNameTypes.add(axisName, type);
});
// Columns that have been seen as arguments to aggregates or as axes
@@ -6081,26 +6079,39 @@ public void validateAggregateParams(
throw new AssertionError(op);
}
+ // Because there are two forms of the PERCENTILE_CONT/PERCENTILE_DISC functions,
+ // they are distinguished by their operand count and then validated accordingly.
+ // For example, the standard single operand form requires group order while the
+ // 2-operand form allows for null treatment and requires an OVER() clause.
if (op.isPercentile()) {
- assert op.requiresGroupOrder() == Optionality.MANDATORY;
- assert orderList != null;
-
- // Validate that percentile function have a single ORDER BY expression
- if (orderList.size() != 1) {
- throw newValidationError(orderList, RESOURCE.orderByRequiresOneKey(op.getName()));
- }
-
- // Validate that the ORDER BY field is of NUMERIC type
- SqlNode node = orderList.get(0);
- assert node != null;
-
- final RelDataType type = deriveType(scope, node);
- final @Nullable SqlTypeFamily family = type.getSqlTypeName().getFamily();
- if (family == null || family.allowableDifferenceTypes().isEmpty()) {
- throw newValidationError(
- orderList,
- RESOURCE.unsupportedTypeInOrderBy(
- type.getSqlTypeName().getName(), op.getName()));
+ switch (aggCall.operandCount()) {
+ case 1:
+ assert op.requiresGroupOrder() == Optionality.MANDATORY;
+ assert orderList != null;
+ // Validate that percentile function have a single ORDER BY expression
+ if (orderList.size() != 1) {
+ throw newValidationError(
+ orderList, RESOURCE.orderByRequiresOneKey(op.getName()));
+ }
+ // Validate that the ORDER BY field is of NUMERIC type
+ SqlNode node = orderList.get(0);
+ assert node != null;
+ final RelDataType type = deriveType(scope, node);
+ final @Nullable SqlTypeFamily family = type.getSqlTypeName().getFamily();
+ if (family == null || family.allowableDifferenceTypes().isEmpty()) {
+ throw newValidationError(
+ orderList,
+ RESOURCE.unsupportedTypeInOrderBy(
+ type.getSqlTypeName().getName(), op.getName()));
+ }
+ break;
+ case 2:
+ assert op.allowsNullTreatment();
+ assert op.requiresOver();
+ assert op.requiresGroupOrder() == Optionality.FORBIDDEN;
+ break;
+ default:
+ throw newValidationError(aggCall, RESOURCE.percentileFunctionsArgumentLimit());
}
}
}
@@ -6834,7 +6845,7 @@ static class ExtendedExpander extends Expander {
final boolean replaceAliases = clause.shouldReplaceAliases(validator.config);
if (!replaceAliases) {
- final SelectScope scope = validator.getRawSelectScope(select);
+ final SelectScope scope = validator.getRawSelectScopeNonNull(select);
SqlNode node = expandCommonColumn(select, id, scope, validator);
if (node != id) {
return node;
@@ -7291,8 +7302,10 @@ private class Permute {
final List sources;
final RelDataType rowType;
final boolean trivial;
+ final int offset;
Permute(SqlNode from, int offset) {
+ this.offset = offset;
switch (from.getKind()) {
case JOIN:
final SqlJoin join = (SqlJoin) from;
@@ -7357,16 +7370,17 @@ private RelDataTypeField field(String name) {
}
/** Moves fields according to the permutation. */
- public void permute(
- List selectItems, List> fields) {
+ void permute(List selectItems, PairList fields) {
if (trivial) {
return;
}
final List oldSelectItems = ImmutableList.copyOf(selectItems);
selectItems.clear();
- final List> oldFields = ImmutableList.copyOf(fields);
+ selectItems.addAll(oldSelectItems.subList(0, offset));
+ final PairList oldFields = fields.immutable();
fields.clear();
+ fields.addAll(oldFields.subList(0, offset));
for (ImmutableIntList source : sources) {
final int p0 = source.get(0);
Map.Entry field = oldFields.get(p0);
@@ -7399,7 +7413,7 @@ public void permute(
new SqlIdentifier(name, SqlParserPos.ZERO));
type = typeFactory.createTypeWithNullability(type2, nullable);
}
- fields.add(Pair.of(name, type));
+ fields.add(name, type);
selectItems.add(selectItem);
}
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/AggConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/AggConverter.java
new file mode 100644
index 0000000000000..849d222f6a845
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/AggConverter.java
@@ -0,0 +1,623 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql2rel;
+
+import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelFieldCollation;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.runtime.PairList;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlDataTypeSpec;
+import org.apache.calcite.sql.SqlDynamicParam;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlIntervalQualifier;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlSelectKeyword;
+import org.apache.calcite.sql.SqlUtil;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.util.SqlVisitor;
+import org.apache.calcite.sql.validate.AggregatingSelectScope;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql.validate.SqlValidatorUtil;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Litmus;
+import org.apache.calcite.util.Util;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static org.apache.calcite.linq4j.Nullness.castNonNull;
+
+/**
+ * FLINK modifications are at lines
+ *
+ *
+ * - Added in FLINK-34057, FLINK-34058, FLINK-34312: Lines 452 ~ 469
+ *
+ */
+class AggConverter implements SqlVisitor {
+ private final SqlToRelConverter.Blackboard bb;
+ private final Map nameMap;
+
+ /** The group-by expressions, in {@link SqlNode} format. */
+ final SqlNodeList groupExprs = new SqlNodeList(SqlParserPos.ZERO);
+
+ /** The auxiliary group-by expressions. */
+ private final Map> auxiliaryGroupExprs = new HashMap<>();
+
+ /** Measure expressions, in {@link SqlNode} format. */
+ private final SqlNodeList measureExprs = new SqlNodeList(SqlParserPos.ZERO);
+
+ /**
+ * Input expressions for the group columns and aggregates, in {@link RexNode} format. The first
+ * elements of the list correspond to the elements in {@link #groupExprs}; the remaining
+ * elements are for aggregates. The right field of each pair is the name of the expression,
+ * where the expressions are simple mappings to input fields.
+ */
+ final PairList convertedInputExprs = PairList.of();
+
+ /**
+ * Expressions to be evaluated as rows are being placed into the aggregate's hash table. This is
+ * when group functions such as TUMBLE cause rows to be expanded.
+ */
+ final List aggCalls = new ArrayList<>();
+
+ private final Map aggMapping = new HashMap<>();
+ private final Map aggCallMapping = new HashMap<>();
+ private final SqlValidator validator;
+ private final AggregatingSelectScope scope;
+
+ /** Whether we are directly inside a windowed aggregate. */
+ boolean inOver = false;
+
+ /** Creates an AggConverter. */
+ private AggConverter(SqlToRelConverter.Blackboard bb, ImmutableMap nameMap) {
+ this(bb, nameMap, null, null);
+ }
+
+ private AggConverter(
+ SqlToRelConverter.Blackboard bb,
+ ImmutableMap nameMap,
+ SqlValidator validator,
+ AggregatingSelectScope scope) {
+ this.bb = bb;
+ this.nameMap = nameMap;
+ this.validator = validator;
+ this.scope = scope;
+ }
+
+ /**
+ * Creates an AggConverter for a pivot query.
+ *
+ * @param bb Blackboard
+ */
+ static AggConverter create(SqlToRelConverter.Blackboard bb) {
+ return new AggConverter(bb, ImmutableMap.of());
+ }
+
+ /**
+ * Creates an AggConverter.
+ *
+ * The {@code aggregatingSelectScope} parameter provides enough context to name aggregate
+ * calls which are top-level select list items.
+ *
+ * @param bb Blackboard
+ * @param scope Scope of a SELECT that has a GROUP BY
+ */
+ static AggConverter create(
+ SqlToRelConverter.Blackboard bb, AggregatingSelectScope scope, SqlValidator validator) {
+ // Collect all expressions used in the select list so that aggregate
+ // calls can be named correctly.
+ final Map nameMap = new HashMap<>();
+ Ord.forEach(
+ scope.getNode().getSelectList(),
+ (selectItem, i) -> {
+ final String name;
+ if (SqlUtil.isCallTo(selectItem, SqlStdOperatorTable.AS)) {
+ final SqlCall call = (SqlCall) selectItem;
+ selectItem = call.operand(0);
+ name = call.operand(1).toString();
+ } else {
+ name = SqlValidatorUtil.alias(selectItem, i);
+ }
+ nameMap.put(selectItem.toString(), name);
+ });
+
+ final AggregatingSelectScope.Resolved resolved = scope.resolved.get();
+ return new AggConverter(bb, ImmutableMap.copyOf(nameMap), validator, scope) {
+ @Override
+ AggregatingSelectScope.Resolved getResolved() {
+ return resolved;
+ }
+ };
+ }
+
+ int addGroupExpr(SqlNode expr) {
+ int ref = lookupGroupExpr(expr);
+ if (ref >= 0) {
+ return ref;
+ }
+ final int index = groupExprs.size();
+ groupExprs.add(expr);
+ String name = nameMap.get(expr.toString());
+ RexNode convExpr = bb.convertExpression(expr);
+ addExpr(convExpr, name);
+
+ if (expr instanceof SqlCall) {
+ SqlCall call = (SqlCall) expr;
+ SqlStdOperatorTable.convertGroupToAuxiliaryCalls(
+ call, (node, converter) -> addAuxiliaryGroupExpr(node, index, converter));
+ }
+
+ return index;
+ }
+
+ void addAuxiliaryGroupExpr(SqlNode node, int index, AuxiliaryConverter converter) {
+ for (SqlNode node2 : auxiliaryGroupExprs.keySet()) {
+ if (node2.equalsDeep(node, Litmus.IGNORE)) {
+ return;
+ }
+ }
+ auxiliaryGroupExprs.put(node, Ord.of(index, converter));
+ }
+
+ boolean addMeasureExpr(SqlNode expr) {
+ if (isMeasureExpr(expr)) {
+ return false; // already present
+ }
+ measureExprs.add(expr);
+ String name = nameMap.get(expr.toString());
+ RexNode convExpr = bb.convertExpression(expr);
+ addExpr(convExpr, name);
+ return true;
+ }
+
+ /**
+ * Adds an expression, deducing an appropriate name if possible.
+ *
+ * @param expr Expression
+ * @param name Suggested name
+ */
+ private void addExpr(RexNode expr, @Nullable String name) {
+ if (name == null && expr instanceof RexInputRef) {
+ final int i = ((RexInputRef) expr).getIndex();
+ name = bb.root().getRowType().getFieldList().get(i).getName();
+ }
+ if (convertedInputExprs.rightList().contains(name)) {
+ // In case like 'SELECT ... GROUP BY x, y, x', don't add
+ // name 'x' twice.
+ name = null;
+ }
+ convertedInputExprs.add(expr, name);
+ }
+
+ @Override
+ public Void visit(SqlIdentifier id) {
+ return null;
+ }
+
+ @Override
+ public Void visit(SqlNodeList nodeList) {
+ nodeList.forEach(this::visitNode);
+ return null;
+ }
+
+ @Override
+ public Void visit(SqlLiteral lit) {
+ return null;
+ }
+
+ @Override
+ public Void visit(SqlDataTypeSpec type) {
+ return null;
+ }
+
+ @Override
+ public Void visit(SqlDynamicParam param) {
+ return null;
+ }
+
+ @Override
+ public Void visit(SqlIntervalQualifier intervalQualifier) {
+ return null;
+ }
+
+ @Override
+ public Void visit(SqlCall call) {
+ switch (call.getKind()) {
+ case FILTER:
+ case IGNORE_NULLS:
+ case RESPECT_NULLS:
+ case WITHIN_DISTINCT:
+ case WITHIN_GROUP:
+ translateAgg(call);
+ return null;
+ case SELECT:
+ // rchen 2006-10-17:
+ // for now do not detect aggregates in sub-queries.
+ return null;
+ default:
+ break;
+ }
+ final boolean prevInOver = inOver;
+ // Ignore window aggregates and ranking functions (associated with OVER
+ // operator). However, do not ignore nested window aggregates.
+ if (call.getOperator().getKind() == SqlKind.OVER) {
+ // Track aggregate nesting levels only within an OVER operator.
+ List operandList = call.getOperandList();
+ assert operandList.size() == 2;
+
+ // Ignore the top level window aggregates and ranking functions
+ // positioned as the first operand of a OVER operator
+ inOver = true;
+ operandList.get(0).accept(this);
+
+ // Normal translation for the second operand of a OVER operator
+ inOver = false;
+ operandList.get(1).accept(this);
+ return null;
+ }
+
+ // Do not translate the top level window aggregate. Only do so for
+ // nested aggregates, if present
+ if (call.getOperator().isAggregator()) {
+ if (inOver) {
+ // Add the parent aggregate level before visiting its children
+ inOver = false;
+ } else {
+ // We're beyond the one ignored level
+ translateAgg(call);
+ return null;
+ }
+ }
+ for (SqlNode operand : call.getOperandList()) {
+ // Operands are occasionally null, e.g. switched CASE arg 0.
+ if (operand != null) {
+ operand.accept(this);
+ }
+ }
+ // Remove the parent aggregate level after visiting its children
+ inOver = prevInOver;
+ return null;
+ }
+
+ private void translateAgg(SqlCall call) {
+ translateAgg(call, null, null, null, false, call);
+ }
+
+ private void translateAgg(
+ SqlCall call,
+ @Nullable SqlNode filter,
+ @Nullable SqlNodeList distinctList,
+ @Nullable SqlNodeList orderList,
+ boolean ignoreNulls,
+ SqlCall outerCall) {
+ assert bb.agg == this;
+ final RexBuilder rexBuilder = bb.getRexBuilder();
+ final List operands = call.getOperandList();
+ final SqlParserPos pos = call.getParserPosition();
+ final SqlCall call2;
+ final List operands2;
+ switch (call.getKind()) {
+ case FILTER:
+ assert filter == null;
+ translateAgg(
+ call.operand(0),
+ call.operand(1),
+ distinctList,
+ orderList,
+ ignoreNulls,
+ outerCall);
+ return;
+ case WITHIN_DISTINCT:
+ assert orderList == null;
+ translateAgg(
+ call.operand(0),
+ filter,
+ call.operand(1),
+ orderList,
+ ignoreNulls,
+ outerCall);
+ return;
+ case WITHIN_GROUP:
+ assert orderList == null;
+ translateAgg(
+ call.operand(0),
+ filter,
+ distinctList,
+ call.operand(1),
+ ignoreNulls,
+ outerCall);
+ return;
+ case IGNORE_NULLS:
+ ignoreNulls = true;
+ // fall through
+ case RESPECT_NULLS:
+ translateAgg(
+ call.operand(0), filter, distinctList, orderList, ignoreNulls, outerCall);
+ return;
+
+ case COUNTIF:
+ // COUNTIF(b) ==> COUNT(*) FILTER (WHERE b)
+ // COUNTIF(b) FILTER (WHERE b2) ==> COUNT(*) FILTER (WHERE b2 AND b)
+ call2 = SqlStdOperatorTable.COUNT.createCall(pos, SqlIdentifier.star(pos));
+ final SqlNode filter2 = SqlUtil.andExpressions(filter, call.operand(0));
+ translateAgg(call2, filter2, distinctList, orderList, ignoreNulls, outerCall);
+ return;
+
+ case STRING_AGG:
+ // Translate "STRING_AGG(s, sep ORDER BY x, y)"
+ // as if it were "LISTAGG(s, sep) WITHIN GROUP (ORDER BY x, y)";
+ // and "STRING_AGG(s, sep)" as "LISTAGG(s, sep)".
+ if (!operands.isEmpty() && Util.last(operands) instanceof SqlNodeList) {
+ orderList = (SqlNodeList) Util.last(operands);
+ operands2 = Util.skipLast(operands);
+ } else {
+ operands2 = operands;
+ }
+ call2 =
+ SqlStdOperatorTable.LISTAGG.createCall(
+ call.getFunctionQuantifier(), pos, operands2);
+ translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
+ return;
+
+ case GROUP_CONCAT:
+ // Translate "GROUP_CONCAT(s ORDER BY x, y SEPARATOR ',')"
+ // as if it were "LISTAGG(s, ',') WITHIN GROUP (ORDER BY x, y)".
+ // To do this, build a list of operands without ORDER BY with with sep.
+ operands2 = new ArrayList<>(operands);
+ final SqlNode separator;
+ if (!operands2.isEmpty() && Util.last(operands2).getKind() == SqlKind.SEPARATOR) {
+ final SqlCall sepCall = (SqlCall) operands2.remove(operands.size() - 1);
+ separator = sepCall.operand(0);
+ } else {
+ separator = null;
+ }
+
+ if (!operands2.isEmpty() && Util.last(operands2) instanceof SqlNodeList) {
+ orderList = (SqlNodeList) operands2.remove(operands2.size() - 1);
+ }
+
+ if (separator != null) {
+ operands2.add(separator);
+ }
+
+ call2 =
+ SqlStdOperatorTable.LISTAGG.createCall(
+ call.getFunctionQuantifier(), pos, operands2);
+ translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
+ return;
+
+ case ARRAY_AGG:
+ case ARRAY_CONCAT_AGG:
+ // Translate "ARRAY_AGG(s ORDER BY x, y)"
+ // as if it were "ARRAY_AGG(s) WITHIN GROUP (ORDER BY x, y)";
+ // similarly "ARRAY_CONCAT_AGG".
+ if (!operands.isEmpty() && Util.last(operands) instanceof SqlNodeList) {
+ orderList = (SqlNodeList) Util.last(operands);
+ call2 =
+ call.getOperator()
+ .createCall(
+ call.getFunctionQuantifier(),
+ pos,
+ Util.skipLast(operands));
+ translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
+ return;
+ }
+ // "ARRAY_AGG" and "ARRAY_CONCAT_AGG" without "ORDER BY"
+ // are handled normally; fall through.
+
+ default:
+ break;
+ }
+ final List args = new ArrayList<>();
+ int filterArg = -1;
+ final ImmutableBitSet distinctKeys;
+ try {
+ // switch out of agg mode
+ bb.agg = null;
+ // ----- FLINK MODIFICATION BEGIN -----
+ FlinkSqlCallBinding binding = new FlinkSqlCallBinding(validator, scope, call);
+ List sqlNodes = binding.operands();
+ for (int i = 0; i < sqlNodes.size(); i++) {
+ SqlNode operand = sqlNodes.get(i);
+ // special case for COUNT(*): delete the *
+ if (operand instanceof SqlIdentifier) {
+ SqlIdentifier id = (SqlIdentifier) operand;
+ if (id.isStar()) {
+ assert call.operandCount() == 1;
+ assert args.isEmpty();
+ break;
+ }
+ }
+ RexNode convertedExpr = bb.convertExpression(operand);
+ args.add(lookupOrCreateGroupExpr(convertedExpr));
+ }
+ // ----- FLINK MODIFICATION END -----
+
+ if (filter != null) {
+ RexNode convertedExpr = bb.convertExpression(filter);
+ if (convertedExpr.getType().isNullable()) {
+ convertedExpr = rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, convertedExpr);
+ }
+ filterArg = lookupOrCreateGroupExpr(convertedExpr);
+ }
+
+ if (distinctList == null) {
+ distinctKeys = null;
+ } else {
+ final ImmutableBitSet.Builder distinctBuilder = ImmutableBitSet.builder();
+ for (SqlNode distinct : distinctList) {
+ RexNode e = bb.convertExpression(distinct);
+ distinctBuilder.set(lookupOrCreateGroupExpr(e));
+ }
+ distinctKeys = distinctBuilder.build();
+ }
+ } finally {
+ // switch back into agg mode
+ bb.agg = this;
+ }
+
+ SqlAggFunction aggFunction = (SqlAggFunction) call.getOperator();
+ final RelDataType type = bb.getValidator().deriveType(bb.scope, call);
+ boolean distinct = false;
+ SqlLiteral quantifier = call.getFunctionQuantifier();
+ if ((null != quantifier) && (quantifier.getValue() == SqlSelectKeyword.DISTINCT)) {
+ distinct = true;
+ }
+ boolean approximate = false;
+ if (aggFunction == SqlStdOperatorTable.APPROX_COUNT_DISTINCT) {
+ aggFunction = SqlStdOperatorTable.COUNT;
+ distinct = true;
+ approximate = true;
+ }
+ final RelCollation collation;
+ if (orderList == null || orderList.size() == 0) {
+ collation = RelCollations.EMPTY;
+ } else {
+ try {
+ // switch out of agg mode
+ bb.agg = null;
+ collation =
+ RelCollations.of(
+ orderList.stream()
+ .map(
+ order ->
+ bb.convertSortExpression(
+ order,
+ RelFieldCollation.Direction
+ .ASCENDING,
+ RelFieldCollation.NullDirection
+ .UNSPECIFIED,
+ this::sortToFieldCollation))
+ .collect(Collectors.toList()));
+ } finally {
+ // switch back into agg mode
+ bb.agg = this;
+ }
+ }
+ final AggregateCall aggCall =
+ AggregateCall.create(
+ aggFunction,
+ distinct,
+ approximate,
+ ignoreNulls,
+ ImmutableList.of(),
+ args,
+ filterArg,
+ distinctKeys,
+ collation,
+ type,
+ nameMap.get(outerCall.toString()));
+ RexNode rex =
+ rexBuilder.addAggCall(
+ aggCall,
+ groupExprs.size(),
+ aggCalls,
+ aggCallMapping,
+ i -> convertedInputExprs.leftList().get(i).getType().isNullable());
+ aggMapping.put(outerCall, rex);
+ }
+
+ private RelFieldCollation sortToFieldCollation(
+ SqlNode expr,
+ RelFieldCollation.Direction direction,
+ RelFieldCollation.NullDirection nullDirection) {
+ final RexNode node = bb.convertExpression(expr);
+ final int fieldIndex = lookupOrCreateGroupExpr(node);
+ if (nullDirection == RelFieldCollation.NullDirection.UNSPECIFIED) {
+ nullDirection = direction.defaultNullDirection();
+ }
+ return new RelFieldCollation(fieldIndex, direction, nullDirection);
+ }
+
+ private int lookupOrCreateGroupExpr(RexNode expr) {
+ int index = 0;
+ for (RexNode convertedInputExpr : convertedInputExprs.leftList()) {
+ if (expr.equals(convertedInputExpr)) {
+ return index;
+ }
+ ++index;
+ }
+
+ // not found -- add it
+ addExpr(expr, null);
+ return index;
+ }
+
+ /**
+ * If an expression is structurally identical to one of the group-by expressions, returns a
+ * reference to the expression, otherwise returns null.
+ */
+ int lookupGroupExpr(SqlNode expr) {
+ return SqlUtil.indexOfDeep(groupExprs, expr, Litmus.IGNORE);
+ }
+
+ boolean isMeasureExpr(SqlNode expr) {
+ return SqlUtil.indexOfDeep(measureExprs, expr, Litmus.IGNORE) >= 0;
+ }
+
+ @Nullable RexNode lookupMeasure(SqlNode expr) {
+ return aggMapping.get(expr);
+ }
+
+ @Nullable RexNode lookupAggregates(SqlCall call) {
+ // assert call.getOperator().isAggregator();
+ assert bb.agg == this;
+
+ for (Map.Entry> e : auxiliaryGroupExprs.entrySet()) {
+ if (call.equalsDeep(e.getKey(), Litmus.IGNORE)) {
+ AuxiliaryConverter converter = e.getValue().e;
+ final RexBuilder rexBuilder = bb.getRexBuilder();
+ final int groupOrdinal = e.getValue().i;
+ return converter.convert(
+ rexBuilder,
+ convertedInputExprs.leftList().get(groupOrdinal),
+ rexBuilder.makeInputRef(castNonNull(bb.root), groupOrdinal));
+ }
+ }
+
+ return aggMapping.get(call);
+ }
+
+ /**
+ * Returns the resolved. Valid only if this AggConverter was created via {@link
+ * #create(SqlToRelConverter.Blackboard, AggregatingSelectScope)}.
+ */
+ AggregatingSelectScope.Resolved getResolved() {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index 0cf6f30723ded..212daeadc7407 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -18,7 +18,6 @@
import org.apache.flink.table.planner.plan.rules.logical.FlinkFilterProjectTransposeRule;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
@@ -83,6 +82,7 @@
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
+import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlFunction;
@@ -131,9 +131,9 @@
* TODO:
*
*
- * - Was changed within FLINK-29280, FLINK-28682, FLINK-35804: Line 218 ~ 225, Line 273 ~ 288
- *
- Should be removed after fix of FLINK-29540: Line 293 ~ 299
- *
- Should be removed after fix of FLINK-29540: Line 311 ~ 317
+ *
- Was changed within FLINK-29280, FLINK-28682, FLINK-35804: Line 222 ~ 229, Line 277 ~ 292
+ *
- Should be removed after fix of FLINK-29540: Line 297 ~ 303
+ *
- Should be removed after fix of FLINK-29540: Line 315 ~ 321
*
*/
public class RelDecorrelator implements ReflectiveVisitor {
@@ -551,7 +551,7 @@ protected RexNode removeCorrelationExpr(
// Project projects the original expressions,
// plus any correlated variables the input wants to pass along.
- final List> projects = new ArrayList<>();
+ final PairList projects = PairList.of();
List newInputOutput = newInput.getRowType().getFieldList();
@@ -572,7 +572,7 @@ protected RexNode removeCorrelationExpr(
// add mapping of group keys.
outputMap.put(idx, newPos);
int newInputPos = requireNonNull(frame.oldToNewOutputs.get(idx));
- projects.add(RexInputRef.of2(newInputPos, newInputOutput));
+ RexInputRef.add2(projects, newInputPos, newInputOutput);
mapNewInputToProjOutputs.put(newInputPos, newPos);
newPos++;
}
@@ -585,7 +585,7 @@ protected RexNode removeCorrelationExpr(
// Now add the corVars from the input, starting from
// position oldGroupKeyCount.
for (Map.Entry entry : frame.corDefOutputs.entrySet()) {
- projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
+ RexInputRef.add2(projects, entry.getValue(), newInputOutput);
corDefOutputs.put(entry.getKey(), newPos);
mapNewInputToProjOutputs.put(entry.getValue(), newPos);
@@ -597,7 +597,7 @@ protected RexNode removeCorrelationExpr(
final int newGroupKeyCount = newPos;
for (int i = 0; i < newInputOutput.size(); i++) {
if (!mapNewInputToProjOutputs.containsKey(i)) {
- projects.add(RexInputRef.of2(i, newInputOutput));
+ RexInputRef.add2(projects, i, newInputOutput);
mapNewInputToProjOutputs.put(i, newPos);
newPos++;
}
@@ -610,7 +610,7 @@ protected RexNode removeCorrelationExpr(
RelNode newProject =
relBuilder
.push(newInput)
- .projectNamed(Pair.left(projects), Pair.right(projects), true)
+ .projectNamed(projects.leftList(), projects.rightList(), true)
.build();
// update mappings:
@@ -788,7 +788,7 @@ private static void shiftMapping(Map mapping, int startIndex,
// Project projects the original expressions,
// plus any correlated variables the input wants to pass along.
- final List> projects = new ArrayList<>();
+ final PairList projects = PairList.of();
// If this Project has correlated reference, create value generator
// and produce the correlated variables in the new output.
@@ -802,20 +802,19 @@ private static void shiftMapping(Map mapping, int startIndex,
for (newPos = 0; newPos < oldProjects.size(); newPos++) {
projects.add(
newPos,
- Pair.of(
- decorrelateExpr(
- requireNonNull(currentRel, "currentRel"),
- map,
- cm,
- oldProjects.get(newPos)),
- relOutput.get(newPos).getName()));
+ decorrelateExpr(
+ requireNonNull(currentRel, "currentRel"),
+ map,
+ cm,
+ oldProjects.get(newPos)),
+ relOutput.get(newPos).getName());
mapOldToNewOutputs.put(newPos, newPos);
}
// Project any correlated variables the input wants to pass along.
final NavigableMap corDefOutputs = new TreeMap<>();
for (Map.Entry entry : frame.corDefOutputs.entrySet()) {
- projects.add(RexInputRef.of2(entry.getValue(), frame.r.getRowType().getFieldList()));
+ RexInputRef.add2(projects, entry.getValue(), frame.r.getRowType().getFieldList());
corDefOutputs.put(entry.getKey(), newPos);
newPos++;
}
@@ -823,7 +822,7 @@ private static void shiftMapping(Map mapping, int startIndex,
RelNode newProject =
relBuilder
.push(frame.r)
- .projectNamed(Pair.left(projects), Pair.right(projects), true)
+ .projectNamed(projects.leftList(), projects.rightList(), true)
.build();
return register(rel, newProject, mapOldToNewOutputs, corDefOutputs);
@@ -1456,14 +1455,14 @@ private RelNode projectJoinOutputWithNullability(
true));
// now create the new project
- List> newProjExprs = new ArrayList<>();
+ final PairList newProjExprs = PairList.of();
// project everything from the LHS and then those from the original
// projRel
List leftInputFields = left.getRowType().getFieldList();
for (int i = 0; i < leftInputFields.size(); i++) {
- newProjExprs.add(RexInputRef.of2(i, leftInputFields));
+ RexInputRef.add2(newProjExprs, i, leftInputFields);
}
// Marked where the projected expr is coming from so that the types will
@@ -1476,12 +1475,12 @@ private RelNode projectJoinOutputWithNullability(
removeCorrelationExpr(
pair.left, projectPulledAboveLeftCorrelator, nullIndicator);
- newProjExprs.add(Pair.of(newProjExpr, pair.right));
+ newProjExprs.add(newProjExpr, pair.right);
}
return relBuilder
.push(join)
- .projectNamed(Pair.left(newProjExprs), Pair.right(newProjExprs), true)
+ .projectNamed(newProjExprs.leftList(), newProjExprs.rightList(), true)
.build();
}
@@ -1500,14 +1499,14 @@ private RelNode aggregateCorrelatorOutput(
final JoinRelType joinType = correlate.getJoinType();
// now create the new project
- final List> newProjects = new ArrayList<>();
+ final PairList newProjects = PairList.of();
// Project everything from the LHS and then those from the original
// project
final List leftInputFields = left.getRowType().getFieldList();
for (int i = 0; i < leftInputFields.size(); i++) {
- newProjects.add(RexInputRef.of2(i, leftInputFields));
+ RexInputRef.add2(newProjects, i, leftInputFields);
}
// Marked where the projected expr is coming from so that the types will
@@ -1518,12 +1517,12 @@ private RelNode aggregateCorrelatorOutput(
for (Pair pair : project.getNamedProjects()) {
RexNode newProjExpr =
removeCorrelationExpr(pair.left, projectPulledAboveLeftCorrelator, isCount);
- newProjects.add(Pair.of(newProjExpr, pair.right));
+ newProjects.add(newProjExpr, pair.right);
}
return relBuilder
.push(correlate)
- .projectNamed(Pair.left(newProjects), Pair.right(newProjects), true)
+ .projectNamed(newProjects.leftList(), newProjects.rightList(), true)
.build();
}
@@ -1599,20 +1598,19 @@ private void removeCorVarFromTree(Correlate correlate) {
* @return the new Project
*/
private RelNode createProjectWithAdditionalExprs(
- RelNode input, List> additionalExprs) {
+ RelNode input, PairList additionalExprs) {
final List fieldList = input.getRowType().getFieldList();
- List> projects = new ArrayList<>();
+ PairList projects = PairList.of();
Ord.forEach(
fieldList,
(field, i) ->
projects.add(
- Pair.of(
- relBuilder.getRexBuilder().makeInputRef(field.getType(), i),
- field.getName())));
+ relBuilder.getRexBuilder().makeInputRef(field.getType(), i),
+ field.getName()));
projects.addAll(additionalExprs);
return relBuilder
.push(input)
- .projectNamed(Pair.left(projects), Pair.right(projects), true)
+ .projectNamed(projects.leftList(), projects.rightList(), true)
.build();
}
@@ -1650,11 +1648,7 @@ static boolean allLessThan(Collection integers, int limit, Litmus ret)
}
private static RelNode stripHep(RelNode rel) {
- if (rel instanceof HepRelVertex) {
- HepRelVertex hepRelVertex = (HepRelVertex) rel;
- rel = hepRelVertex.getCurrentRel();
- }
- return rel;
+ return rel instanceof HepRelVertex ? rel.stripped() : rel;
}
// ~ Inner Classes ----------------------------------------------------------
@@ -1884,20 +1878,32 @@ public RexNode visitCall(final RexCall call) {
/**
* Rule to remove an Aggregate with SINGLE_VALUE. For cases like:
*
- * Aggregate(SINGLE_VALUE) Project(single expression) Aggregate
+ *
{@code
+ * Aggregate(SINGLE_VALUE)
+ * Project(single expression)
+ * Aggregate
+ * }
*
- * For instance (subtree taken from TPCH query 17):
+ *
For instance, the following subtree from TPCH query 17:
*
- *
LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])
- * LogicalProject(EXPR$0=[*(0.2:DECIMAL(2, 1), $0)]) LogicalAggregate(group=[{}],
- * agg#0=[AVG($0)]) LogicalProject(L_QUANTITY=[$4]) LogicalFilter(condition=[=($1,
- * $cor0.P_PARTKEY)]) LogicalTableScan(table=[[TPCH_01, LINEITEM]])
+ *
{@code
+ * LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])
+ * LogicalProject(EXPR$0=[*(0.2:DECIMAL(2, 1), $0)])
+ * LogicalAggregate(group=[{}], agg#0=[AVG($0)])
+ * LogicalProject(L_QUANTITY=[$4])
+ * LogicalFilter(condition=[=($1, $cor0.P_PARTKEY)])
+ * LogicalTableScan(table=[[TPCH_01, LINEITEM]])
+ * }
*
- * Will be converted into:
+ *
will be converted into:
*
- *
LogicalProject($f0=[*(0.2:DECIMAL(2, 1), $0)]) LogicalAggregate(group=[{}],
- * agg#0=[AVG($0)]) LogicalProject(L_QUANTITY=[$4]) LogicalFilter(condition=[=($1,
- * $cor0.P_PARTKEY)]) LogicalTableScan(table=[[TPCH_01, LINEITEM]])
+ *
{@code
+ * LogicalProject($f0=[*(0.2:DECIMAL(2, 1), $0)])
+ * LogicalAggregate(group=[{}], agg#0=[AVG($0)])
+ * LogicalProject(L_QUANTITY=[$4])
+ * LogicalFilter(condition=[=($1, $cor0.P_PARTKEY)])
+ * LogicalTableScan(table=[[TPCH_01, LINEITEM]])
+ * }
*/
public static final class RemoveSingleAggregateRule
extends RelRule {
@@ -1951,14 +1957,10 @@ public void onMatch(RelOptRuleCall call) {
// ensure we keep the same type after removing the SINGLE_VALUE Aggregate
final RelBuilder relBuilder = call.builder();
- final RelDataType singleAggType =
- singleAggregate.getRowType().getFieldList().get(0).getType();
- final RexNode oldProjectExp = projExprs.get(0);
- final RexNode newProjectExp =
- singleAggType.equals(oldProjectExp.getType())
- ? oldProjectExp
- : relBuilder.getRexBuilder().makeCast(singleAggType, oldProjectExp);
- relBuilder.push(aggregate).project(newProjectExp);
+ relBuilder
+ .push(aggregate)
+ .project(project.getAliasedProjects(relBuilder))
+ .convert(singleAggregate.getRowType(), false);
call.transformTo(relBuilder.build());
}
@@ -2071,7 +2073,7 @@ public void onMatch(RelOptRuleCall call) {
right = filter.getInput();
assert right instanceof HepRelVertex;
- right = ((HepRelVertex) right).getCurrentRel();
+ right = right.stripped();
// check filter input contains no correlation
if (RelOptUtil.getVariablesUsed(right).size() > 0) {
@@ -2156,9 +2158,7 @@ public void onMatch(RelOptRuleCall call) {
// make the new Project to provide a null indicator
right =
d.createProjectWithAdditionalExprs(
- right,
- ImmutableList.of(
- Pair.of(d.relBuilder.literal(true), "nullIndicator")));
+ right, PairList.of(d.relBuilder.literal(true), "nullIndicator"));
// make the new aggRel
right = RelOptUtil.createSingleValueAggRel(cluster, right);
@@ -2316,7 +2316,7 @@ public void onMatch(RelOptRuleCall call) {
right = filter.getInput();
assert right instanceof HepRelVertex;
- right = ((HepRelVertex) right).getCurrentRel();
+ right = right.stripped();
// check filter input contains no correlation
if (RelOptUtil.getVariablesUsed(right).size() > 0) {
@@ -2485,9 +2485,7 @@ public void onMatch(RelOptRuleCall call) {
right =
d.createProjectWithAdditionalExprs(
- right,
- ImmutableList.of(
- Pair.of(rexBuilder.makeLiteral(true), "nullIndicator")));
+ right, PairList.of(rexBuilder.makeLiteral(true), "nullIndicator"));
Join join = (Join) d.relBuilder.push(left).push(right).join(joinType, joinCond).build();
@@ -2656,15 +2654,15 @@ public void onMatch(RelOptRuleCall call) {
aggregate = call.rel(2);
// Create identity projection
- final List> projects = new ArrayList<>();
+ final PairList projects = PairList.of();
final List fields = aggregate.getRowType().getFieldList();
for (int i = 0; i < fields.size(); i++) {
- projects.add(RexInputRef.of2(projects.size(), fields));
+ RexInputRef.add2(projects, projects.size(), fields);
}
final RelBuilder relBuilder = call.builder();
relBuilder
.push(aggregate)
- .projectNamed(Pair.left(projects), Pair.right(projects), true);
+ .projectNamed(projects.leftList(), projects.rightList(), true);
aggOutputProject = (Project) relBuilder.build();
}
onMatch2(call, correlate, left, aggOutputProject, aggregate);
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index 1226f6c7030a2..06fcc071f51e3 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -106,6 +106,7 @@
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
+import org.apache.calcite.runtime.PairList;
import org.apache.calcite.schema.ColumnStrategy;
import org.apache.calcite.schema.ModifiableTable;
import org.apache.calcite.schema.ModifiableView;
@@ -203,7 +204,6 @@
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.time.ZoneId;
-import java.util.AbstractList;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
@@ -218,6 +218,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
+import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
@@ -227,6 +228,7 @@
import static java.util.Objects.requireNonNull;
import static org.apache.calcite.linq4j.Nullness.castNonNull;
import static org.apache.calcite.runtime.FlatLists.append;
+import static org.apache.calcite.sql.SqlUtil.containsIn;
import static org.apache.calcite.sql.SqlUtil.stripAs;
import static org.apache.calcite.util.Static.RESOURCE;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -240,19 +242,18 @@
* FLINK modifications are at lines
*
*
- * - Added in FLINK-29081, FLINK-28682, FLINK-33395: Lines 670 ~ 687
- *
- Added in Flink-24024: Lines 1463 ~ 1469
- *
- Added in Flink-24024: Lines 1483 ~ 1522
- *
- Added in Flink-37269: Lines 2239 ~ 2261
- *
- Added in FLINK-28682: Lines 2372 ~ 2389
- *
- Added in FLINK-28682: Lines 2426 ~ 2454
- *
- Added in FLINK-32474: Lines 2507 ~ 2509
- *
- Added in FLINK-32474: Lines 2513 ~ 2515
- *
- Added in FLINK-32474: Lines 2526 ~ 2528
- *
- Added in FLINK-32474: Lines 2934 ~ 2945
- *
- Added in FLINK-32474: Lines 3046 ~ 3080
- *
- Added in FLINK-34312: Lines 5827 ~ 5838
- *
- Added in FLINK-34057, FLINK-34058, FLINK-34312: Lines 6285 ~ 6303
+ *
- Added in FLINK-29081, FLINK-28682, FLINK-33395: Lines 673 ~ 690
+ *
- Added in Flink-24024: Lines 1440 ~ 1446
+ *
- Added in Flink-24024: Lines 1460 ~ 1499
+ *
- Added in Flink-37269: Lines 2237 ~ 2259
+ *
- Added in FLINK-28682: Lines 2370 ~ 2387
+ *
- Added in FLINK-28682: Lines 2424 ~ 2452
+ *
- Added in FLINK-32474: Lines 2504 ~ 2506
+ *
- Added in FLINK-32474: Lines 2510 ~ 2512
+ *
- Added in FLINK-32474: Lines 2523 ~ 2525
+ *
- Added in FLINK-32474: Lines 2929 ~ 2941
+ *
- Added in FLINK-32474: Lines 3042 ~ 3076
+ *
- Added in FLINK-34312: Lines 5805 ~ 5816
*
*
* In official extension point (i.e. {@link #convertExtendedExpression(SqlNode, Blackboard)}):
@@ -734,9 +735,7 @@ public RelNode convertSelect(SqlSelect select, boolean top) {
/** Factory method for creating translation workspace. */
protected Blackboard createBlackboard(
- @Nullable SqlValidatorScope scope,
- @Nullable Map nameToNodeMap,
- boolean top) {
+ SqlValidatorScope scope, @Nullable Map nameToNodeMap, boolean top) {
return new Blackboard(scope, nameToNodeMap, top);
}
@@ -864,19 +863,19 @@ private void distinctify(Blackboard bb, boolean checkForDupExprs) {
final Map squished = new HashMap<>();
final List fields = rel.getRowType().getFieldList();
- final List> newProjects = new ArrayList<>();
+ final PairList newProjects = PairList.of();
for (int i = 0; i < fields.size(); i++) {
if (origins.get(i) == i) {
squished.put(i, newProjects.size());
- newProjects.add(RexInputRef.of2(i, fields));
+ RexInputRef.add2(newProjects, i, fields);
}
}
rel =
LogicalProject.create(
rel,
ImmutableList.of(),
- Pair.left(newProjects),
- Pair.right(newProjects),
+ newProjects.leftList(),
+ newProjects.rightList(),
project.getVariablesSet());
bb.root = rel;
distinctify(bb, false);
@@ -884,22 +883,21 @@ private void distinctify(Blackboard bb, boolean checkForDupExprs) {
// Create the expressions to reverse the mapping.
// Project($0, $1, $0, $2).
- final List> undoProjects = new ArrayList<>();
+ final PairList undoProjects = PairList.of();
for (int i = 0; i < fields.size(); i++) {
final int origin = origins.get(i);
RelDataTypeField field = fields.get(i);
undoProjects.add(
- Pair.of(
- new RexInputRef(castNonNull(squished.get(origin)), field.getType()),
- field.getName()));
+ new RexInputRef(castNonNull(squished.get(origin)), field.getType()),
+ field.getName());
}
rel =
LogicalProject.create(
rel,
ImmutableList.of(),
- Pair.left(undoProjects),
- Pair.right(undoProjects),
+ undoProjects.leftList(),
+ undoProjects.rightList(),
ImmutableSet.of());
bb.setRoot(rel, false);
@@ -990,31 +988,6 @@ private boolean removeSortInSubQuery(boolean top) {
return config.isRemoveSortInSubQuery() && !top;
}
- /**
- * Returns whether a given node contains a {@link SqlInOperator}.
- *
- * @param node a RexNode tree
- */
- private static boolean containsInOperator(SqlNode node) {
- try {
- SqlVisitor visitor =
- new SqlBasicVisitor() {
- @Override
- public Void visit(SqlCall call) {
- if (call.getOperator() instanceof SqlInOperator) {
- throw new Util.FoundOne(call);
- }
- return super.visit(call);
- }
- };
- node.accept(visitor);
- return false;
- } catch (Util.FoundOne e) {
- Util.swallow(e, null);
- return true;
- }
- }
-
/**
* Push down all the NOT logical operators into any IN/NOT IN operators.
*
@@ -1023,7 +996,7 @@ public Void visit(SqlCall call) {
* @return the transformed SqlNode representation with NOT pushed down.
*/
private static SqlNode pushDownNotForIn(SqlValidatorScope scope, SqlNode sqlNode) {
- if (!(sqlNode instanceof SqlCall) || !containsInOperator(sqlNode)) {
+ if (!(sqlNode instanceof SqlCall) || !containsIn(sqlNode)) {
return sqlNode;
}
final SqlCall sqlCall = (SqlCall) sqlNode;
@@ -1147,7 +1120,7 @@ private void convertWhere(final Blackboard bb, final @Nullable SqlNode where) {
if (where == null) {
return;
}
- SqlNode newWhere = pushDownNotForIn(bb.scope(), where);
+ SqlNode newWhere = pushDownNotForIn(bb.scope, where);
replaceSubQueries(bb, newWhere, RelOptUtil.Logic.UNKNOWN_AS_FALSE);
final RexNode convertedWhere = bb.convertExpression(newWhere);
final RexNode convertedWhere2 = RexUtil.removeNullabilityCast(typeFactory, convertedWhere);
@@ -1325,6 +1298,7 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) {
false,
false,
ImmutableList.of(),
+ ImmutableList.of(),
-1,
null,
RelCollations.EMPTY,
@@ -1335,6 +1309,7 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) {
false,
false,
false,
+ ImmutableList.of(),
args,
-1,
null,
@@ -1392,7 +1367,7 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) {
final SqlValidatorScope seekScope =
(query instanceof SqlSelect)
? validator().getSelectScope((SqlSelect) query)
- : null;
+ : validator().getEmptyScope();
final Blackboard seekBb = createBlackboard(seekScope, null, false);
final RelNode seekRel = convertQueryOrInList(seekBb, query, null);
requireNonNull(seekRel, () -> "seekRel is null for query " + query);
@@ -1535,7 +1510,7 @@ private void substituteSubQueryOfSetSemanticsInputTable(Blackboard bb, SubQuery
private ImmutableBitSet buildPartitionKeys(Blackboard bb, SqlNodeList partitionList) {
final ImmutableBitSet.Builder partitionKeys = ImmutableBitSet.builder();
for (SqlNode partition : partitionList) {
- validator().deriveType(bb.scope(), partition);
+ validator().deriveType(bb.scope, partition);
RexNode e = bb.convertExpression(partition);
partitionKeys.set(parseFieldIdx(e));
}
@@ -1909,7 +1884,9 @@ private RelOptUtil.Exists convertExists(
boolean notIn,
@Nullable RelDataType targetDataType) {
final SqlValidatorScope seekScope =
- (seek instanceof SqlSelect) ? validator().getSelectScope((SqlSelect) seek) : null;
+ (seek instanceof SqlSelect)
+ ? validator().getSelectScope((SqlSelect) seek)
+ : validator().getEmptyScope();
final Blackboard seekBb = createBlackboard(seekScope, null, false);
RelNode seekRel = convertQueryOrInList(seekBb, seek, targetDataType);
requireNonNull(seekRel, () -> "seekRel is null for query " + seek);
@@ -2181,6 +2158,26 @@ private void findSubQueries(
default:
break;
}
+ if (node instanceof SqlBasicCall
+ && ((SqlCall) node).getOperator() instanceof SqlQuantifyOperator
+ && ((SqlQuantifyOperator) ((SqlCall) node).getOperator())
+ .tryDeriveTypeForCollection(
+ bb.getValidator(), bb.scope, (SqlCall) node)
+ != null) {
+ findSubQueries(
+ bb,
+ ((SqlCall) node).operand(0),
+ logic,
+ registerOnlyScalarSubQueries,
+ clause);
+ findSubQueries(
+ bb,
+ ((SqlCall) node).operand(1),
+ logic,
+ registerOnlyScalarSubQueries,
+ clause);
+ break;
+ }
bb.registerSubQuery(node, logic, clause);
break;
default:
@@ -2278,7 +2275,7 @@ private RexNode convertOver(Blackboard bb, SqlNode node) {
}
SqlNode windowOrRef = call.operand(1);
- final SqlWindow window = validator().resolveWindow(windowOrRef, bb.scope());
+ final SqlWindow window = validator().resolveWindow(windowOrRef, bb.scope);
SqlNode sqlLowerBound = window.getLowerBound();
SqlNode sqlUpperBound = window.getUpperBound();
@@ -2309,7 +2306,7 @@ private RexNode convertOver(Blackboard bb, SqlNode node) {
final SqlNodeList partitionList = window.getPartitionList();
final ImmutableList.Builder partitionKeys = ImmutableList.builder();
for (SqlNode partition : partitionList) {
- validator().deriveType(bb.scope(), partition);
+ validator().deriveType(bb.scope, partition);
partitionKeys.add(bb.convertExpression(partition));
}
final RexNode lowerBound =
@@ -2320,7 +2317,7 @@ private RexNode convertOver(Blackboard bb, SqlNode node) {
// A logical range requires an ORDER BY clause. Use the implicit
// ordering of this relation. There must be one, otherwise it would
// have failed validation.
- orderList = bb.scope().getOrderList();
+ orderList = bb.scope.getOrderList();
if (orderList == null) {
throw new AssertionError("Relation should have sort key for implicit ORDER BY");
}
@@ -2775,7 +2772,6 @@ public RexNode visit(SqlLiteral literal) {
protected void convertPivot(Blackboard bb, SqlPivot pivot) {
final SqlValidatorScope scope = validator().getJoinScope(pivot);
-
final Blackboard pivotBb = createBlackboard(scope, null, false);
// Convert input
@@ -2786,7 +2782,7 @@ protected void convertPivot(Blackboard bb, SqlPivot pivot) {
relBuilder.push(input);
// Gather fields.
- final AggConverter aggConverter = new AggConverter(pivotBb, (AggregatingSelectScope) null);
+ final AggConverter aggConverter = AggConverter.create(pivotBb);
final Set usedColumnNames = pivot.usedColumnNames();
// 1. Gather group keys.
@@ -2814,7 +2810,8 @@ protected void convertPivot(Blackboard bb, SqlPivot pivot) {
// Project the fields that we will need.
relBuilder.project(
- Pair.left(aggConverter.getPreExprs()), Pair.right(aggConverter.getPreExprs()));
+ aggConverter.convertedInputExprs.leftList(),
+ aggConverter.convertedInputExprs.rightList());
// Build expressions.
@@ -2866,7 +2863,6 @@ protected void convertPivot(Blackboard bb, SqlPivot pivot) {
protected void convertUnpivot(Blackboard bb, SqlUnpivot unpivot) {
final SqlValidatorScope scope = validator().getJoinScope(unpivot);
-
final Blackboard unpivotBb = createBlackboard(scope, null, false);
// Convert input
@@ -2985,7 +2981,7 @@ protected void convertCollectionTable(Blackboard bb, SqlCall call) {
// Expand table macro if possible. It's more efficient than
// LogicalTableFunctionScan.
final SqlCallBinding callBinding =
- new SqlCallBinding(bb.scope().getValidator(), bb.scope, call);
+ new SqlCallBinding(bb.scope.getValidator(), bb.scope, call);
if (operator instanceof SqlUserDefinedTableMacro) {
final SqlUserDefinedTableMacro udf = (SqlUserDefinedTableMacro) operator;
final TranslatableTable table = udf.getTable(callBinding);
@@ -3187,21 +3183,19 @@ protected RelNode createJoin(
mapCorrelToDeferred.get(correlName),
() -> "correlation variable is not found: " + correlName);
RexFieldAccess fieldAccess = lookup.getFieldAccess(correlName);
- String originalRelName = lookup.getOriginalRelName();
String originalFieldName = fieldAccess.getField().getName();
final SqlNameMatcher nameMatcher = bb.getValidator().getCatalogReader().nameMatcher();
final SqlValidatorScope.ResolvedImpl resolved = new SqlValidatorScope.ResolvedImpl();
- lookup.bb
- .scope()
- .resolve(ImmutableList.of(originalRelName), nameMatcher, false, resolved);
+ lookup.bb.scope.resolve(
+ ImmutableList.of(lookup.originalRelName), nameMatcher, false, resolved);
assert resolved.count() == 1;
final SqlValidatorScope.Resolve resolve = resolved.only();
final SqlValidatorNamespace foundNs = resolve.namespace;
final RelDataType rowType = resolve.rowType();
final int childNamespaceIndex = resolve.path.steps().get(0).i;
final SqlValidatorScope ancestorScope = resolve.scope;
- boolean correlInCurrentScope = bb.scope().isWithin(ancestorScope);
+ boolean correlInCurrentScope = bb.scope.isWithin(ancestorScope);
if (!correlInCurrentScope) {
continue;
@@ -3258,7 +3252,7 @@ protected RelNode createJoin(
// correl not grouped
throw new AssertionError(
"Identifier '"
- + originalRelName
+ + lookup.originalRelName
+ "."
+ originalFieldName
+ "' is not a group expr");
@@ -3304,14 +3298,13 @@ private boolean isSubQueryNonCorrelated(RelNode subq, Blackboard bb) {
requireNonNull(
mapCorrelToDeferred.get(correlName),
() -> "correlation variable is not found: " + correlName);
- String originalRelName = lookup.getOriginalRelName();
+ String originalRelName = lookup.originalRelName;
final SqlNameMatcher nameMatcher =
- lookup.bb.scope().getValidator().getCatalogReader().nameMatcher();
+ lookup.bb.scope.getValidator().getCatalogReader().nameMatcher();
final SqlValidatorScope.ResolvedImpl resolved = new SqlValidatorScope.ResolvedImpl();
- lookup.bb
- .scope()
- .resolve(ImmutableList.of(originalRelName), nameMatcher, false, resolved);
+ lookup.bb.scope.resolve(
+ ImmutableList.of(originalRelName), nameMatcher, false, resolved);
SqlValidatorScope ancestorScope = resolved.only().scope;
@@ -3345,15 +3338,12 @@ private void convertJoin(Blackboard bb, SqlJoin join) {
SqlValidator validator = validator();
final SqlValidatorScope scope = validator.getJoinScope(join);
final Blackboard fromBlackboard = createBlackboard(scope, null, false);
+
SqlNode left = join.getLeft();
SqlNode right = join.getRight();
- final SqlValidatorScope leftScope =
- Util.first(
- validator.getJoinScope(left), ((DelegatingScope) bb.scope()).getParent());
+ final SqlValidatorScope leftScope = validator.getJoinScope(left);
final Blackboard leftBlackboard = createBlackboard(leftScope, null, false);
- final SqlValidatorScope rightScope =
- Util.first(
- validator.getJoinScope(right), ((DelegatingScope) bb.scope()).getParent());
+ final SqlValidatorScope rightScope = validator.getJoinScope(right);
final Blackboard rightBlackboard = createBlackboard(rightScope, null, false);
convertFrom(leftBlackboard, left);
final RelNode leftRel = requireNonNull(leftBlackboard.root, "leftBlackboard.root");
@@ -3539,12 +3529,16 @@ private static JoinRelType convertJoinType(JoinType joinType) {
* @param orderExprList Additional expressions needed to implement ORDER BY
*/
protected void convertAgg(Blackboard bb, SqlSelect select, List orderExprList) {
- assert bb.root != null : "precondition: child != null";
+ requireNonNull(bb.root, "bb.root");
SqlNodeList groupList = select.getGroup();
SqlNodeList selectList = select.getSelectList();
SqlNode having = select.getHaving();
- final AggConverter aggConverter = new AggConverter(bb, select);
+ final AggConverter aggConverter =
+ AggConverter.create(
+ bb,
+ (AggregatingSelectScope) validator().getSelectScope(select),
+ validator());
createAggImpl(bb, aggConverter, selectList, groupList, having, orderExprList);
}
@@ -3587,15 +3581,16 @@ protected final void createAggImpl(
// Calcite allows expressions, not just column references in
// group by list. This is not SQL 2003 compliant, but hey.
- final AggregatingSelectScope scope =
- requireNonNull(aggConverter.aggregatingSelectScope, "aggregatingSelectScope");
- final AggregatingSelectScope.Resolved r = scope.resolved.get();
- for (SqlNode groupExpr : r.groupExprList) {
- aggConverter.addGroupExpr(groupExpr);
+ final AggregatingSelectScope.Resolved r = aggConverter.getResolved();
+ for (SqlNode e : r.groupExprList) {
+ aggConverter.addGroupExpr(e);
+ }
+ for (SqlNode e : r.measureExprList) {
+ aggConverter.addMeasureExpr(e);
}
final RexNode havingExpr;
- final List> projects = new ArrayList<>();
+ final PairList projects = PairList.of();
try {
checkArgument(bb.agg == null, "already in agg mode");
@@ -3617,14 +3612,15 @@ protected final void createAggImpl(
}
// compute inputs to the aggregator
- List> preExprs = aggConverter.getPreExprs();
-
- if (preExprs.size() == 0) {
+ final PairList preExprs;
+ if (aggConverter.convertedInputExprs.isEmpty()) {
// Special case for COUNT(*), where we can end up with no inputs
// at all. The rest of the system doesn't like 0-tuples, so we
// select a dummy constant here.
final RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
- preExprs = ImmutableList.of(Pair.of(zero, null));
+ preExprs = PairList.of(zero, null);
+ } else {
+ preExprs = aggConverter.convertedInputExprs;
}
final RelNode inputRel = bb.root();
@@ -3633,7 +3629,7 @@ protected final void createAggImpl(
bb.setRoot(
relBuilder
.push(inputRel)
- .projectNamed(Pair.left(preExprs), Pair.right(preExprs), false)
+ .projectNamed(preExprs.leftList(), preExprs.rightList(), false)
.build(),
false);
bb.mapRootRelToFieldProjection.put(bb.root(), r.groupExprProjection);
@@ -3645,20 +3641,19 @@ protected final void createAggImpl(
// Tell bb which of group columns are sorted.
bb.columnMonotonicities.clear();
for (SqlNode groupItem : groupList) {
- bb.columnMonotonicities.add(bb.scope().getMonotonicity(groupItem));
+ bb.columnMonotonicities.add(bb.scope.getMonotonicity(groupItem));
}
// Add the aggregator
bb.setRoot(
- createAggregate(
- bb, r.groupSet, r.groupSets.asList(), aggConverter.getAggCalls()),
+ createAggregate(bb, r.groupSet, r.groupSets.asList(), aggConverter.aggCalls),
false);
bb.mapRootRelToFieldProjection.put(bb.root(), r.groupExprProjection);
// Replace sub-queries in having here and modify having to use
// the replaced expressions
if (having != null) {
- SqlNode newHaving = pushDownNotForIn(bb.scope(), having);
+ SqlNode newHaving = pushDownNotForIn(bb.scope, having);
replaceSubQueries(bb, newHaving, RelOptUtil.Logic.UNKNOWN_AS_FALSE);
havingExpr = bb.convertExpression(newHaving);
} else {
@@ -3694,16 +3689,14 @@ protected final void createAggImpl(
int sysFieldCount = selectList.size() - names.size();
for (SqlNode expr : selectList) {
projects.add(
- Pair.of(
- bb.convertExpression(expr),
- k < sysFieldCount
- ? SqlValidatorUtil.alias(expr, k++)
- : names.get(k++ - sysFieldCount)));
+ bb.convertExpression(expr),
+ k < sysFieldCount
+ ? SqlValidatorUtil.alias(expr, k++)
+ : names.get(k++ - sysFieldCount));
}
for (SqlNode expr : orderExprList) {
- projects.add(
- Pair.of(bb.convertExpression(expr), SqlValidatorUtil.alias(expr, k++)));
+ projects.add(bb.convertExpression(expr), SqlValidatorUtil.alias(expr, k++));
}
} finally {
bb.agg = null;
@@ -3711,18 +3704,16 @@ protected final void createAggImpl(
// implement HAVING (we have already checked that it is non-trivial)
relBuilder.push(bb.root());
- if (havingExpr != null) {
- relBuilder.filter(havingExpr);
- }
+ relBuilder.filter(havingExpr);
// implement the SELECT list
- relBuilder.project(Pair.left(projects), Pair.right(projects)).rename(Pair.right(projects));
+ relBuilder.project(projects.leftList(), projects.rightList()).rename(projects.rightList());
bb.setRoot(relBuilder.build(), false);
// Tell bb which of group columns are sorted.
bb.columnMonotonicities.clear();
for (SqlNode selectItem : selectList) {
- bb.columnMonotonicities.add(bb.scope().getMonotonicity(selectItem));
+ bb.columnMonotonicities.add(bb.scope.getMonotonicity(selectItem));
}
}
@@ -4066,19 +4057,18 @@ private RelNode createSource(
// filter.
final RexNode constraint = modifiableView.getConstraint(rexBuilder, delegateRowType);
RelOptUtil.inferViewPredicates(projectMap, filters, constraint);
- final List> projects = new ArrayList<>();
+ final PairList projects = PairList.of();
for (RelDataTypeField field : delegateRowType.getFieldList()) {
RexNode node = projectMap.get(field.getIndex());
if (node == null) {
node = rexBuilder.makeNullLiteral(field.getType());
}
- projects.add(
- Pair.of(rexBuilder.ensureType(field.getType(), node, false), field.getName()));
+ projects.add(rexBuilder.ensureType(field.getType(), node, false), field.getName());
}
return relBuilder
.push(source)
- .projectNamed(Pair.left(projects), Pair.right(projects), false)
+ .projectNamed(projects.leftList(), projects.rightList(), false)
.filter(filters)
.build();
}
@@ -4237,7 +4227,7 @@ private Blackboard createInsertBlackboard(
nameToNodeMap.put(targetColumnName, rexBuilder.makeFieldAccess(sourceRef, j++));
}
}
- return createBlackboard(null, nameToNodeMap, false);
+ return createBlackboard(validator().getEmptyScope(), nameToNodeMap, false);
}
private static InitializerExpressionFactory getInitializerFactory(
@@ -4498,12 +4488,7 @@ private RexNode convertIdentifier(Blackboard bb, SqlIdentifier identifier) {
pv = identifier.names.get(0);
}
- final SqlQualified qualified;
- if (bb.scope != null) {
- qualified = bb.scope.fullyQualify(identifier);
- } else {
- qualified = SqlQualified.create(null, 1, null, identifier);
- }
+ final SqlQualified qualified = bb.scope.fullyQualify(identifier);
final Pair> e0 =
bb.lookupExp(qualified);
RexNode e = e0.left;
@@ -4606,7 +4591,7 @@ private RelNode convertMultisets(final List operands, Blackboard bb) {
} else {
usedBb =
createBlackboard(
- new ListScope(bb.scope()) {
+ new ListScope(bb.scope) {
@Override
public SqlNode getNode() {
return call;
@@ -4952,16 +4937,15 @@ private void convertValuesImpl(
SqlCall rowConstructor = (SqlCall) rowConstructor1;
Blackboard tmpBb = createBlackboard(bb.scope, null, false);
replaceSubQueries(tmpBb, rowConstructor, RelOptUtil.Logic.TRUE_FALSE_UNKNOWN);
- final List> exps = new ArrayList<>();
+ final PairList exps = PairList.of();
Ord.forEach(
rowConstructor.getOperandList(),
(operand, i) ->
exps.add(
- Pair.of(
- tmpBb.convertExpression(operand),
- SqlValidatorUtil.alias(operand, i))));
+ tmpBb.convertExpression(operand),
+ SqlValidatorUtil.alias(operand, i)));
RelNode in = (null == tmpBb.root) ? LogicalValues.createOneRow(cluster) : tmpBb.root;
- relBuilder.push(in).project(Pair.left(exps), Pair.right(exps));
+ relBuilder.push(in).project(exps.leftList(), exps.rightList());
}
bb.setRoot(relBuilder.union(true, values.getOperandList().size()).build(), true);
@@ -4999,7 +4983,7 @@ R convert(
/** Workspace for translating an individual SELECT statement (or sub-SELECT). */
protected class Blackboard implements SqlRexContext, SqlVisitor, InitializerContext {
/** Collection of {@link RelNode} objects which correspond to a SELECT statement. */
- public final @Nullable SqlValidatorScope scope;
+ public final SqlValidatorScope scope;
private final @Nullable Map nameToNodeMap;
public @Nullable RelNode root;
@@ -5054,7 +5038,7 @@ protected Blackboard(
@Nullable SqlValidatorScope scope,
@Nullable Map nameToNodeMap,
boolean top) {
- this.scope = scope;
+ this.scope = requireNonNull(scope, "scope");
this.nameToNodeMap = nameToNodeMap;
this.top = top;
}
@@ -5063,8 +5047,9 @@ public RelNode root() {
return requireNonNull(root, "root");
}
+ @Deprecated // to be removed before 2.0
public SqlValidatorScope scope() {
- return requireNonNull(scope, "scope");
+ return scope;
}
public void setPatternVarRef(boolean isVarRef) {
@@ -5149,23 +5134,12 @@ public RexNode register(
assert leftKeyCount == rightFieldLength - 1;
final int rexRangeRefLength = leftKeyCount + rightFieldLength;
- RelDataType returnType =
- typeFactory.createStructType(
- new AbstractList>() {
- @Override
- public Map.Entry get(int index) {
- return join.getRowType()
- .getFieldList()
- .get(origLeftInputCount + index);
- }
-
- @Override
- public int size() {
- return rexRangeRefLength;
- }
- });
-
- return rexBuilder.makeRangeReference(returnType, origLeftInputCount, false);
+ final RelDataTypeFactory.Builder builder = typeFactory.builder();
+ for (int i = 0; i < rexRangeRefLength; i++) {
+ builder.add(join.getRowType().getFieldList().get(origLeftInputCount + i));
+ }
+
+ return rexBuilder.makeRangeReference(builder.build(), origLeftInputCount, false);
} else {
return rexBuilder.makeRangeReference(
rel.getRowType(), leftFieldCount, joinType.generatesNullsOnRight());
@@ -5259,9 +5233,9 @@ void setRoot(List inputs) {
return Pair.of(node, null);
}
final SqlNameMatcher nameMatcher =
- scope().getValidator().getCatalogReader().nameMatcher();
+ scope.getValidator().getCatalogReader().nameMatcher();
final SqlValidatorScope.ResolvedImpl resolved = new SqlValidatorScope.ResolvedImpl();
- scope().resolve(qualified.prefix(), nameMatcher, false, resolved);
+ scope.resolve(qualified.prefix(), nameMatcher, false, resolved);
if (resolved.count() != 1) {
throw new AssertionError(
"no unique expression found for "
@@ -5280,7 +5254,6 @@ void setRoot(List inputs) {
if ((inputs != null) && !isParent) {
final LookupContext rels = new LookupContext(this, inputs, systemFieldList.size());
final RexNode node = lookup(resolve.path.steps().get(0).i, rels);
- assert node != null;
return Pair.of(
node,
(e, fieldName) -> {
@@ -5344,8 +5317,9 @@ void setRoot(List inputs) {
* from-list is {@code offset}.
*/
RexNode lookup(int offset, LookupContext lookupContext) {
- Pair pair = lookupContext.findRel(offset);
- return rexBuilder.makeRangeReference(pair.left.getRowType(), pair.right, false);
+ Map.Entry pair = lookupContext.findRel(offset);
+ return rexBuilder.makeRangeReference(
+ pair.getKey().getRowType(), pair.getValue(), false);
}
@Nullable RelDataTypeField getRootField(RexInputRef inputRef) {
@@ -5368,13 +5342,13 @@ public void flatten(
List rels,
int systemFieldCount,
int[] start,
- List> relOffsetList) {
+ BiConsumer relOffsetList) {
for (RelNode rel : rels) {
if (leaves.containsKey(rel)) {
- relOffsetList.add(Pair.of(rel, start[0]));
+ relOffsetList.accept(rel, start[0]);
start[0] += leaves.get(rel);
} else if (rel instanceof LogicalMatch) {
- relOffsetList.add(Pair.of(rel, start[0]));
+ relOffsetList.accept(rel, start[0]);
start[0] += rel.getRowType().getFieldCount();
} else {
if (rel instanceof LogicalJoin || rel instanceof LogicalAggregate) {
@@ -5426,7 +5400,7 @@ public RexNode convertExpression(SqlNode expr) {
// GROUP BY clause, return a reference to the field.
AggConverter agg = this.agg;
if (agg != null) {
- final SqlNode expandedGroupExpr = validator().expand(expr, scope());
+ final SqlNode expandedGroupExpr = validator().expand(expr, scope);
final int ref = agg.lookupGroupExpr(expandedGroupExpr);
if (ref >= 0) {
return rexBuilder.makeInputRef(root(), ref);
@@ -5563,10 +5537,14 @@ public RexNode convertExpression(SqlNode expr) {
case CURSOR:
case IN:
case NOT_IN:
- subQuery = requireNonNull(getSubQuery(expr, null));
+ subQuery = getSubQuery(expr, null);
+ if (subQuery == null && (kind == SqlKind.SOME || kind == SqlKind.ALL)) {
+ break;
+ }
+ assert subQuery != null;
rex = requireNonNull(subQuery.expr);
return StandardConvertletTable.castToValidatedType(
- expr, rex, validator(), rexBuilder);
+ expr, rex, validator(), rexBuilder, false);
case SELECT:
case EXISTS:
@@ -5874,23 +5852,17 @@ private static SqlQuantifyOperator negate(SqlQuantifyOperator operator) {
/** Deferred lookup. */
private static class DeferredLookup {
- Blackboard bb;
- String originalRelName;
+ final Blackboard bb;
+ final String originalRelName;
DeferredLookup(Blackboard bb, String originalRelName) {
this.bb = bb;
this.originalRelName = originalRelName;
}
- public RexFieldAccess getFieldAccess(CorrelationId name) {
- return (RexFieldAccess)
- requireNonNull(
- bb.mapCorrelateToRex.get(name),
- () -> "Correlation " + name + " is not found");
- }
-
- public String getOriginalRelName() {
- return originalRelName;
+ RexFieldAccess getFieldAccess(CorrelationId name) {
+ return requireNonNull(
+ bb.mapCorrelateToRex.get(name), () -> "Correlation " + name + " is not found");
}
}
@@ -5911,556 +5883,9 @@ public RexNode convertSubQuery(
}
}
- /**
- * Converts expressions to aggregates.
- *
- * Consider the expression
- *
- *
- *
- * {@code SELECT deptno, SUM(2 * sal) FROM emp GROUP BY deptno}
- *
- *
- *
- * Then:
- *
- *
- * - groupExprs = {SqlIdentifier(deptno)}
- *
- convertedInputExprs = {RexInputRef(deptno), 2 * RefInputRef(sal)}
- *
- inputRefs = {RefInputRef(#0), RexInputRef(#1)}
- *
- aggCalls = {AggCall(SUM, {1})}
- *
- */
- protected class AggConverter implements SqlVisitor {
- private final Blackboard bb;
- public final @Nullable AggregatingSelectScope aggregatingSelectScope;
-
- private final Map nameMap = new HashMap<>();
-
- /** The group-by expressions, in {@link SqlNode} format. */
- private final SqlNodeList groupExprs = new SqlNodeList(SqlParserPos.ZERO);
-
- /** The auxiliary group-by expressions. */
- private final Map> auxiliaryGroupExprs = new HashMap<>();
-
- /**
- * Input expressions for the group columns and aggregates, in {@link RexNode} format. The
- * first elements of the list correspond to the elements in {@link #groupExprs}; the
- * remaining elements are for aggregates. The right field of each pair is the name of the
- * expression, where the expressions are simple mappings to input fields.
- */
- private final List> convertedInputExprs = new ArrayList<>();
-
- /**
- * Expressions to be evaluated as rows are being placed into the aggregate's hash table.
- * This is when group functions such as TUMBLE cause rows to be expanded.
- */
- private final List aggCalls = new ArrayList<>();
-
- private final Map aggMapping = new HashMap<>();
- private final Map aggCallMapping = new HashMap<>();
-
- /** Whether we are directly inside a windowed aggregate. */
- private boolean inOver = false;
-
- AggConverter(Blackboard bb, @Nullable AggregatingSelectScope aggregatingSelectScope) {
- this.bb = bb;
- this.aggregatingSelectScope = aggregatingSelectScope;
- }
-
- /**
- * Creates an AggConverter.
- *
- * The select
parameter provides enough context to name aggregate calls
- * which are top-level select list items.
- *
- * @param bb Blackboard
- * @param select Query being translated; provides context to give
- */
- public AggConverter(Blackboard bb, SqlSelect select) {
- this(bb, (AggregatingSelectScope) bb.getValidator().getSelectScope(select));
-
- // Collect all expressions used in the select list so that aggregate
- // calls can be named correctly.
- final SqlNodeList selectList = select.getSelectList();
- for (int i = 0; i < selectList.size(); i++) {
- SqlNode selectItem = selectList.get(i);
- String name = null;
- if (SqlUtil.isCallTo(selectItem, SqlStdOperatorTable.AS)) {
- final SqlCall call = (SqlCall) selectItem;
- selectItem = call.operand(0);
- name = call.operand(1).toString();
- }
- if (name == null) {
- name = SqlValidatorUtil.alias(selectItem, i);
- }
- nameMap.put(selectItem.toString(), name);
- }
- }
-
- public int addGroupExpr(SqlNode expr) {
- int ref = lookupGroupExpr(expr);
- if (ref >= 0) {
- return ref;
- }
- final int index = groupExprs.size();
- groupExprs.add(expr);
- String name = nameMap.get(expr.toString());
- RexNode convExpr = bb.convertExpression(expr);
- addExpr(convExpr, name);
-
- if (expr instanceof SqlCall) {
- SqlCall call = (SqlCall) expr;
- for (Pair p :
- SqlStdOperatorTable.convertGroupToAuxiliaryCalls(call)) {
- addAuxiliaryGroupExpr(p.left, index, p.right);
- }
- }
-
- return index;
- }
-
- void addAuxiliaryGroupExpr(SqlNode node, int index, AuxiliaryConverter converter) {
- for (SqlNode node2 : auxiliaryGroupExprs.keySet()) {
- if (node2.equalsDeep(node, Litmus.IGNORE)) {
- return;
- }
- }
- auxiliaryGroupExprs.put(node, Ord.of(index, converter));
- }
-
- /**
- * Adds an expression, deducing an appropriate name if possible.
- *
- * @param expr Expression
- * @param name Suggested name
- */
- private void addExpr(RexNode expr, @Nullable String name) {
- if ((name == null) && (expr instanceof RexInputRef)) {
- final int i = ((RexInputRef) expr).getIndex();
- name = bb.root().getRowType().getFieldList().get(i).getName();
- }
- if (Pair.right(convertedInputExprs).contains(name)) {
- // In case like 'SELECT ... GROUP BY x, y, x', don't add
- // name 'x' twice.
- name = null;
- }
- convertedInputExprs.add(Pair.of(expr, name));
- }
-
- @Override
- public Void visit(SqlIdentifier id) {
- return null;
- }
-
- @Override
- public Void visit(SqlNodeList nodeList) {
- for (int i = 0; i < nodeList.size(); i++) {
- nodeList.get(i).accept(this);
- }
- return null;
- }
-
- @Override
- public Void visit(SqlLiteral lit) {
- return null;
- }
-
- @Override
- public Void visit(SqlDataTypeSpec type) {
- return null;
- }
-
- @Override
- public Void visit(SqlDynamicParam param) {
- return null;
- }
-
- @Override
- public Void visit(SqlIntervalQualifier intervalQualifier) {
- return null;
- }
-
- @Override
- public Void visit(SqlCall call) {
- switch (call.getKind()) {
- case FILTER:
- case IGNORE_NULLS:
- case RESPECT_NULLS:
- case WITHIN_DISTINCT:
- case WITHIN_GROUP:
- translateAgg(call);
- return null;
- case SELECT:
- // rchen 2006-10-17:
- // for now do not detect aggregates in sub-queries.
- return null;
- default:
- break;
- }
- final boolean prevInOver = inOver;
- // Ignore window aggregates and ranking functions (associated with OVER
- // operator). However, do not ignore nested window aggregates.
- if (call.getOperator().getKind() == SqlKind.OVER) {
- // Track aggregate nesting levels only within an OVER operator.
- List operandList = call.getOperandList();
- assert operandList.size() == 2;
-
- // Ignore the top level window aggregates and ranking functions
- // positioned as the first operand of a OVER operator
- inOver = true;
- operandList.get(0).accept(this);
-
- // Normal translation for the second operand of a OVER operator
- inOver = false;
- operandList.get(1).accept(this);
- return null;
- }
-
- // Do not translate the top level window aggregate. Only do so for
- // nested aggregates, if present
- if (call.getOperator().isAggregator()) {
- if (inOver) {
- // Add the parent aggregate level before visiting its children
- inOver = false;
- } else {
- // We're beyond the one ignored level
- translateAgg(call);
- return null;
- }
- }
- for (SqlNode operand : call.getOperandList()) {
- // Operands are occasionally null, e.g. switched CASE arg 0.
- if (operand != null) {
- operand.accept(this);
- }
- }
- // Remove the parent aggregate level after visiting its children
- inOver = prevInOver;
- return null;
- }
-
- private void translateAgg(SqlCall call) {
- translateAgg(call, null, null, null, false, call);
- }
-
- private void translateAgg(
- SqlCall call,
- @Nullable SqlNode filter,
- @Nullable SqlNodeList distinctList,
- @Nullable SqlNodeList orderList,
- boolean ignoreNulls,
- SqlCall outerCall) {
- assert bb.agg == this;
- assert outerCall != null;
- final List operands = call.getOperandList();
- final SqlParserPos pos = call.getParserPosition();
- final SqlCall call2;
- switch (call.getKind()) {
- case FILTER:
- assert filter == null;
- translateAgg(
- call.operand(0),
- call.operand(1),
- distinctList,
- orderList,
- ignoreNulls,
- outerCall);
- return;
- case WITHIN_DISTINCT:
- assert orderList == null;
- translateAgg(
- call.operand(0),
- filter,
- call.operand(1),
- orderList,
- ignoreNulls,
- outerCall);
- return;
- case WITHIN_GROUP:
- assert orderList == null;
- translateAgg(
- call.operand(0),
- filter,
- distinctList,
- call.operand(1),
- ignoreNulls,
- outerCall);
- return;
- case IGNORE_NULLS:
- ignoreNulls = true;
- // fall through
- case RESPECT_NULLS:
- translateAgg(
- call.operand(0),
- filter,
- distinctList,
- orderList,
- ignoreNulls,
- outerCall);
- return;
-
- case COUNTIF:
- // COUNTIF(b) ==> COUNT(*) FILTER (WHERE b)
- // COUNTIF(b) FILTER (WHERE b2) ==> COUNT(*) FILTER (WHERE b2 AND b)
- call2 = SqlStdOperatorTable.COUNT.createCall(pos, SqlIdentifier.star(pos));
- final SqlNode filter2 = SqlUtil.andExpressions(filter, call.operand(0));
- translateAgg(call2, filter2, distinctList, orderList, ignoreNulls, outerCall);
- return;
-
- case STRING_AGG:
- // Translate "STRING_AGG(s, sep ORDER BY x, y)"
- // as if it were "LISTAGG(s, sep) WITHIN GROUP (ORDER BY x, y)";
- // and "STRING_AGG(s, sep)" as "LISTAGG(s, sep)".
- final List operands2;
- if (!operands.isEmpty() && Util.last(operands) instanceof SqlNodeList) {
- orderList = (SqlNodeList) Util.last(operands);
- operands2 = Util.skipLast(operands);
- } else {
- operands2 = operands;
- }
- call2 =
- SqlStdOperatorTable.LISTAGG.createCall(
- call.getFunctionQuantifier(), pos, operands2);
- translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
- return;
-
- case GROUP_CONCAT:
- // Translate "GROUP_CONCAT(s ORDER BY x, y SEPARATOR ',')"
- // as if it were "LISTAGG(s, ',') WITHIN GROUP (ORDER BY x, y)".
- // To do this, build a list of operands without ORDER BY with with sep.
- operands2 = new ArrayList<>(operands);
- final SqlNode separator;
- if (!operands2.isEmpty()
- && Util.last(operands2).getKind() == SqlKind.SEPARATOR) {
- final SqlCall sepCall = (SqlCall) operands2.remove(operands.size() - 1);
- separator = sepCall.operand(0);
- } else {
- separator = null;
- }
-
- if (!operands2.isEmpty() && Util.last(operands2) instanceof SqlNodeList) {
- orderList = (SqlNodeList) operands2.remove(operands2.size() - 1);
- }
-
- if (separator != null) {
- operands2.add(separator);
- }
-
- call2 =
- SqlStdOperatorTable.LISTAGG.createCall(
- call.getFunctionQuantifier(), pos, operands2);
- translateAgg(call2, filter, distinctList, orderList, ignoreNulls, outerCall);
- return;
-
- case ARRAY_AGG:
- case ARRAY_CONCAT_AGG:
- // Translate "ARRAY_AGG(s ORDER BY x, y)"
- // as if it were "ARRAY_AGG(s) WITHIN GROUP (ORDER BY x, y)";
- // similarly "ARRAY_CONCAT_AGG".
- if (!operands.isEmpty() && Util.last(operands) instanceof SqlNodeList) {
- orderList = (SqlNodeList) Util.last(operands);
- call2 =
- call.getOperator()
- .createCall(
- call.getFunctionQuantifier(),
- pos,
- Util.skipLast(operands));
- translateAgg(
- call2, filter, distinctList, orderList, ignoreNulls, outerCall);
- return;
- }
- // "ARRAY_AGG" and "ARRAY_CONCAT_AGG" without "ORDER BY"
- // are handled normally; fall through.
-
- default:
- break;
- }
- final List args = new ArrayList<>();
- int filterArg = -1;
- final ImmutableBitSet distinctKeys;
- try {
- // switch out of agg mode
- bb.agg = null;
- // ----- FLINK MODIFICATION BEGIN -----
- FlinkSqlCallBinding binding =
- new FlinkSqlCallBinding(validator(), aggregatingSelectScope, call);
- List sqlNodes = binding.operands();
- for (int i = 0; i < sqlNodes.size(); i++) {
- SqlNode operand = sqlNodes.get(i);
- // special case for COUNT(*): delete the *
- if (operand instanceof SqlIdentifier) {
- SqlIdentifier id = (SqlIdentifier) operand;
- if (id.isStar()) {
- assert call.operandCount() == 1;
- assert args.isEmpty();
- break;
- }
- }
- RexNode convertedExpr = bb.convertExpression(operand);
- args.add(lookupOrCreateGroupExpr(convertedExpr));
- }
- // ----- FLINK MODIFICATION END -----
-
- if (filter != null) {
- RexNode convertedExpr = bb.convertExpression(filter);
- if (convertedExpr.getType().isNullable()) {
- convertedExpr =
- rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, convertedExpr);
- }
- filterArg = lookupOrCreateGroupExpr(convertedExpr);
- }
-
- if (distinctList == null) {
- distinctKeys = null;
- } else {
- final ImmutableBitSet.Builder distinctBuilder = ImmutableBitSet.builder();
- for (SqlNode distinct : distinctList) {
- RexNode e = bb.convertExpression(distinct);
- assert e != null;
- distinctBuilder.set(lookupOrCreateGroupExpr(e));
- }
- distinctKeys = distinctBuilder.build();
- }
- } finally {
- // switch back into agg mode
- bb.agg = this;
- }
-
- SqlAggFunction aggFunction = (SqlAggFunction) call.getOperator();
- final RelDataType type = validator().deriveType(bb.scope(), call);
- boolean distinct = false;
- SqlLiteral quantifier = call.getFunctionQuantifier();
- if ((null != quantifier) && (quantifier.getValue() == SqlSelectKeyword.DISTINCT)) {
- distinct = true;
- }
- boolean approximate = false;
- if (aggFunction == SqlStdOperatorTable.APPROX_COUNT_DISTINCT) {
- aggFunction = SqlStdOperatorTable.COUNT;
- distinct = true;
- approximate = true;
- }
- final RelCollation collation;
- if (orderList == null || orderList.size() == 0) {
- collation = RelCollations.EMPTY;
- } else {
- try {
- // switch out of agg mode
- bb.agg = null;
- collation =
- RelCollations.of(
- orderList.stream()
- .map(
- order ->
- bb.convertSortExpression(
- order,
- RelFieldCollation.Direction
- .ASCENDING,
- RelFieldCollation.NullDirection
- .UNSPECIFIED,
- this::sortToFieldCollation))
- .collect(Collectors.toList()));
- } finally {
- // switch back into agg mode
- bb.agg = this;
- }
- }
- final AggregateCall aggCall =
- AggregateCall.create(
- aggFunction,
- distinct,
- approximate,
- ignoreNulls,
- args,
- filterArg,
- distinctKeys,
- collation,
- type,
- nameMap.get(outerCall.toString()));
- RexNode rex =
- rexBuilder.addAggCall(
- aggCall,
- groupExprs.size(),
- aggCalls,
- aggCallMapping,
- i -> convertedInputExprs.get(i).left.getType().isNullable());
- aggMapping.put(outerCall, rex);
- }
-
- private RelFieldCollation sortToFieldCollation(
- SqlNode expr,
- RelFieldCollation.Direction direction,
- RelFieldCollation.NullDirection nullDirection) {
- final RexNode node = bb.convertExpression(expr);
- final int fieldIndex = lookupOrCreateGroupExpr(node);
- if (nullDirection == RelFieldCollation.NullDirection.UNSPECIFIED) {
- nullDirection = direction.defaultNullDirection();
- }
- return new RelFieldCollation(fieldIndex, direction, nullDirection);
- }
-
- private int lookupOrCreateGroupExpr(RexNode expr) {
- int index = 0;
- for (RexNode convertedInputExpr : Pair.left(convertedInputExprs)) {
- if (expr.equals(convertedInputExpr)) {
- return index;
- }
- ++index;
- }
-
- // not found -- add it
- addExpr(expr, null);
- return index;
- }
-
- /**
- * If an expression is structurally identical to one of the group-by expressions, returns a
- * reference to the expression, otherwise returns null.
- */
- public int lookupGroupExpr(SqlNode expr) {
- for (int i = 0; i < groupExprs.size(); i++) {
- SqlNode groupExpr = groupExprs.get(i);
- if (expr.equalsDeep(groupExpr, Litmus.IGNORE)) {
- return i;
- }
- }
- return -1;
- }
-
- public @Nullable RexNode lookupAggregates(SqlCall call) {
- // assert call.getOperator().isAggregator();
- assert bb.agg == this;
-
- for (Map.Entry> e : auxiliaryGroupExprs.entrySet()) {
- if (call.equalsDeep(e.getKey(), Litmus.IGNORE)) {
- AuxiliaryConverter converter = e.getValue().e;
- final int groupOrdinal = e.getValue().i;
- return converter.convert(
- rexBuilder,
- convertedInputExprs.get(groupOrdinal).left,
- rexBuilder.makeInputRef(castNonNull(bb.root), groupOrdinal));
- }
- }
-
- return aggMapping.get(call);
- }
-
- public List> getPreExprs() {
- return convertedInputExprs;
- }
-
- public List getAggCalls() {
- return aggCalls;
- }
-
- public RelDataTypeFactory getTypeFactory() {
- return typeFactory;
- }
- }
-
/** Context to find a relational expression to a field offset. */
private static class LookupContext {
- private final List> relOffsetList = new ArrayList<>();
+ private final PairList relOffsetList = PairList.of();
/**
* Creates a LookupContext with multiple input relational expressions.
@@ -6470,7 +5895,7 @@ private static class LookupContext {
* @param systemFieldCount Number of system fields
*/
LookupContext(Blackboard bb, List rels, int systemFieldCount) {
- bb.flatten(rels, systemFieldCount, new int[] {0}, relOffsetList);
+ bb.flatten(rels, systemFieldCount, new int[] {0}, relOffsetList::add);
}
/**
@@ -6484,7 +5909,7 @@ private static class LookupContext {
* @param offset Offset of relational expression in FROM clause
* @return Relational expression and the ordinal of its first field
*/
- Pair findRel(int offset) {
+ Map.Entry findRel(int offset) {
return relOffsetList.get(offset);
}
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index a7363bd5a99fa..29064715339c2 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -58,6 +58,7 @@
import org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import org.apache.calcite.sql.fun.SqlBetweenOperator;
import org.apache.calcite.sql.fun.SqlCase;
+import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
import org.apache.calcite.sql.fun.SqlExtractFunction;
import org.apache.calcite.sql.fun.SqlInternalOperators;
@@ -81,7 +82,6 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
-import org.apache.calcite.sql.validate.SqlValidatorImpl;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -96,7 +96,9 @@
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
+import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS;
import static org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow;
import static org.apache.calcite.util.Util.first;
@@ -106,7 +108,7 @@
* FLINK modifications are at lines
*
*
- * - Added in Flink-35216: Lines 731 ~ 776
+ *
- Added in Flink-35216: Lines 832 ~ 878
*
*/
public class StandardConvertletTable extends ReflectiveConvertletTable {
@@ -131,6 +133,8 @@ private StandardConvertletTable() {
// Register convertlets for specific objects.
registerOp(SqlStdOperatorTable.CAST, this::convertCast);
+ registerOp(SqlLibraryOperators.SAFE_CAST, this::convertCast);
+ registerOp(SqlLibraryOperators.TRY_CAST, this::convertCast);
registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
registerOp(
SqlStdOperatorTable.IS_DISTINCT_FROM,
@@ -168,6 +172,9 @@ private StandardConvertletTable() {
return e;
});
+ registerOp(SqlLibraryOperators.DATETIME_TRUNC, new TruncConvertlet());
+ registerOp(SqlLibraryOperators.TIMESTAMP_TRUNC, new TruncConvertlet());
+
registerOp(SqlLibraryOperators.LTRIM, new TrimConvertlet(SqlTrimFunction.Flag.LEADING));
registerOp(SqlLibraryOperators.RTRIM, new TrimConvertlet(SqlTrimFunction.Flag.TRAILING));
@@ -193,6 +200,9 @@ private StandardConvertletTable() {
registerOp(SqlLibraryOperators.TIMESTAMP_DIFF3, new TimestampDiffConvertlet());
registerOp(SqlLibraryOperators.TIMESTAMP_SUB, new TimestampSubConvertlet());
+ QUANTIFY_OPERATORS.forEach(
+ operator -> registerOp(operator, StandardConvertletTable::convertQuantifyOperator));
+
registerOp(SqlLibraryOperators.NVL, StandardConvertletTable::convertNvl);
registerOp(SqlLibraryOperators.DECODE, StandardConvertletTable::convertDecode);
registerOp(SqlLibraryOperators.IF, StandardConvertletTable::convertIf);
@@ -255,6 +265,8 @@ private StandardConvertletTable() {
registerOp(SqlStdOperatorTable.ITEM, this::convertItem);
// "AS" has no effect, so expand "x AS id" into "x".
registerOp(SqlStdOperatorTable.AS, (cx, call) -> cx.convertExpression(call.operand(0)));
+ registerOp(SqlStdOperatorTable.CONVERT, this::convertCharset);
+ registerOp(SqlStdOperatorTable.TRANSLATE, this::translateCharset);
// "SQRT(x)" is equivalent to "POWER(x, .5)"
registerOp(
SqlStdOperatorTable.SQRT,
@@ -265,6 +277,19 @@ private StandardConvertletTable() {
call.operand(0),
SqlLiteral.createExactNumeric("0.5", SqlParserPos.ZERO))));
+ // "STRPOS(string, substring) is equivalent to
+ // "POSITION(substring IN string)"
+ registerOp(
+ SqlLibraryOperators.STRPOS,
+ (cx, call) ->
+ cx.convertExpression(
+ SqlStdOperatorTable.POSITION.createCall(
+ SqlParserPos.ZERO, call.operand(1), call.operand(0))));
+
+ // "INSTR(string, substring, position, occurrence) is equivalent to
+ // "POSITION(substring, string, position, occurrence)"
+ registerOp(SqlLibraryOperators.INSTR, StandardConvertletTable::convertInstr);
+
// REVIEW jvs 24-Apr-2006: This only seems to be working from within a
// windowed agg. I have added an optimizer rule
// org.apache.calcite.rel.rules.AggregateReduceFunctionsRule which handles
@@ -338,6 +363,24 @@ private StandardConvertletTable() {
}
}
+ /** Converts ALL or SOME operators. */
+ private static RexNode convertQuantifyOperator(SqlRexContext cx, SqlCall call) {
+ final RexBuilder rexBuilder = cx.getRexBuilder();
+ final RexNode left = cx.convertExpression(call.getOperandList().get(0));
+ assert call.getOperandList().get(1) instanceof SqlNodeList;
+ final RexNode right =
+ cx.convertExpression(((SqlNodeList) call.getOperandList().get(1)).get(0));
+ final RelDataType rightComponentType = requireNonNull(right.getType().getComponentType());
+ final RelDataType returnType =
+ cx.getTypeFactory()
+ .createTypeWithNullability(
+ cx.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN),
+ right.getType().isNullable()
+ || left.getType().isNullable()
+ || rightComponentType.isNullable());
+ return rexBuilder.makeCall(returnType, call.getOperator(), ImmutableList.of(left, right));
+ }
+
/** Converts a call to the {@code NVL} function (and also its synonym, {@code IFNULL}). */
private static RexNode convertNvl(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
@@ -362,6 +405,40 @@ private static RexNode convertNvl(SqlRexContext cx, SqlCall call) {
operand1)));
}
+ /**
+ * Converts a call to the INSTR function. INSTR(string, substring, position, occurrence) is
+ * equivalent to POSITION(substring, string, position, occurrence)
+ */
+ private static RexNode convertInstr(SqlRexContext cx, SqlCall call) {
+ final RexBuilder rexBuilder = cx.getRexBuilder();
+ final List operands =
+ convertOperands(cx, call, SqlOperandTypeChecker.Consistency.NONE);
+ final RelDataType type = cx.getValidator().getValidatedNodeType(call);
+ final List exprs = new ArrayList<>();
+ switch (call.operandCount()) {
+ // Must reverse order of first 2 operands.
+ case 2:
+ exprs.add(operands.get(1)); // Substring
+ exprs.add(operands.get(0)); // String
+ break;
+ case 3:
+ exprs.add(operands.get(1)); // Substring
+ exprs.add(operands.get(0)); // String
+ exprs.add(operands.get(2)); // Position
+ break;
+ case 4:
+ exprs.add(operands.get(1)); // Substring
+ exprs.add(operands.get(0)); // String
+ exprs.add(operands.get(2)); // Position
+ exprs.add(operands.get(3)); // Occurrence
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Position does not accept " + call.operandCount() + " operands");
+ }
+ return rexBuilder.makeCall(type, SqlStdOperatorTable.POSITION, exprs);
+ }
+
/** Converts a call to the DECODE function. */
private static RexNode convertDecode(SqlRexContext cx, SqlCall call) {
final RexBuilder rexBuilder = cx.getRexBuilder();
@@ -552,39 +629,41 @@ public RexNode convertJdbc(SqlRexContext cx, SqlJdbcFunctionCall op, SqlCall cal
protected RexNode convertCast(SqlRexContext cx, final SqlCall call) {
RelDataTypeFactory typeFactory = cx.getTypeFactory();
- assert call.getKind() == SqlKind.CAST;
+ final SqlValidator validator = cx.getValidator();
+ final SqlKind kind = call.getKind();
+ checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
+ final boolean safe = kind == SqlKind.SAFE_CAST;
final SqlNode left = call.operand(0);
final SqlNode right = call.operand(1);
+ final RexBuilder rexBuilder = cx.getRexBuilder();
if (right instanceof SqlIntervalQualifier) {
final SqlIntervalQualifier intervalQualifier = (SqlIntervalQualifier) right;
if (left instanceof SqlIntervalLiteral) {
RexLiteral sourceInterval = (RexLiteral) cx.convertExpression(left);
BigDecimal sourceValue = (BigDecimal) sourceInterval.getValue();
RexLiteral castedInterval =
- cx.getRexBuilder().makeIntervalLiteral(sourceValue, intervalQualifier);
- return castToValidatedType(cx, call, castedInterval);
+ rexBuilder.makeIntervalLiteral(sourceValue, intervalQualifier);
+ return castToValidatedType(call, castedInterval, validator, rexBuilder, safe);
} else if (left instanceof SqlNumericLiteral) {
RexLiteral sourceInterval = (RexLiteral) cx.convertExpression(left);
- BigDecimal sourceValue = (BigDecimal) sourceInterval.getValue();
+ BigDecimal sourceValue =
+ requireNonNull(sourceInterval.getValueAs(BigDecimal.class), "sourceValue");
final BigDecimal multiplier = intervalQualifier.getUnit().multiplier;
- sourceValue = SqlFunctions.multiply(sourceValue, multiplier);
RexLiteral castedInterval =
- cx.getRexBuilder().makeIntervalLiteral(sourceValue, intervalQualifier);
- return castToValidatedType(cx, call, castedInterval);
+ rexBuilder.makeIntervalLiteral(
+ SqlFunctions.multiply(sourceValue, multiplier), intervalQualifier);
+ return castToValidatedType(call, castedInterval, validator, rexBuilder, safe);
}
- return castToValidatedType(cx, call, cx.convertExpression(left));
- }
- SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
- RelDataType type = dataType.deriveType(cx.getValidator());
- if (type == null) {
- type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
- }
- RexNode arg = cx.convertExpression(left);
- if (arg.getType().isNullable()) {
- type = typeFactory.createTypeWithNullability(type, true);
+ RexNode value = cx.convertExpression(left);
+ return castToValidatedType(call, value, validator, rexBuilder, safe);
}
+
+ final RexNode arg = cx.convertExpression(left);
+ final SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
+ RelDataType type =
+ SqlCastFunction.deriveType(
+ cx.getTypeFactory(), arg.getType(), dataType.deriveType(validator), safe);
if (SqlUtil.isNullLiteral(left, false)) {
- final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
validator.setValidatedNodeType(left, type);
return cx.convertExpression(left);
}
@@ -593,7 +672,7 @@ protected RexNode convertCast(SqlRexContext cx, final SqlCall call) {
// arg.getType() may be ANY
if (argComponentType == null) {
- argComponentType = dataType.getComponentTypeSpec().deriveType(cx.getValidator());
+ argComponentType = dataType.getComponentTypeSpec().deriveType(validator);
}
requireNonNull(argComponentType, () -> "componentType of " + arg);
@@ -615,7 +694,7 @@ protected RexNode convertCast(SqlRexContext cx, final SqlCall call) {
type = typeFactory.createTypeWithNullability(type, isn);
}
}
- return cx.getRexBuilder().makeCast(type, arg);
+ return rexBuilder.makeCast(type, arg, safe, safe);
}
protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
@@ -655,6 +734,28 @@ protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
return convertFunction(cx, (SqlFunction) call.getOperator(), call);
}
+ protected RexNode convertCharset(SqlRexContext cx, SqlCall call) {
+ final SqlNode expr = call.operand(0);
+ final String srcCharset = call.operand(1).toString();
+ final String destCharset = call.operand(2).toString();
+ final RexBuilder rexBuilder = cx.getRexBuilder();
+ return rexBuilder.makeCall(
+ SqlStdOperatorTable.CONVERT,
+ cx.convertExpression(expr),
+ rexBuilder.makeLiteral(srcCharset),
+ rexBuilder.makeLiteral(destCharset));
+ }
+
+ protected RexNode translateCharset(SqlRexContext cx, SqlCall call) {
+ final SqlNode expr = call.operand(0);
+ final String transcodingName = call.operand(1).toString();
+ final RexBuilder rexBuilder = cx.getRexBuilder();
+ return rexBuilder.makeCall(
+ SqlStdOperatorTable.TRANSLATE,
+ cx.convertExpression(expr),
+ rexBuilder.makeLiteral(transcodingName));
+ }
+
/**
* Converts a call to the {@code EXTRACT} function.
*
@@ -1248,12 +1349,15 @@ private static Pair convertOverlapsOperand(
return Pair.of(r0, r1);
}
- /**
- * Casts a RexNode value to the validated type of a SqlCall. If the value was already of the
- * validated type, then the value is returned without an additional cast.
- */
+ @Deprecated // to be removed before 2.0
public RexNode castToValidatedType(SqlRexContext cx, SqlCall call, RexNode value) {
- return castToValidatedType(call, value, cx.getValidator(), cx.getRexBuilder());
+ return castToValidatedType(call, value, cx.getValidator(), cx.getRexBuilder(), false);
+ }
+
+ @Deprecated // to be removed before 2.0
+ public static RexNode castToValidatedType(
+ SqlNode node, RexNode e, SqlValidator validator, RexBuilder rexBuilder) {
+ return castToValidatedType(node, e, validator, rexBuilder, false);
}
/**
@@ -1261,12 +1365,12 @@ public RexNode castToValidatedType(SqlRexContext cx, SqlCall call, RexNode value
* validated type, then the value is returned without an additional cast.
*/
public static RexNode castToValidatedType(
- SqlNode node, RexNode e, SqlValidator validator, RexBuilder rexBuilder) {
+ SqlNode node, RexNode e, SqlValidator validator, RexBuilder rexBuilder, boolean safe) {
final RelDataType type = validator.getValidatedNodeType(node);
if (e.getType() == type) {
return e;
}
- return rexBuilder.makeCast(type, e);
+ return rexBuilder.makeCast(type, e, safe, safe);
}
/**
@@ -1842,7 +1946,7 @@ public RexNode convertCall(SqlRexContext cx, SqlCall call) {
// If the TIMESTAMPADD call has type TIMESTAMP and op2 has type DATE
// (which can happen for sub-day time frames such as HOUR), cast op2 to
// TIMESTAMP.
- final RexNode op2b = rexBuilder.makeCast(type, op2, false);
+ final RexNode op2b = rexBuilder.makeCast(type, op2);
return rexBuilder.makeCall(
type,
SqlStdOperatorTable.TIMESTAMP_ADD,
@@ -1880,6 +1984,25 @@ public RexNode convertCall(SqlRexContext cx, SqlCall call) {
}
}
+ /**
+ * Convertlet that handles the BigQuery {@code DATETIME_TRUNC} and {@code TIMESTAMP_TRUNC}
+ * functions. Ensures that DATE operands are cast to TIMESTAMPs to match the expected return
+ * type for BigQuery.
+ */
+ private static class TruncConvertlet implements SqlRexConvertlet {
+ @Override
+ public RexNode convertCall(SqlRexContext cx, SqlCall call) {
+ final RexBuilder rexBuilder = cx.getRexBuilder();
+ RexNode op1 = cx.convertExpression(call.operand(0));
+ RexNode op2 = cx.convertExpression(call.operand(1));
+ if (op1.getType().getSqlTypeName() == SqlTypeName.DATE) {
+ RelDataType type = cx.getValidator().getValidatedNodeType(call);
+ op1 = cx.getRexBuilder().makeCast(type, op1);
+ }
+ return rexBuilder.makeCall(call.getOperator(), op1, op2);
+ }
+ }
+
/** Convertlet that handles the BigQuery {@code TIMESTAMP_SUB} function. */
private static class TimestampSubConvertlet implements SqlRexConvertlet {
@Override
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/tools/RelBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/tools/RelBuilder.java
deleted file mode 100644
index a2df489a9b633..0000000000000
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ /dev/null
@@ -1,5258 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.calcite.tools;
-
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.ImmutableSortedMultiset;
-import com.google.common.collect.Iterables;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Multiset;
-import com.google.common.collect.Sets;
-import org.apache.calcite.linq4j.Ord;
-import org.apache.calcite.linq4j.function.Experimental;
-import org.apache.calcite.plan.Context;
-import org.apache.calcite.plan.Contexts;
-import org.apache.calcite.plan.Convention;
-import org.apache.calcite.plan.RelOptCluster;
-import org.apache.calcite.plan.RelOptPredicateList;
-import org.apache.calcite.plan.RelOptSchema;
-import org.apache.calcite.plan.RelOptTable;
-import org.apache.calcite.plan.RelOptUtil;
-import org.apache.calcite.plan.ViewExpanders;
-import org.apache.calcite.prepare.RelOptTableImpl;
-import org.apache.calcite.rel.RelCollation;
-import org.apache.calcite.rel.RelCollations;
-import org.apache.calcite.rel.RelDistribution;
-import org.apache.calcite.rel.RelFieldCollation;
-import org.apache.calcite.rel.RelHomogeneousShuttle;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.Correlate;
-import org.apache.calcite.rel.core.CorrelationId;
-import org.apache.calcite.rel.core.Filter;
-import org.apache.calcite.rel.core.Intersect;
-import org.apache.calcite.rel.core.Join;
-import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.core.Match;
-import org.apache.calcite.rel.core.Minus;
-import org.apache.calcite.rel.core.Project;
-import org.apache.calcite.rel.core.RelFactories;
-import org.apache.calcite.rel.core.RepeatUnion;
-import org.apache.calcite.rel.core.Snapshot;
-import org.apache.calcite.rel.core.Sort;
-import org.apache.calcite.rel.core.Spool;
-import org.apache.calcite.rel.core.TableFunctionScan;
-import org.apache.calcite.rel.core.TableScan;
-import org.apache.calcite.rel.core.TableSpool;
-import org.apache.calcite.rel.core.Uncollect;
-import org.apache.calcite.rel.core.Union;
-import org.apache.calcite.rel.core.Values;
-import org.apache.calcite.rel.hint.Hintable;
-import org.apache.calcite.rel.hint.RelHint;
-import org.apache.calcite.rel.metadata.RelColumnMapping;
-import org.apache.calcite.rel.metadata.RelMetadataQuery;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexCallBinding;
-import org.apache.calcite.rex.RexCorrelVariable;
-import org.apache.calcite.rex.RexDynamicParam;
-import org.apache.calcite.rex.RexExecutor;
-import org.apache.calcite.rex.RexFieldCollation;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexShuttle;
-import org.apache.calcite.rex.RexSimplify;
-import org.apache.calcite.rex.RexSubQuery;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.rex.RexWindowBound;
-import org.apache.calcite.rex.RexWindowBounds;
-import org.apache.calcite.runtime.Hook;
-import org.apache.calcite.schema.TransientTable;
-import org.apache.calcite.schema.impl.ListTransientTable;
-import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.SqlUtil;
-import org.apache.calcite.sql.SqlWindow;
-import org.apache.calcite.sql.fun.SqlCountAggFunction;
-import org.apache.calcite.sql.fun.SqlLikeOperator;
-import org.apache.calcite.sql.fun.SqlQuantifyOperator;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.type.SqlReturnTypeInference;
-import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.sql.type.TableFunctionReturnTypeInference;
-import org.apache.calcite.sql.validate.SqlValidatorUtil;
-import org.apache.calcite.sql2rel.SqlToRelConverter;
-import org.apache.calcite.util.DateString;
-import org.apache.calcite.util.Holder;
-import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.ImmutableIntList;
-import org.apache.calcite.util.ImmutableNullableList;
-import org.apache.calcite.util.Litmus;
-import org.apache.calcite.util.NlsString;
-import org.apache.calcite.util.Optionality;
-import org.apache.calcite.util.Pair;
-import org.apache.calcite.util.Util;
-import org.apache.calcite.util.mapping.Mapping;
-import org.apache.calcite.util.mapping.Mappings;
-import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import org.immutables.value.Value;
-
-import java.math.BigDecimal;
-import java.util.AbstractList;
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.BitSet;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.EnumSet;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.SortedSet;
-import java.util.TreeSet;
-import java.util.function.BiFunction;
-import java.util.function.Function;
-import java.util.function.UnaryOperator;
-import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
-
-import static java.util.Objects.requireNonNull;
-import static org.apache.calcite.linq4j.Nullness.castNonNull;
-import static org.apache.calcite.sql.SqlKind.UNION;
-import static org.apache.calcite.util.Static.RESOURCE;
-
-/**
- * Copied from calcite to workaround CALCITE-4668
- *
- *