diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
index f70e09126cb..a5865311307 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
@@ -25,8 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PersistedView, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -128,7 +127,7 @@ trait LineageParser {
           exp.toAttribute,
           if (!containsCountAll(exp.child)) references
           else references + exp.toAttribute.withName(AGGREGATE_COUNT_COLUMN_IDENTIFIER))
-      case a: Attribute => a -> a.references
+      case a: Attribute => a -> AttributeSet(a)
     }
     ListMap(exps: _*)
   }
@@ -149,6 +148,9 @@ trait LineageParser {
             attr.withQualifier(attr.qualifier.init)
           case attr if attr.name.equalsIgnoreCase(AGGREGATE_COUNT_COLUMN_IDENTIFIER) =>
             attr.withQualifier(qualifier)
+          case attr if isNameWithQualifier(attr, qualifier) =>
+            val newName = attr.name.split('.').last.stripPrefix("`").stripSuffix("`")
+            attr.withName(newName).withQualifier(qualifier)
         })
       }
     } else {
@@ -160,6 +162,12 @@ trait LineageParser {
     }
   }
 
+  private def isNameWithQualifier(attr: Attribute, qualifier: Seq[String]): Boolean = {
+    val nameTokens = attr.name.split('.')
+    val namespace = nameTokens.init.mkString(".")
+    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
+  }
+
   private def mergeRelationColumnLineage(
       parentColumnsLineage: AttributeMap[AttributeSet],
       relationOutput: Seq[Attribute],
@@ -327,6 +335,31 @@ trait LineageParser {
       joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.aggregateExpressions))
       p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
 
+    case p: Expand =>
+      val references =
+        p.projections.transpose.map(_.flatMap(x => x.references)).map(AttributeSet(_))
+
+      val childColumnsLineage = ListMap(p.output.zip(references): _*)
+      val nextColumnsLineage =
+        joinColumnsLineage(parentColumnsLineage, childColumnsLineage)
+      p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+
+    case p: Window =>
+      val windowColumnsLineage =
+        ListMap(p.windowExpressions.map(exp => (exp.toAttribute, exp.references)): _*)
+      val nextColumnsLineage = if (parentColumnsLineage.isEmpty) {
+        ListMap(p.child.output.map(attr => (attr, attr.references)): _*) ++ windowColumnsLineage
+      } else {
+        parentColumnsLineage.map {
+          case (k, _) if windowColumnsLineage.contains(k) =>
+            k -> windowColumnsLineage(k)
+          case (k, attrs) =>
+            k -> AttributeSet(attrs.flatten(attr =>
+              windowColumnsLineage.getOrElse(attr, AttributeSet(attr))))
+        }
+      }
+      p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+
     case p: Join =>
       p.joinType match {
         case LeftSemi | LeftAnti =>
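For intuition, here is the non-empty-parent branch of the new `Window` case in isolation. This is a minimal standalone sketch, using plain `Map[String, Set[String]]` with column-name strings standing in for Catalyst's `AttributeMap`/`AttributeSet`; the two input maps are hypothetical data mirroring the cached-table test below, not the plugin's real types.

```scala
// Minimal model of the Window merge step: plain Scala collections in place
// of Catalyst's AttributeMap/AttributeSet (an assumption for illustration).
object WindowLineageSketch extends App {
  type Lineage = Map[String, Set[String]]

  // What the Window operator contributes: each window expression's output
  // column mapped to the columns it references.
  val windowColumnsLineage: Lineage = Map("rank" -> Set("a", "b"))

  // Lineage already collected from the parent operator: output column ->
  // source columns seen so far. "t1.b" is fed by the window output "rank".
  val parentColumnsLineage: Lineage = Map("t1.b" -> Set("rank"), "t1.a" -> Set("a"))

  // Mirror of the patch's else-branch: a parent reference to a window output
  // is replaced by that window expression's own inputs; ordinary columns
  // pass through unchanged.
  val merged: Lineage = parentColumnsLineage.map {
    case (k, _) if windowColumnsLineage.contains(k) => k -> windowColumnsLineage(k)
    case (k, attrs) =>
      k -> attrs.flatMap(attr => windowColumnsLineage.getOrElse(attr, Set(attr)))
  }

  assert(merged == Map("t1.b" -> Set("a", "b"), "t1.a" -> Set("a")))
  println(merged)
}
```

The key point is the second case: `rank` dissolves into its inputs `a` and `b`, which is exactly what the `ret1` assertion in the cache-table test expects.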
diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
index 6652be9ea15..050f3ddc9f0 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
@@ -1094,6 +1094,125 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }
 
+  test("test group by") {
+    withTable("t1", "t2", "v2_catalog.db.t1", "v2_catalog.db.t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE v2_catalog.db.t1 (a string, b string, c string)")
+      spark.sql("CREATE TABLE v2_catalog.db.t2 (a string, b string, c string)")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c)) " +
+            s"from t2 group by a")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set("default.t2.b", "default.t2.c")))))
+
+      val ret1 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c)) " +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret1 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+
+      val ret2 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"count(distinct(b+c))," +
+            s"count(distinct(b)) * count(distinct(c)) " +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret2 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+    }
+  }
+
+  test("test grouping sets") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string, d string) USING hive")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a,b,GROUPING__ID " +
+            s"from t2 group by a,b,c,d grouping sets ((a,b,c), (a,b,d))")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set()))))
+    }
+  }
+
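The grouping-sets expectation above (GROUPING__ID contributing no source columns) falls out of the new `Expand` branch. Below is a minimal sketch of that transpose step, again with plain Scala collections and hypothetical data: `Option[String]` stands in for a Catalyst expression, with `None` modeling the `Literal(null)` slots a grouping set leaves unset.

```scala
// Minimal model of the Expand handling: one inner Seq per grouping-set
// projection, transposed into one collection per output attribute.
object ExpandLineageSketch extends App {
  // Hypothetical projections for GROUPING SETS ((a, b, c), (a, b, d));
  // the Expand output is (a, b, c, d, grouping_id).
  val projections: Seq[Seq[Option[String]]] = Seq(
    Seq(Some("a"), Some("b"), Some("c"), None, None),
    Seq(Some("a"), Some("b"), None, Some("d"), None))

  val output = Seq("a", "b", "c", "d", "grouping_id")

  // Transposing turns "one row per grouping set" into "one column per
  // output attribute", so each output position collects the references
  // contributed by every projection, mirroring the patch's
  // p.projections.transpose.map(_.flatMap(x => x.references)).
  val references: Seq[Set[String]] = projections.transpose.map(_.flatten.toSet)

  val lineage: Map[String, Set[String]] = output.zip(references).toMap
  assert(lineage("c") == Set("c"))
  // grouping_id is synthesized from literals only, hence the empty Set()
  // asserted for default.t1.c in the grouping-sets test above.
  assert(lineage("grouping_id") == Set.empty[String])
  println(lineage)
}
```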
+  test("test cache table with window function") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string) USING hive")
+
+      spark.sql(
+        s"cache table c1 select * from (" +
+          s"select a, b, row_number() over (partition by a order by b asc) rank from t2)" +
+          s" where rank=1")
+      val ret0 = exectractLineage("insert overwrite table t1 select a, b from c1")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      val ret1 = exectractLineage("insert overwrite table t1 select a, rank from c1")
+      assert(ret1 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.a", "default.t2.b")))))
+
+      spark.sql(
+        s"cache table c2 select * from (" +
+          s"select b, a, row_number() over (partition by a order by b asc) rank from t2)" +
+          s" where rank=1")
+      val ret2 = exectractLineage("insert overwrite table t1 select a, b from c2")
+      assert(ret2 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      spark.sql(
+        s"cache table c3 select * from (" +
+          s"select a as aa, b as bb, row_number() over (partition by a order by b asc) rank" +
+          s" from t2) where rank=1")
+      val ret3 = exectractLineage("insert overwrite table t1 select aa, bb from c3")
+      assert(ret3 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+    }
+  }
+
   private def exectractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)
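The cached-table assertions also exercise the new `isNameWithQualifier` path: an attribute resurfaced from a cached plan can carry its origin embedded in the name itself (e.g. `default.t2.a`), which the patch splits back into a bare name plus a proper qualifier. A standalone sketch of that check, with plain strings in place of Catalyst `Attribute`s and hypothetical inputs:

```scala
// Minimal model of isNameWithQualifier plus the renaming it gates; the real
// code reads attr.name and attr.qualifier, strings are an assumption here.
object QualifierSketch extends App {
  def isNameWithQualifier(name: String, qualifier: Seq[String]): Boolean = {
    val nameTokens = name.split('.')
    val namespace = nameTokens.init.mkString(".")
    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
  }

  val qualifier = Seq("default", "t2")
  assert(isNameWithQualifier("default.t2.a", qualifier))
  assert(!isNameWithQualifier("a", qualifier))

  // The renaming step: keep only the last token, dropping any backticks,
  // then (in the real code) re-attach the qualifier via withQualifier.
  val newName = "default.t2.`a`".split('.').last.stripPrefix("`").stripSuffix("`")
  assert(newName == "a")
  println(newName)
}
```

Without this branch, columns selected from `c1`/`c2`/`c3` would keep their composite names and fail to join against the source table's lineage, which is what the four `ret*` assertions guard.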