
[SPARK-52232][SQL] Fix non-deterministic queries to produce different results at every step #50957

Open · wants to merge 5 commits into master
@@ -1267,6 +1267,9 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U

override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))

override def withShiftedSeed(shift: Long): Shuffle =
copy(randomSeed = randomSeed.map(_ + shift))

override lazy val resolved: Boolean =
childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined

@@ -260,6 +260,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non

override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))

override def withShiftedSeed(shift: Long): Uuid = Uuid(randomSeed.map(_ + shift))

override lazy val resolved: Boolean = randomSeed.isDefined

override def nullable: Boolean = false
@@ -76,6 +76,7 @@ trait ExpressionWithRandomSeed extends Expression {

def seedExpression: Expression
def withNewSeed(seed: Long): Expression
def withShiftedSeed(shift: Long): Expression
}

private[catalyst] object ExpressionWithRandomSeed {
@@ -114,6 +115,9 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi

override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed)

override def withShiftedSeed(shift: Long): Rand =
Rand(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)

override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -165,6 +169,9 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends Nondeterm

override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed)

override def withShiftedSeed(shift: Long): Randn =
Randn(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)

override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -268,6 +275,9 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression,
override def withNewSeed(newSeed: Long): Expression =
Uniform(min, max, Literal(newSeed, LongType), hideSeed)

override def withShiftedSeed(shift: Long): Expression =
Uniform(min, max, Literal(seed + shift, LongType), hideSeed)

override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
Uniform(newFirst, newSecond, newThird, hideSeed)
@@ -348,6 +358,10 @@ case class RandStr(

override def withNewSeed(newSeed: Long): Expression =
RandStr(length, Literal(newSeed, LongType), hideSeed)

override def withShiftedSeed(shift: Long): Expression =
RandStr(length, Literal(seed + shift, LongType), hideSeed)

override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
RandStr(newFirst, newSecond, hideSeed)

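For orientation, here is a minimal standalone sketch (using a hypothetical SeededExpr class, not Spark's Catalyst types) of the contract the new withShiftedSeed method adds alongside the existing withNewSeed: withNewSeed replaces the seed outright, while withShiftedSeed offsets whatever seed is already assigned.

// Toy stand-in for ExpressionWithRandomSeed (assumption: simplified model, not Spark code).
final case class SeededExpr(randomSeed: Option[Long]) {
  // Replace the seed outright, as the existing withNewSeed does.
  def withNewSeed(seed: Long): SeededExpr = SeededExpr(Some(seed))

  // Offset the already-assigned seed, as the new withShiftedSeed does for Shuffle and Uuid.
  def withShiftedSeed(shift: Long): SeededExpr =
    SeededExpr(randomSeed.map(_ + shift))
}

object WithShiftedSeedDemo extends App {
  val e = SeededExpr(Some(237685L))
  println(e.withNewSeed(42L))     // SeededExpr(Some(42))
  println(e.withShiftedSeed(3L))  // SeededExpr(Some(237688))
}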
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.SparkException
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, InterpretedMutableProjection, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionWithRandomSeed, InterpretedMutableProjection, Literal}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation.hasUnevaluableExpr
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation, LogicalPlan, OneRowRelation, Project, Union, UnionLoopRef}
@@ -183,11 +183,24 @@ case class UnionLoopExec(
// Main loop for obtaining the result of the recursive query.
while (prevCount > 0 && !limitReached) {
var prevPlan: LogicalPlan = null

// If the recursive part contains non-deterministic expressions that depend on a seed, we
// need to create a new seed for every iterative step, since the seed for such an expression
// is set during analysis and we avoid re-triggering the analysis at each step.
val recursionReseeded = if (recursion.deterministic) {
recursion
} else {
recursion.transformAllExpressionsWithSubqueries {
case e: ExpressionWithRandomSeed =>
e.withShiftedSeed(currentLevel)
Review comment (Contributor): does the currentLevel start with 0?

Reply (Author): It starts with 1.

}
}

// the current plan is created by substituting UnionLoopRef node with the project node of
// the previous plan.
// This way we support only UNION ALL case. Additional case should be added for UNION case.
// One way of supporting UNION case can be seen at SPARK-24497 PR from Peter Toth.
val newRecursion = recursion.transformWithSubqueries {
val newRecursion = recursionReseeded.transformWithSubqueries {
case r: UnionLoopRef if r.loopId == loopId =>
prevDF.queryExecution.optimizedPlan match {
case l: LocalRelation =>
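To see why the loop above shifts the seed by currentLevel, here is a small self-contained sketch (toy types, not the real UnionLoopExec) of the iteration: without the shift, every step would re-evaluate the recursive part with the same analysis-time seed and produce identical rows.

import scala.util.Random

// Toy stand-in for a seeded non-deterministic expression (assumption: simplified model).
final case class ToyRand(seed: Long) {
  def withShiftedSeed(shift: Long): ToyRand = ToyRand(seed + shift)
  def eval(): Int = new Random(seed).nextInt(5) + 1
}

object UnionLoopReseedDemo extends App {
  val recursion = ToyRand(237685L)   // seed fixed once, at "analysis" time
  var currentLevel = 1               // as noted in the review thread, the level starts at 1
  while (currentLevel <= 5) {
    // Shift the seed per level so each iterative step draws different values.
    val reseeded = recursion.withShiftedSeed(currentLevel)
    println(s"level $currentLevel -> ${reseeded.eval()}")
    currentLevel += 1
  }
}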
@@ -1629,3 +1629,75 @@ WithCTE
+- Project [n#x]
+- SubqueryAlias t1
+- CTERelationRef xxxx, true, [n#x], false, false


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
UNION ALL
SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
UNION ALL
SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
UNION ALL
SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT randstr(10, 82374)
UNION ALL
SELECT randstr(10, 237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT UUID(82374)
UNION ALL
SELECT UUID(237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT ARRAY(1,2,3,4,5)
UNION ALL
SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query analysis
[Analyzer test output redacted due to nondeterminism]
56 changes: 55 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -586,4 +586,58 @@ WITH RECURSIVE t1 AS (
SELECT 1 AS n
UNION ALL
SELECT n+1 FROM t2 WHERE n < 5)
SELECT * FROM t1;

-- Non-deterministic query with rand with seed
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
UNION ALL
SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5;

-- Non-deterministic query with uniform with seed
WITH RECURSIVE randoms(val) AS (
SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
UNION ALL
SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5;

-- Non-deterministic query with randn with seed
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
UNION ALL
SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5;

-- Non-deterministic query with randstr
WITH RECURSIVE randoms(val) AS (
SELECT randstr(10, 82374)
UNION ALL
SELECT randstr(10, 237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5;

-- Non-deterministic query with UUID
WITH RECURSIVE randoms(val) AS (
SELECT UUID(82374)
UNION ALL
SELECT UUID(237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5;

-- Non-deterministic query with shuffle
WITH RECURSIVE randoms(val) AS (
SELECT ARRAY(1,2,3,4,5)
UNION ALL
SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5;
108 changes: 108 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -1475,3 +1475,111 @@ struct<n:int>
3
4
5


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
UNION ALL
SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query schema
struct<val:int>
-- !query output
1
4
4
4
5


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
UNION ALL
SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query schema
struct<val:int>
-- !query output
1
4
4
4
5


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
UNION ALL
SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query schema
struct<val:int>
-- !query output
-2
1
2
2
6


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT randstr(10, 82374)
UNION ALL
SELECT randstr(10, 237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query schema
struct<val:string>
-- !query output
IpXzdTW03I
UxLgwhvH5j
dBlWnfo7rO
fmfDBMf60f
kFeBV7dQWi


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT UUID(82374)
UNION ALL
SELECT UUID(237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query schema
struct<val:string>
-- !query output
4ea190e3-c088-4ddd-a545-fb431059ae3c
8b88900e-f862-468c-8d3b-828188116155
9e75cc4e-6d6e-4235-9788-5efd4d0fd0cc
be4f5346-1c7f-4697-8a2c-1343347872c5
d0032efe-ae60-461b-8582-f6a7c649f238


-- !query
WITH RECURSIVE randoms(val) AS (
SELECT ARRAY(1,2,3,4,5)
UNION ALL
SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
FROM randoms
)
SELECT val FROM randoms LIMIT 5
-- !query schema
struct<val:array<int>>
-- !query output
[1,2,3,4,5]
[1,2,3,5,4]
[3,4,2,1,5]
[4,3,2,5,1]
[4,5,1,2,3]