[KYUUBI apache#4393] [Kyuubi apache#4332] Fix some bugs with `Groupby` and `CacheTable`

close apache#4332
### _Why are the changes needed?_

For the case where the table name has already been resolved and an `Expand` logical plan exists:
```
InsertIntoHiveTable `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [a, b]
+- Aggregate [a#0], [a#0, ansi_cast((count(if ((gid#9 = 1)) spark_catalog.default.t2.`b`#10 else null) * count(if ((gid#9 = 2)) spark_catalog.default.t2.`c`#11 else null)) as string) AS b#8]
   +- Aggregate [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9], [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9]
      +- Expand [ArrayBuffer(a#0, b#1, null, 1), ArrayBuffer(a#0, null, c#2, 2)], [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9]
         +- HiveTableRelation [`default`.`t2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [a#0, b#1, c#2], Partition Cols: []]
```
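This shape can be reproduced with a sketch (not part of the patch) mirroring the tests added below, which assume two Hive tables `t1` and `t2`:

```
// Sketch: two count(distinct ...) aggregates make Spark plan an Expand node
// whose synthetic outputs carry qualified names like spark_catalog.default.t2.`b`.
spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
spark.sql(
  "INSERT INTO t1 SELECT a, count(distinct b) * count(distinct c) FROM t2 GROUP BY a")
```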
For the case of `CacheTable` with a `window` function:
```
InsertIntoHiveTable `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, true, false, [a, b]
+- Project [a#98, b#99]
   +- InMemoryRelation [a#98, b#99, rank#100], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(2) Filter (isnotnull(rank#4) AND (rank#4 = 1))
            +- Window [row_number() windowspecdefinition(a#9, b#10 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rank#4], [a#9], [b#10 ASC NULLS FIRST]
               +- *(1) Sort [a#9 ASC NULLS FIRST, b#10 ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(a#9, 200), ENSURE_REQUIREMENTS, [id=apache#38]
                      +- Scan hive default.t2 [a#9, b#10], HiveTableRelation [`default`.`t2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [a#9, b#10], Partition Cols: []]
```
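This case can likewise be reproduced with a sketch (table definitions as above):

```
// Sketch: the cached query plans a Window node; inserting from the cache
// requires tracing lineage through the InMemoryRelation.
spark.sql(
  "CACHE TABLE c1 SELECT * FROM (" +
    "SELECT a, b, row_number() OVER (PARTITION BY a ORDER BY b) rank FROM t2" +
    ") WHERE rank = 1")
spark.sql("INSERT OVERWRITE TABLE t1 SELECT a, b FROM c1")
```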

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes apache#4393 from iodone/kyuubi-4332.

Closes apache#4393

d2afdab [odone] fix cache table bug
443af79 [odone] fix some bugs with groupby

Authored-by: odone <[email protected]>
Signed-off-by: ulyssesyou <[email protected]>
iodone authored and yanghua committed Apr 25, 2023
1 parent e4be464 commit cf4f7cc
Showing 2 changed files with 155 additions and 3 deletions.
```
@@ -25,8 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PersistedView, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -128,7 +127,7 @@ trait LineageParser {
           exp.toAttribute,
           if (!containsCountAll(exp.child)) references
           else references + exp.toAttribute.withName(AGGREGATE_COUNT_COLUMN_IDENTIFIER))
-      case a: Attribute => a -> a.references
+      case a: Attribute => a -> AttributeSet(a)
     }
     ListMap(exps: _*)
   }
```
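For intuition, `getSelectColumnLineage` builds a map from each output attribute to the input attributes it derives from. A plain-Scala sketch of the expected shape (strings stand in for Catalyst attributes, not the real API):

```
import scala.collection.immutable.ListMap

// Expected shape for a projection like `SELECT a, concat(a, b) AS c`:
// a bare attribute maps to itself, an alias maps to its referenced inputs.
val lineage = ListMap(
  "a" -> Set("a"),
  "c" -> Set("a", "b"))
```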
```
@@ -149,6 +148,9 @@ trait LineageParser {
             attr.withQualifier(attr.qualifier.init)
           case attr if attr.name.equalsIgnoreCase(AGGREGATE_COUNT_COLUMN_IDENTIFIER) =>
             attr.withQualifier(qualifier)
+          case attr if isNameWithQualifier(attr, qualifier) =>
+            val newName = attr.name.split('.').last.stripPrefix("`").stripSuffix("`")
+            attr.withName(newName).withQualifier(qualifier)
         })
       }
     } else {
@@ -160,6 +162,12 @@ trait LineageParser {
     }
   }
 
+  private def isNameWithQualifier(attr: Attribute, qualifier: Seq[String]): Boolean = {
+    val nameTokens = attr.name.split('.')
+    val namespace = nameTokens.init.mkString(".")
+    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
+  }
+
   private def mergeRelationColumnLineage(
       parentColumnsLineage: AttributeMap[AttributeSet],
       relationOutput: Seq[Attribute],
```
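The new guard targets attributes whose name embeds the qualifier, i.e. names of the form `<catalog>.<db>.<table>.<column>` as seen in the `Expand` plan above. A standalone sketch of the check and rename (values are illustrative):

```
// Illustrative values, matching the Expand plan in the commit message.
val name = "spark_catalog.default.t2.`b`"
val qualifier = Seq("default", "t2")

val nameTokens = name.split('.')              // Array(spark_catalog, default, t2, `b`)
val namespace = nameTokens.init.mkString(".") // "spark_catalog.default.t2"
// The attribute matches because the embedded namespace ends with the qualifier:
val matches = nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
// ...so it is renamed to the bare column name:
val newName = nameTokens.last.stripPrefix("`").stripSuffix("`") // "b"
```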
```
@@ -327,6 +335,31 @@ trait LineageParser {
           joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.aggregateExpressions))
         p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
 
+      case p: Expand =>
+        val references =
+          p.projections.transpose.map(_.flatMap(x => x.references)).map(AttributeSet(_))
+
+        val childColumnsLineage = ListMap(p.output.zip(references): _*)
+        val nextColumnsLineage =
+          joinColumnsLineage(parentColumnsLineage, childColumnsLineage)
+        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+
+      case p: Window =>
+        val windowColumnsLineage =
+          ListMap(p.windowExpressions.map(exp => (exp.toAttribute, exp.references)): _*)
+
+        val nextColumnsLineage = if (parentColumnsLineage.isEmpty) {
+          ListMap(p.child.output.map(attr => (attr, attr.references)): _*) ++ windowColumnsLineage
+        } else {
+          parentColumnsLineage.map {
+            case (k, _) if windowColumnsLineage.contains(k) =>
+              k -> windowColumnsLineage(k)
+            case (k, attrs) =>
+              k -> AttributeSet(attrs.flatten(attr =>
+                windowColumnsLineage.getOrElse(attr, AttributeSet(attr))))
+          }
+        }
+        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
       case p: Join =>
         p.joinType match {
           case LeftSemi | LeftAnti =>
```
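The `Expand` case hinges on `transpose`: the i-th output attribute of `Expand` is fed by the i-th expression of every projection, so its lineage is the union of those expressions' references. A plain-Scala sketch of that step (hypothetical column names, `None` standing in for the `null` slots):

```
// Projections for GROUPING SETS ((a, b), (a, c)); the last slot is the gid literal.
val projections = Seq(
  Seq(Some("a"), Some("b"), None, None),
  Seq(Some("a"), None, Some("c"), None))
val referencesPerOutput = projections.transpose.map(_.flatten.toSet)
// => List(Set(a), Set(b), Set(c), Set()) — one reference set per Expand output.
```

The `Window` case is analogous: each window expression maps to its references, and every other entry of the parent lineage is rewritten through that map. The added tests below exercise both paths end-to-end.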
```
@@ -1094,6 +1094,125 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }
 
+  test("test group by") {
+    withTable("t1", "t2", "v2_catalog.db.t1", "v2_catalog.db.t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE v2_catalog.db.t1 (a string, b string, c string)")
+      spark.sql("CREATE TABLE v2_catalog.db.t2 (a string, b string, c string)")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from t2 group by a")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set("default.t2.b", "default.t2.c")))))
+
+      val ret1 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret1 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+
+      val ret2 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"count(distinct(b+c))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret2 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+    }
+  }
+
+  test("test grouping sets") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string, d string) USING hive")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a,b,GROUPING__ID " +
+            s"from t2 group by a,b,c,d grouping sets ((a,b,c), (a,b,d))")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set()))))
+    }
+  }
+
+  test("test cache table with window function") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string) USING hive")
+
+      spark.sql(
+        s"cache table c1 select * from (" +
+          s"select a, b, row_number() over (partition by a order by b asc ) rank from t2)" +
+          s" where rank=1")
+      val ret0 = exectractLineage("insert overwrite table t1 select a, b from c1")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      val ret1 = exectractLineage("insert overwrite table t1 select a, rank from c1")
+      assert(ret1 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.a", "default.t2.b")))))
+
+      spark.sql(
+        s"cache table c2 select * from (" +
+          s"select b, a, row_number() over (partition by a order by b asc ) rank from t2)" +
+          s" where rank=1")
+      val ret2 = exectractLineage("insert overwrite table t1 select a, b from c2")
+      assert(ret2 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      spark.sql(
+        s"cache table c3 select * from (" +
+          s"select a as aa, b as bb, row_number() over (partition by a order by b asc ) rank" +
+          s" from t2) where rank=1")
+      val ret3 = exectractLineage("insert overwrite table t1 select aa, bb from c3")
+      assert(ret3 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+    }
+  }
+
   private def exectractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)
```
