
Commit bc4a676

mgaido91 authored and cloud-fan committed
[SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
## What changes were proposed in this pull request?

SPARK-23179 introduced a flag to control the behavior of decimal operations on overflow: when `spark.sql.decimalOperations.nullOnOverflow` is true (the default, matching traditional Spark behavior), `null` is returned; when it is false, an `ArithmeticException` is thrown (in line with the SQL standard and other databases).

`MakeDecimal` so far had an ambiguous behavior: in codegen mode it returned `null` like the other operators, but in interpreted mode it threw an `IllegalArgumentException`. This PR aligns `MakeDecimal` with the behavior defined in SPARK-23179, so both modes now either return `null` or throw an `ArithmeticException`, according to the value of `spark.sql.decimalOperations.nullOnOverflow`.

Credits for this PR to mickjermsurawong-stripe, who pointed out the wrong behavior in apache#20350.

## How was this patch tested?

Improved UTs.

Closes apache#25010 from mgaido91/SPARK-28201.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 048224c commit bc4a676
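
To make the new contract concrete, here is a minimal sketch of the two overflow modes against the `Decimal` setters that `MakeDecimal` delegates to (values mirror the updated tests; the snippet assumes the `Decimal` API exactly as it appears in the diffs below):

```scala
import org.apache.spark.sql.types.Decimal

// MakeDecimal builds a Decimal from an unscaled Long. With precision 3 and
// scale 1, unscaled 1000 would represent 100.0, which needs 4 digits: overflow.

// nullOnOverflow = true (the default): setOrNull returns null instead of failing.
assert(new Decimal().setOrNull(1000L, 3, 1) == null)

// nullOnOverflow = false: set throws ArithmeticException after this change
// (interpreted mode used to throw IllegalArgumentException here).
try {
  new Decimal().set(1000L, 3, 1)
} catch {
  case e: ArithmeticException => println(s"overflow: ${e.getMessage}")
}

// In-range values behave the same in both modes: unscaled 101 at scale 1 is 10.1.
println(new Decimal().set(101L, 3, 1)) // 10.1
```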

File tree

4 files changed, +54 -17 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala

+26 -6

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 /**
@@ -46,19 +47,38 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
  */
 case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
 
+  private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+
   override def dataType: DataType = DecimalType(precision, scale)
-  override def nullable: Boolean = true
+  override def nullable: Boolean = child.nullable || nullOnOverflow
   override def toString: String = s"MakeDecimal($child,$precision,$scale)"
 
-  protected override def nullSafeEval(input: Any): Any =
-    Decimal(input.asInstanceOf[Long], precision, scale)
+  protected override def nullSafeEval(input: Any): Any = {
+    val longInput = input.asInstanceOf[Long]
+    val result = new Decimal()
+    if (nullOnOverflow) {
+      result.setOrNull(longInput, precision, scale)
+    } else {
+      result.set(longInput, precision, scale)
+    }
+  }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, eval => {
+      val setMethod = if (nullOnOverflow) {
+        "setOrNull"
+      } else {
+        "set"
+      }
+      val setNull = if (nullable) {
+        s"${ev.isNull} = ${ev.value} == null;"
+      } else {
+        ""
+      }
       s"""
-        ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
-        ${ev.isNull} = ${ev.value} == null;
-      """
+         |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+         |$setNull
+         |""".stripMargin
     })
   }
 }
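
A note on the generated code above: the `isNull` assignment is only emitted when the expression is actually nullable, which now holds only if the child is nullable or `nullOnOverflow` is true; in strict mode `set` throws before a null result could ever be observed. Below is a rough sketch of the new interpreted path (a standalone illustration, assuming the default value of `spark.sql.decimalOperations.nullOnOverflow` is picked up via `SQLConf.get` when the expression is constructed):

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal}
import org.apache.spark.sql.types.LongType

// With the default nullOnOverflow = true, an out-of-range unscaled value now
// evaluates to null in interpreted mode too, instead of throwing
// IllegalArgumentException as it did before this patch.
val overflow = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
println(overflow.eval()) // null

// An in-range value is unaffected: unscaled 101 with precision 3, scale 1 is 10.1.
println(MakeDecimal(Literal(101L), 3, 1).eval()) // 10.1
```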

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

+5 -4

@@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    */
   def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
     if (setOrNull(unscaled, precision, scale) == null) {
-      throw new IllegalArgumentException("Unscaled value too large for precision")
+      throw new ArithmeticException("Unscaled value too large for precision")
     }
     this
   }
@@ -111,9 +111,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    */
   def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
     this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
-    require(
-      decimalVal.precision <= precision,
-      s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
+    if (decimalVal.precision > precision) {
+      throw new ArithmeticException(
+        s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
+    }
     this.longVal = 0L
     this._precision = precision
     this._scale = scale
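
The caller-visible effect of this file's change, sketched with the `Decimal` companion's `apply` overloads that the updated suites below exercise (an illustration only, not part of the patch):

```scala
import org.apache.spark.sql.types.Decimal

// Both set overloads now signal overflow with ArithmeticException rather than
// IllegalArgumentException, giving strict-mode callers such as MakeDecimal a
// single, consistent exception type.
try {
  Decimal(170L, 2, 1) // unscaled path: 17.0 does not fit DecimalType(2, 1)
} catch {
  case e: ArithmeticException => println(e.getMessage) // Unscaled value too large for precision
}

try {
  Decimal(BigDecimal("10.030"), 2, 1) // BigDecimal path: precision 3 exceeds max precision 2
} catch {
  case e: ArithmeticException => println(e.getMessage)
}
```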

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala

+18 -2

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}
 
 class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -31,8 +32,23 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("MakeDecimal") {
-    checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
-    checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
+      checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+      checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+      val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
+      checkEvaluation(overflowExpr, null)
+      checkEvaluationWithMutableProjection(overflowExpr, null)
+      evaluateWithoutCodegen(overflowExpr, null)
+      checkEvaluationWithUnsafeProjection(overflowExpr, null)
+    }
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
+      checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+      checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+      val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
+      intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr, null))
+      intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr, null))
+      intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr, null))
+    }
   }
 
   test("PromotePrecision") {

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala

+5 -5

@@ -56,11 +56,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
     checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2)
     checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
     checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
-    intercept[IllegalArgumentException](Decimal(170L, 2, 1))
-    intercept[IllegalArgumentException](Decimal(170L, 2, 0))
-    intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
-    intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
-    intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
+    intercept[ArithmeticException](Decimal(170L, 2, 1))
+    intercept[ArithmeticException](Decimal(170L, 2, 0))
+    intercept[ArithmeticException](Decimal(BigDecimal("10.030"), 2, 1))
+    intercept[ArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1))
+    intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0))
   }
 
   test("creating decimals with negative scale") {
