Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions daft/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,21 @@
replace,
regexp_replace,
find,
find_in_set,
hamming_distance_str,
levenshtein_distance,
jaro_similarity,
jaro_winkler_similarity,
damerau_levenshtein_distance,
overlay,
translate,
substring_index,
soundex,
ascii_func,
chr_func,
space,
url_encode,
url_decode,
)
from .struct import unnest, to_struct
from .url import download, upload, parse_url
Expand Down Expand Up @@ -415,6 +419,7 @@
"fill_nan",
"fill_null",
"find",
"find_in_set",
"first_value",
"floor",
"format",
Expand Down Expand Up @@ -507,6 +512,7 @@
"not_nan",
"not_null",
"over",
"overlay",
"parse_url",
"partition_days",
"partition_hours",
Expand Down Expand Up @@ -608,6 +614,8 @@
"unnest",
"upload",
"upper",
"url_decode",
"url_encode",
"uuid",
"value_counts",
"var",
Expand Down
161 changes: 161 additions & 0 deletions daft/functions/str.py
Original file line number Diff line number Diff line change
Expand Up @@ -1995,3 +1995,164 @@ def space(expr: Expression) -> Expression:

"""
return Expression._call_builtin_scalar_fn("space", expr)


def find_in_set(needle: Expression, str_array: Expression) -> Expression:
"""Returns the 1-based index of `needle` in the comma-separated `str_array`.

Returns 0 when `needle` is not found in `str_array`. Returns 0 when `needle`
contains a comma (since the search uses ``,`` as the only separator).
Returns null if either input is null. This is compatible with Spark's
``find_in_set`` function.

Args:
needle: The string expression to search for.
str_array: A string expression of values separated by commas.

Returns:
Expression: an Int32 expression with the 1-based index, or 0 if not found.

Examples:
>>> import daft
>>> from daft.functions import find_in_set
>>> df = daft.from_pydict({"x": ["ab", "d", "a,b"], "y": ["abc,b,ab,c", "a,b,c", "a,b,c"]})
>>> df = df.with_column("idx", find_in_set(df["x"], df["y"]))
>>> df.collect()
╭────────┬────────────┬───────╮
│ x ┆ y ┆ idx │
│ --- ┆ --- ┆ --- │
│ String ┆ String ┆ Int32 │
╞════════╪════════════╪═══════╡
│ ab ┆ abc,b,ab,c ┆ 3 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ d ┆ a,b,c ┆ 0 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a,b ┆ a,b,c ┆ 0 │
╰────────┴────────────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
return Expression._call_builtin_scalar_fn("find_in_set", needle, str_array)


def overlay(
input: Expression,
replace: Expression,
pos: Expression | int,
length: Expression | int | None = None,
) -> Expression:
"""Replaces the substring of ``input`` starting at ``pos`` (1-based) with ``replace``.

``length`` controls how many characters of the original string are removed.
When ``length`` is omitted or negative, the character length of ``replace`` is
used. Positions less than 1 are treated as 1. This is compatible with Spark's
``overlay`` function.

Args:
input: The string expression to modify.
replace: The replacement string expression.
pos: 1-based starting position (integer expression or literal).
length: Number of characters to overwrite. If None or negative, uses the
length of ``replace``.

Returns:
Expression: a String expression with the substring overlaid.

Examples:
>>> import daft
>>> from daft.functions import overlay
>>> df = daft.from_pydict({"x": ["Spark SQL", "Spark SQL", "Spark SQL"]})
>>> df = df.with_column("a", overlay(df["x"], lit("_"), 6))
>>> df = df.with_column("b", overlay(df["x"], lit("CORE"), 7))
>>> df = df.with_column("c", overlay(df["x"], lit("ANSI "), 7, 0))
>>> df.collect()
╭───────────┬───────────┬────────────┬────────────────╮
│ x ┆ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- ┆ --- │
│ String ┆ String ┆ String ┆ String │
╞═══════════╪═══════════╪════════════╪════════════════╡
│ Spark SQL ┆ Spark_SQL ┆ Spark CORE ┆ Spark ANSI SQL │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ Spark SQL ┆ Spark_SQL ┆ Spark CORE ┆ Spark ANSI SQL │
├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ Spark SQL ┆ Spark_SQL ┆ Spark CORE ┆ Spark ANSI SQL │
╰───────────┴───────────┴────────────┴────────────────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
pos_expr = pos if isinstance(pos, Expression) else lit(pos)
if length is None:
return Expression._call_builtin_scalar_fn("overlay", input, replace, pos_expr)
length_expr = length if isinstance(length, Expression) else lit(length)
return Expression._call_builtin_scalar_fn("overlay", input, replace, pos_expr, length_expr)


def url_encode(expr: Expression) -> Expression:
"""Translates a string into ``application/x-www-form-urlencoded`` format using UTF-8.

Spaces become ``+`` and unsafe characters are percent-encoded. This is
compatible with Spark's ``url_encode`` function.

Args:
expr: The string expression to encode.

Returns:
Expression: a String expression with the URL-encoded result.

Examples:
>>> import daft
>>> from daft.functions import url_encode
>>> df = daft.from_pydict({"x": ["Spark SQL", "https://daft.ai", "中文"]})
>>> df = df.with_column("encoded", url_encode(df["x"]))
>>> df.collect()
╭─────────────────┬───────────────────────╮
│ x ┆ encoded │
│ --- ┆ --- │
│ String ┆ String │
╞═════════════════╪═══════════════════════╡
│ Spark SQL ┆ Spark+SQL │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ https://daft.ai ┆ https%3A%2F%2Fdaft.ai │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 中文 ┆ %E4%B8%AD%E6%96%87 │
╰─────────────────┴───────────────────────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
return Expression._call_builtin_scalar_fn("url_encode", expr)


def url_decode(expr: Expression) -> Expression:
"""Decodes a string in ``application/x-www-form-urlencoded`` format using UTF-8.

``+`` is converted to a space and ``%XX`` escape sequences are converted back
to bytes. This is compatible with Spark's ``url_decode`` function. Raises a
ValueError when the input contains invalid percent-encodings.

Args:
expr: The URL-encoded string expression to decode.

Returns:
Expression: a String expression with the decoded result.

Examples:
>>> import daft
>>> from daft.functions import url_decode
>>> df = daft.from_pydict({"x": ["Spark+SQL", "https%3A%2F%2Fdaft.ai", "%E4%B8%AD"]})
>>> df = df.with_column("decoded", url_decode(df["x"]))
>>> df.collect()
╭───────────────────────┬─────────────────╮
│ x ┆ decoded │
│ --- ┆ --- │
│ String ┆ String │
╞═══════════════════════╪═════════════════╡
│ Spark+SQL ┆ Spark SQL │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ https%3A%2F%2Fdaft.ai ┆ https://daft.ai │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ %E4%B8%AD ┆ 中 │
╰───────────────────────┴─────────────────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
return Expression._call_builtin_scalar_fn("url_decode", expr)
151 changes: 151 additions & 0 deletions src/daft-functions-utf8/src/find_in_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use common_error::{DaftError, DaftResult, ensure};
use daft_core::{
array::DataArray,
datatypes::Int32Type,
prelude::{DataType, Field, FullNull, Schema},
series::{IntoSeries, Series},
};
use daft_dsl::{
ExprRef,
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
};
use serde::{Deserialize, Serialize};

use crate::utils::{create_broadcasted_str_iter, parse_inputs};

/// Compute Spark-compatible `find_in_set(str, str_array)`:
/// Returns the 1-based index of `str` in the comma-separated `str_array`,
/// or 0 if not found, or 0 if `str` contains a comma.
fn compute_find_in_set(needle: &str, haystack: &str) -> i32 {
if needle.contains(',') {
return 0;
}
for (idx, part) in haystack.split(',').enumerate() {
if part == needle {
return (idx as i32) + 1;
}
}
0
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct FindInSet;

#[typetag::serde]
impl ScalarUDF for FindInSet {
fn name(&self) -> &'static str {
"find_in_set"
}

fn call(
&self,
inputs: FunctionArgs<Series>,
_ctx: &daft_dsl::functions::scalar::EvalContext,
) -> DaftResult<Series> {
let needle = inputs.required((0, "str"))?.cast(&DataType::Utf8)?;
let haystack = inputs.required((1, "str_array"))?.cast(&DataType::Utf8)?;
let name = needle.name();

needle.with_utf8_array(|needle_arr| {
haystack.with_utf8_array(|haystack_arr| {
let (is_full_null, expected_size) = parse_inputs(needle_arr, &[haystack_arr])
.map_err(|e| DaftError::ValueError(format!("Error in find_in_set: {e}")))?;

if is_full_null {
return Ok(DataArray::<Int32Type>::full_null(
name,
&DataType::Int32,
expected_size,
)
.into_series());
}
if expected_size == 0 {
return Ok(DataArray::<Int32Type>::empty(name, &DataType::Int32).into_series());
}

let needle_iter = create_broadcasted_str_iter(needle_arr, expected_size);
let haystack_iter = create_broadcasted_str_iter(haystack_arr, expected_size);

let result: DataArray<Int32Type> = needle_iter
.zip(haystack_iter)
.map(|(n, h)| match (n, h) {
(Some(n), Some(h)) => Some(compute_find_in_set(n, h)),
_ => None,
})
.collect();

Ok(result.rename(name).into_series())
})
})
}

fn get_return_field(
&self,
inputs: FunctionArgs<ExprRef>,
schema: &Schema,
) -> DaftResult<Field> {
ensure!(
inputs.len() == 2,
SchemaMismatch: "Expected 2 inputs, but received {}",
inputs.len()
);

let needle = inputs.required((0, "str"))?.to_field(schema)?;
let haystack = inputs.required((1, "str_array"))?.to_field(schema)?;

ensure!(
needle.dtype.is_string() || needle.dtype == DataType::Null,
TypeError: "First argument to 'find_in_set' must be a string, got {}",
needle.dtype
);
ensure!(
haystack.dtype.is_string() || haystack.dtype == DataType::Null,
TypeError: "Second argument to 'find_in_set' must be a string, got {}",
haystack.dtype
);

Ok(Field::new(needle.name, DataType::Int32))
}

fn docstring(&self) -> &'static str {
"Returns the 1-based index of the first argument in the comma-separated second argument. \
Returns 0 if not found, or 0 if the first argument contains a comma. \
Returns null if either input is null."
}
}

#[must_use]
pub fn find_in_set(needle: ExprRef, haystack: ExprRef) -> ExprRef {
ScalarFn::builtin(FindInSet, vec![needle, haystack]).into()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_find_in_set_basic() {
assert_eq!(compute_find_in_set("ab", "abc,b,ab,c,def"), 3);
assert_eq!(compute_find_in_set("a", "a,b,c"), 1);
assert_eq!(compute_find_in_set("c", "a,b,c"), 3);
assert_eq!(compute_find_in_set("d", "a,b,c"), 0);
}

#[test]
fn test_find_in_set_empty() {
assert_eq!(compute_find_in_set("", ""), 1);
assert_eq!(compute_find_in_set("", "a,,b"), 2);
assert_eq!(compute_find_in_set("a", ""), 0);
}

#[test]
fn test_find_in_set_with_comma_in_needle() {
// Needles containing comma always return 0 per Spark spec.
assert_eq!(compute_find_in_set("a,b", "a,b,c"), 0);
}

#[test]
fn test_find_in_set_unicode() {
assert_eq!(compute_find_in_set("β", "α,β,γ"), 2);
}
}
4 changes: 4 additions & 0 deletions src/daft-functions-utf8/src/levenshtein.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ impl ScalarUDF for LevenshteinDistance {
"levenshtein_distance"
}

fn aliases(&self) -> &'static [&'static str] {
&["levenshtein"]
}

fn call(
&self,
inputs: FunctionArgs<Series>,
Expand Down
Loading
Loading