Skip to content

Commit

Permalink
IGNITE-22716 SQL Calcite: Fix implicit conversion to DECIMAL for some…
Browse files Browse the repository at this point in the history
… functions
  • Loading branch information
alex-plekhanov committed Jul 17, 2024
1 parent 2448ffa commit ec8f27c
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.stream.Collectors;
import org.apache.calcite.plan.Context;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.util.CancelFlag;
import org.apache.ignite.IgniteCheckedException;
Expand Down Expand Up @@ -139,10 +140,15 @@ public RootQuery(

Context parent = Commons.convert(qryCtx);

FrameworkConfig frameworkCfg = qryCtx != null ? qryCtx.unwrap(FrameworkConfig.class) : null;

if (frameworkCfg == null)
frameworkCfg = FRAMEWORK_CONFIG;

ctx = BaseQueryContext.builder()
.parentContext(parent)
.frameworkConfig(
Frameworks.newConfigBuilder(FRAMEWORK_CONFIG)
Frameworks.newConfigBuilder(frameworkCfg)
.defaultSchema(schema)
.build()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Util;
import org.apache.ignite.internal.processors.query.calcite.util.Commons;

/** */
public class ConverterUtils {
Expand Down Expand Up @@ -166,6 +168,21 @@ static List<Type> internalTypes(List<? extends RexNode> operandList) {
return Util.transform(operandList, node -> toInternal(node.getType()));
}

/**
* Convert {@code operand} to {@code targetType}.
*
* @param operand The expression to convert
* @param targetType Target type
* @return A new expression with java type corresponding to {@code targetType}
* or original expression if there is no need to convert.
*/
public static Expression convert(Expression operand, RelDataType targetType) {
if (SqlTypeUtil.isDecimal(targetType))
return convertToDecimal(operand, targetType);
else
return convert(operand, Commons.typeFactory().getJavaClass(targetType));
}

/**
* Convert {@code operand} to target type {@code toType}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1963,7 +1963,7 @@ private ParameterExpression genValueStatement(
final Expression convertedCallVal =
noConvert
? callVal
: ConverterUtils.convert(callVal, returnType);
: ConverterUtils.convert(callVal, call.getType());

final Expression valExpression =
Expressions.condition(condition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlConformance;
import org.apache.calcite.util.BuiltInMethod;
Expand Down Expand Up @@ -552,11 +551,9 @@ Expression translateCast(
}
break;
}
if (targetType.getSqlTypeName() == SqlTypeName.DECIMAL)
convert = ConverterUtils.convertToDecimal(operand, targetType);

if (convert == null)
convert = ConverterUtils.convert(operand, typeFactory.getJavaClass(targetType));
convert = ConverterUtils.convert(operand, targetType);

// Going from anything to CHAR(n) or VARCHAR(n), make sure value is no
// longer than n.
Expand Down Expand Up @@ -1073,7 +1070,7 @@ private Result implementCaseWhen(RexCall call) {
list.newName("case_when_value"));
list.add(Expressions.declare(0, valVariable, null));
final List<RexNode> operandList = call.getOperands();
implementRecursively(this, operandList, valVariable, 0);
implementRecursively(this, operandList, valVariable, call.getType(), 0);
final Expression isNullExpression = checkNull(valVariable);
final ParameterExpression isNullVariable =
Expressions.parameter(
Expand Down Expand Up @@ -1108,8 +1105,13 @@ private Result implementCaseWhen(RexCall call) {
* }
* </pre></blockquote>
*/
private void implementRecursively(final RexToLixTranslator currentTranslator,
final List<RexNode> operandList, final ParameterExpression valueVariable, int pos) {
private void implementRecursively(
final RexToLixTranslator currentTranslator,
final List<RexNode> operandList,
final ParameterExpression valueVariable,
final RelDataType valueType,
int pos
) {
final BlockBuilder curBlockBuilder = currentTranslator.getBlockBuilder();
final List<Type> storageTypes = ConverterUtils.internalTypes(operandList);
// [ELSE] clause
Expand All @@ -1119,7 +1121,7 @@ private void implementRecursively(final RexToLixTranslator currentTranslator,
curBlockBuilder.add(
Expressions.statement(
Expressions.assign(valueVariable,
ConverterUtils.convert(res, valueVariable.getType()))));
ConverterUtils.convert(res, valueType))));
return;
}
// Condition code: !a_isNull && a_value
Expand All @@ -1141,7 +1143,7 @@ private void implementRecursively(final RexToLixTranslator currentTranslator,
ifTrueBlockBuilder.add(
Expressions.statement(
Expressions.assign(valueVariable,
ConverterUtils.convert(ifTrueRes, valueVariable.getType()))));
ConverterUtils.convert(ifTrueRes, valueType))));
final BlockStatement ifTrue = ifTrueBlockBuilder.toBlock();
// There is no [ELSE] clause
if (pos + 1 == operandList.size() - 1) {
Expand All @@ -1154,7 +1156,7 @@ private void implementRecursively(final RexToLixTranslator currentTranslator,
new BlockBuilder(true, curBlockBuilder);
final RexToLixTranslator ifFalseTranslator =
currentTranslator.setBlock(ifFalseBlockBuilder);
implementRecursively(ifFalseTranslator, operandList, valueVariable, pos + 2);
implementRecursively(ifFalseTranslator, operandList, valueVariable, valueType, pos + 2);
final BlockStatement ifFalse = ifFalseBlockBuilder.toBlock();
curBlockBuilder.add(
Expressions.ifThenElse(tester, ifTrue, ifFalse));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.calcite.tools.FrameworkConfig;
import org.apache.ignite.Ignite;
import org.apache.ignite.cache.query.FieldsQueryCursor;
import org.apache.ignite.internal.IgniteEx;
import org.apache.ignite.internal.IgniteInterruptedCheckedException;
import org.apache.ignite.internal.processors.cache.distributed.dht.atomic.GridDhtAtomicCache;
import org.apache.ignite.internal.processors.cache.distributed.dht.topology.GridDhtLocalPartition;
import org.apache.ignite.internal.processors.query.QueryContext;
import org.apache.ignite.internal.processors.query.QueryEngine;
import org.apache.ignite.internal.processors.query.schema.management.SchemaManager;
import org.apache.ignite.internal.util.typedef.F;
Expand Down Expand Up @@ -296,6 +298,9 @@ public static Matcher<String> containsAnyScan(final String schema, final String
/** */
private String exactPlan;

/** */
private FrameworkConfig frameworkCfg;

/** */
public QueryChecker(String qry) {
this.qry = qry;
Expand All @@ -322,6 +327,13 @@ public QueryChecker withParams(Object... params) {
return this;
}

/** */
public QueryChecker withFrameworkConfig(FrameworkConfig frameworkCfg) {
this.frameworkCfg = frameworkCfg;

return this;
}

/** */
public QueryChecker returns(Object... res) {
if (expectedResult == null)
Expand Down Expand Up @@ -370,8 +382,10 @@ public void check() {
// Check plan.
QueryEngine engine = getEngine();

QueryContext ctx = frameworkCfg != null ? QueryContext.of(frameworkCfg) : null;

List<FieldsQueryCursor<List<?>>> explainCursors =
engine.query(null, "PUBLIC", "EXPLAIN PLAN FOR " + qry, params);
engine.query(ctx, "PUBLIC", "EXPLAIN PLAN FOR " + qry, params);

FieldsQueryCursor<List<?>> explainCursor = explainCursors.get(0);
List<List<?>> explainRes = explainCursor.getAll();
Expand All @@ -387,7 +401,7 @@ public void check() {

// Check result.
List<FieldsQueryCursor<List<?>>> cursors =
engine.query(null, "PUBLIC", qry, params);
engine.query(ctx, "PUBLIC", qry, params);

FieldsQueryCursor<List<?>> cur = cursors.get(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
import java.util.stream.Collectors;
import com.google.common.collect.ImmutableSet;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.QueryEntity;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.processors.query.IgniteSQLException;
import org.apache.ignite.internal.util.typedef.F;
import org.junit.Test;

import static org.apache.ignite.internal.processors.query.calcite.CalciteQueryProcessor.FRAMEWORK_CONFIG;

/**
* Test SQL data types.
*/
Expand Down Expand Up @@ -467,6 +471,24 @@ public void testNumericConversion() {
.check();
}

/** */
@Test
public void testFunctionArgsToNumericImplicitConversion() {
assertQuery("select decode(?, 0, 0, 1, 1.0)").withParams(0).returns(new BigDecimal("0.0")).check();
assertQuery("select decode(?, 0, 0, 1, 1.0)").withParams(1).returns(new BigDecimal("1.0")).check();
assertQuery("select decode(?, 0, 0, 1, 1.000)").withParams(0).returns(new BigDecimal("0.000")).check();
assertQuery("select decode(?, 0, 0, 1, 1.000)").withParams(1).returns(new BigDecimal("1.000")).check();
assertQuery("select decode(?, 0, 0.0, 1, 1.000)").withParams(0).returns(new BigDecimal("0.000")).check();
assertQuery("select decode(?, 0, 0.000, 1, 1.0)").withParams(1).returns(new BigDecimal("1.000")).check();

FrameworkConfig frameworkCfg = Frameworks.newConfigBuilder(FRAMEWORK_CONFIG)
.sqlValidatorConfig(FRAMEWORK_CONFIG.getSqlValidatorConfig().withCallRewrite(false))
.build();

assertQuery("select coalesce(?, 1.000)").withParams(0).withFrameworkConfig(frameworkCfg)
.returns(new BigDecimal("0.000")).check();
}

/** */
@Test
public void testArithmeticOverflow() {
Expand Down

0 comments on commit ec8f27c

Please sign in to comment.