
Commit 832ff87

beliefer authored and ueshin committed
[SPARK-28077][SQL] Support ANSI SQL OVERLAY function.
## What changes were proposed in this pull request?

The `OVERLAY` function is ANSI SQL. For example:

```
SELECT OVERLAY('abcdef' PLACING '45' FROM 4);
SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5);
SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0);
SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4);
```

The results of the above four queries are:

```
abc45f
yabadaba
yabadabadoo
bubba
```

Note: if the input string is null, then the result is null too.

Several mainstream databases support this syntax:

**PostgreSQL:** https://www.postgresql.org/docs/11/functions-string.html
**Vertica:** https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/SQLReferenceManual/Functions/String/OVERLAY.htm?zoom_highlight=overlay
**Oracle:** https://docs.oracle.com/en/database/oracle/oracle-database/19/arpls/UTL_RAW.html#GUID-342E37E7-FE43-4CE1-A0E9-7DAABD000369
**DB2:** https://www.ibm.com/support/knowledgecenter/SSGMCP_5.3.0/com.ibm.cics.rexx.doc/rexx/overlay.html

Below are some demonstrations of this PR in my production environment:

```
spark-sql> SELECT OVERLAY('abcdef' PLACING '45' FROM 4);
abc45f
Time taken: 6.385 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5);
yabadaba
Time taken: 0.191 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0);
yabadabadoo
Time taken: 0.186 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4);
bubba
Time taken: 0.151 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING '45' FROM 4);
NULL
Time taken: 0.22 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5);
NULL
Time taken: 0.157 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5 FOR 0);
NULL
Time taken: 0.254 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'ubb' FROM 2 FOR 4);
NULL
Time taken: 0.159 seconds, Fetched 1 row(s)
```

## How was this patch tested?

Existing UTs and new UTs.
Closes apache#24918 from beliefer/ansi-sql-overlay. Lead-authored-by: gengjiaan <[email protected]> Co-authored-by: Jiaan Geng <[email protected]> Signed-off-by: Takuya UESHIN <[email protected]>
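The semantics described above can be modeled in a few lines. The sketch below is a plain-Python illustration (not Spark code): positions are 1-based and count characters, a null input yields a null result, and a negative `length` (the sentinel the patch uses for the omitted `FOR` clause) falls back to the length of the replacement.

```python
def overlay(src, replace, pos, length=-1):
    """Illustrative model of ANSI SQL OVERLAY; not Spark's implementation."""
    if src is None:
        return None  # null input propagates to a null result
    if length < 0:
        length = len(replace)  # default: replace exactly len(replace) characters
    # keep the prefix before pos, insert the replacement,
    # then resume at pos + length in the original string
    return src[:pos - 1] + replace + src[pos - 1 + length:]

print(overlay('abcdef', '45', 4))        # abc45f
print(overlay('yabadoo', 'daba', 5))     # yabadaba
print(overlay('yabadoo', 'daba', 5, 0))  # yabadabadoo
print(overlay('babosa', 'ubb', 2, 4))    # bubba
print(overlay(None, '45', 4))            # None
```

Note how `FOR 0` turns OVERLAY into a pure insertion: nothing is consumed from the source string, so `'daba'` is spliced into `'yabadoo'` without removing any characters.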
1 parent 31e7c37 commit 832ff87

File tree

10 files changed: +281 −0 lines changed

docs/sql-keywords.md

+2
```
@@ -194,12 +194,14 @@ Below is a list of all the keywords in Spark SQL.
 <tr><td>OUTPUTFORMAT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>OVER</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>OVERLAPS</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
+<tr><td>OVERLAY</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>OVERWRITE</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>PARTITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
 <tr><td>PARTITIONED</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>PARTITIONS</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>PERCENT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>PIVOT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
+<tr><td>PLACING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>POSITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
 <tr><td>PRECEDING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>PRIMARY</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
```

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

+7
```
@@ -705,6 +705,8 @@ primaryExpression
       ((FOR | ',') len=valueExpression)? ')' #substring
   | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
       FROM srcStr=valueExpression ')' #trim
+  | OVERLAY '(' input=valueExpression PLACING replace=valueExpression
+      FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay
   ;

 constant
@@ -1002,6 +1004,7 @@ ansiNonReserved
   | OUT
   | OUTPUTFORMAT
   | OVER
+  | OVERLAY
   | OVERWRITE
   | PARTITION
   | PARTITIONED
@@ -1253,12 +1256,14 @@ nonReserved
   | OUTPUTFORMAT
   | OVER
   | OVERLAPS
+  | OVERLAY
   | OVERWRITE
   | PARTITION
   | PARTITIONED
   | PARTITIONS
   | PERCENTLIT
   | PIVOT
+  | PLACING
   | POSITION
   | PRECEDING
   | PRIMARY
@@ -1509,12 +1514,14 @@ OUTER: 'OUTER';
 OUTPUTFORMAT: 'OUTPUTFORMAT';
 OVER: 'OVER';
 OVERLAPS: 'OVERLAPS';
+OVERLAY: 'OVERLAY';
 OVERWRITE: 'OVERWRITE';
 PARTITION: 'PARTITION';
 PARTITIONED: 'PARTITIONED';
 PARTITIONS: 'PARTITIONS';
 PERCENTLIT: 'PERCENT';
 PIVOT: 'PIVOT';
+PLACING: 'PLACING';
 POSITION: 'POSITION';
 PRECEDING: 'PRECEDING';
 PRIMARY: 'PRIMARY';
```
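The key point of the new `#overlay` grammar alternative is that the `FOR length` subrule is optional, so the parser accepts both a three-part and a four-part call. A crude regex sketch (purely illustrative, not the ANTLR grammar) shows the two accepted shapes:

```python
import re

# Hypothetical sketch of the #overlay alternative's surface syntax:
# OVERLAY(input PLACING replace FROM position [FOR length])
OVERLAY_RE = re.compile(
    r"OVERLAY\((?P<input>.+) PLACING (?P<replace>.+) FROM (?P<position>\d+)"
    r"( FOR (?P<length>\d+))?\)")

m = OVERLAY_RE.fullmatch("OVERLAY('Spark SQL' PLACING '_' FROM 6)")
print(m.group('input'), m.group('position'), m.group('length'))  # length is None

m = OVERLAY_RE.fullmatch("OVERLAY('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0)")
print(m.group('length'))  # the FOR clause was supplied
```

When `length` is absent, the AST builder (see `visitOverlay` below in `AstBuilder.scala`) falls back to the three-argument `Overlay` constructor.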

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

+1
```
@@ -348,6 +348,7 @@ object FunctionRegistry {
     expression[RegExpReplace]("regexp_replace"),
     expression[StringRepeat]("repeat"),
     expression[StringReplace]("replace"),
+    expression[Overlay]("overlay"),
     expression[RLike]("rlike"),
     expression[StringRPad]("rpad"),
     expression[StringTrimRight]("rtrim"),
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

+106
```
@@ -68,6 +68,7 @@ import org.apache.spark.sql.types._
  * - [[UnaryExpression]]: an expression that has one child.
  * - [[BinaryExpression]]: an expression that has two children.
  * - [[TernaryExpression]]: an expression that has three children.
+ * - [[QuaternaryExpression]]: an expression that has four children.
  * - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have
  *   the same output data type.
  *
@@ -757,6 +758,111 @@ abstract class TernaryExpression extends Expression {
   }
 }

+/**
+ * An expression with four inputs and one output. The output is by default evaluated to null
+ * if any input is evaluated to null.
+ */
+abstract class QuaternaryExpression extends Expression {
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  /**
+   * Default behavior of evaluation according to the default nullability of
+   * QuaternaryExpression. If a subclass overrides nullable, it should probably also
+   * override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    val exprs = children
+    val value1 = exprs(0).eval(input)
+    if (value1 != null) {
+      val value2 = exprs(1).eval(input)
+      if (value2 != null) {
+        val value3 = exprs(2).eval(input)
+        if (value3 != null) {
+          val value4 = exprs(3).eval(input)
+          if (value4 != null) {
+            return nullSafeEval(value1, value2, value3, value4)
+          }
+        }
+      }
+    }
+    null
+  }
+
+  /**
+   * Called by the default [[eval]] implementation. Subclasses of QuaternaryExpression that
+   * keep the default nullability can override this method to save null-check code. If full
+   * control of the evaluation process is needed, override [[eval]] instead.
+   */
+  protected def nullSafeEval(input1: Any, input2: Any, input3: Any, input4: Any): Any =
+    sys.error("QuaternaryExpressions must override either eval or nullSafeEval")
+
+  /**
+   * Shorthand for generating quaternary evaluation code.
+   * If any of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   *
+   * @param f accepts four variable names and returns Java code to compute the output.
+   */
+  protected def defineCodeGen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      f: (String, String, String, String) => String): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4) => {
+      s"${ev.value} = ${f(eval1, eval2, eval3, eval4)};"
+    })
+  }
+
+  /**
+   * Shorthand for generating quaternary evaluation code.
+   * If any of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   *
+   * @param f function that accepts the 4 non-null evaluation result names of children
+   *          and returns Java code to compute the output.
+   */
+  protected def nullSafeCodeGen(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      f: (String, String, String, String) => String): ExprCode = {
+    val firstGen = children(0).genCode(ctx)
+    val secondGen = children(1).genCode(ctx)
+    val thirdGen = children(2).genCode(ctx)
+    val fourthGen = children(3).genCode(ctx)
+    val resultCode = f(firstGen.value, secondGen.value, thirdGen.value, fourthGen.value)
+
+    if (nullable) {
+      val nullSafeEval =
+        firstGen.code + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
+          secondGen.code + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
+            thirdGen.code + ctx.nullSafeExec(children(2).nullable, thirdGen.isNull) {
+              fourthGen.code + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
+                s"""
+                  ${ev.isNull} = false; // resultCode could change nullability.
+                  $resultCode
+                """
+              }
+            }
+          }
+        }
+
+      ev.copy(code = code"""
+        boolean ${ev.isNull} = true;
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $nullSafeEval""")
+    } else {
+      ev.copy(code = code"""
+        ${firstGen.code}
+        ${secondGen.code}
+        ${thirdGen.code}
+        ${fourthGen.code}
+        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        $resultCode""", isNull = FalseLiteral)
+    }
+  }
+}
+
 /**
  * A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]]
  * and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output data
```
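The contract that `QuaternaryExpression.eval` implements is a left-to-right, short-circuiting null check: children are evaluated in order, the first null result makes the whole expression null, and `nullSafeEval` only ever sees non-null values. A minimal Python model (illustrative names, not Spark's API):

```python
def quaternary_eval(children, null_safe_eval, row):
    """Model of QuaternaryExpression.eval: evaluate children left to right,
    short-circuiting to None on the first null, else delegate to the
    null-safe function with four guaranteed non-null values."""
    values = []
    for child in children:
        v = child(row)      # evaluate one child against the input row
        if v is None:
            return None     # short-circuit: remaining children never run
        values.append(v)
    return null_safe_eval(*values)

# usage: an overlay-like expression reading four columns from a dict "row"
row = {'s': 'Spark SQL', 'r': '_', 'p': 6, 'l': -1}
result = quaternary_eval(
    [lambda r: r['s'], lambda r: r['r'], lambda r: r['p'], lambda r: r['l']],
    lambda s, rep, p, l: (s, rep, p, l),
    row)
print(result)  # ('Spark SQL', '_', 6, -1)
```

The generated-code path (`nullSafeCodeGen`) mirrors the same structure with nested `nullSafeExec` guards instead of a Python loop.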

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

+64
```
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -454,6 +455,69 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
   override def prettyName: String = "replace"
 }

+object Overlay {
+
+  def calculate(input: UTF8String, replace: UTF8String, pos: Int, len: Int): UTF8String = {
+    val builder = new UTF8StringBuilder
+    builder.append(input.substringSQL(1, pos - 1))
+    builder.append(replace)
+    // If a length is specified, it must be a positive whole number or zero;
+    // otherwise it is ignored. The default length is the length of replace.
+    val length = if (len >= 0) {
+      len
+    } else {
+      replace.numChars
+    }
+    builder.append(input.substringSQL(pos + length, Int.MaxValue))
+    builder.build()
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(input, replace, pos[, len]) - Replace the portion of `input` that starts at `pos` and is of length `len` with `replace`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark SQL' PLACING '_' FROM 6);
+       Spark_SQL
+      > SELECT _FUNC_('Spark SQL' PLACING 'CORE' FROM 7);
+       Spark CORE
+      > SELECT _FUNC_('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0);
+       Spark ANSI SQL
+      > SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4);
+       Structured SQL
+  """)
+// scalastyle:on line.size.limit
+case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression)
+  extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+  def this(str: Expression, replace: Expression, pos: Expression) = {
+    this(str, replace, pos, Literal.create(-1, IntegerType))
+  }
+
+  override def dataType: DataType = StringType
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringType, StringType, IntegerType, IntegerType)
+
+  override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
+
+  override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = {
+    val inputStr = inputEval.asInstanceOf[UTF8String]
+    val replaceStr = replaceEval.asInstanceOf[UTF8String]
+    val position = posEval.asInstanceOf[Int]
+    val length = lenEval.asInstanceOf[Int]
+    Overlay.calculate(inputStr, replaceStr, position, length)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, (input, replace, pos, len) =>
+      "org.apache.spark.sql.catalyst.expressions.Overlay" +
+        s".calculate($input, $replace, $pos, $len)")
+  }
+}
+
 object StringTranslate {

   def buildDict(matchingString: UTF8String, replaceString: UTF8String)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

+14
```
@@ -1421,6 +1421,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     }
   }

+  /**
+   * Create an Overlay expression.
+   */
+  override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
+    val input = expression(ctx.input)
+    val replace = expression(ctx.replace)
+    val position = expression(ctx.position)
+    val lengthOpt = Option(ctx.length).map(expression)
+    lengthOpt match {
+      case Some(length) => Overlay(input, replace, position, length)
+      case None => new Overlay(input, replace, position)
+    }
+  }
+
   /**
    * Create a (windowed) Function expression.
    */
```

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

+24
```
@@ -428,6 +428,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // scalastyle:on
   }

+  test("overlay") {
+    checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
+      Literal.create(6, IntegerType)), "Spark_SQL")
+    checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"),
+      Literal.create(7, IntegerType)), "Spark CORE")
+    checkEvaluation(Overlay(Literal("Spark SQL"), Literal("ANSI "),
+      Literal.create(7, IntegerType), Literal.create(0, IntegerType)), "Spark ANSI SQL")
+    checkEvaluation(Overlay(Literal("Spark SQL"), Literal("tructured"),
+      Literal.create(2, IntegerType), Literal.create(4, IntegerType)), "Structured SQL")
+    checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("_"),
+      Literal.create(6, IntegerType)), null)
+    checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("CORE"),
+      Literal.create(7, IntegerType)), null)
+    checkEvaluation(Overlay(Literal.create(null, StringType), Literal("ANSI "),
+      Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null)
+    checkEvaluation(Overlay(Literal.create(null, StringType), Literal("tructured"),
+      Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null)
+    // scalastyle:off
+    // Non-ASCII characters are not allowed in the source code, so we disable scalastyle here.
+    checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"),
+      Literal.create(6, IntegerType)), "Spark_SQL")
+    // scalastyle:on
+  }
+
   test("translate") {
     checkEvaluation(
       StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
```
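The non-ASCII case in the test above is worth spelling out: positions in OVERLAY count characters, not bytes, so `的` (3 bytes in UTF-8) occupies a single position and `FROM 6` lands exactly on it. A plain-Python model (not Spark's UTF8String code) makes this visible:

```python
def overlay(src, replace, pos, length=-1):
    """Character-based model of OVERLAY; positions are 1-based."""
    if length < 0:
        length = len(replace)  # default length is the replacement's length
    return src[:pos - 1] + replace + src[pos - 1 + length:]

# '的' is one character at position 6, even though it is 3 bytes in UTF-8
print(overlay('Spark的SQL', '_', 6))          # Spark_SQL
print(len('Spark的SQL'))                       # 9 characters...
print(len('Spark的SQL'.encode('utf-8')))       # ...but 11 bytes
```

If positions were byte offsets, replacing at position 6 would split the multi-byte character and corrupt the string; `substringSQL`/`numChars` in `Overlay.calculate` avoid that by operating on characters.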

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

+29
```
@@ -744,6 +744,35 @@ class PlanParserSuite extends AnalysisTest {
     )
   }

+  test("OVERLAY function") {
+    def assertOverlayPlans(inputSQL: String, expectedExpression: Expression): Unit = {
+      comparePlans(
+        parsePlan(inputSQL),
+        Project(Seq(UnresolvedAlias(expectedExpression)), OneRowRelation())
+      )
+    }
+
+    assertOverlayPlans(
+      "SELECT OVERLAY('Spark SQL' PLACING '_' FROM 6)",
+      new Overlay(Literal("Spark SQL"), Literal("_"), Literal(6))
+    )
+
+    assertOverlayPlans(
+      "SELECT OVERLAY('Spark SQL' PLACING 'CORE' FROM 7)",
+      new Overlay(Literal("Spark SQL"), Literal("CORE"), Literal(7))
+    )
+
+    assertOverlayPlans(
+      "SELECT OVERLAY('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0)",
+      Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal(7), Literal(0))
+    )
+
+    assertOverlayPlans(
+      "SELECT OVERLAY('Spark SQL' PLACING 'tructured' FROM 2 FOR 4)",
+      Overlay(Literal("Spark SQL"), Literal("tructured"), Literal(2), Literal(4))
+    )
+  }
+
   test("precedence of set operations") {
     val a = table("a").select(star())
     val b = table("b").select(star())
```

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

+22
```
@@ -2516,6 +2516,28 @@ object functions {
     SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
   }

+  /**
+   * Overlay the specified portion of `src` with `replaceString`,
+   * starting from position `pos` of `src` and proceeding for `len` characters.
+   *
+   * @group string_funcs
+   * @since 3.0.0
+   */
+  def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column = withExpr {
+    Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr)
+  }
+
+  /**
+   * Overlay the specified portion of `src` with `replaceString`,
+   * starting from position `pos` of `src`.
+   *
+   * @group string_funcs
+   * @since 3.0.0
+   */
+  def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr {
+    new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr)
+  }
+
   /**
    * Translate any character in the src by a character in replaceString.
    * The characters in replaceString correspond to the characters in matchingString.
```
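The two DataFrame-API overloads above differ only in whether `len` is supplied; the three-argument form delegates through the constructor that defaults `len` to the `-1` sentinel, meaning "use the replacement's length". A plain-Python sketch of that delegation (hypothetical names, not the Scala API):

```python
def overlay4(src, replace_string, pos, length):
    """Model of the four-argument overload."""
    if length < 0:
        length = len(replace_string)  # -1 sentinel: default to replacement length
    return src[:pos - 1] + replace_string + src[pos - 1 + length:]

def overlay3(src, replace_string, pos):
    """Model of the three-argument overload: same sentinel the Scala ctor uses."""
    return overlay4(src, replace_string, pos, -1)

print(overlay3('Spark SQL', 'CORE', 7))      # Spark CORE
print(overlay4('Spark SQL', 'ANSI ', 7, 0))  # Spark ANSI SQL
```

Keeping the default in one place (the sentinel, resolved inside `Overlay.calculate`) means both overloads share a single code path rather than duplicating the length logic.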
