Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StringType}
import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.FileSegment
import org.apache.spark.unsafe.types.UTF8String

import org.apache.auron.{protobuf => pb, sparkver}
import org.apache.auron.common.AuronBuildInfo
Expand Down Expand Up @@ -578,17 +579,8 @@ class ShimsImpl extends Shims with Logging {
.setMonotonicIncreasingIdExpr(pb.MonotonicIncreasingIdExprNode.newBuilder())
.build())

case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, IntegerType))
// native StringSplit implementation does not support regex, so only most frequently
// used cases without regex are supported
if Seq(",", ", ", ":", ";", "#", "@", "_", "-", "\\|", "\\.").contains(
pat.value.toString) =>
val nativePat = pat.value.toString match {
case "\\|" => "|"
case "\\." => "."
case other => other
}
Some(
case StringSplit(str, Literal(pattern: UTF8String, StringType), Literal(-1, IntegerType)) =>
literalStringSplitPattern(pattern.toString).map { nativePat =>
pb.PhysicalExprNode
.newBuilder()
.setScalarFunction(
Expand All @@ -600,7 +592,8 @@ class ShimsImpl extends Shims with Logging {
.addArgs(NativeConverters
.convertExprWithFallback(Literal(nativePat), isPruningExpr, fallback))
.setReturnType(NativeConverters.convertDataType(ArrayType(StringType))))
.build())
.build()
}

case e: TaggingExpression =>
Some(NativeConverters.convertExprWithFallback(e.child, isPruningExpr, fallback))
Expand All @@ -617,6 +610,42 @@ class ShimsImpl extends Shims with Logging {
}
}

/**
 * Translates a split pattern written as a Java regex into the plain string it matches, provided
 * the pattern has no regex semantics at all.
 *
 * A pattern qualifies when every character is either an ordinary character or a backslash
 * followed by a regex metacharacter or another backslash (e.g. `\\+`, `\\.`, `\\\\`). Anything
 * else (an unescaped metacharacter, a trailing lone backslash, or an escape such as `\\d` or
 * `\\t`) is rejected so that the caller can fall back instead of treating a real regex as a
 * literal.
 *
 * @param pattern the regex pattern text taken from the split expression
 * @return the unescaped literal delimiter, or `None` when the pattern is empty or is not a pure
 *         literal
 */
private def literalStringSplitPattern(pattern: String): Option[String] = {
  val regexMetaCharacters = Set('.', '^', '$', '|', '?', '*', '+', '(', ')', '[', ']', '{', '}')
  val literalPattern = new StringBuilder

  // Scans the pattern from `index`, appending literal characters as it goes. Yields false as
  // soon as a construct that carries regex meaning is encountered.
  @scala.annotation.tailrec
  def unescapeFrom(index: Int): Boolean = {
    if (index >= pattern.length) {
      true
    } else {
      pattern.charAt(index) match {
        case '\\' =>
          // A backslash is only accepted when it escapes a metacharacter or another backslash;
          // a trailing backslash or a class/control escape is not a literal.
          val hasEscapedLiteral = index + 1 < pattern.length && {
            val escaped = pattern.charAt(index + 1)
            escaped == '\\' || regexMetaCharacters.contains(escaped)
          }
          if (hasEscapedLiteral) {
            literalPattern.append(pattern.charAt(index + 1))
            unescapeFrom(index + 2)
          } else {
            false
          }
        case ch if regexMetaCharacters.contains(ch) =>
          false
        case ch =>
          literalPattern.append(ch)
          unescapeFrom(index + 1)
      }
    }
  }

  // A non-empty pattern that scans successfully always yields at least one literal character,
  // so no separate emptiness check on the result is needed.
  if (pattern.nonEmpty && unescapeFrom(0)) {
    Some(literalPattern.toString)
  } else {
    None
  }
}

/**
 * Reads the escape character from a `Like` expression.
 *
 * @param expr expression that is cast to `Like`; any other expression type fails the cast
 * @return the `escapeChar` of the `Like` expression
 */
override def getLikeEscapeChar(expr: Expression): Char = {
  val likeExpr = expr.asInstanceOf[Like]
  likeExpr.escapeChar
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.sql.Date
import java.text.SimpleDateFormat

import org.apache.spark.sql.{AuronQueryTest, Row}
import org.apache.spark.sql.functions.{col, split}
import org.apache.spark.sql.internal.SQLConf

import org.apache.auron.util.AuronTestUtils
Expand Down Expand Up @@ -189,6 +190,46 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite {
}
}

// Splits on delimiters that carry no regex meaning: a plain character, an escaped
// metacharacter, and a multi-character literal. Results are compared against Spark.
test("split function with literal regex patterns") {
  withTable("t1") {
    sql("create table t1(c1 string, c2 string, c3 string) using parquet")
    sql("insert into t1 values('a/b/c', 'a+b+c', 'a::b::c'), (null, null, null)")
    checkSparkAnswerAndOperator { () =>
      val slashSplit = split(col("c1"), "/")
      val escapedPlusSplit = split(col("c2"), "\\+")
      val doubleColonSplit = split(col("c3"), "::")
      sql("select c1, c2, c3 from t1").select(slashSplit, escapedPlusSplit, doubleColonSplit)
    }
  }
}

// Splits on patterns that are genuine regexes; each query must still match Spark's answer, and
// the executed plan must contain at least one operator that is not native or pass-through.
test("split function with regex patterns falls back to Spark") {
  withTable("t1") {
    sql("create table t1(c1 string, c2 string, c3 string) using parquet")
    sql("insert into t1 values('abc', 'a+b+c', 'a.b.c'), (null, null, null)")

    // Quantifier, character class and wildcard patterns.
    val regexSplitDf = checkSparkAnswerAndOperator(
      () => {
        val input = sql("select c1, c2, c3 from t1")
        input.select(split(col("c1"), ".+"), split(col("c2"), "[+]"), split(col("c3"), "."))
      },
      requireNative = false)
    val regexSplitPlan = stripAQEPlan(regexSplitDf.queryExecution.executedPlan)
    val regexFallbackOp = regexSplitPlan.collectFirst {
      case op if !isNativeOrPassThrough(op) => op
    }
    assert(regexFallbackOp.nonEmpty, "regex split patterns should fall back to Spark")

    // An escaped backslash followed by a quantifier is still a regex, not a literal.
    val backslashDf = checkSparkAnswerAndOperator(
      () => sql("select c2 from t1").select(split(col("c2"), "\\\\+")),
      requireNative = false)
    val backslashPlan = stripAQEPlan(backslashDf.queryExecution.executedPlan)
    assert(
      backslashPlan.collectFirst { case op if !isNativeOrPassThrough(op) => op }.nonEmpty,
      "split on repeated backslash regex should fall back to Spark")
  }
}

test("weekofyear function") {
withSQLConf("spark.sql.session.timeZone" -> "America/Los_Angeles") {
withTable("t1") {
Expand Down
Loading