Commit 93db7b8

mgaido91 authored and JoshRosen committed
[SPARK-27684][SQL] Avoid conversion overhead for primitive types
## What changes were proposed in this pull request?

As outlined in the JIRA by JoshRosen, our conversion mechanism from Catalyst types to Scala ones is pretty inefficient for primitive data types. Indeed, in these cases, most of the time we are adding useless calls to the `identity` function, or in any case to functions which return the same value. Using the information we have when we generate the code, we can avoid most of these overheads.

## How was this patch tested?

Here is a simple test which shows the benefit that this PR can bring:

```
test("SPARK-27684: perf evaluation") {
  val intLongUdf = ScalaUDF(
    (a: Int, b: Long) => a + b, LongType,
    Literal(1) :: Literal(1L) :: Nil,
    true :: true :: Nil,
    nullable = false)

  val plan = generateProject(
    MutableProjection.create(Alias(intLongUdf, s"udf")() :: Nil),
    intLongUdf)
  plan.initialize(0)

  var i = 0
  val N = 100000000
  val t0 = System.nanoTime()
  while(i < N) {
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    plan(EmptyRow).get(0, intLongUdf.dataType)
    i += 1
  }
  val t1 = System.nanoTime()
  println(s"Avg time: ${(t1 - t0).toDouble / N} ns")
}
```

The output before the patch is:

```
Avg time: 51.27083294 ns
```

after, we get:

```
Avg time: 11.85874227 ns
```

which is roughly 4.3X faster (51.27 ns vs. 11.86 ns).

Moreover a benchmark has been added for Scala UDF. The output after the patch can be seen in this PR; before the patch, the output was:

```
================================================================================================
UDF with mixed input types
================================================================================================

Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
long/nullable int/string to string:                  Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
long/nullable int/string to string wholestage off              257          287        42       0,4      2569,5     1,0X
long/nullable int/string to string wholestage on               158          172        18       0,6      1579,0     1,6X

Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
long/nullable int/string to option:                  Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
long/nullable int/string to option wholestage off              104          107         5       1,0      1037,9     1,0X
long/nullable int/string to option wholestage on                80           92        12       1,2       804,0     1,3X

Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
long/nullable int to primitive:                      Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
long/nullable int to primitive wholestage off                   71           76         7       1,4       712,1     1,0X
long/nullable int to primitive wholestage on                    64           71         6       1,6       636,2     1,1X

================================================================================================
UDF with primitive types
================================================================================================

Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
long/nullable int to string:                         Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
long/nullable int to string wholestage off                      60           60         0       1,7       600,3     1,0X
long/nullable int to string wholestage on                       55           64         8       1,8       551,2     1,1X

Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
long/nullable int to option:                         Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
long/nullable int to option wholestage off                      66           73         9       1,5       663,0     1,0X
long/nullable int to option wholestage on                       30           32         2       3,3       300,7     2,2X

Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
long/nullable int/string to primitive:               Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------------------------------
long/nullable int/string to primitive wholestage off            32           35         5       3,2       316,7     1,0X
long/nullable int/string to primitive wholestage on             41           68        17       2,4       414,0     0,8X
```

The improvements are particularly visible in the second case, i.e. when only primitive types are used as inputs.

Closes apache#24636 from mgaido91/SPARK-27684.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Josh Rosen <[email protected]>
1 parent 568512c commit 93db7b8
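
The overhead the commit message describes can be pictured with a minimal sketch. This is not Spark's code: `ConverterOverheadSketch`, `converters`, `viaConverters` and `direct` are hypothetical names, and the converter array is only a stand-in for the per-argument converters that the generated code used to call. For primitive inputs each converter returns its argument unchanged, so the lookup, the call and the cast are pure overhead.

```scala
object ConverterOverheadSketch {
  // Hypothetical stand-in for a per-argument converter array.
  // For primitive types every entry simply returns its input.
  val converters: Array[Any => Any] = Array((v: Any) => v, (v: Any) => v)

  // Old style: each argument goes through an array lookup, a function call and a cast.
  def viaConverters(a: Int, b: Long): Long = {
    val x = converters(0)(a).asInstanceOf[Int]
    val y = converters(1)(b).asInstanceOf[Long]
    x + y
  }

  // New style: the values are used as they are, with no converter call at all.
  def direct(a: Int, b: Long): Long = a + b

  def main(args: Array[String]): Unit = {
    // Both paths compute the same value; only the amount of work differs.
    println(viaConverters(1, 1L) == direct(1, 1L))
  }
}
```

Both methods return the same result, which is why the converter calls can be dropped whenever the data type is known to be primitive at code generation time.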

4 files changed: +204 -7 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

+1 -1
@@ -42,7 +42,7 @@ object CatalystTypeConverters {
   // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
   import scala.collection.Map

-  private def isPrimitive(dataType: DataType): Boolean = {
+  private[sql] def isPrimitive(dataType: DataType): Boolean = {
     dataType match {
       case BooleanType => true
       case ByteType => true

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

+22 -6
@@ -1003,22 +1003,38 @@ case class ScalaUDF(
     // such as IntegerType, its javaType is `int` and the returned type of user-defined
     // function is Object. Trying to convert an Object to `int` will cause casting exception.
     val evalCode = evals.map(_.code).mkString("\n")
-    val (funcArgs, initArgs) = evals.zipWithIndex.map { case (eval, i) =>
-      val argTerm = ctx.freshName("arg")
-      val convert = s"$convertersTerm[$i].apply(${eval.value})"
-      val initArg = s"Object $argTerm = ${eval.isNull} ? null : $convert;"
-      (argTerm, initArg)
+    val (funcArgs, initArgs) = evals.zipWithIndex.zip(children.map(_.dataType)).map {
+      case ((eval, i), dt) =>
+        val argTerm = ctx.freshName("arg")
+        val initArg = if (CatalystTypeConverters.isPrimitive(dt)) {
+          val convertedTerm = ctx.freshName("conv")
+          s"""
+             |${CodeGenerator.boxedType(dt)} $convertedTerm = ${eval.value};
+             |Object $argTerm = ${eval.isNull} ? null : $convertedTerm;
+           """.stripMargin
+        } else {
+          s"Object $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});"
+        }
+        (argTerm, initArg)
     }.unzip

     val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
     val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
     val resultConverter = s"$convertersTerm[${children.length}]"
     val boxedType = CodeGenerator.boxedType(dataType)
+
+    val funcInvokation = if (CatalystTypeConverters.isPrimitive(dataType)
+        // If the output is nullable, the returned value must be unwrapped from the Option
+        && !nullable) {
+      s"$resultTerm = ($boxedType)$getFuncResult"
+    } else {
+      s"$resultTerm = ($boxedType)$resultConverter.apply($getFuncResult)"
+    }
     val callFunc =
       s"""
          |$boxedType $resultTerm = null;
          |try {
-         |  $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
+         |  $funcInvokation;
          |} catch (Exception e) {
          |  throw new org.apache.spark.SparkException($errorMsgTerm, e);
          |}
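
The branching in the hunk above can be tried out in isolation with a small sketch. It is not the Spark implementation: the `DataType` hierarchy, `isPrimitive` and `boxedType` below are simplified, hypothetical stand-ins for the Catalyst classes, and the term names (`arg_0`, `conv_0`, `converters`, `result`) are fixed strings rather than names obtained from `ctx.freshName`. It only shows which Java snippet is emitted in each case.

```scala
object ArgInitSketch {
  // Simplified stand-ins for Catalyst data types (assumption, not the real classes).
  sealed trait DataType
  case object IntegerType extends DataType
  case object LongType extends DataType
  case object StringType extends DataType

  def isPrimitive(dt: DataType): Boolean = dt match {
    case IntegerType | LongType => true
    case _ => false
  }

  // Toy mapping from a data type to the Java type name used in the emitted code.
  def boxedType(dt: DataType): String = dt match {
    case IntegerType => "Integer"
    case LongType => "Long"
    case StringType => "UTF8String"
  }

  // Emits the Java statements that prepare argument `i` for the UDF call.
  def initArg(i: Int, dt: DataType, value: String, isNull: String): String = {
    val argTerm = s"arg_$i"
    if (isPrimitive(dt)) {
      // Primitive input: box the value directly, no converter call.
      val conv = s"conv_$i"
      s"${boxedType(dt)} $conv = $value;\nObject $argTerm = $isNull ? null : $conv;"
    } else {
      // Any other input: still goes through the converter for that argument.
      s"Object $argTerm = $isNull ? null : converters[$i].apply($value);"
    }
  }

  // Emits the assignment of the UDF result; the converter is skipped only for
  // primitive, non-nullable results.
  def resultAssign(dt: DataType, nullable: Boolean, resultIndex: Int, call: String): String =
    if (isPrimitive(dt) && !nullable) s"result = (${boxedType(dt)})$call"
    else s"result = (${boxedType(dt)})converters[$resultIndex].apply($call)"

  def main(args: Array[String]): Unit = {
    println(initArg(0, IntegerType, "value_0", "isNull_0"))
    println(initArg(1, StringType, "value_1", "isNull_1"))
    println(resultAssign(LongType, nullable = false, 2, "udf.apply(arg_0, arg_1)"))
  }
}
```

Because the data type of every child is known while the code is being generated, the choice between the two branches costs nothing at run time.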
@@ -0,0 +1,59 @@
+================================================================================================
+UDF with mixed input types
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+long/nullable int/string to string:                  Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+long/nullable int/string to string wholestage off              194          248        76       0,5      1941,4     1,0X
+long/nullable int/string to string wholestage on               127          136         8       0,8      1269,5     1,5X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+long/nullable int/string to option:                  Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+long/nullable int/string to option wholestage off               91           97         8       1,1       910,1     1,0X
+long/nullable int/string to option wholestage on                60           79        29       1,7       603,8     1,5X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+long/nullable int/string to primitive:               Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+long/nullable int/string to primitive wholestage off            55           63        12       1,8       547,9     1,0X
+long/nullable int/string to primitive wholestage on             43           44         2       2,3       428,0     1,3X
+
+
+================================================================================================
+UDF with primitive types
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+long/nullable int to string:                         Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+long/nullable int to string wholestage off                      46           48         2       2,2       461,2     1,0X
+long/nullable int to string wholestage on                       49           56         8       2,0       488,9     0,9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+long/nullable int to option:                         Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+long/nullable int to option wholestage off                      41           47         9       2,4       408,2     1,0X
+long/nullable int to option wholestage on                       26           28         2       3,9       256,7     1,6X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+long/nullable int to primitive:                      Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+long/nullable int to primitive wholestage off                   26           27         0       3,8       263,7     1,0X
+long/nullable int to primitive wholestage on                    26           31         5       3,8       262,2     1,0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+UDF identity overhead:                               Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+Baseline                                                        20           22         1       4,9       204,3     1,0X
+With identity UDF                                               24           26         2       4,1       241,3     0,8X
+
+
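
To compare these figures with the pre-patch output quoted in the commit message, the per-row times can be divided row by row. The sketch below does that for three rows whose labels match in both outputs. `SpeedupSketch`, `parseNs` and `speedup` are hypothetical helper names, and the input strings are copied from the two benchmark outputs (note the comma used as decimal separator). Each figure comes from a single run on one laptop, so the ratios are only indicative.

```scala
object SpeedupSketch {
  // Parse a number printed with a comma as decimal separator, e.g. "2569,5".
  def parseNs(s: String): Double = s.replace(',', '.').toDouble

  // A ratio above 1 means the patched build spends less time per row.
  def speedup(before: String, after: String): Double = parseNs(before) / parseNs(after)

  def main(args: Array[String]): Unit = {
    // (benchmark case, Per Row(ns) before the patch, Per Row(ns) after the patch)
    val rows = Seq(
      ("long/nullable int/string to string, wholestage off", "2569,5", "1941,4"),
      ("long/nullable int to string, wholestage off", "600,3", "461,2"),
      ("long/nullable int to option, wholestage on", "300,7", "256,7")
    )
    rows.foreach { case (name, before, after) =>
      println(f"$name%-55s ${speedup(before, after)}%.2fX")
    }
  }
}
```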
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+/**
+ * Synthetic benchmark for Scala User Defined Functions.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class> --jars <spark core test jar> <sql core test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/UDFBenchmark-results.txt".
+ * }}}
+ */
+object UDFBenchmark extends SqlBasedBenchmark {
+
+  private def doRunBenchmarkWithMixedTypes(udf: UserDefinedFunction, cardinality: Int): Unit = {
+    val idCol = col("id")
+    val nullableIntCol = when(
+      idCol % 2 === 0, idCol.cast(IntegerType)).otherwise(Literal(null, IntegerType))
+    val stringCol = idCol.cast(StringType)
+    spark.range(cardinality).select(
+      udf(idCol, nullableIntCol, stringCol)).write.format("noop").save()
+  }
+
+  private def doRunBenchmarkWithPrimitiveTypes(
+      udf: UserDefinedFunction, cardinality: Int): Unit = {
+    val idCol = col("id")
+    val nullableIntCol = when(
+      idCol % 2 === 0, idCol.cast(IntegerType)).otherwise(Literal(null, IntegerType))
+    spark.range(cardinality).select(udf(idCol, nullableIntCol)).write.format("noop").save()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    val cardinality = 100000
+    runBenchmark("UDF with mixed input types") {
+      codegenBenchmark("long/nullable int/string to string", cardinality) {
+        val sampleUDF = udf {(a: Long, b: java.lang.Integer, c: String) =>
+          s"$a,$b,$c"
+        }
+        doRunBenchmarkWithMixedTypes(sampleUDF, cardinality)
+      }
+
+      codegenBenchmark("long/nullable int/string to option", cardinality) {
+        val sampleUDF = udf {(_: Long, b: java.lang.Integer, _: String) =>
+          Option(b)
+        }
+        doRunBenchmarkWithMixedTypes(sampleUDF, cardinality)
+      }
+
+      codegenBenchmark("long/nullable int/string to primitive", cardinality) {
+        val sampleUDF = udf {(a: Long, b: java.lang.Integer, _: String) =>
+          Option(b).map(_.longValue()).getOrElse(a)
+        }
+        doRunBenchmarkWithMixedTypes(sampleUDF, cardinality)
+      }
+    }
+
+    runBenchmark("UDF with primitive types") {
+      codegenBenchmark("long/nullable int to string", cardinality) {
+        val sampleUDF = udf {(a: Long, b: java.lang.Integer) =>
+          s"$a,$b"
+        }
+        doRunBenchmarkWithPrimitiveTypes(sampleUDF, cardinality)
+      }
+
+      codegenBenchmark("long/nullable int to option", cardinality) {
+        val sampleUDF = udf {(_: Long, b: java.lang.Integer) =>
+          Option(b)
+        }
+        doRunBenchmarkWithPrimitiveTypes(sampleUDF, cardinality)
+      }
+
+      codegenBenchmark("long/nullable int to primitive", cardinality) {
+        val sampleUDF = udf {(a: Long, b: java.lang.Integer) =>
+          Option(b).map(_.longValue()).getOrElse(a)
+        }
+        doRunBenchmarkWithPrimitiveTypes(sampleUDF, cardinality)
+      }
+
+      val benchmark = new Benchmark("UDF identity overhead", cardinality, output = output)
+
+      benchmark.addCase(s"Baseline", numIters = 5) { _ =>
+        spark.range(cardinality).select(
+          col("id"), col("id") * 2, col("id") * 3).write.format("noop").save()
+      }
+
+      val identityUDF = udf { x: Long => x }
+      benchmark.addCase(s"With identity UDF", numIters = 5) { _ =>
+        spark.range(cardinality).select(
+          identityUDF(col("id")),
+          identityUDF(col("id") * 2),
+          identityUDF(col("id") * 3)).write.format("noop").save()
+      }
+
+      benchmark.run()
+    }
+  }
+}
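
The benchmark UDFs take the nullable integer column as `java.lang.Integer` rather than `Int`, so that a SQL `NULL` reaches the function as a Java `null`. The function bodies can be exercised without Spark, as in the following sketch. `NullableArgSketch` and its member names are hypothetical; the lambdas are copied from the "to option" and "to primitive" cases above.

```scala
object NullableArgSketch {
  // Same body as the "to option" case: a null input becomes None.
  val toOption: (Long, java.lang.Integer) => Option[java.lang.Integer] =
    (_, b) => Option(b)

  // Same body as the "to primitive" case: fall back to `a` when `b` is null.
  val toPrimitive: (Long, java.lang.Integer) => Long =
    (a, b) => Option(b).map(_.longValue()).getOrElse(a)

  def main(args: Array[String]): Unit = {
    println(toOption(7L, null))                    // b is null
    println(toPrimitive(7L, null))                 // falls back to a
    println(toPrimitive(7L, Integer.valueOf(3)))   // uses b
  }
}
```

Declaring the parameter as a plain `Int` would leave the function with no way to tell a missing value from a real one, which is why the boxed type is used for the nullable column.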
