diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0a5e46d98728f..dc6a1ab2db9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -449,14 +449,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveEncodersInUDF), // The rewrite rules might move resolved query plan into subquery. Once the resolved plan // contains ScalaUDF, their encoders won't be resolved if `ResolveEncodersInUDF` is not - // applied before the rewrite rules. So we need to apply `ResolveEncodersInUDF` before the - // rewrite rules. + // applied before the rewrite rules. So we need to apply the rewrite rules after + // `ResolveEncodersInUDF` Batch("DML rewrite", fixedPoint, RewriteDeleteFromTable, RewriteUpdateTable, RewriteMergeIntoTable, // Ensures columns of an output table are correctly resolved from the data in a logical plan. - ResolveOutputRelation), + ResolveOutputRelation, + // Apply table check constraints to validate data during write operations. + new ResolveTableConstraints(catalogManager)), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, @@ -1437,6 +1439,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor new ResolveReferencesInUpdate(catalogManager) private val resolveReferencesInSort = new ResolveReferencesInSort(catalogManager) + private val resolveReferencesInFilter = + new ResolveReferencesInFilter(catalogManager) /** * Return true if there're conflicting attributes among children's outputs of a plan @@ -1483,9 +1487,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor checkTrailingCommaInSelect(expanded, starRemoved = true) } expanded - // If the filter list contains Stars, expand it. - case p: Filter if containsStar(Seq(p.condition)) => - p.copy(expandStarExpression(p.condition, p.child)) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { @@ -1711,23 +1712,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Project(child.output, r.copy(resolvedFinal, newChild)) } - // Filter can host both grouping expressions/aggregate functions and missing attributes. - // The grouping expressions/aggregate functions resolution takes precedence over missing - // attributes. See the classdoc of `ResolveReferences` for details. - case f @ Filter(cond, child) if !cond.resolved || f.missingInput.nonEmpty => - val resolvedBasic = resolveExpressionByPlanChildren(cond, f) - val resolvedWithAgg = resolveColWithAgg(resolvedBasic, child) - val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child) - // Missing columns should be resolved right after basic column resolution. - // See the doc of `ResolveReferences`. - val resolvedFinal = resolveColsLastResort(newCond.head) - if (child.output == newChild.output) { - f.copy(condition = resolvedFinal) + case f: Filter => + // If the filter list contains Stars, expand it. + val afterStarExpansion = if (containsStar(Seq(f.condition))) { + f.copy(expandStarExpression(f.condition, f.child)) } else { - // Add missing attributes and then project them away. 
- val newFilter = Filter(resolvedFinal, newChild) - Project(child.output, newFilter) + f } + resolveReferencesInFilter.apply(afterStarExpansion) case s: Sort if !s.resolved || s.missingInput.nonEmpty => resolveReferencesInSort(s) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInFilter.scala new file mode 100644 index 0000000000000..44349324916bb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInFilter.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.connector.catalog.CatalogManager + + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[Filter]]. It's only used by the real + * rules `ResolveReferences` and `ResolveTableConstraints`. Filters containing unresolved stars + * should have been expanded before applying this rule. + * Filter can host both grouping expressions/aggregate functions and missing attributes. + * The grouping expressions/aggregate functions resolution takes precedence over missing + * attributes. See the classdoc of `ResolveReferences` for details. + */ + class ResolveReferencesInFilter(val catalogManager: CatalogManager) + extends SQLConfHelper with ColumnResolutionHelper { + def apply(f: Filter): LogicalPlan = { + if (f.condition.resolved && f.missingInput.isEmpty) { + return f + } + val resolvedBasic = resolveExpressionByPlanChildren(f.condition, f) + val resolvedWithAgg = resolveColWithAgg(resolvedBasic, f.child) + val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), f.child) + // Missing columns should be resolved right after basic column resolution. + // See the doc of `ResolveReferences`. + val resolvedFinal = resolveColsLastResort(newCond.head) + if (f.child.output == newChild.output) { + f.copy(condition = resolvedFinal) + } else { + // Add missing attributes and then project them away. 
+ val newFilter = Filter(resolvedFinal, newChild) + Project(f.child.output, newFilter) + } + } + + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala index 3b86b9580ae19..a2dde6669be13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableConstraints.scala @@ -22,38 +22,58 @@ import org.apache.spark.sql.catalyst.expressions.{And, CheckInvariant, Expressio import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, V2WriteCommand} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND -import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.connector.catalog.{CatalogManager, Table} import org.apache.spark.sql.connector.catalog.constraints.Check import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation class ResolveTableConstraints(val catalogManager: CatalogManager) extends Rule[LogicalPlan] { + private val resolveReferencesInFilter = new ResolveReferencesInFilter(catalogManager) + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( _.containsPattern(COMMAND), ruleId) { case v2Write: V2WriteCommand if v2Write.table.resolved && v2Write.query.resolved && !containsCheckInvariant(v2Write.query) && v2Write.outputResolved => v2Write.table match { - case r: DataSourceV2Relation - if r.table.constraints != null && r.table.constraints.nonEmpty => - // Check constraint is the only enforced constraint for DSV2 tables. - val checkInvariants = r.table.constraints.collect { - case c: Check => - val unresolvedExpr = buildCatalystExpression(c) - val columnExtractors = mutable.Map[String, Expression]() - buildColumnExtractors(unresolvedExpr, columnExtractors) - CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql) - } - // Combine the check invariants into a single expression using conjunctive AND. - checkInvariants.reduceOption(And).fold(v2Write)( - condition => v2Write.withNewQuery(Filter(condition, v2Write.query))) + case r: DataSourceV2Relation => + buildCheckCondition(r.table).map { condition => + val filter = Filter(condition, v2Write.query) + // Resolve attribute references in the filter condition only, not the entire query. + // We use a targeted resolver (ResolveReferencesInFilter) instead of the full + // `ResolveReferences` rule to avoid the creation of `TempResolvedColumn` nodes that + // would interfere with the analyzer's ability to correctly identify unresolved + // attributes. + val resolvedFilter = resolveReferencesInFilter(filter) + v2Write.withNewQuery(resolvedFilter) + }.getOrElse(v2Write) case _ => v2Write } } + // Constructs an optional check condition based on the table's check constraints. + // This condition validates data during write operations. + // Returns None if no check constraints exist; otherwise, combines all constraints using + // logical AND. + private def buildCheckCondition(table: Table): Option[Expression] = { + if (table.constraints == null || table.constraints.isEmpty) { + None + } else { + val checkInvariants = table.constraints.collect { + // Check constraint is the only enforced constraint for DSV2 tables. 
+ case c: Check => + val unresolvedExpr = buildCatalystExpression(c) + val columnExtractors = mutable.Map[String, Expression]() + buildColumnExtractors(unresolvedExpr, columnExtractors) + CheckInvariant(unresolvedExpr, columnExtractors.toSeq, c.name, c.predicateSql) + } + checkInvariants.reduceOption(And) + } + } + private def containsCheckInvariant(plan: LogicalPlan): Boolean = { - plan match { + plan exists { case Filter(condition, _) => condition.exists(_.isInstanceOf[CheckInvariant]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala index 07acacd9a35d3..5980be8635d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/RowLevelOperationTable.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector.write import java.util import org.apache.spark.sql.connector.catalog.{Column, SupportsRead, SupportsRowLevelOperations, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -40,6 +41,7 @@ private[sql] case class RowLevelOperationTable( override def schema: StructType = table.schema override def columns: Array[Column] = table.columns() override def capabilities: util.Set[TableCapability] = table.capabilities + override def constraints(): Array[Constraint] = table.constraints() override def toString: String = table.toString override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 29be5b19acd3c..af651e8f886cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -17,37 +17,27 @@ package org.apache.spark.sql.connector.catalog -import java.util - import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException -import org.apache.spark.sql.connector.expressions.Transform class InMemoryRowLevelOperationTableCatalog extends InMemoryTableCatalog { import CatalogV2Implicits._ - override def createTable( - ident: Identifier, - columns: Array[Column], - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { if (tables.containsKey(ident)) { throw new TableAlreadyExistsException(ident.asMultipartIdentifier) } - InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + InMemoryTableCatalog.maybeSimulateFailedTableCreation(tableInfo.properties) val tableName = s"$name.${ident.quoted}" - val schema = CatalogV2Util.v2ColumnsToStructType(columns) - val table = new InMemoryRowLevelOperationTable(tableName, schema, partitions, properties) + val schema = CatalogV2Util.v2ColumnsToStructType(tableInfo.columns) + val table = new InMemoryRowLevelOperationTable( + tableName, schema, tableInfo.partitions, tableInfo.properties, tableInfo.constraints()) 
tables.put(ident, table) namespaces.putIfAbsent(ident.namespace.toList, Map()) table } - override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { - createTable(ident, tableInfo.columns(), tableInfo.partitions(), tableInfo.properties) - } - override def alterTable(ident: Identifier, changes: TableChange*): Table = { val table = loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 585480ace7255..608d701af5783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -63,6 +63,31 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase { Row(6, "hr", "new-text"))) } + test("delete from table with table constraints") { + sql( + s""" + |CREATE TABLE $tableNameAsString ( + | pk INT NOT NULL PRIMARY KEY, + | id INT UNIQUE, + | dep STRING, + | CONSTRAINT pk_check CHECK (pk > 0)) + | PARTITIONED BY (dep) + |""".stripMargin) + append("pk INT NOT NULL, id INT, dep STRING", + """{ "pk": 1, "id": 2, "dep": "hr" } + |{ "pk": 2, "id": 4, "dep": "eng" } + |{ "pk": 3, "id": 6, "dep": "eng" } + |""".stripMargin) + sql(s"DELETE FROM $tableNameAsString WHERE pk < 2") + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(2, 4, "eng"), Row(3, 6, "eng"))) + sql(s"DELETE FROM $tableNameAsString WHERE pk >=3") + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq(Row(2, 4, "eng"))) + } + test("delete from table containing struct column with default value") { sql( s"""CREATE TABLE $tableNameAsString ( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 580638230218b..f7aec678292ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, GenericRowWithSchema} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, Update, Write} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Delete, Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog, Insert, MetadataColumn, Operation, Reinsert, TableInfo, Update, Write} import org.apache.spark.sql.connector.expressions.LogicalExpressions.{identity, reference} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} @@ -102,7 +102,12 @@ abstract class RowLevelOperationSuiteBase protected def createTable(columns: Array[Column]): Unit = { val transforms = Array[Transform](identity(reference(Seq("dep")))) - catalog.createTable(ident, columns, transforms, extraTableProps) + 
val tableInfo = new TableInfo.Builder() + .withColumns(columns) + .withPartitions(transforms) + .withProperties(extraTableProps) + .build() + catalog.createTable(ident, tableInfo) } protected def createAndInitTable(schemaString: String, jsonData: String): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala index 39624a33d8614..a97be14536177 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala @@ -62,7 +62,7 @@ trait DDLCommandTestUtils extends SQLTestUtils { (f: String => Unit): Unit = { val nsCat = s"$cat.$ns" withNamespace(nsCat) { - sql(s"CREATE NAMESPACE $nsCat") + sql(s"CREATE NAMESPACE IF NOT EXISTS $nsCat") val t = s"$nsCat.$tableName" withTable(t) { f(t) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala index 0dbda6a1cf6b5..efc315ecacb54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala @@ -422,6 +422,221 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma } } + test("Check constraint violation on table update - top level column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" CONSTRAINT positive_id CHECK (id > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (5, 10)") + val error = intercept[SparkRuntimeException] { + sql(s"UPDATE $t SET id = -1 WHERE value = 10") + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_id", "expression" -> "id > 0", "values" -> " - id : -1") + ) + } + } + + test("Check constraint violation on table update - nested column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" s STRUCT CONSTRAINT positive_num CHECK (s.num > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (5, 10, struct(5, 'test'))") + val error = intercept[SparkRuntimeException] { + sql(s"UPDATE $t SET s = named_struct('num', -1, 'str', 'test') WHERE value = 10") + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_num", + "expression" -> "s.num > 0", "values" -> " - s.num : -1") + ) + } + } + + test("Check constraint violation on table update - map type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" m MAP CONSTRAINT positive_num CHECK (m['a'] > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (5, 10, map('a', 5))") + val error = intercept[SparkRuntimeException] { + sql(s"UPDATE $t SET m = map('a', -1) WHERE value = 10") + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_num", + "expression" -> "m['a'] > 0", "values" -> " - m['a'] : -1") + ) + } + } + + test("Check constraint violation on table update - array type column") { + 
withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" a ARRAY, CONSTRAINT positive_array CHECK (a[1] > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (5, 10, array(5, 6))") + val error = intercept[SparkRuntimeException] { + sql(s"UPDATE $t SET a = array(1, -2) WHERE value = 10") + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_array", + "expression" -> "a[1] > 0", "values" -> " - a[1] : -2") + ) + } + } + + test("Check constraint violation on table merge - top level column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" CONSTRAINT positive_id CHECK (id > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (5, 10)") + sql(s"INSERT INTO $source VALUES (-1, 20)") + + var error = intercept[SparkRuntimeException] { + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value) VALUES (s.id, s.value) + |""".stripMargin) + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_id", "expression" -> "id > 0", + "values" -> " - id : -1") + ) + + error = intercept[SparkRuntimeException] { + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN MATCHED THEN UPDATE SET id = s.id + |WHEN NOT MATCHED THEN INSERT(id, value) VALUES (s.id, s.value) + |""".stripMargin) + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_id", "expression" -> "id > 0", + "values" -> " - id : -1") + ) + } + } + } + + test("Check constraint violation on table merge - nested column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" s STRUCT " + + s"CONSTRAINT positive_num CHECK (s.num > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (5, 10, struct(5, 'test'))") + sql(s"INSERT INTO $source VALUES (-1, 20)") + + val error = intercept[SparkRuntimeException] { + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value, s) VALUES + | (s.id, s.value, named_struct('num', -1, 'str', 'test')) + |""".stripMargin) + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_num", "expression" -> "s.num > 0", + "values" -> " - s.num : -1") + ) + } + } + } + + test("Check constraint violation on table merge - map type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" m MAP CONSTRAINT positive_num CHECK (m['a'] > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (5, 10, map('a', 5))") + sql(s"INSERT INTO $source VALUES (-1, 20)") + + val error = intercept[SparkRuntimeException] { 
+ sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value, m) VALUES (s.id, s.value, map('a', -1)) + |""".stripMargin) + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_num", "expression" -> "m['a'] > 0", + "values" -> " - m['a'] : -1") + ) + } + } + } + + test("Check constraint violation on table merge - array type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" a ARRAY, CONSTRAINT positive_array CHECK (a[1] > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (5, 10, array(5, 6))") + sql(s"INSERT INTO $source VALUES (-1, 20)") + + val error = intercept[SparkRuntimeException] { + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value, a) VALUES (s.id, s.value, array(1, -2)) + |""".stripMargin) + } + checkError( + exception = error, + condition = "CHECK_CONSTRAINT_VIOLATION", + sqlState = "23001", + parameters = + Map("constraintName" -> "positive_array", "expression" -> "a[1] > 0", + "values" -> " - a[1] : -2") + ) + } + } + } test("Check constraint violation on insert overwrite by position") { withNamespaceAndTable("ns", "tbl", nonPartitionCatalog) { t => @@ -507,4 +722,135 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma checkAnswer(spark.table(t), Seq(Row(1, Seq(5, 6, 7)), Row(2, Seq(8, null)))) } } + + test("Check constraint validation succeeds on table update - top level column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" CONSTRAINT positive_id CHECK (id > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (1, 10), (2, 20)") + sql(s"UPDATE $t SET id = null WHERE value = 10") + checkAnswer(spark.table(t), Seq(Row(null, 10), Row(2, 20))) + } + } + + test("Check constraint validation succeeds on table update - nested column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" s STRUCT CONSTRAINT positive_num CHECK (s.num > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (1, 10, struct(5, 'test')), (2, 20, struct(10, 'test'))") + sql(s"UPDATE $t SET s = named_struct('num', null, 'str', 'test') WHERE value = 10") + checkAnswer(spark.table(t), Seq(Row(1, 10, Row(null, "test")), Row(2, 20, Row(10, "test")))) + } + } + + test("Check constraint validation succeeds on table update - map type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" m MAP CONSTRAINT positive_num CHECK (m['a'] > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (1, 10, map('a', 5)), (2, 20, map('b', 10))") + sql(s"UPDATE $t SET m = map('a', null) WHERE value = 10") + checkAnswer(spark.table(t), Seq(Row(1, 10, Map("a" -> null)), Row(2, 20, Map("b" -> 10)))) + } + } + + test("Check constraint validation succeeds on table update - array type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { t => + sql(s"CREATE TABLE $t (id INT, value INT," + + s" a ARRAY, CONSTRAINT positive_array CHECK (a[1] > 0)) $defaultUsing") + sql(s"INSERT INTO $t VALUES (1, 10, array(5, 6)), (2, 20, array(7, 8))") + sql(s"UPDATE 
$t SET a = array(null, 1) WHERE value = 10") + checkAnswer(spark.table(t), Seq(Row(1, 10, Seq(null, 1)), Row(2, 20, Seq(7, 8)))) + } + } + + test("Check constraint validation succeeds on table merge - top level column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" CONSTRAINT positive_id CHECK (id > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (1, 10), (2, 20)") + sql(s"INSERT INTO $source VALUES (3, 30), (4, 40)") + + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value) VALUES (s.id, s.value) + |""".stripMargin) + checkAnswer(spark.table(target), Seq(Row(1, 10), Row(2, 20), Row(3, 30), Row(4, 40))) + } + } + } + + test("Check constraint validation succeeds on table merge - nested column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" s STRUCT " + + s"CONSTRAINT positive_num CHECK (s.num > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (1, 10, struct(5, 'test')), (2, 20, struct(10, 'test'))") + sql(s"INSERT INTO $source VALUES (3, 30), (4, 40)") + + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value) VALUES (s.id, s.value) + |""".stripMargin) + checkAnswer(spark.table(target), + Seq(Row(1, 10, Row(5, "test")), Row(2, 20, Row(10, "test")), + Row(3, 30, null), Row(4, 40, null))) + } + } + } + + test("Check constraint validation succeeds on table merge - map type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" m MAP CONSTRAINT positive_num CHECK (m['a'] > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (1, 10, map('a', 5)), (2, 20, map('b', 10))") + sql(s"INSERT INTO $source VALUES (3, 30), (4, 40)") + + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value) VALUES (s.id, s.value) + |""".stripMargin) + checkAnswer(spark.table(target), + Seq(Row(1, 10, Map("a" -> 5)), Row(2, 20, Map("b" -> 10)), + Row(3, 30, null), Row(4, 40, null))) + } + } + } + + test("Check constraint validation succeeds on table merge - array type column") { + withNamespaceAndTable("ns", "tbl", rowLevelOPCatalog) { target => + withNamespaceAndTable("ns", "tbl2", rowLevelOPCatalog) { source => + sql(s"CREATE TABLE $target (id INT, value INT," + + s" a ARRAY, CONSTRAINT positive_array CHECK (a[1] > 0)) $defaultUsing") + sql(s"CREATE TABLE $source (id INT, value INT) $defaultUsing") + sql(s"INSERT INTO $target VALUES (1, 10, array(5, 6)), (2, 20, array(7, 8))") + sql(s"INSERT INTO $source VALUES (3, 30), (4, 40)") + + sql( + s""" + |MERGE INTO $target t + |USING $source s + |ON t.value = s.value + |WHEN NOT MATCHED THEN INSERT(id, value) VALUES (s.id, s.value) + |""".stripMargin) + checkAnswer(spark.table(target), + Seq(Row(1, 10, Seq(5, 6)), Row(2, 20, Seq(7, 8)), + Row(3, 30, null), Row(4, 40, null))) + } + } + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index 76928faec3189..24bc4483d31c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -35,6 +35,7 @@ trait CommandSuiteBase extends SharedSparkSession { def commandVersion: String = "V2" // The command version is added to test names def catalog: String = "test_catalog" // The default V2 catalog for testing def nonPartitionCatalog: String = "non_part_test_catalog" // Catalog for non-partitioned tables + def rowLevelOPCatalog: String = "row_level_op_catalog" def defaultUsing: String = "USING _" // The clause is used in creating v2 tables under testing // V2 catalogs created and used especially for testing @@ -42,6 +43,8 @@ trait CommandSuiteBase extends SharedSparkSession { .set(s"spark.sql.catalog.$catalog", classOf[InMemoryPartitionTableCatalog].getName) .set(s"spark.sql.catalog.$nonPartitionCatalog", classOf[InMemoryTableCatalog].getName) .set(s"spark.sql.catalog.fun_$catalog", classOf[InMemoryCatalog].getName) + .set(s"spark.sql.catalog.$rowLevelOPCatalog", + classOf[InMemoryRowLevelOperationTableCatalog].getName) def checkLocation( t: String,
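
End-to-end, the changes above make CHECK constraints on DSV2 tables enforced for row-level writes (UPDATE and MERGE via the DML rewrites), not only for plain INSERTs. A minimal sketch of the behaviour pinned down by the new CheckConstraintSuite tests, assuming a session in which row_level_op_catalog is registered to InMemoryRowLevelOperationTableCatalog as CommandSuiteBase now does, and with USING _ standing in for the suite's defaultUsing clause:

  // Table with a CHECK constraint in the row-level-operation test catalog.
  spark.sql("CREATE NAMESPACE IF NOT EXISTS row_level_op_catalog.ns")
  spark.sql(
    """CREATE TABLE row_level_op_catalog.ns.tbl (
      |  id INT, value INT,
      |  CONSTRAINT positive_id CHECK (id > 0)) USING _""".stripMargin)
  spark.sql("INSERT INTO row_level_op_catalog.ns.tbl VALUES (5, 10)")

  // A row-level write that violates the constraint now fails with
  // CHECK_CONSTRAINT_VIOLATION (SQLSTATE 23001), reporting the offending value (id : -1).
  spark.sql("UPDATE row_level_op_catalog.ns.tbl SET id = -1 WHERE value = 10")

  // Writes that keep the predicate true, or make it NULL, still succeed.
  spark.sql("UPDATE row_level_op_catalog.ns.tbl SET id = null WHERE value = 10")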
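
The heart of the change is ResolveTableConstraints, now also scheduled in the "DML rewrite" batch so it sees the V2 writes produced by RewriteUpdateTable and RewriteMergeIntoTable. For a resolved V2 write whose target DataSourceV2Relation declares CHECK constraints, buildCheckCondition turns each Check into a CheckInvariant, ANDs them together, and the rule pushes the conjunction onto the write's query as a Filter, resolving only that Filter through ResolveReferencesInFilter so that no TempResolvedColumn nodes leak into the rest of the plan. Condensed from the patch (a sketch, not the full rule):

  val checkCondition: Option[Expression] =
    table.constraints.collect {
      // CHECK is the only enforced constraint kind for DSV2 tables.
      case c: Check =>
        val expr = buildCatalystExpression(c)
        val extractors = mutable.Map[String, Expression]()
        buildColumnExtractors(expr, extractors) // column -> extractor, used to report offending values
        CheckInvariant(expr, extractors.toSeq, c.name, c.predicateSql)
    }.reduceOption(And)

  val newWrite = checkCondition.fold(v2Write) { cond =>
    v2Write.withNewQuery(resolveReferencesInFilter(Filter(cond, v2Write.query)))
  }

So for a table declared with CONSTRAINT positive_id CHECK (id > 0), the query of a rewritten UPDATE (e.g. the child of a ReplaceData or WriteDelta) ends up wrapped as Filter(CheckInvariant(id > 0, ..., "positive_id", "id > 0"), query); CheckInvariant fails the write at runtime when the predicate evaluates to false for a row, while a NULL result passes, matching the behaviour exercised by the new tests.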
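
ResolveReferencesInFilter is the old Filter branch of ResolveReferences lifted into a standalone "virtual rule", so the analyzer and ResolveTableConstraints resolve filter conditions the same way: basic column resolution first, then grouping/aggregate columns, then missing attributes (added below the Filter and projected away again), then the last-resort resolution. A rough illustration of the missing-attribute case, assuming a catalogManager taken from some test harness; the relation and attribute names are made up for the example:

  import org.apache.spark.sql.catalyst.analysis.{ResolveReferencesInFilter, UnresolvedAttribute}
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions.AttributeReference
  import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}
  import org.apache.spark.sql.types.{IntegerType, StringType}

  val pk = AttributeReference("pk", IntegerType)()
  val dep = AttributeReference("dep", StringType)()
  val relation = LocalRelation(pk, dep)
  val child = Project(Seq(pk), relation) // `dep` is not part of child.output

  val rule = new ResolveReferencesInFilter(catalogManager) // catalogManager assumed to be in scope
  val resolved = rule(Filter(UnresolvedAttribute("dep") === "hr", child))

  // resolved is roughly:
  //   Project([pk], Filter(dep = 'hr', Project([pk, dep], relation)))
  // i.e. the missing attribute is added below the Filter and projected away afterwards,
  // exactly what the previously inlined branch in ResolveReferences did.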
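
On the test side, constraints now have to survive table creation: RowLevelOperationTable forwards constraints() to the underlying table, and InMemoryRowLevelOperationTableCatalog keeps only the TableInfo-based createTable overload (the four-argument columns/partitions/properties overload is removed), passing tableInfo.constraints() into InMemoryRowLevelOperationTable. Callers therefore build a TableInfo, as RowLevelOperationSuiteBase now does:

  val tableInfo = new TableInfo.Builder()
    .withColumns(columns)
    .withPartitions(transforms)
    .withProperties(extraTableProps)
    .build()
  catalog.createTable(ident, tableInfo) // constraints carried by tableInfo now reach the in-memory table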