From 27ff365f474f26c89bd283ee842a58aad75e359d Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Tue, 20 May 2025 15:08:09 +0200 Subject: [PATCH 1/7] [SPARK-52232][SQL] Fix non-deterministic queries to produce different results at every step --- .../expressions/collectionOperations.scala | 2 ++ .../spark/sql/catalyst/expressions/misc.scala | 2 ++ .../expressions/randomExpressions.scala | 12 ++++++++++++ .../spark/sql/execution/UnionLoopExec.scala | 19 +++++++++++++++++-- .../analyzer-results/cte-recursion.sql.out | 12 ++++++++++++ .../sql-tests/inputs/cte-recursion.sql | 11 ++++++++++- .../sql-tests/results/cte-recursion.sql.out | 18 ++++++++++++++++++ 7 files changed, 73 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 84e52282b632f..f8b34e9b79c66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1267,6 +1267,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) + override def withNextSeed(): Shuffle = copy(randomSeed = Some(randomSeed.get + 1)) + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e7d3701544c54..239ad461008c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -260,6 +260,8 @@ case class Uuid(randomSeed: 
Option[Long] = None) extends LeafExpression with Non override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) + override def withNextSeed(): Uuid = Uuid(Some(randomSeed.get + 1)) + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index fa6eb2c111895..78787c1103284 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -76,6 +76,7 @@ trait ExpressionWithRandomSeed extends Expression { def seedExpression: Expression def withNewSeed(seed: Long): Expression + def withNextSeed(): Expression } private[catalyst] object ExpressionWithRandomSeed { @@ -114,6 +115,8 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed) + override def withNextSeed(): Rand = Rand(Add(child, Literal(1)), hideSeed) + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -165,6 +168,8 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends Nondeterm override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed) + override def withNextSeed(): Randn = Randn(Literal(seed, LongType), hideSeed) + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -268,6 +273,9 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, override def withNewSeed(newSeed: Long): Expression = Uniform(min, max, Literal(newSeed, 
LongType), hideSeed) + override def withNextSeed(): Expression = + Uniform(min, max, Literal(seed + 1, LongType), hideSeed) + override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = Uniform(newFirst, newSecond, newThird, hideSeed) @@ -348,6 +356,10 @@ case class RandStr( override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType), hideSeed) + + override def withNextSeed(): Expression = + RandStr(length, Literal(seed + 1, LongType), hideSeed) + override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = RandStr(newFirst, newSecond, hideSeed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala index 33188db5d23b0..087953a65e700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionWithRandomSeed, InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation.hasUnevaluableExpr import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation, LogicalPlan, OneRowRelation, Project, Union, UnionLoopRef} @@ -180,14 +180,29 @@ case class UnionLoopExec( val numPartitions = prevDF.queryExecution.toRdd.partitions.length + var recursionReseeded = recursion + // Main loop for obtaining the result of the recursive query. 
while (prevCount > 0 && !limitReached) { var prevPlan: LogicalPlan = null + + // If the recursive part contains non-deterministic expressions that depends on a seed, we + // need to create a new seed since the seed for this expression is set in the analysis, and + // we avoid re-triggering the analysis for every iterative step. + recursionReseeded = if (recursion.deterministic) { + recursionReseeded + } else { + recursionReseeded.transformExpressionsDown { + case e: ExpressionWithRandomSeed => + e.withNextSeed() + } + } + // the current plan is created by substituting UnionLoopRef node with the project node of // the previous plan. // This way we support only UNION ALL case. Additional case should be added for UNION case. // One way of supporting UNION case can be seen at SPARK-24497 PR from Peter Toth. - val newRecursion = recursion.transformWithSubqueries { + val newRecursion = recursionReseeded.transformWithSubqueries { case r: UnionLoopRef if r.loopId == loopId => prevDF.queryExecution.optimizedPlan match { case l: LocalRelation => diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out index 4aff038838654..fd3311333ec5a 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out @@ -1629,3 +1629,15 @@ WithCTE +- Project [n#x] +- SubqueryAlias t1 +- CTERelationRef xxxx, true, [n#x], false, false + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(rand(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql index 
fba8861083be4..d5654d827e28f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql @@ -586,4 +586,13 @@ WITH RECURSIVE t1 AS ( SELECT 1 AS n UNION ALL SELECT n+1 FROM t2 WHERE n < 5) -SELECT * FROM t1; \ No newline at end of file +SELECT * FROM t1; + +-- Non-deterministic query with rand with seed +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(rand(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index 06f440a3f6335..ebbc8214864ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1475,3 +1475,21 @@ struct 3 4 5 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(rand(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +1 +4 +4 +4 +5 From 317e0194016e7b9dfa4487443dc63499860dc837 Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Tue, 20 May 2025 18:59:09 +0200 Subject: [PATCH 2/7] change new function to work for every shift --- .../expressions/collectionOperations.scala | 3 ++- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/randomExpressions.scala | 14 +++++++------- .../spark/sql/execution/UnionLoopExec.scala | 10 ++++------ .../analyzer-results/cte-recursion.sql.out | 12 ++++++++++++ .../sql-tests/inputs/cte-recursion.sql | 9 +++++++++ .../sql-tests/results/cte-recursion.sql.out | 18 ++++++++++++++++++ 7 files changed, 53 insertions(+), 15 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f8b34e9b79c66..75063670b270d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1267,7 +1267,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) - override def withNextSeed(): Shuffle = copy(randomSeed = Some(randomSeed.get + 1)) + override def withShiftedSeed(shift: Long): Shuffle = + copy(randomSeed = Some(randomSeed.get + shift)) override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 239ad461008c9..a44f49f684d38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -260,7 +260,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) - override def withNextSeed(): Uuid = Uuid(Some(randomSeed.get + 1)) + override def withShiftedSeed(shift: Long): Uuid = Uuid(Some(randomSeed.get + shift)) override lazy val resolved: Boolean = randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 78787c1103284..1b0484051e427 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -76,7 +76,7 @@ trait ExpressionWithRandomSeed extends Expression { def seedExpression: Expression def withNewSeed(seed: Long): Expression - def withNextSeed(): Expression + def withShiftedSeed(shift: Long): Expression } private[catalyst] object ExpressionWithRandomSeed { @@ -115,7 +115,7 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed) - override def withNextSeed(): Rand = Rand(Add(child, Literal(1)), hideSeed) + override def withShiftedSeed(shift: Long): Rand = Rand(Add(child, Literal(shift)), hideSeed) override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -168,7 +168,7 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends Nondeterm override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed) - override def withNextSeed(): Randn = Randn(Literal(seed, LongType), hideSeed) + override def withShiftedSeed(shift: Long): Randn = Randn(Add(child, Literal(shift)), hideSeed) override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() @@ -273,8 +273,8 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, override def withNewSeed(newSeed: Long): Expression = Uniform(min, max, Literal(newSeed, LongType), hideSeed) - override def withNextSeed(): Expression = - Uniform(min, max, Literal(seed + 1, LongType), hideSeed) + override def withShiftedSeed(shift: Long): Expression = + Uniform(min, max, Literal(seed + shift, LongType), hideSeed) override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = @@ -357,8 +357,8 @@ case class RandStr( override def withNewSeed(newSeed: Long): 
Expression = RandStr(length, Literal(newSeed, LongType), hideSeed) - override def withNextSeed(): Expression = - RandStr(length, Literal(seed + 1, LongType), hideSeed) + override def withShiftedSeed(shift: Long): Expression = + RandStr(length, Literal(seed + shift, LongType), hideSeed) override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = RandStr(newFirst, newSecond, hideSeed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala index 087953a65e700..debf5dc05dd4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala @@ -180,8 +180,6 @@ case class UnionLoopExec( val numPartitions = prevDF.queryExecution.toRdd.partitions.length - var recursionReseeded = recursion - // Main loop for obtaining the result of the recursive query. while (prevCount > 0 && !limitReached) { var prevPlan: LogicalPlan = null @@ -189,12 +187,12 @@ case class UnionLoopExec( // If the recursive part contains non-deterministic expressions that depends on a seed, we // need to create a new seed since the seed for this expression is set in the analysis, and // we avoid re-triggering the analysis for every iterative step. 
- recursionReseeded = if (recursion.deterministic) { - recursionReseeded + val recursionReseeded = if (recursion.deterministic) { + recursion } else { - recursionReseeded.transformExpressionsDown { + recursion.transformExpressionsDown { case e: ExpressionWithRandomSeed => - e.withNextSeed() + e.withShiftedSeed(currentLevel) } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out index fd3311333ec5a..b271ebe4f8bcd 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out @@ -1641,3 +1641,15 @@ WITH RECURSIVE randoms(val) AS ( SELECT val FROM randoms LIMIT 5 -- !query analysis [Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val, step) AS ( + SELECT CAST(UNIFORM(1, 6, 82374) AS INT), 1 AS step + UNION ALL + SELECT CAST(UNIFORM(1, 6, 237685) AS INT), step + 1 + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql index d5654d827e28f..427f4fa7d64bc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql @@ -595,4 +595,13 @@ WITH RECURSIVE randoms(val) AS ( SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) FROM randoms ) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with uniform with seed +WITH RECURSIVE randoms(val, step) AS ( + SELECT CAST(UNIFORM(1, 6, 82374) AS INT), 1 AS step + UNION ALL + SELECT CAST(UNIFORM(1, 6, 237685) AS INT), step + 1 + FROM randoms +) SELECT val FROM randoms LIMIT 5; \ No newline at end of file diff --git 
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index ebbc8214864ef..a5b5f4a13e4aa 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1493,3 +1493,21 @@ struct 4 4 5 + + +-- !query +WITH RECURSIVE randoms(val, step) AS ( + SELECT CAST(UNIFORM(1, 6, 82374) AS INT), 1 AS step + UNION ALL + SELECT CAST(UNIFORM(1, 6, 237685) AS INT), step + 1 + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +1 +4 +4 +4 +5 From fb941371cebe5619ada9d3ef727c4844b316c576 Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Wed, 21 May 2025 14:05:57 +0200 Subject: [PATCH 3/7] Add additional tests covering all random functions --- .../analyzer-results/cte-recursion.sql.out | 31 +++++- .../sql-tests/inputs/cte-recursion.sql | 13 ++- .../sql-tests/results/cte-recursion.sql.out | 41 ++++++- .../sql/execution/RecursiveCTESuite.scala | 103 ++++++++++++++++++ 4 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out index b271ebe4f8bcd..230b476cfa51e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out @@ -1645,9 +1645,36 @@ SELECT val FROM randoms LIMIT 5 -- !query WITH RECURSIVE randoms(val, step) AS ( - SELECT CAST(UNIFORM(1, 6, 82374) AS INT), 1 AS step + SELECT CAST(UNIFORM(1, 6, 82374) AS INT) UNION ALL - SELECT CAST(UNIFORM(1, 6, 237685) AS INT), step + 1 + SELECT CAST(UNIFORM(1, 6, 237685) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis
+org.apache.spark.sql.AnalysisException +{ + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", + "messageParameters" : { + "numExpr" : "2", + "numTarget" : "1" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 154, + "stopIndex" : 160, + "fragment" : "randoms" + } ] +} + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(randn(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) FROM randoms ) SELECT val FROM randoms LIMIT 5 diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql index 427f4fa7d64bc..1859bad1c32d3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql @@ -599,9 +599,18 @@ SELECT val FROM randoms LIMIT 5; -- Non-deterministic query with uniform with seed WITH RECURSIVE randoms(val, step) AS ( - SELECT CAST(UNIFORM(1, 6, 82374) AS INT), 1 AS step + SELECT CAST(UNIFORM(1, 6, 82374) AS INT) UNION ALL - SELECT CAST(UNIFORM(1, 6, 237685) AS INT), step + 1 + SELECT CAST(UNIFORM(1, 6, 237685) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with randn with seed +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(randn(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) FROM randoms ) SELECT val FROM randoms LIMIT 5; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index a5b5f4a13e4aa..6242cff3cb1d0 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1497,17 +1497,46 @@ struct -- !query WITH RECURSIVE randoms(val, step) AS ( - SELECT CAST(UNIFORM(1, 6, 82374) AS INT), 1 AS 
step + SELECT CAST(UNIFORM(1, 6, 82374) AS INT) UNION ALL - SELECT CAST(UNIFORM(1, 6, 237685) AS INT), step + 1 + SELECT CAST(UNIFORM(1, 6, 237685) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", + "messageParameters" : { + "numExpr" : "2", + "numTarget" : "1" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 154, + "stopIndex" : 160, + "fragment" : "randoms" + } ] +} + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(randn(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) FROM randoms ) SELECT val FROM randoms LIMIT 5 -- !query schema struct -- !query output +-2 1 -4 -4 -4 -5 +2 +2 +6 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala new file mode 100644 index 0000000000000..c2aaa2629de6c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.countDistinct +import org.apache.spark.sql.test.SharedSparkSession + +class RecursiveCTESuite extends QueryTest with SharedSparkSession { + + test("Random rCTEs produce different results in different iterations - RAND") { + val df = sql("""WITH RECURSIVE randoms(val) AS ( + | SELECT CAST(floor(rand() * 50 + 1) AS INT) + | UNION ALL + | SELECT CAST(floor(rand() * 50 + 1) AS INT) + | FROM randoms + |) + |SELECT val FROM randoms LIMIT 10;""".stripMargin) + + val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) + assert(distinctCount > 2) + } + + test("Random rCTEs produce different results in different iterations - UNIFORM") { + val df = sql("""WITH RECURSIVE randoms(val) AS ( + | SELECT CAST(UNIFORM(1,51) AS INT) + | UNION ALL + | SELECT CAST(UNIFORM(1,51) AS INT) + | FROM randoms + |) + |SELECT val FROM randoms LIMIT 10;""".stripMargin) + + val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) + assert(distinctCount > 2) + } + + test("Random rCTEs produce different results in different iterations - RANDN") { + val df = sql("""WITH RECURSIVE randoms(val) AS ( + | SELECT CAST(floor(randn() * 50) AS INT) + | UNION ALL + | SELECT CAST(floor(randn() * 50) AS INT) + | FROM randoms + |) + |SELECT val FROM randoms LIMIT 10;""".stripMargin) + + val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) + assert(distinctCount > 2) + } + + test("Random rCTEs produce different results in different iterations - RANDSTR") { + val df = sql("""WITH RECURSIVE randoms(val) AS ( + | SELECT randstr(10) + | UNION ALL + | SELECT randstr(10) + | FROM randoms + |) + |SELECT val FROM randoms LIMIT 10;""".stripMargin) + + val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) + 
assert(distinctCount > 2) + } + + test("Random rCTEs produce different results in different iterations - UUID") { + val df = sql("""WITH RECURSIVE randoms(val) AS ( + | SELECT UUID() + | UNION ALL + | SELECT UUID() + | FROM randoms + |) + |SELECT val FROM randoms LIMIT 10;""".stripMargin) + + val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) + assert(distinctCount > 2) + } + + test("Random rCTEs produce different results in different iterations - SHUFFLE") { + val df = sql("""WITH RECURSIVE randoms(val) AS ( + | SELECT ARRAY(1,2,3,4,5) + | UNION ALL + | SELECT SHUFFLE(ARRAY(1,2,3,4,5)) + | FROM randoms + | ) + |SELECT val FROM randoms LIMIT 10;""".stripMargin) + + val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) + assert(distinctCount > 2) + } +} From 94a1ec6600e69d851adb386e8b8b637dc64b559b Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Thu, 22 May 2025 10:15:18 +0200 Subject: [PATCH 4/7] Make shift and tests more stable; change transformation --- .../expressions/collectionOperations.scala | 2 +- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/randomExpressions.scala | 6 ++- .../spark/sql/execution/UnionLoopExec.scala | 2 +- .../analyzer-results/cte-recursion.sql.out | 31 ++++++------- .../sql-tests/inputs/cte-recursion.sql | 10 ++++- .../sql-tests/results/cte-recursion.sql.out | 43 +++++++++++-------- .../sql/execution/RecursiveCTESuite.scala | 12 +++--- 8 files changed, 61 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 75063670b270d..b4978fbe1f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1268,7 +1268,7 @@ case 
class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) override def withShiftedSeed(shift: Long): Shuffle = - copy(randomSeed = Some(randomSeed.get + shift)) + copy(randomSeed = randomSeed.map(_ + shift)) override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a44f49f684d38..dcbca34b240b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -260,7 +260,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) - override def withShiftedSeed(shift: Long): Uuid = Uuid(Some(randomSeed.get + shift)) + override def withShiftedSeed(shift: Long): Uuid = Uuid(randomSeed.map(_ + shift)) override lazy val resolved: Boolean = randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 1b0484051e427..06cc6e55c8ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -115,7 +115,8 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed) - override def withShiftedSeed(shift: Long): Rand = Rand(Add(child, Literal(shift)), hideSeed) + override def withShiftedSeed(shift: Long): Rand = + Rand(Add(child, 
Literal(shift), evalMode = EvalMode.LEGACY), hideSeed) override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -168,7 +169,8 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends Nondeterm override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed) - override def withShiftedSeed(shift: Long): Randn = Randn(Add(child, Literal(shift)), hideSeed) + override def withShiftedSeed(shift: Long): Randn = + Randn(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed) override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala index debf5dc05dd4b..e836681c513db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala @@ -190,7 +190,7 @@ case class UnionLoopExec( val recursionReseeded = if (recursion.deterministic) { recursion } else { - recursion.transformExpressionsDown { + recursion.transformAllExpressionsWithSubqueries { case e: ExpressionWithRandomSeed => e.withShiftedSeed(currentLevel) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out index 230b476cfa51e..5187c59e25683 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out @@ -1644,7 +1644,7 @@ SELECT val FROM randoms LIMIT 5 -- !query -WITH RECURSIVE randoms(val, step) AS ( +WITH RECURSIVE randoms(val) AS ( SELECT CAST(UNIFORM(1, 6, 82374) AS INT) UNION ALL SELECT CAST(UNIFORM(1, 6, 237685) AS INT) @@ -1652,22 +1652,7 @@ WITH RECURSIVE randoms(val, step) AS ( ) SELECT val FROM randoms LIMIT 5 -- !query 
analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", - "sqlState" : "42802", - "messageParameters" : { - "numExpr" : "2", - "numTarget" : "1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 154, - "stopIndex" : 160, - "fragment" : "randoms" - } ] -} +[Analyzer test output redacted due to nondeterminism] -- !query @@ -1680,3 +1665,15 @@ WITH RECURSIVE randoms(val) AS ( SELECT val FROM randoms LIMIT 5 -- !query analysis [Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT randstr(10, 82374) + UNION ALL + SELECT randstr(10, 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql index 1859bad1c32d3..57cb95f03bb0f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql @@ -598,7 +598,7 @@ WITH RECURSIVE randoms(val) AS ( SELECT val FROM randoms LIMIT 5; -- Non-deterministic query with uniform with seed -WITH RECURSIVE randoms(val, step) AS ( +WITH RECURSIVE randoms(val) AS ( SELECT CAST(UNIFORM(1, 6, 82374) AS INT) UNION ALL SELECT CAST(UNIFORM(1, 6, 237685) AS INT) @@ -613,4 +613,12 @@ WITH RECURSIVE randoms(val) AS ( SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) FROM randoms ) +SELECT val FROM randoms LIMIT 5; + +WITH RECURSIVE randoms(val) AS ( + SELECT randstr(10, 82374) + UNION ALL + SELECT randstr(10, 237685) + FROM randoms +) SELECT val FROM randoms LIMIT 5; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index 6242cff3cb1d0..20733670e3dfa 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1496,7 +1496,7 @@ struct -- !query -WITH RECURSIVE randoms(val, step) AS ( +WITH RECURSIVE randoms(val) AS ( SELECT CAST(UNIFORM(1, 6, 82374) AS INT) UNION ALL SELECT CAST(UNIFORM(1, 6, 237685) AS INT) @@ -1504,24 +1504,13 @@ WITH RECURSIVE randoms(val, step) AS ( ) SELECT val FROM randoms LIMIT 5 -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", - "sqlState" : "42802", - "messageParameters" : { - "numExpr" : "2", - "numTarget" : "1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 154, - "stopIndex" : 160, - "fragment" : "randoms" - } ] -} +1 +4 +4 +4 +5 -- !query @@ -1540,3 +1529,21 @@ struct 2 2 6 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT randstr(10, 82374) + UNION ALL + SELECT randstr(10, 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +IpXzdTW03I +UxLgwhvH5j +dBlWnfo7rO +fmfDBMf60f +kFeBV7dQWi diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala index c2aaa2629de6c..05db3ac669063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala @@ -30,7 +30,7 @@ class RecursiveCTESuite extends QueryTest with SharedSparkSession { | SELECT CAST(floor(rand() * 50 + 1) AS INT) | FROM randoms |) - |SELECT val FROM randoms LIMIT 10;""".stripMargin) + |SELECT val FROM randoms LIMIT 30;""".stripMargin) val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) assert(distinctCount > 2) @@ -43,7 +43,7 @@ class RecursiveCTESuite extends QueryTest with SharedSparkSession { | SELECT 
CAST(UNIFORM(1,51) AS INT) | FROM randoms |) - |SELECT val FROM randoms LIMIT 10;""".stripMargin) + |SELECT val FROM randoms LIMIT 30;""".stripMargin) val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) assert(distinctCount > 2) @@ -56,7 +56,7 @@ class RecursiveCTESuite extends QueryTest with SharedSparkSession { | SELECT CAST(floor(randn() * 50) AS INT) | FROM randoms |) - |SELECT val FROM randoms LIMIT 10;""".stripMargin) + |SELECT val FROM randoms LIMIT 30;""".stripMargin) val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) assert(distinctCount > 2) @@ -69,7 +69,7 @@ class RecursiveCTESuite extends QueryTest with SharedSparkSession { | SELECT randstr(10) | FROM randoms |) - |SELECT val FROM randoms LIMIT 10;""".stripMargin) + |SELECT val FROM randoms LIMIT 30;""".stripMargin) val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) assert(distinctCount > 2) @@ -82,7 +82,7 @@ class RecursiveCTESuite extends QueryTest with SharedSparkSession { | SELECT UUID() | FROM randoms |) - |SELECT val FROM randoms LIMIT 10;""".stripMargin) + |SELECT val FROM randoms LIMIT 30;""".stripMargin) val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) assert(distinctCount > 2) @@ -95,7 +95,7 @@ class RecursiveCTESuite extends QueryTest with SharedSparkSession { | SELECT SHUFFLE(ARRAY(1,2,3,4,5)) | FROM randoms | ) - |SELECT val FROM randoms LIMIT 10;""".stripMargin) + |SELECT val FROM randoms LIMIT 30;""".stripMargin) val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) assert(distinctCount > 2) From fa2ca557359f19185e1da055704365c1cd795b78 Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Thu, 22 May 2025 10:32:53 +0200 Subject: [PATCH 5/7] add golden file tests for rest of functions --- .../analyzer-results/cte-recursion.sql.out | 24 +++++++++++++ .../sql-tests/inputs/cte-recursion.sql | 19 ++++++++++ .../sql-tests/results/cte-recursion.sql.out | 36 
+++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out index 5187c59e25683..dc2b5a20fde51 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out @@ -1677,3 +1677,27 @@ WITH RECURSIVE randoms(val) AS ( SELECT val FROM randoms LIMIT 5 -- !query analysis [Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT UUID(82374) + UNION ALL + SELECT UUID(237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT ARRAY(1,2,3,4,5) + UNION ALL + SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql index 57cb95f03bb0f..8ef0c391a3fc5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql @@ -615,10 +615,29 @@ WITH RECURSIVE randoms(val) AS ( ) SELECT val FROM randoms LIMIT 5; +-- Non-deterministic query with randstr WITH RECURSIVE randoms(val) AS ( SELECT randstr(10, 82374) UNION ALL SELECT randstr(10, 237685) FROM randoms ) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with UUID +WITH RECURSIVE randoms(val) AS ( + SELECT UUID(82374) + UNION ALL + SELECT UUID(237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with shuffle +WITH RECURSIVE randoms(val) AS ( + SELECT ARRAY(1,2,3,4,5) + UNION ALL + SELECT 
SHUFFLE(ARRAY(1,2,3,4,5), 237685) + FROM randoms +) SELECT val FROM randoms LIMIT 5; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index 20733670e3dfa..fab186730d0b8 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1547,3 +1547,39 @@ UxLgwhvH5j dBlWnfo7rO fmfDBMf60f kFeBV7dQWi + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT UUID(82374) + UNION ALL + SELECT UUID(237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +4ea190e3-c088-4ddd-a545-fb431059ae3c +8b88900e-f862-468c-8d3b-828188116155 +9e75cc4e-6d6e-4235-9788-5efd4d0fd0cc +be4f5346-1c7f-4697-8a2c-1343347872c5 +d0032efe-ae60-461b-8582-f6a7c649f238 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT ARRAY(1,2,3,4,5) + UNION ALL + SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct> +-- !query output +[1,2,3,4,5] +[1,2,3,5,4] +[3,4,2,1,5] +[4,3,2,5,1] +[4,5,1,2,3] From 142e9708ce5a0ce2f77720ebd171df2666344670 Mon Sep 17 00:00:00 2001 From: pavle-martinovic_data Date: Thu, 22 May 2025 16:20:08 +0200 Subject: [PATCH 6/7] remove suite and start seeds from +0 --- .../spark/sql/execution/UnionLoopExec.scala | 2 +- .../sql-tests/results/cte-recursion.sql.out | 12 +- .../sql/execution/RecursiveCTESuite.scala | 103 ------------------ 3 files changed, 7 insertions(+), 110 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala index e836681c513db..af956462fe7a6 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala @@ -192,7 +192,7 @@ case class UnionLoopExec( } else { recursion.transformAllExpressionsWithSubqueries { case e: ExpressionWithRandomSeed => - e.withShiftedSeed(currentLevel) + e.withShiftedSeed(currentLevel - 1) } } diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index fab186730d0b8..d6939ab84b57c 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1489,7 +1489,7 @@ SELECT val FROM randoms LIMIT 5 struct -- !query output 1 -4 +3 4 4 5 @@ -1507,7 +1507,7 @@ SELECT val FROM randoms LIMIT 5 struct -- !query output 1 -4 +3 4 4 5 @@ -1525,9 +1525,9 @@ SELECT val FROM randoms LIMIT 5 struct -- !query output -2 -1 2 2 +5 6 @@ -1543,7 +1543,7 @@ SELECT val FROM randoms LIMIT 5 struct -- !query output IpXzdTW03I -UxLgwhvH5j +Zj7uI2Ex6e dBlWnfo7rO fmfDBMf60f kFeBV7dQWi @@ -1560,9 +1560,9 @@ SELECT val FROM randoms LIMIT 5 -- !query schema struct -- !query output +19974dca-21f6-47ef-b58c-73908ab52aa0 4ea190e3-c088-4ddd-a545-fb431059ae3c 8b88900e-f862-468c-8d3b-828188116155 -9e75cc4e-6d6e-4235-9788-5efd4d0fd0cc be4f5346-1c7f-4697-8a2c-1343347872c5 d0032efe-ae60-461b-8582-f6a7c649f238 @@ -1580,6 +1580,6 @@ struct> -- !query output [1,2,3,4,5] [1,2,3,5,4] -[3,4,2,1,5] +[2,1,5,3,4] [4,3,2,5,1] [4,5,1,2,3] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala deleted file mode 100644 index 05db3ac669063..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RecursiveCTESuite.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or 
more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.countDistinct -import org.apache.spark.sql.test.SharedSparkSession - -class RecursiveCTESuite extends QueryTest with SharedSparkSession { - - test("Random rCTEs produce different results in different iterations - RAND") { - val df = sql("""WITH RECURSIVE randoms(val) AS ( - | SELECT CAST(floor(rand() * 50 + 1) AS INT) - | UNION ALL - | SELECT CAST(floor(rand() * 50 + 1) AS INT) - | FROM randoms - |) - |SELECT val FROM randoms LIMIT 30;""".stripMargin) - - val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) - assert(distinctCount > 2) - } - - test("Random rCTEs produce different results in different iterations - UNIFORM") { - val df = sql("""WITH RECURSIVE randoms(val) AS ( - | SELECT CAST(UNIFORM(1,51) AS INT) - | UNION ALL - | SELECT CAST(UNIFORM(1,51) AS INT) - | FROM randoms - |) - |SELECT val FROM randoms LIMIT 30;""".stripMargin) - - val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) - assert(distinctCount > 2) - } - - test("Random rCTEs produce different results in different iterations - RANDN") { - val df = sql("""WITH RECURSIVE randoms(val) AS ( - | SELECT 
CAST(floor(randn() * 50) AS INT) - | UNION ALL - | SELECT CAST(floor(randn() * 50) AS INT) - | FROM randoms - |) - |SELECT val FROM randoms LIMIT 30;""".stripMargin) - - val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) - assert(distinctCount > 2) - } - - test("Random rCTEs produce different results in different iterations - RANDSTR") { - val df = sql("""WITH RECURSIVE randoms(val) AS ( - | SELECT randstr(10) - | UNION ALL - | SELECT randstr(10) - | FROM randoms - |) - |SELECT val FROM randoms LIMIT 30;""".stripMargin) - - val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) - assert(distinctCount > 2) - } - - test("Random rCTEs produce different results in different iterations - UUID") { - val df = sql("""WITH RECURSIVE randoms(val) AS ( - | SELECT UUID() - | UNION ALL - | SELECT UUID() - | FROM randoms - |) - |SELECT val FROM randoms LIMIT 30;""".stripMargin) - - val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) - assert(distinctCount > 2) - } - - test("Random rCTEs produce different results in different iterations - SHUFFLE") { - val df = sql("""WITH RECURSIVE randoms(val) AS ( - | SELECT ARRAY(1,2,3,4,5) - | UNION ALL - | SELECT SHUFFLE(ARRAY(1,2,3,4,5)) - | FROM randoms - | ) - |SELECT val FROM randoms LIMIT 30;""".stripMargin) - - val distinctCount = df.select(countDistinct("val")).collect()(0).getLong(0) - assert(distinctCount > 2) - } -} From eac0ee579a092b49c771bf0b882185be8161655e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 22 May 2025 23:11:11 +0800 Subject: [PATCH 7/7] Update sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala --- .../scala/org/apache/spark/sql/execution/UnionLoopExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala index af956462fe7a6..d44d3b0b6ef0b 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala @@ -187,7 +187,7 @@ case class UnionLoopExec( // If the recursive part contains non-deterministic expressions that depends on a seed, we // need to create a new seed since the seed for this expression is set in the analysis, and // we avoid re-triggering the analysis for every iterative step. - val recursionReseeded = if (recursion.deterministic) { + val recursionReseeded = if (currentLevel == 1 || recursion.deterministic) { recursion } else { recursion.transformAllExpressionsWithSubqueries {