From 2d4b9b4800401e5d1079ddc006a88affc0672738 Mon Sep 17 00:00:00 2001 From: Nishaanth Reddy Date: Wed, 3 Jun 2026 18:10:53 -0700 Subject: [PATCH 1/5] feat(functions): add string distance/similarity functions - add levenshtein_distance, jaro_similarity, jaro_winkler_similarity, damerau_levenshtein_distance - pure Rust implementations with no external dependencies, following hamming_distance_str pattern - expose as top-level daft.functions API and Expression methods - handle null inputs (return null) and null-typed columns (DataType::Null) - include 24 pytest test cases covering correctness, edge cases, and null handling --- daft/expressions/expressions.py | 40 ++++ daft/functions/__init__.py | 8 + daft/functions/str.py | 151 +++++++++++++++ .../src/damerau_levenshtein.rs | 145 ++++++++++++++ src/daft-functions-utf8/src/jaro.rs | 157 +++++++++++++++ src/daft-functions-utf8/src/jaro_winkler.rs | 114 +++++++++++ src/daft-functions-utf8/src/levenshtein.rs | 132 +++++++++++++ src/daft-functions-utf8/src/lib.rs | 12 ++ tests/functions/test_string_distance.py | 178 ++++++++++++++++++ 9 files changed, 937 insertions(+) create mode 100644 src/daft-functions-utf8/src/damerau_levenshtein.rs create mode 100644 src/daft-functions-utf8/src/jaro.rs create mode 100644 src/daft-functions-utf8/src/jaro_winkler.rs create mode 100644 src/daft-functions-utf8/src/levenshtein.rs create mode 100644 tests/functions/test_string_distance.py diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2221a26c625..6afdbbd2407 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -2361,6 +2361,46 @@ def hamming_distance_str(self, other: Expression) -> Expression: return hamming_distance_str(self, other) + def levenshtein_distance(self, other: Expression) -> Expression: + """Compute the Levenshtein edit distance between two strings. + + Tip: See Also + [`daft.functions.levenshtein_distance`](https://docs.daft.ai/en/stable/api/functions/levenshtein_distance/) + """ + from daft.functions import levenshtein_distance + + return levenshtein_distance(self, other) + + def jaro_similarity(self, other: Expression) -> Expression: + """Compute the Jaro similarity between two strings. + + Tip: See Also + [`daft.functions.jaro_similarity`](https://docs.daft.ai/en/stable/api/functions/jaro_similarity/) + """ + from daft.functions import jaro_similarity + + return jaro_similarity(self, other) + + def jaro_winkler_similarity(self, other: Expression) -> Expression: + """Compute the Jaro-Winkler similarity between two strings. + + Tip: See Also + [`daft.functions.jaro_winkler_similarity`](https://docs.daft.ai/en/stable/api/functions/jaro_winkler_similarity/) + """ + from daft.functions import jaro_winkler_similarity + + return jaro_winkler_similarity(self, other) + + def damerau_levenshtein_distance(self, other: Expression) -> Expression: + """Compute the Damerau-Levenshtein distance between two strings. + + Tip: See Also + [`daft.functions.damerau_levenshtein_distance`](https://docs.daft.ai/en/stable/api/functions/damerau_levenshtein_distance/) + """ + from daft.functions import damerau_levenshtein_distance + + return damerau_levenshtein_distance(self, other) + def value_counts(self) -> Expression: """Counts the occurrences of each distinct value in the list. diff --git a/daft/functions/__init__.py b/daft/functions/__init__.py index 6613248ba36..e297726987c 100644 --- a/daft/functions/__init__.py +++ b/daft/functions/__init__.py @@ -276,6 +276,10 @@ regexp_replace, find, hamming_distance_str, + levenshtein_distance, + jaro_similarity, + jaro_winkler_similarity, + damerau_levenshtein_distance, ) from .struct import unnest, to_struct from .url import download, upload, parse_url @@ -349,6 +353,7 @@ "current_date", "current_timestamp", "current_timezone", + "damerau_levenshtein_distance", "date", "date_add", "date_diff", @@ -423,6 +428,8 @@ "is_nan", "is_null", "jaccard_similarity", + "jaro_similarity", + "jaro_winkler_similarity", "jq", "json_array_length", "json_object_keys", @@ -434,6 +441,7 @@ "left", "length", "length_bytes", + "levenshtein_distance", "like", "list_agg", "list_agg_distinct", diff --git a/daft/functions/str.py b/daft/functions/str.py index 6677769fddb..a831310cc80 100644 --- a/daft/functions/str.py +++ b/daft/functions/str.py @@ -1612,3 +1612,154 @@ def hamming_distance_str(left: Expression, right: Expression) -> Expression: (Showing first 3 of 3 rows) """ return Expression._call_builtin_scalar_fn("hamming_distance_str", left, right) + + +def levenshtein_distance(left: Expression, right: Expression) -> Expression: + """Compute the Levenshtein edit distance between two strings. + + The Levenshtein distance is the minimum number of single-character insertions, + deletions, or substitutions required to transform one string into the other. + + Args: + left: The left string expression to compare. + right: The right string expression to compare against. + + Returns: + The Levenshtein distance for each pair of strings. Returns null when either + input is null. + + Examples: + >>> import daft + >>> from daft.functions import levenshtein_distance + >>> df = daft.from_pydict({"x": ["kitten", "saturday", ""], "y": ["sitting", "sunday", "abc"]}) + >>> df = df.with_column("distance", levenshtein_distance(df["x"], df["y"])) + >>> df.collect() + ╭──────────┬─────────┬──────────╮ + │ x ┆ y ┆ distance │ + │ --- ┆ --- ┆ --- │ + │ String ┆ String ┆ Int64 │ + ╞══════════╪═════════╪══════════╡ + │ kitten ┆ sitting ┆ 3 │ + ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ saturday ┆ sunday ┆ 3 │ + ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ ┆ abc ┆ 3 │ + ╰──────────┴─────────┴──────────╯ + + (Showing first 3 of 3 rows) + """ + return Expression._call_builtin_scalar_fn("levenshtein_distance", left, right) + + +def jaro_similarity(left: Expression, right: Expression) -> Expression: + """Compute the Jaro similarity between two strings. + + The Jaro similarity is a measure of similarity between two strings, based on + matching characters and transpositions. Returns a value between 0.0 (no similarity) + and 1.0 (identical strings). + + Args: + left: The left string expression to compare. + right: The right string expression to compare against. + + Returns: + The Jaro similarity (0.0 to 1.0) for each pair of strings. Returns null when + either input is null. + + Examples: + >>> import daft + >>> from daft.functions import jaro_similarity + >>> df = daft.from_pydict({"x": ["martha", "dwayne", "dixon"], "y": ["marhta", "duane", "dicksonx"]}) + >>> df = df.with_column("similarity", jaro_similarity(df["x"], df["y"])) + >>> df.collect() + ╭────────┬──────────┬────────────╮ + │ x ┆ y ┆ similarity │ + │ --- ┆ --- ┆ --- │ + │ String ┆ String ┆ Float64 │ + ╞════════╪══════════╪════════════╡ + │ martha ┆ marhta ┆ 0.944444 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dwayne ┆ duane ┆ 0.822222 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dixon ┆ dicksonx ┆ 0.766667 │ + ╰────────┴──────────┴────────────╯ + + (Showing first 3 of 3 rows) + """ + return Expression._call_builtin_scalar_fn("jaro_similarity", left, right) + + +def jaro_winkler_similarity(left: Expression, right: Expression) -> Expression: + """Compute the Jaro-Winkler similarity between two strings. + + This is the Jaro similarity with a prefix bonus for strings sharing a common + prefix (up to 4 characters). Returns a value between 0.0 (no similarity) and + 1.0 (identical strings). + + Args: + left: The left string expression to compare. + right: The right string expression to compare against. + + Returns: + The Jaro-Winkler similarity (0.0 to 1.0) for each pair of strings. Returns + null when either input is null. + + Examples: + >>> import daft + >>> from daft.functions import jaro_winkler_similarity + >>> df = daft.from_pydict({"x": ["martha", "dwayne", "dixon"], "y": ["marhta", "duane", "dicksonx"]}) + >>> df = df.with_column("similarity", jaro_winkler_similarity(df["x"], df["y"])) + >>> df.collect() + ╭────────┬──────────┬────────────╮ + │ x ┆ y ┆ similarity │ + │ --- ┆ --- ┆ --- │ + │ String ┆ String ┆ Float64 │ + ╞════════╪══════════╪════════════╡ + │ martha ┆ marhta ┆ 0.961111 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dwayne ┆ duane ┆ 0.840000 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dixon ┆ dicksonx ┆ 0.813333 │ + ╰────────┴──────────┴────────────╯ + + (Showing first 3 of 3 rows) + """ + return Expression._call_builtin_scalar_fn("jaro_winkler_similarity", left, right) + + +def damerau_levenshtein_distance(left: Expression, right: Expression) -> Expression: + """Compute the Damerau-Levenshtein distance between two strings. + + This extends the Levenshtein distance by also counting transpositions of two + adjacent characters as a single edit operation (in addition to insertions, + deletions, and substitutions). + + Args: + left: The left string expression to compare. + right: The right string expression to compare against. + + Returns: + The Damerau-Levenshtein distance for each pair of strings. Returns null when + either input is null. + + Examples: + >>> import daft + >>> from daft.functions import damerau_levenshtein_distance + >>> df = daft.from_pydict({"x": ["abc", "abc", ""], "y": ["bac", "abd", "abc"]}) + >>> df = df.with_column("distance", damerau_levenshtein_distance(df["x"], df["y"])) + >>> df.collect() + ╭────────┬────────┬──────────╮ + │ x ┆ y ┆ distance │ + │ --- ┆ --- ┆ --- │ + │ String ┆ String ┆ Int64 │ + ╞════════╪════════╪══════════╡ + │ abc ┆ bac ┆ 1 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ abc ┆ abd ┆ 1 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ ┆ abc ┆ 3 │ + ╰────────┴────────┴──────────╯ + + (Showing first 3 of 3 rows) + """ + return Expression._call_builtin_scalar_fn("damerau_levenshtein_distance", left, right) diff --git a/src/daft-functions-utf8/src/damerau_levenshtein.rs b/src/daft-functions-utf8/src/damerau_levenshtein.rs new file mode 100644 index 00000000000..8477c464791 --- /dev/null +++ b/src/daft-functions-utf8/src/damerau_levenshtein.rs @@ -0,0 +1,145 @@ +use std::sync::Arc; + +use arrow_buffer::NullBufferBuilder; +use daft_core::prelude::{Int64Array, IntoSeries}; +use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; +use serde::{Deserialize, Serialize}; + +const NULL_SENTINEL: i64 = 0; + +/// Compute the Damerau-Levenshtein distance (optimal string alignment variant). +/// This extends Levenshtein by also allowing transposition of two adjacent characters +/// as a single edit operation. +fn compute_damerau_levenshtein_distance(left: &str, right: &str) -> i64 { + let left_chars: Vec = left.chars().collect(); + let right_chars: Vec = right.chars().collect(); + + let n = left_chars.len(); + let m = right_chars.len(); + + if n == 0 { + return m as i64; + } + if m == 0 { + return n as i64; + } + + // Full matrix needed for transposition lookback + let mut matrix = vec![vec![0i64; m + 1]; n + 1]; + + for i in 0..=n { + matrix[i][0] = i as i64; + } + for j in 0..=m { + matrix[0][j] = j as i64; + } + + for i in 1..=n { + for j in 1..=m { + let cost = if left_chars[i - 1] == right_chars[j - 1] { + 0 + } else { + 1 + }; + + matrix[i][j] = (matrix[i - 1][j] + 1) // deletion + .min(matrix[i][j - 1] + 1) // insertion + .min(matrix[i - 1][j - 1] + cost); // substitution + + // Transposition + if i > 1 + && j > 1 + && left_chars[i - 1] == right_chars[j - 2] + && left_chars[i - 2] == right_chars[j - 1] + { + matrix[i][j] = matrix[i][j].min(matrix[i - 2][j - 2] + 1); + } + } + } + + matrix[n][m] +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct DamerauLevenshteinDistance; + +#[typetag::serde] +impl ScalarUDF for DamerauLevenshteinDistance { + fn name(&self) -> &'static str { + "damerau_levenshtein_distance" + } + + fn call( + &self, + inputs: FunctionArgs, + _ctx: &daft_dsl::functions::scalar::EvalContext, + ) -> DaftResult { + let left = inputs.required(0)?.cast(&DataType::Utf8)?; + let right = inputs.required(1)?.cast(&DataType::Utf8)?; + + left.with_utf8_array(|left| { + right.with_utf8_array(|right| { + let len = left.len(); + let mut values = Vec::with_capacity(len); + let mut validity = NullBufferBuilder::new(len); + + for i in 0..len { + match (left.get(i), right.get(i)) { + (Some(l), Some(r)) => { + values.push(compute_damerau_levenshtein_distance(l, r)); + validity.append_non_null(); + } + _ => { + values.push(NULL_SENTINEL); + validity.append_null(); + } + } + } + + let field = Arc::new(Field::new(self.name(), DataType::Int64)); + let result = + Int64Array::from_field_and_values(field, values).with_nulls(validity.finish())?; + Ok(result.into_series()) + }) + }) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + ensure!( + inputs.len() == 2, + SchemaMismatch: "Expected 2 inputs, but received {}", + inputs.len() + ); + + let left = inputs.required(0)?.to_field(schema)?; + let right = inputs.required(1)?.to_field(schema)?; + + ensure!( + left.dtype.is_string() || left.dtype == DataType::Null, + TypeError: "First argument must be a string, got {}", + left.dtype + ); + ensure!( + right.dtype.is_string() || right.dtype == DataType::Null, + TypeError: "Second argument must be a string, got {}", + right.dtype + ); + + Ok(Field::new(self.name(), DataType::Int64)) + } + + fn docstring(&self) -> &'static str { + "Compute the Damerau-Levenshtein distance between two strings. This extends the \ + Levenshtein distance by also counting transpositions of two adjacent characters \ + as a single edit operation. Returns null when either input is null." + } +} + +#[must_use] +pub fn damerau_levenshtein_distance(left: ExprRef, right: ExprRef) -> ExprRef { + ScalarFn::builtin(DamerauLevenshteinDistance, vec![left, right]).into() +} diff --git a/src/daft-functions-utf8/src/jaro.rs b/src/daft-functions-utf8/src/jaro.rs new file mode 100644 index 00000000000..30d1126b227 --- /dev/null +++ b/src/daft-functions-utf8/src/jaro.rs @@ -0,0 +1,157 @@ +use std::sync::Arc; + +use arrow_buffer::NullBufferBuilder; +use daft_core::prelude::{Float64Array, IntoSeries}; +use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; +use serde::{Deserialize, Serialize}; + +const NULL_SENTINEL: f64 = 0.0; + +/// Compute Jaro similarity between two strings. +/// Returns a value between 0.0 (no similarity) and 1.0 (identical). +pub(crate) fn compute_jaro_similarity(left: &str, right: &str) -> f64 { + let left_chars: Vec = left.chars().collect(); + let right_chars: Vec = right.chars().collect(); + + let s1_len = left_chars.len(); + let s2_len = right_chars.len(); + + if s1_len == 0 && s2_len == 0 { + return 1.0; + } + if s1_len == 0 || s2_len == 0 { + return 0.0; + } + + // Maximum distance for matching characters + let match_distance = (s1_len.max(s2_len) / 2).saturating_sub(1); + + let mut s1_matches = vec![false; s1_len]; + let mut s2_matches = vec![false; s2_len]; + + let mut matches: f64 = 0.0; + let mut transpositions: f64 = 0.0; + + // Find matching characters + for i in 0..s1_len { + let start = i.saturating_sub(match_distance); + let end = (i + match_distance + 1).min(s2_len); + + for j in start..end { + if s2_matches[j] || left_chars[i] != right_chars[j] { + continue; + } + s1_matches[i] = true; + s2_matches[j] = true; + matches += 1.0; + break; + } + } + + if matches == 0.0 { + return 0.0; + } + + // Count transpositions + let mut k = 0; + for i in 0..s1_len { + if !s1_matches[i] { + continue; + } + while !s2_matches[k] { + k += 1; + } + if left_chars[i] != right_chars[k] { + transpositions += 1.0; + } + k += 1; + } + + (matches / s1_len as f64 + + matches / s2_len as f64 + + (matches - transpositions / 2.0) / matches) + / 3.0 +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct JaroSimilarity; + +#[typetag::serde] +impl ScalarUDF for JaroSimilarity { + fn name(&self) -> &'static str { + "jaro_similarity" + } + + fn call( + &self, + inputs: FunctionArgs, + _ctx: &daft_dsl::functions::scalar::EvalContext, + ) -> DaftResult { + let left = inputs.required(0)?.cast(&DataType::Utf8)?; + let right = inputs.required(1)?.cast(&DataType::Utf8)?; + + left.with_utf8_array(|left| { + right.with_utf8_array(|right| { + let len = left.len(); + let mut values = Vec::with_capacity(len); + let mut validity = NullBufferBuilder::new(len); + + for i in 0..len { + match (left.get(i), right.get(i)) { + (Some(l), Some(r)) => { + values.push(compute_jaro_similarity(l, r)); + validity.append_non_null(); + } + _ => { + values.push(NULL_SENTINEL); + validity.append_null(); + } + } + } + + let field = Arc::new(Field::new(self.name(), DataType::Float64)); + let result = Float64Array::from_field_and_values(field, values) + .with_nulls(validity.finish())?; + Ok(result.into_series()) + }) + }) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + ensure!( + inputs.len() == 2, + SchemaMismatch: "Expected 2 inputs, but received {}", + inputs.len() + ); + + let left = inputs.required(0)?.to_field(schema)?; + let right = inputs.required(1)?.to_field(schema)?; + + ensure!( + left.dtype.is_string() || left.dtype == DataType::Null, + TypeError: "First argument must be a string, got {}", + left.dtype + ); + ensure!( + right.dtype.is_string() || right.dtype == DataType::Null, + TypeError: "Second argument must be a string, got {}", + right.dtype + ); + + Ok(Field::new(self.name(), DataType::Float64)) + } + + fn docstring(&self) -> &'static str { + "Compute the Jaro similarity between two strings. Returns a value between 0.0 \ + (no similarity) and 1.0 (identical strings). Returns null when either input is null." + } +} + +#[must_use] +pub fn jaro_similarity(left: ExprRef, right: ExprRef) -> ExprRef { + ScalarFn::builtin(JaroSimilarity, vec![left, right]).into() +} diff --git a/src/daft-functions-utf8/src/jaro_winkler.rs b/src/daft-functions-utf8/src/jaro_winkler.rs new file mode 100644 index 00000000000..54d92873493 --- /dev/null +++ b/src/daft-functions-utf8/src/jaro_winkler.rs @@ -0,0 +1,114 @@ +use std::sync::Arc; + +use arrow_buffer::NullBufferBuilder; +use daft_core::prelude::{Float64Array, IntoSeries}; +use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; +use serde::{Deserialize, Serialize}; + +use crate::jaro::compute_jaro_similarity; + +const NULL_SENTINEL: f64 = 0.0; + +/// Compute Jaro-Winkler similarity. Applies a prefix bonus to the Jaro similarity +/// for strings that share a common prefix (up to 4 characters). +/// The scaling factor p is fixed at 0.1 (standard value). +fn compute_jaro_winkler_similarity(left: &str, right: &str) -> f64 { + let jaro = compute_jaro_similarity(left, right); + + // Find common prefix length (max 4 characters) + let prefix_len = left + .chars() + .zip(right.chars()) + .take(4) + .take_while(|(a, b)| a == b) + .count(); + + let p = 0.1; // Standard Winkler scaling factor + + jaro + (prefix_len as f64 * p * (1.0 - jaro)) +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct JaroWinklerSimilarity; + +#[typetag::serde] +impl ScalarUDF for JaroWinklerSimilarity { + fn name(&self) -> &'static str { + "jaro_winkler_similarity" + } + + fn call( + &self, + inputs: FunctionArgs, + _ctx: &daft_dsl::functions::scalar::EvalContext, + ) -> DaftResult { + let left = inputs.required(0)?.cast(&DataType::Utf8)?; + let right = inputs.required(1)?.cast(&DataType::Utf8)?; + + left.with_utf8_array(|left| { + right.with_utf8_array(|right| { + let len = left.len(); + let mut values = Vec::with_capacity(len); + let mut validity = NullBufferBuilder::new(len); + + for i in 0..len { + match (left.get(i), right.get(i)) { + (Some(l), Some(r)) => { + values.push(compute_jaro_winkler_similarity(l, r)); + validity.append_non_null(); + } + _ => { + values.push(NULL_SENTINEL); + validity.append_null(); + } + } + } + + let field = Arc::new(Field::new(self.name(), DataType::Float64)); + let result = Float64Array::from_field_and_values(field, values) + .with_nulls(validity.finish())?; + Ok(result.into_series()) + }) + }) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + ensure!( + inputs.len() == 2, + SchemaMismatch: "Expected 2 inputs, but received {}", + inputs.len() + ); + + let left = inputs.required(0)?.to_field(schema)?; + let right = inputs.required(1)?.to_field(schema)?; + + ensure!( + left.dtype.is_string() || left.dtype == DataType::Null, + TypeError: "First argument must be a string, got {}", + left.dtype + ); + ensure!( + right.dtype.is_string() || right.dtype == DataType::Null, + TypeError: "Second argument must be a string, got {}", + right.dtype + ); + + Ok(Field::new(self.name(), DataType::Float64)) + } + + fn docstring(&self) -> &'static str { + "Compute the Jaro-Winkler similarity between two strings. This is the Jaro \ + similarity with a prefix bonus for strings sharing a common prefix (up to 4 chars). \ + Returns a value between 0.0 (no similarity) and 1.0 (identical). Returns null when \ + either input is null." + } +} + +#[must_use] +pub fn jaro_winkler_similarity(left: ExprRef, right: ExprRef) -> ExprRef { + ScalarFn::builtin(JaroWinklerSimilarity, vec![left, right]).into() +} diff --git a/src/daft-functions-utf8/src/levenshtein.rs b/src/daft-functions-utf8/src/levenshtein.rs new file mode 100644 index 00000000000..ac7c5f92c91 --- /dev/null +++ b/src/daft-functions-utf8/src/levenshtein.rs @@ -0,0 +1,132 @@ +use std::sync::Arc; + +use arrow_buffer::NullBufferBuilder; +use daft_core::prelude::{Int64Array, IntoSeries}; +use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; +use serde::{Deserialize, Serialize}; + +const NULL_SENTINEL: i64 = 0; + +/// Compute Levenshtein edit distance using Wagner-Fischer algorithm. +/// Uses O(min(n,m)) space by only keeping two rows of the DP matrix. +fn compute_levenshtein_distance(left: &str, right: &str) -> i64 { + let left_chars: Vec = left.chars().collect(); + let right_chars: Vec = right.chars().collect(); + + let n = left_chars.len(); + let m = right_chars.len(); + + if n == 0 { + return m as i64; + } + if m == 0 { + return n as i64; + } + + // Ensure we iterate over the shorter string for the inner loop + let (shorter, longer, short_len, long_len) = if n <= m { + (&left_chars, &right_chars, n, m) + } else { + (&right_chars, &left_chars, m, n) + }; + + let mut prev_row: Vec = (0..=(short_len as i64)).collect(); + let mut curr_row: Vec = vec![0; short_len + 1]; + + for i in 1..=long_len { + curr_row[0] = i as i64; + for j in 1..=short_len { + let cost = if longer[i - 1] == shorter[j - 1] { 0 } else { 1 }; + curr_row[j] = (prev_row[j] + 1) // deletion + .min(curr_row[j - 1] + 1) // insertion + .min(prev_row[j - 1] + cost); // substitution + } + std::mem::swap(&mut prev_row, &mut curr_row); + } + + prev_row[short_len] +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct LevenshteinDistance; + +#[typetag::serde] +impl ScalarUDF for LevenshteinDistance { + fn name(&self) -> &'static str { + "levenshtein_distance" + } + + fn call( + &self, + inputs: FunctionArgs, + _ctx: &daft_dsl::functions::scalar::EvalContext, + ) -> DaftResult { + let left = inputs.required(0)?.cast(&DataType::Utf8)?; + let right = inputs.required(1)?.cast(&DataType::Utf8)?; + + left.with_utf8_array(|left| { + right.with_utf8_array(|right| { + let len = left.len(); + let mut values = Vec::with_capacity(len); + let mut validity = NullBufferBuilder::new(len); + + for i in 0..len { + match (left.get(i), right.get(i)) { + (Some(l), Some(r)) => { + values.push(compute_levenshtein_distance(l, r)); + validity.append_non_null(); + } + _ => { + values.push(NULL_SENTINEL); + validity.append_null(); + } + } + } + + let field = Arc::new(Field::new(self.name(), DataType::Int64)); + let result = + Int64Array::from_field_and_values(field, values).with_nulls(validity.finish())?; + Ok(result.into_series()) + }) + }) + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + ensure!( + inputs.len() == 2, + SchemaMismatch: "Expected 2 inputs, but received {}", + inputs.len() + ); + + let left = inputs.required(0)?.to_field(schema)?; + let right = inputs.required(1)?.to_field(schema)?; + + ensure!( + left.dtype.is_string() || left.dtype == DataType::Null, + TypeError: "First argument must be a string, got {}", + left.dtype + ); + ensure!( + right.dtype.is_string() || right.dtype == DataType::Null, + TypeError: "Second argument must be a string, got {}", + right.dtype + ); + + Ok(Field::new(self.name(), DataType::Int64)) + } + + fn docstring(&self) -> &'static str { + "Compute the Levenshtein edit distance between two strings. The Levenshtein distance \ + is the minimum number of single-character insertions, deletions, or substitutions \ + required to transform one string into the other. Returns null when either input is null." + } +} + +#[must_use] +pub fn levenshtein_distance(left: ExprRef, right: ExprRef) -> ExprRef { + ScalarFn::builtin(LevenshteinDistance, vec![left, right]).into() +} diff --git a/src/daft-functions-utf8/src/lib.rs b/src/daft-functions-utf8/src/lib.rs index 4cbf100d8cd..815ac2f86a7 100644 --- a/src/daft-functions-utf8/src/lib.rs +++ b/src/daft-functions-utf8/src/lib.rs @@ -2,12 +2,16 @@ mod capitalize; mod case; mod contains; mod count_matches; +mod damerau_levenshtein; mod endswith; mod find; mod hamming; mod ilike; +mod jaro; +mod jaro_winkler; mod left; mod length_bytes; +mod levenshtein; mod like; mod lower; mod lpad; @@ -37,12 +41,16 @@ pub use capitalize::*; pub use case::*; pub use contains::*; pub use count_matches::*; +pub use damerau_levenshtein::*; pub use endswith::*; pub use find::*; pub use hamming::*; pub use ilike::*; +pub use jaro::*; +pub use jaro_winkler::*; pub use left::*; pub use length_bytes::*; +pub use levenshtein::*; pub use like::*; pub use lower::*; pub use lpad::*; @@ -74,13 +82,17 @@ impl daft_dsl::functions::FunctionModule for Utf8Functions { parent.add_fn(Capitalize); parent.add_fn(Contains); parent.add_fn(CountMatches); + parent.add_fn(DamerauLevenshteinDistance); parent.add_fn(EndsWith); parent.add_fn(Find); parent.add_fn(HammingDistance); parent.add_fn(ILike); + parent.add_fn(JaroSimilarity); + parent.add_fn(JaroWinklerSimilarity); parent.add_fn(KebabCase); parent.add_fn(Left); parent.add_fn(LengthBytes); + parent.add_fn(LevenshteinDistance); parent.add_fn(Like); parent.add_fn(Lower); parent.add_fn(LPad); diff --git a/tests/functions/test_string_distance.py b/tests/functions/test_string_distance.py new file mode 100644 index 00000000000..b896294a153 --- /dev/null +++ b/tests/functions/test_string_distance.py @@ -0,0 +1,178 @@ +"""Tests for string distance/similarity functions (issue #6794).""" +from __future__ import annotations + +import pytest + +import daft +from daft import col +from daft.functions import ( + damerau_levenshtein_distance, + jaro_similarity, + jaro_winkler_similarity, + levenshtein_distance, +) + + +class TestLevenshteinDistance: + def test_basic(self): + df = daft.from_pydict({"a": ["kitten", "saturday", "abc"], "b": ["sitting", "sunday", "abc"]}) + result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances == [3, 3, 0] + + def test_empty_strings(self): + df = daft.from_pydict({"a": ["", "abc", ""], "b": ["abc", "", ""]}) + result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances == [3, 3, 0] + + def test_null_handling(self): + df = daft.from_pydict({"a": ["hello", None, "world"], "b": ["hallo", "test", None]}) + result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances[0] == 1 + assert distances[1] is None + assert distances[2] is None + + def test_identical_strings(self): + df = daft.from_pydict({"a": ["foo", "bar", "baz"], "b": ["foo", "bar", "baz"]}) + result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances == [0, 0, 0] + + def test_single_char_edits(self): + df = daft.from_pydict({ + "a": ["cat", "cat", "cat"], + "b": ["hat", "cats", "at"], + }) + result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + # substitution, insertion, deletion + assert distances == [1, 1, 1] + + def test_expression_method(self): + df = daft.from_pydict({"a": ["kitten"], "b": ["sitting"]}) + result = df.with_column("dist", col("a").levenshtein_distance(col("b"))).collect() + assert result.to_pydict()["dist"] == [3] + + +class TestJaroSimilarity: + def test_identical(self): + df = daft.from_pydict({"a": ["hello", ""], "b": ["hello", ""]}) + result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + sims = result.to_pydict()["sim"] + assert sims[0] == pytest.approx(1.0) + assert sims[1] == pytest.approx(1.0) + + def test_completely_different(self): + df = daft.from_pydict({"a": ["abc"], "b": ["xyz"]}) + result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + assert result.to_pydict()["sim"][0] == pytest.approx(0.0) + + def test_known_values(self): + # Well-known test case: martha vs marhta -> 0.944444 + df = daft.from_pydict({"a": ["martha"], "b": ["marhta"]}) + result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + assert result.to_pydict()["sim"][0] == pytest.approx(0.944444, rel=1e-4) + + def test_null_handling(self): + df = daft.from_pydict({"a": ["hello", None], "b": [None, "world"]}) + result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + sims = result.to_pydict()["sim"] + assert sims[0] is None + assert sims[1] is None + + def test_empty_vs_nonempty(self): + df = daft.from_pydict({"a": ["", "abc"], "b": ["abc", ""]}) + result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + sims = result.to_pydict()["sim"] + assert sims[0] == pytest.approx(0.0) + assert sims[1] == pytest.approx(0.0) + + def test_expression_method(self): + df = daft.from_pydict({"a": ["martha"], "b": ["marhta"]}) + result = df.with_column("sim", col("a").jaro_similarity(col("b"))).collect() + assert result.to_pydict()["sim"][0] == pytest.approx(0.944444, rel=1e-4) + + +class TestJaroWinklerSimilarity: + def test_identical(self): + df = daft.from_pydict({"a": ["hello"], "b": ["hello"]}) + result = df.with_column("sim", jaro_winkler_similarity(col("a"), col("b"))).collect() + assert result.to_pydict()["sim"][0] == pytest.approx(1.0) + + def test_prefix_bonus(self): + # Jaro-Winkler should be >= Jaro for strings sharing a prefix + df = daft.from_pydict({"a": ["martha"], "b": ["marhta"]}) + jaro_result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + jw_result = df.with_column("sim", jaro_winkler_similarity(col("a"), col("b"))).collect() + jaro_val = jaro_result.to_pydict()["sim"][0] + jw_val = jw_result.to_pydict()["sim"][0] + assert jw_val >= jaro_val + + def test_known_values(self): + # martha vs marhta: Jaro = 0.944444, prefix "mar" (len=3) + # JW = 0.944444 + (3 * 0.1 * (1 - 0.944444)) = 0.961111 + df = daft.from_pydict({"a": ["martha"], "b": ["marhta"]}) + result = df.with_column("sim", jaro_winkler_similarity(col("a"), col("b"))).collect() + assert result.to_pydict()["sim"][0] == pytest.approx(0.961111, rel=1e-4) + + def test_no_common_prefix(self): + # No common prefix means JW == Jaro + df = daft.from_pydict({"a": ["abc"], "b": ["xyz"]}) + jaro_result = df.with_column("sim", jaro_similarity(col("a"), col("b"))).collect() + jw_result = df.with_column("sim", jaro_winkler_similarity(col("a"), col("b"))).collect() + assert jw_result.to_pydict()["sim"][0] == pytest.approx(jaro_result.to_pydict()["sim"][0]) + + def test_null_handling(self): + df = daft.from_pydict({"a": ["hello", None], "b": [None, "world"]}) + result = df.with_column("sim", jaro_winkler_similarity(col("a"), col("b"))).collect() + sims = result.to_pydict()["sim"] + assert sims[0] is None + assert sims[1] is None + + def test_expression_method(self): + df = daft.from_pydict({"a": ["martha"], "b": ["marhta"]}) + result = df.with_column("sim", col("a").jaro_winkler_similarity(col("b"))).collect() + assert result.to_pydict()["sim"][0] == pytest.approx(0.961111, rel=1e-4) + + +class TestDamerauLevenshteinDistance: + def test_basic(self): + df = daft.from_pydict({"a": ["abc", "abc"], "b": ["bac", "abc"]}) + result = df.with_column("dist", damerau_levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances[0] == 1 # abc -> bac: single transposition + assert distances[1] == 0 # identical + + def test_transposition_vs_levenshtein(self): + # "ab" -> "ba" should be 1 for Damerau-Levenshtein (transposition) + # but 2 for standard Levenshtein (two substitutions) + df = daft.from_pydict({"a": ["ab"], "b": ["ba"]}) + dl_result = df.with_column("dist", damerau_levenshtein_distance(col("a"), col("b"))).collect() + lev_result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() + assert dl_result.to_pydict()["dist"][0] == 1 + assert lev_result.to_pydict()["dist"][0] == 2 + + def test_empty_strings(self): + df = daft.from_pydict({"a": ["", "abc", ""], "b": ["abc", "", ""]}) + result = df.with_column("dist", damerau_levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances == [3, 3, 0] + + def test_identical(self): + df = daft.from_pydict({"a": ["hello"], "b": ["hello"]}) + result = df.with_column("dist", damerau_levenshtein_distance(col("a"), col("b"))).collect() + assert result.to_pydict()["dist"][0] == 0 + + def test_null_handling(self): + df = daft.from_pydict({"a": ["hello", None], "b": [None, "world"]}) + result = df.with_column("dist", damerau_levenshtein_distance(col("a"), col("b"))).collect() + distances = result.to_pydict()["dist"] + assert distances[0] is None + assert distances[1] is None + + def test_expression_method(self): + df = daft.from_pydict({"a": ["abc"], "b": ["bac"]}) + result = df.with_column("dist", col("a").damerau_levenshtein_distance(col("b"))).collect() + assert result.to_pydict()["dist"] == [1] From 51add79405990af5ad715edf3f3d1831ee248a12 Mon Sep 17 00:00:00 2001 From: Nishaanth Reddy Date: Wed, 3 Jun 2026 18:34:54 -0700 Subject: [PATCH 2/5] fix: resolve style and doctest CI failures - apply rustfmt to levenshtein.rs, jaro.rs, damerau_levenshtein.rs - apply ruff format to test_string_distance.py - fix jaro_similarity and jaro_winkler_similarity docstring examples to use full-precision Float64 output --- daft/functions/str.py | 44 +++++++++---------- .../src/damerau_levenshtein.rs | 4 +- src/daft-functions-utf8/src/jaro.rs | 4 +- src/daft-functions-utf8/src/levenshtein.rs | 10 +++-- tests/functions/test_string_distance.py | 11 +++-- 5 files changed, 39 insertions(+), 34 deletions(-) diff --git a/daft/functions/str.py b/daft/functions/str.py index a831310cc80..3057a9da18b 100644 --- a/daft/functions/str.py +++ b/daft/functions/str.py @@ -1672,17 +1672,17 @@ def jaro_similarity(left: Expression, right: Expression) -> Expression: >>> df = daft.from_pydict({"x": ["martha", "dwayne", "dixon"], "y": ["marhta", "duane", "dicksonx"]}) >>> df = df.with_column("similarity", jaro_similarity(df["x"], df["y"])) >>> df.collect() - ╭────────┬──────────┬────────────╮ - │ x ┆ y ┆ similarity │ - │ --- ┆ --- ┆ --- │ - │ String ┆ String ┆ Float64 │ - ╞════════╪══════════╪════════════╡ - │ martha ┆ marhta ┆ 0.944444 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ dwayne ┆ duane ┆ 0.822222 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ dixon ┆ dicksonx ┆ 0.766667 │ - ╰────────┴──────────┴────────────╯ + ╭────────┬──────────┬────────────────────╮ + │ x ┆ y ┆ similarity │ + │ --- ┆ --- ┆ --- │ + │ String ┆ String ┆ Float64 │ + ╞════════╪══════════╪════════════════════╡ + │ martha ┆ marhta ┆ 0.9444444444444445 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dwayne ┆ duane ┆ 0.8222222222222223 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dixon ┆ dicksonx ┆ 0.7666666666666666 │ + ╰────────┴──────────┴────────────────────╯ (Showing first 3 of 3 rows) """ @@ -1710,17 +1710,17 @@ def jaro_winkler_similarity(left: Expression, right: Expression) -> Expression: >>> df = daft.from_pydict({"x": ["martha", "dwayne", "dixon"], "y": ["marhta", "duane", "dicksonx"]}) >>> df = df.with_column("similarity", jaro_winkler_similarity(df["x"], df["y"])) >>> df.collect() - ╭────────┬──────────┬────────────╮ - │ x ┆ y ┆ similarity │ - │ --- ┆ --- ┆ --- │ - │ String ┆ String ┆ Float64 │ - ╞════════╪══════════╪════════════╡ - │ martha ┆ marhta ┆ 0.961111 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ dwayne ┆ duane ┆ 0.840000 │ - ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤ - │ dixon ┆ dicksonx ┆ 0.813333 │ - ╰────────┴──────────┴────────────╯ + ╭────────┬──────────┬────────────────────╮ + │ x ┆ y ┆ similarity │ + │ --- ┆ --- ┆ --- │ + │ String ┆ String ┆ Float64 │ + ╞════════╪══════════╪════════════════════╡ + │ martha ┆ marhta ┆ 0.9611111111111111 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dwayne ┆ duane ┆ 0.8400000000000001 │ + ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ dixon ┆ dicksonx ┆ 0.8133333333333332 │ + ╰────────┴──────────┴────────────────────╯ (Showing first 3 of 3 rows) """ diff --git a/src/daft-functions-utf8/src/damerau_levenshtein.rs b/src/daft-functions-utf8/src/damerau_levenshtein.rs index 8477c464791..9c9b95fb389 100644 --- a/src/daft-functions-utf8/src/damerau_levenshtein.rs +++ b/src/daft-functions-utf8/src/damerau_levenshtein.rs @@ -97,8 +97,8 @@ impl ScalarUDF for DamerauLevenshteinDistance { } let field = Arc::new(Field::new(self.name(), DataType::Int64)); - let result = - Int64Array::from_field_and_values(field, values).with_nulls(validity.finish())?; + let result = Int64Array::from_field_and_values(field, values) + .with_nulls(validity.finish())?; Ok(result.into_series()) }) }) diff --git a/src/daft-functions-utf8/src/jaro.rs b/src/daft-functions-utf8/src/jaro.rs index 30d1126b227..febc7ff89c8 100644 --- a/src/daft-functions-utf8/src/jaro.rs +++ b/src/daft-functions-utf8/src/jaro.rs @@ -67,9 +67,7 @@ pub(crate) fn compute_jaro_similarity(left: &str, right: &str) -> f64 { k += 1; } - (matches / s1_len as f64 - + matches / s2_len as f64 - + (matches - transpositions / 2.0) / matches) + (matches / s1_len as f64 + matches / s2_len as f64 + (matches - transpositions / 2.0) / matches) / 3.0 } diff --git a/src/daft-functions-utf8/src/levenshtein.rs b/src/daft-functions-utf8/src/levenshtein.rs index ac7c5f92c91..e160cd0948f 100644 --- a/src/daft-functions-utf8/src/levenshtein.rs +++ b/src/daft-functions-utf8/src/levenshtein.rs @@ -36,7 +36,11 @@ fn compute_levenshtein_distance(left: &str, right: &str) -> i64 { for i in 1..=long_len { curr_row[0] = i as i64; for j in 1..=short_len { - let cost = if longer[i - 1] == shorter[j - 1] { 0 } else { 1 }; + let cost = if longer[i - 1] == shorter[j - 1] { + 0 + } else { + 1 + }; curr_row[j] = (prev_row[j] + 1) // deletion .min(curr_row[j - 1] + 1) // insertion .min(prev_row[j - 1] + cost); // substitution @@ -84,8 +88,8 @@ impl ScalarUDF for LevenshteinDistance { } let field = Arc::new(Field::new(self.name(), DataType::Int64)); - let result = - Int64Array::from_field_and_values(field, values).with_nulls(validity.finish())?; + let result = Int64Array::from_field_and_values(field, values) + .with_nulls(validity.finish())?; Ok(result.into_series()) }) }) diff --git a/tests/functions/test_string_distance.py b/tests/functions/test_string_distance.py index b896294a153..9d3a23465ff 100644 --- a/tests/functions/test_string_distance.py +++ b/tests/functions/test_string_distance.py @@ -1,4 +1,5 @@ """Tests for string distance/similarity functions (issue #6794).""" + from __future__ import annotations import pytest @@ -41,10 +42,12 @@ def test_identical_strings(self): assert distances == [0, 0, 0] def test_single_char_edits(self): - df = daft.from_pydict({ - "a": ["cat", "cat", "cat"], - "b": ["hat", "cats", "at"], - }) + df = daft.from_pydict( + { + "a": ["cat", "cat", "cat"], + "b": ["hat", "cats", "at"], + } + ) result = df.with_column("dist", levenshtein_distance(col("a"), col("b"))).collect() distances = result.to_pydict()["dist"] # substitution, insertion, deletion From b5e383228459e5e1444ccd2321f4c5b3ee55703d Mon Sep 17 00:00:00 2001 From: Nishaanth Reddy Date: Wed, 3 Jun 2026 19:05:19 -0700 Subject: [PATCH 3/5] fix: resolve clippy and spellcheck CI failures - use i64::from(bool) instead of if/else for boolean-to-int conversion - use iter_mut().enumerate() instead of indexing loop (needless_range_loop) - use mul_add for jaro-winkler formula (suboptimal_flops) - replace "abd" with "acb" in docstring example (spellcheck flagged "abd") --- daft/functions/str.py | 4 ++-- src/daft-functions-utf8/src/damerau_levenshtein.rs | 10 +++------- src/daft-functions-utf8/src/jaro_winkler.rs | 2 +- src/daft-functions-utf8/src/levenshtein.rs | 6 +----- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/daft/functions/str.py b/daft/functions/str.py index 3057a9da18b..dd25850f6a7 100644 --- a/daft/functions/str.py +++ b/daft/functions/str.py @@ -1745,7 +1745,7 @@ def damerau_levenshtein_distance(left: Expression, right: Expression) -> Express Examples: >>> import daft >>> from daft.functions import damerau_levenshtein_distance - >>> df = daft.from_pydict({"x": ["abc", "abc", ""], "y": ["bac", "abd", "abc"]}) + >>> df = daft.from_pydict({"x": ["abc", "abc", ""], "y": ["bac", "acb", "abc"]}) >>> df = df.with_column("distance", damerau_levenshtein_distance(df["x"], df["y"])) >>> df.collect() ╭────────┬────────┬──────────╮ @@ -1755,7 +1755,7 @@ def damerau_levenshtein_distance(left: Expression, right: Expression) -> Express ╞════════╪════════╪══════════╡ │ abc ┆ bac ┆ 1 │ ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ - │ abc ┆ abd ┆ 1 │ + │ abc ┆ acb ┆ 1 │ ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ │ ┆ abc ┆ 3 │ ╰────────┴────────┴──────────╯ diff --git a/src/daft-functions-utf8/src/damerau_levenshtein.rs b/src/daft-functions-utf8/src/damerau_levenshtein.rs index 9c9b95fb389..5ec5c7ee869 100644 --- a/src/daft-functions-utf8/src/damerau_levenshtein.rs +++ b/src/daft-functions-utf8/src/damerau_levenshtein.rs @@ -27,8 +27,8 @@ fn compute_damerau_levenshtein_distance(left: &str, right: &str) -> i64 { // Full matrix needed for transposition lookback let mut matrix = vec![vec![0i64; m + 1]; n + 1]; - for i in 0..=n { - matrix[i][0] = i as i64; + for (i, row) in matrix.iter_mut().enumerate() { + row[0] = i as i64; } for j in 0..=m { matrix[0][j] = j as i64; @@ -36,11 +36,7 @@ fn compute_damerau_levenshtein_distance(left: &str, right: &str) -> i64 { for i in 1..=n { for j in 1..=m { - let cost = if left_chars[i - 1] == right_chars[j - 1] { - 0 - } else { - 1 - }; + let cost = i64::from(left_chars[i - 1] != right_chars[j - 1]); matrix[i][j] = (matrix[i - 1][j] + 1) // deletion .min(matrix[i][j - 1] + 1) // insertion diff --git a/src/daft-functions-utf8/src/jaro_winkler.rs b/src/daft-functions-utf8/src/jaro_winkler.rs index 54d92873493..299cddac292 100644 --- a/src/daft-functions-utf8/src/jaro_winkler.rs +++ b/src/daft-functions-utf8/src/jaro_winkler.rs @@ -25,7 +25,7 @@ fn compute_jaro_winkler_similarity(left: &str, right: &str) -> f64 { let p = 0.1; // Standard Winkler scaling factor - jaro + (prefix_len as f64 * p * (1.0 - jaro)) + (prefix_len as f64 * p).mul_add(1.0 - jaro, jaro) } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] diff --git a/src/daft-functions-utf8/src/levenshtein.rs b/src/daft-functions-utf8/src/levenshtein.rs index e160cd0948f..2b2075c8726 100644 --- a/src/daft-functions-utf8/src/levenshtein.rs +++ b/src/daft-functions-utf8/src/levenshtein.rs @@ -36,11 +36,7 @@ fn compute_levenshtein_distance(left: &str, right: &str) -> i64 { for i in 1..=long_len { curr_row[0] = i as i64; for j in 1..=short_len { - let cost = if longer[i - 1] == shorter[j - 1] { - 0 - } else { - 1 - }; + let cost = i64::from(longer[i - 1] != shorter[j - 1]); curr_row[j] = (prev_row[j] + 1) // deletion .min(curr_row[j - 1] + 1) // insertion .min(prev_row[j - 1] + cost); // substitution From fc6adaf85daef299c8ffaceca3e6902dfa4b7d3d Mon Sep 17 00:00:00 2001 From: Nishaanth Reddy Date: Wed, 3 Jun 2026 21:14:04 -0700 Subject: [PATCH 4/5] refactor(functions): address review feedback on string distance functions - document damerau_levenshtein_distance computes OSA variant, noting it differs from true Damerau-Levenshtein for overlapping transpositions - extract shared binary_str_distance and binary_str_distance_to_field helpers in utils.rs - collapse identical call/get_return_field boilerplate across the 4 UDFs into the generic helpers (-91 net lines) - update Rust docstring for damerau_levenshtein to note OSA variant --- daft/functions/str.py | 11 ++- .../src/damerau_levenshtein.rs | 66 +++------------ src/daft-functions-utf8/src/jaro.rs | 62 ++------------ src/daft-functions-utf8/src/jaro_winkler.rs | 67 +++------------ src/daft-functions-utf8/src/levenshtein.rs | 62 ++------------ src/daft-functions-utf8/src/utils.rs | 83 +++++++++++++++++++ 6 files changed, 130 insertions(+), 221 deletions(-) diff --git a/daft/functions/str.py b/daft/functions/str.py index dd25850f6a7..6fce16ec98e 100644 --- a/daft/functions/str.py +++ b/daft/functions/str.py @@ -1734,13 +1734,20 @@ def damerau_levenshtein_distance(left: Expression, right: Expression) -> Express adjacent characters as a single edit operation (in addition to insertions, deletions, and substitutions). + Note: + This computes the Optimal String Alignment (OSA) variant, which does not + allow a substring to be edited more than once. Results may differ from the + true Damerau-Levenshtein distance for inputs with overlapping transpositions + (e.g., ``"CA"`` to ``"ABC"`` is 3 under OSA but 2 under true + Damerau-Levenshtein). OSA does not satisfy the triangle inequality. + Args: left: The left string expression to compare. right: The right string expression to compare against. Returns: - The Damerau-Levenshtein distance for each pair of strings. Returns null when - either input is null. + The Damerau-Levenshtein (OSA) distance for each pair of strings. Returns null + when either input is null. Examples: >>> import daft diff --git a/src/daft-functions-utf8/src/damerau_levenshtein.rs b/src/daft-functions-utf8/src/damerau_levenshtein.rs index 5ec5c7ee869..b837e58ddd4 100644 --- a/src/daft-functions-utf8/src/damerau_levenshtein.rs +++ b/src/daft-functions-utf8/src/damerau_levenshtein.rs @@ -1,11 +1,7 @@ -use std::sync::Arc; - -use arrow_buffer::NullBufferBuilder; -use daft_core::prelude::{Int64Array, IntoSeries}; use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; use serde::{Deserialize, Serialize}; -const NULL_SENTINEL: i64 = 0; +use crate::utils::{binary_str_distance, binary_str_distance_to_field}; /// Compute the Damerau-Levenshtein distance (optimal string alignment variant). /// This extends Levenshtein by also allowing transposition of two adjacent characters @@ -70,34 +66,12 @@ impl ScalarUDF for DamerauLevenshteinDistance { inputs: FunctionArgs, _ctx: &daft_dsl::functions::scalar::EvalContext, ) -> DaftResult { - let left = inputs.required(0)?.cast(&DataType::Utf8)?; - let right = inputs.required(1)?.cast(&DataType::Utf8)?; - - left.with_utf8_array(|left| { - right.with_utf8_array(|right| { - let len = left.len(); - let mut values = Vec::with_capacity(len); - let mut validity = NullBufferBuilder::new(len); - - for i in 0..len { - match (left.get(i), right.get(i)) { - (Some(l), Some(r)) => { - values.push(compute_damerau_levenshtein_distance(l, r)); - validity.append_non_null(); - } - _ => { - values.push(NULL_SENTINEL); - validity.append_null(); - } - } - } - - let field = Arc::new(Field::new(self.name(), DataType::Int64)); - let result = Int64Array::from_field_and_values(field, values) - .with_nulls(validity.finish())?; - Ok(result.into_series()) - }) - }) + binary_str_distance::( + inputs, + self.name(), + DataType::Int64, + compute_damerau_levenshtein_distance, + ) } fn get_return_field( @@ -105,33 +79,15 @@ impl ScalarUDF for DamerauLevenshteinDistance { inputs: FunctionArgs, schema: &Schema, ) -> DaftResult { - ensure!( - inputs.len() == 2, - SchemaMismatch: "Expected 2 inputs, but received {}", - inputs.len() - ); - - let left = inputs.required(0)?.to_field(schema)?; - let right = inputs.required(1)?.to_field(schema)?; - - ensure!( - left.dtype.is_string() || left.dtype == DataType::Null, - TypeError: "First argument must be a string, got {}", - left.dtype - ); - ensure!( - right.dtype.is_string() || right.dtype == DataType::Null, - TypeError: "Second argument must be a string, got {}", - right.dtype - ); - - Ok(Field::new(self.name(), DataType::Int64)) + binary_str_distance_to_field(inputs, schema, self.name(), DataType::Int64) } fn docstring(&self) -> &'static str { "Compute the Damerau-Levenshtein distance between two strings. This extends the \ Levenshtein distance by also counting transpositions of two adjacent characters \ - as a single edit operation. Returns null when either input is null." + as a single edit operation. This computes the Optimal String Alignment (OSA) \ + variant, which may differ from true Damerau-Levenshtein for inputs with \ + overlapping transpositions. Returns null when either input is null." } } diff --git a/src/daft-functions-utf8/src/jaro.rs b/src/daft-functions-utf8/src/jaro.rs index febc7ff89c8..1da2d261d47 100644 --- a/src/daft-functions-utf8/src/jaro.rs +++ b/src/daft-functions-utf8/src/jaro.rs @@ -1,11 +1,7 @@ -use std::sync::Arc; - -use arrow_buffer::NullBufferBuilder; -use daft_core::prelude::{Float64Array, IntoSeries}; use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; use serde::{Deserialize, Serialize}; -const NULL_SENTINEL: f64 = 0.0; +use crate::utils::{binary_str_distance, binary_str_distance_to_field}; /// Compute Jaro similarity between two strings. /// Returns a value between 0.0 (no similarity) and 1.0 (identical). @@ -85,34 +81,12 @@ impl ScalarUDF for JaroSimilarity { inputs: FunctionArgs, _ctx: &daft_dsl::functions::scalar::EvalContext, ) -> DaftResult { - let left = inputs.required(0)?.cast(&DataType::Utf8)?; - let right = inputs.required(1)?.cast(&DataType::Utf8)?; - - left.with_utf8_array(|left| { - right.with_utf8_array(|right| { - let len = left.len(); - let mut values = Vec::with_capacity(len); - let mut validity = NullBufferBuilder::new(len); - - for i in 0..len { - match (left.get(i), right.get(i)) { - (Some(l), Some(r)) => { - values.push(compute_jaro_similarity(l, r)); - validity.append_non_null(); - } - _ => { - values.push(NULL_SENTINEL); - validity.append_null(); - } - } - } - - let field = Arc::new(Field::new(self.name(), DataType::Float64)); - let result = Float64Array::from_field_and_values(field, values) - .with_nulls(validity.finish())?; - Ok(result.into_series()) - }) - }) + binary_str_distance::( + inputs, + self.name(), + DataType::Float64, + compute_jaro_similarity, + ) } fn get_return_field( @@ -120,27 +94,7 @@ impl ScalarUDF for JaroSimilarity { inputs: FunctionArgs, schema: &Schema, ) -> DaftResult { - ensure!( - inputs.len() == 2, - SchemaMismatch: "Expected 2 inputs, but received {}", - inputs.len() - ); - - let left = inputs.required(0)?.to_field(schema)?; - let right = inputs.required(1)?.to_field(schema)?; - - ensure!( - left.dtype.is_string() || left.dtype == DataType::Null, - TypeError: "First argument must be a string, got {}", - left.dtype - ); - ensure!( - right.dtype.is_string() || right.dtype == DataType::Null, - TypeError: "Second argument must be a string, got {}", - right.dtype - ); - - Ok(Field::new(self.name(), DataType::Float64)) + binary_str_distance_to_field(inputs, schema, self.name(), DataType::Float64) } fn docstring(&self) -> &'static str { diff --git a/src/daft-functions-utf8/src/jaro_winkler.rs b/src/daft-functions-utf8/src/jaro_winkler.rs index 299cddac292..1bc12182f82 100644 --- a/src/daft-functions-utf8/src/jaro_winkler.rs +++ b/src/daft-functions-utf8/src/jaro_winkler.rs @@ -1,13 +1,10 @@ -use std::sync::Arc; - -use arrow_buffer::NullBufferBuilder; -use daft_core::prelude::{Float64Array, IntoSeries}; use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; use serde::{Deserialize, Serialize}; -use crate::jaro::compute_jaro_similarity; - -const NULL_SENTINEL: f64 = 0.0; +use crate::{ + jaro::compute_jaro_similarity, + utils::{binary_str_distance, binary_str_distance_to_field}, +}; /// Compute Jaro-Winkler similarity. Applies a prefix bonus to the Jaro similarity /// for strings that share a common prefix (up to 4 characters). @@ -42,34 +39,12 @@ impl ScalarUDF for JaroWinklerSimilarity { inputs: FunctionArgs, _ctx: &daft_dsl::functions::scalar::EvalContext, ) -> DaftResult { - let left = inputs.required(0)?.cast(&DataType::Utf8)?; - let right = inputs.required(1)?.cast(&DataType::Utf8)?; - - left.with_utf8_array(|left| { - right.with_utf8_array(|right| { - let len = left.len(); - let mut values = Vec::with_capacity(len); - let mut validity = NullBufferBuilder::new(len); - - for i in 0..len { - match (left.get(i), right.get(i)) { - (Some(l), Some(r)) => { - values.push(compute_jaro_winkler_similarity(l, r)); - validity.append_non_null(); - } - _ => { - values.push(NULL_SENTINEL); - validity.append_null(); - } - } - } - - let field = Arc::new(Field::new(self.name(), DataType::Float64)); - let result = Float64Array::from_field_and_values(field, values) - .with_nulls(validity.finish())?; - Ok(result.into_series()) - }) - }) + binary_str_distance::( + inputs, + self.name(), + DataType::Float64, + compute_jaro_winkler_similarity, + ) } fn get_return_field( @@ -77,27 +52,7 @@ impl ScalarUDF for JaroWinklerSimilarity { inputs: FunctionArgs, schema: &Schema, ) -> DaftResult { - ensure!( - inputs.len() == 2, - SchemaMismatch: "Expected 2 inputs, but received {}", - inputs.len() - ); - - let left = inputs.required(0)?.to_field(schema)?; - let right = inputs.required(1)?.to_field(schema)?; - - ensure!( - left.dtype.is_string() || left.dtype == DataType::Null, - TypeError: "First argument must be a string, got {}", - left.dtype - ); - ensure!( - right.dtype.is_string() || right.dtype == DataType::Null, - TypeError: "Second argument must be a string, got {}", - right.dtype - ); - - Ok(Field::new(self.name(), DataType::Float64)) + binary_str_distance_to_field(inputs, schema, self.name(), DataType::Float64) } fn docstring(&self) -> &'static str { diff --git a/src/daft-functions-utf8/src/levenshtein.rs b/src/daft-functions-utf8/src/levenshtein.rs index 2b2075c8726..6deb66656b3 100644 --- a/src/daft-functions-utf8/src/levenshtein.rs +++ b/src/daft-functions-utf8/src/levenshtein.rs @@ -1,11 +1,7 @@ -use std::sync::Arc; - -use arrow_buffer::NullBufferBuilder; -use daft_core::prelude::{Int64Array, IntoSeries}; use daft_dsl::functions::{prelude::*, scalar::ScalarFn}; use serde::{Deserialize, Serialize}; -const NULL_SENTINEL: i64 = 0; +use crate::utils::{binary_str_distance, binary_str_distance_to_field}; /// Compute Levenshtein edit distance using Wagner-Fischer algorithm. /// Uses O(min(n,m)) space by only keeping two rows of the DP matrix. @@ -61,34 +57,12 @@ impl ScalarUDF for LevenshteinDistance { inputs: FunctionArgs, _ctx: &daft_dsl::functions::scalar::EvalContext, ) -> DaftResult { - let left = inputs.required(0)?.cast(&DataType::Utf8)?; - let right = inputs.required(1)?.cast(&DataType::Utf8)?; - - left.with_utf8_array(|left| { - right.with_utf8_array(|right| { - let len = left.len(); - let mut values = Vec::with_capacity(len); - let mut validity = NullBufferBuilder::new(len); - - for i in 0..len { - match (left.get(i), right.get(i)) { - (Some(l), Some(r)) => { - values.push(compute_levenshtein_distance(l, r)); - validity.append_non_null(); - } - _ => { - values.push(NULL_SENTINEL); - validity.append_null(); - } - } - } - - let field = Arc::new(Field::new(self.name(), DataType::Int64)); - let result = Int64Array::from_field_and_values(field, values) - .with_nulls(validity.finish())?; - Ok(result.into_series()) - }) - }) + binary_str_distance::( + inputs, + self.name(), + DataType::Int64, + compute_levenshtein_distance, + ) } fn get_return_field( @@ -96,27 +70,7 @@ impl ScalarUDF for LevenshteinDistance { inputs: FunctionArgs, schema: &Schema, ) -> DaftResult { - ensure!( - inputs.len() == 2, - SchemaMismatch: "Expected 2 inputs, but received {}", - inputs.len() - ); - - let left = inputs.required(0)?.to_field(schema)?; - let right = inputs.required(1)?.to_field(schema)?; - - ensure!( - left.dtype.is_string() || left.dtype == DataType::Null, - TypeError: "First argument must be a string, got {}", - left.dtype - ); - ensure!( - right.dtype.is_string() || right.dtype == DataType::Null, - TypeError: "Second argument must be a string, got {}", - right.dtype - ); - - Ok(Field::new(self.name(), DataType::Int64)) + binary_str_distance_to_field(inputs, schema, self.name(), DataType::Int64) } fn docstring(&self) -> &'static str { diff --git a/src/daft-functions-utf8/src/utils.rs b/src/daft-functions-utf8/src/utils.rs index 359ff78b68f..22b18a61b9b 100644 --- a/src/daft-functions-utf8/src/utils.rs +++ b/src/daft-functions-utf8/src/utils.rs @@ -1,9 +1,11 @@ use std::{borrow::Cow, sync::Arc}; use arrow::array::{Datum, Scalar}; +use arrow_buffer::NullBufferBuilder; use common_error::{DaftError, DaftResult, ensure}; use daft_core::{ array::{DataArray, iterator::Utf8Iter}, + datatypes::DaftPrimitiveType, prelude::{BooleanArray, DaftPhysicalType, DataType, Field, FullNull, Schema, Utf8Array}, series::{IntoSeries, Series}, }; @@ -247,3 +249,84 @@ pub(crate) fn binary_utf8_to_field( ); Ok(Field::new(input.name, return_dtype)) } + +/// Evaluate a pairwise string distance/similarity function over two string inputs. +/// +/// Casts both inputs to Utf8, iterates row-by-row tracking nulls, and produces a +/// numeric output array of type `T`. Returns null for a row when either input is null. +/// Shared by the string distance/similarity UDFs (levenshtein, jaro, jaro-winkler, +/// damerau-levenshtein) to avoid duplicating the cast/iterate/null-track/build logic. +pub(crate) fn binary_str_distance( + inputs: FunctionArgs, + name: &'static str, + return_dtype: DataType, + compute: F, +) -> DaftResult +where + T: DaftPrimitiveType, + T::Native: Default, + F: Fn(&str, &str) -> T::Native, + DataArray: IntoSeries, +{ + let left = inputs.required(0)?.cast(&DataType::Utf8)?; + let right = inputs.required(1)?.cast(&DataType::Utf8)?; + let field = Arc::new(Field::new(name, return_dtype)); + + left.with_utf8_array(|left| { + right.with_utf8_array(|right| { + let len = left.len(); + let mut values = Vec::with_capacity(len); + let mut validity = NullBufferBuilder::new(len); + + for i in 0..len { + match (left.get(i), right.get(i)) { + (Some(l), Some(r)) => { + values.push(compute(l, r)); + validity.append_non_null(); + } + _ => { + values.push(T::Native::default()); + validity.append_null(); + } + } + } + + let result = DataArray::::from_field_and_values(field.clone(), values) + .with_nulls(validity.finish())?; + Ok(result.into_series()) + }) + }) +} + +/// Compute the return field for a pairwise string distance/similarity function. +/// +/// Validates that both arguments are string-typed (or Null) and returns a field with +/// the given `return_dtype`. Shared by the string distance/similarity UDFs. +pub(crate) fn binary_str_distance_to_field( + inputs: FunctionArgs, + schema: &Schema, + name: &'static str, + return_dtype: DataType, +) -> DaftResult { + ensure!( + inputs.len() == 2, + SchemaMismatch: "Expected 2 inputs, but received {}", + inputs.len() + ); + + let left = inputs.required(0)?.to_field(schema)?; + let right = inputs.required(1)?.to_field(schema)?; + + ensure!( + left.dtype.is_string() || left.dtype == DataType::Null, + TypeError: "First argument must be a string, got {}", + left.dtype + ); + ensure!( + right.dtype.is_string() || right.dtype == DataType::Null, + TypeError: "Second argument must be a string, got {}", + right.dtype + ); + + Ok(Field::new(name, return_dtype)) +} From 63771041a9b69b26be30e530d715f311a7e3771f Mon Sep 17 00:00:00 2001 From: Nishaanth Reddy Date: Thu, 4 Jun 2026 17:03:46 -0700 Subject: [PATCH 5/5] fix(functions): support scalar broadcasting in string distance functions - rewrite binary_str_distance to use parse_inputs + create_broadcasted_str_iter, matching the broadcast-aware pattern of other utf8 helpers - fixes out-of-bounds panic for col-scalar (e.g. levenshtein_distance(col("a"), "kitten")) and wrong-length output for scalar-col inputs - handle full-null and empty-input cases explicitly - add TestScalarBroadcast regression tests covering col-scalar, scalar-col, and null-scalar - addresses maintainer review feedback (PR #7068) --- src/daft-functions-utf8/src/utils.rs | 32 ++++++++++++++++++------- tests/functions/test_string_distance.py | 28 +++++++++++++++++++++- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/daft-functions-utf8/src/utils.rs b/src/daft-functions-utf8/src/utils.rs index 22b18a61b9b..b77d87a8154 100644 --- a/src/daft-functions-utf8/src/utils.rs +++ b/src/daft-functions-utf8/src/utils.rs @@ -252,8 +252,10 @@ pub(crate) fn binary_utf8_to_field( /// Evaluate a pairwise string distance/similarity function over two string inputs. /// -/// Casts both inputs to Utf8, iterates row-by-row tracking nulls, and produces a -/// numeric output array of type `T`. Returns null for a row when either input is null. +/// Casts both inputs to Utf8, then iterates row-by-row tracking nulls and produces a +/// numeric output array of type `T`. Supports scalar broadcasting on either side +/// (column-column, column-scalar, scalar-column) via `parse_inputs` and +/// `create_broadcasted_str_iter`. Returns null for a row when either input is null. /// Shared by the string distance/similarity UDFs (levenshtein, jaro, jaro-winkler, /// damerau-levenshtein) to avoid duplicating the cast/iterate/null-track/build logic. pub(crate) fn binary_str_distance( @@ -270,16 +272,30 @@ where { let left = inputs.required(0)?.cast(&DataType::Utf8)?; let right = inputs.required(1)?.cast(&DataType::Utf8)?; - let field = Arc::new(Field::new(name, return_dtype)); + let field = Arc::new(Field::new(name, return_dtype.clone())); left.with_utf8_array(|left| { right.with_utf8_array(|right| { - let len = left.len(); - let mut values = Vec::with_capacity(len); - let mut validity = NullBufferBuilder::new(len); + let (is_full_null, expected_size) = parse_inputs(left, &[right]) + .map_err(|e| DaftError::ValueError(format!("Error in {name}: {e}")))?; - for i in 0..len { - match (left.get(i), right.get(i)) { + if is_full_null { + return Ok( + DataArray::::full_null(name, &return_dtype, expected_size).into_series(), + ); + } + if expected_size == 0 { + return Ok(DataArray::::empty(name, &return_dtype).into_series()); + } + + let left_iter = create_broadcasted_str_iter(left, expected_size); + let right_iter = create_broadcasted_str_iter(right, expected_size); + + let mut values = Vec::with_capacity(expected_size); + let mut validity = NullBufferBuilder::new(expected_size); + + for (l, r) in left_iter.zip(right_iter) { + match (l, r) { (Some(l), Some(r)) => { values.push(compute(l, r)); validity.append_non_null(); diff --git a/tests/functions/test_string_distance.py b/tests/functions/test_string_distance.py index 9d3a23465ff..cd25cf4ae94 100644 --- a/tests/functions/test_string_distance.py +++ b/tests/functions/test_string_distance.py @@ -5,7 +5,7 @@ import pytest import daft -from daft import col +from daft import col, lit from daft.functions import ( damerau_levenshtein_distance, jaro_similarity, @@ -59,6 +59,32 @@ def test_expression_method(self): assert result.to_pydict()["dist"] == [3] +class TestScalarBroadcast: + """Scalar broadcasting on either side (column-scalar, scalar-column).""" + + def test_column_scalar(self): + df = daft.from_pydict({"a": ["kitten", "sitting", "kitten"]}) + result = df.with_column("dist", levenshtein_distance(col("a"), lit("kitten"))).collect() + assert result.to_pydict()["dist"] == [0, 3, 0] + + def test_scalar_column(self): + df = daft.from_pydict({"a": ["kitten", "sitting", "kitten"]}) + result = df.with_column("dist", levenshtein_distance(lit("kitten"), col("a"))).collect() + assert result.to_pydict()["dist"] == [0, 3, 0] + + def test_scalar_column_similarity(self): + df = daft.from_pydict({"a": ["martha", "martha"]}) + result = df.with_column("sim", jaro_similarity(lit("marhta"), col("a"))).collect() + sims = result.to_pydict()["sim"] + assert sims[0] == pytest.approx(0.944444, rel=1e-4) + assert sims[1] == pytest.approx(0.944444, rel=1e-4) + + def test_column_scalar_null_scalar(self): + df = daft.from_pydict({"a": ["kitten", "sitting"]}) + result = df.with_column("dist", levenshtein_distance(col("a"), lit(None))).collect() + assert result.to_pydict()["dist"] == [None, None] + + class TestJaroSimilarity: def test_identical(self): df = daft.from_pydict({"a": ["hello", ""], "b": ["hello", ""]})