diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala index b8c825fde..44cc48d51 100644 --- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala +++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala @@ -109,6 +109,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StringType} import org.apache.spark.status.ElementTrackingStore import org.apache.spark.storage.BlockManagerId import org.apache.spark.storage.FileSegment +import org.apache.spark.unsafe.types.UTF8String import org.apache.auron.{protobuf => pb, sparkver} import org.apache.auron.common.AuronBuildInfo @@ -578,17 +579,8 @@ class ShimsImpl extends Shims with Logging { .setMonotonicIncreasingIdExpr(pb.MonotonicIncreasingIdExprNode.newBuilder()) .build()) - case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, IntegerType)) - // native StringSplit implementation does not support regex, so only most frequently - // used cases without regex are supported - if Seq(",", ", ", ":", ";", "#", "@", "_", "-", "\\|", "\\.").contains( - pat.value.toString) => - val nativePat = pat.value.toString match { - case "\\|" => "|" - case "\\." => "." - case other => other - } - Some( + case StringSplit(str, Literal(pattern: UTF8String, StringType), Literal(-1, IntegerType)) => + literalStringSplitPattern(pattern.toString).map { nativePat => pb.PhysicalExprNode .newBuilder() .setScalarFunction( @@ -600,7 +592,8 @@ class ShimsImpl extends Shims with Logging { .addArgs(NativeConverters .convertExprWithFallback(Literal(nativePat), isPruningExpr, fallback)) .setReturnType(NativeConverters.convertDataType(ArrayType(StringType)))) - .build()) + .build() + } case e: TaggingExpression => Some(NativeConverters.convertExprWithFallback(e.child, isPruningExpr, fallback)) @@ -617,6 +610,42 @@ class ShimsImpl extends Shims with Logging { } } + private def literalStringSplitPattern(pattern: String): Option[String] = { + if (pattern.isEmpty) { + return None + } + + val regexMetaCharacters = Set('.', '^', '$', '|', '?', '*', '+', '(', ')', '[', ']', '{', '}') + val literalPattern = new StringBuilder + var index = 0 + while (index < pattern.length) { + val ch = pattern.charAt(index) + if (ch == '\\') { + if (index + 1 >= pattern.length) { + return None + } + val escaped = pattern.charAt(index + 1) + if (regexMetaCharacters.contains(escaped) || escaped == '\\') { + literalPattern.append(escaped) + index += 2 + } else { + return None + } + } else if (regexMetaCharacters.contains(ch)) { + return None + } else { + literalPattern.append(ch) + index += 1 + } + } + + if (literalPattern.nonEmpty) { + Some(literalPattern.toString) + } else { + None + } + } + override def getLikeEscapeChar(expr: Expression): Char = { expr.asInstanceOf[Like].escapeChar } diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala index 9f8825087..735abb71d 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala @@ -20,6 +20,7 @@ import java.sql.Date import java.text.SimpleDateFormat import org.apache.spark.sql.{AuronQueryTest, Row} +import org.apache.spark.sql.functions.{col, split} import org.apache.spark.sql.internal.SQLConf import org.apache.auron.util.AuronTestUtils @@ -189,6 +190,46 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite { } } + test("split function with literal regex patterns") { + withTable("t1") { + sql("create table t1(c1 string, c2 string, c3 string) using parquet") + sql("insert into t1 values('a/b/c', 'a+b+c', 'a::b::c'), (null, null, null)") + checkSparkAnswerAndOperator { () => + sql("select c1, c2, c3 from t1").select( + split(col("c1"), "/"), + split(col("c2"), "\\+"), + split(col("c3"), "::")) + } + } + } + + test("split function with regex patterns falls back to Spark") { + withTable("t1") { + sql("create table t1(c1 string, c2 string, c3 string) using parquet") + sql("insert into t1 values('abc', 'a+b+c', 'a.b.c'), (null, null, null)") + val df = checkSparkAnswerAndOperator( + () => + sql("select c1, c2, c3 from t1") + .select(split(col("c1"), ".+"), split(col("c2"), "[+]"), split(col("c3"), ".")), + requireNative = false) + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert( + plan.collectFirst { case op if !isNativeOrPassThrough(op) => op }.nonEmpty, + "regex split patterns should fall back to Spark") + + val escapedBackslashDf = checkSparkAnswerAndOperator( + () => sql("select c2 from t1").select(split(col("c2"), "\\\\+")), + requireNative = false) + val escapedBackslashPlan = stripAQEPlan(escapedBackslashDf.queryExecution.executedPlan) + val escapedBackslashFallback = escapedBackslashPlan.collectFirst { + case op if !isNativeOrPassThrough(op) => op + }.nonEmpty + assert( + escapedBackslashFallback, + "split on repeated backslash regex should fall back to Spark") + } + } + test("weekofyear function") { withSQLConf("spark.sql.session.timeZone" -> "America/Los_Angeles") { withTable("t1") {