diff --git a/daft/functions/__init__.py b/daft/functions/__init__.py index 51958792ff1..40f5c533d48 100644 --- a/daft/functions/__init__.py +++ b/daft/functions/__init__.py @@ -213,6 +213,7 @@ is_inf, not_nan, fill_nan, + width_bucket, ) from .partition import ( partition_days, @@ -612,5 +613,6 @@ "week_of_year", "weekofyear", "when", + "width_bucket", "year", ] diff --git a/daft/functions/numeric.py b/daft/functions/numeric.py index eeb04c4b881..83a776fde81 100644 --- a/daft/functions/numeric.py +++ b/daft/functions/numeric.py @@ -306,6 +306,23 @@ def conv(expr: Expression, from_base: int, to_base: int) -> Expression: return Expression._call_builtin_scalar_fn("conv", expr, from_base, to_base) +def width_bucket( + value: Expression, + min: Expression, + max: Expression, + num_bucket: Expression, +) -> Expression: + """Returns the 1-indexed bucket of ``value`` in an equiwidth histogram over ``[min, max]``. + + Returns ``0`` below the range and ``num_bucket + 1`` at or above; descending bounds + (``min > max``) flip the orientation. Non-integer ``num_bucket`` truncates toward zero. + Examples: ``width_bucket(5.3, 0.2, 10.6, 5) == 3``, ``width_bucket(-2.1, 1.3, 3.4, 3) == 0``. + Returns NULL when ``num_bucket <= 0``, ``min == max``, ``value`` is NaN, or + ``min``/``max`` is NaN/Infinite. + """ + return Expression._call_builtin_scalar_fn("width_bucket", value, min, max, num_bucket) + + def is_nan(expr: Expression) -> Expression: """Checks if values are NaN (a special float value indicating not-a-number). diff --git a/src/daft-functions/src/numeric/mod.rs b/src/daft-functions/src/numeric/mod.rs index 3bd8447e848..6ba75411f0e 100644 --- a/src/daft-functions/src/numeric/mod.rs +++ b/src/daft-functions/src/numeric/mod.rs @@ -18,6 +18,7 @@ pub mod round; pub mod sign; pub mod sqrt; pub mod trigonometry; +pub mod width_bucket; use abs::Abs; use bin::Bin; @@ -44,6 +45,7 @@ use power::Power; use round::Round; use sign::{Negate, Sign}; use sqrt::Sqrt; +use width_bucket::WidthBucket; fn to_field_numeric(f: &dyn ScalarUDF, input: &Expr, schema: &Schema) -> DaftResult { let field = input.to_field(schema)?; @@ -91,6 +93,7 @@ impl FunctionModule for NumericFunctions { parent.add_fn(Sign); parent.add_fn(Negate); parent.add_fn(Sqrt); + parent.add_fn(WidthBucket); // trig functions use trigonometry::*; diff --git a/src/daft-functions/src/numeric/width_bucket.rs b/src/daft-functions/src/numeric/width_bucket.rs new file mode 100644 index 00000000000..38a3e6617ff --- /dev/null +++ b/src/daft-functions/src/numeric/width_bucket.rs @@ -0,0 +1,176 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::Int64Array, + prelude::{DataType, Field, Schema}, + series::{IntoSeries, Series}, +}; +use daft_dsl::{ + ExprRef, + functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn}, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct WidthBucket; + +#[derive(FunctionArgs)] +struct WidthBucketArgs { + value: T, + min: T, + max: T, + num_bucket: T, +} + +#[typetag::serde] +impl ScalarUDF for WidthBucket { + fn call( + &self, + inputs: FunctionArgs, + _ctx: &daft_dsl::functions::scalar::EvalContext, + ) -> DaftResult { + let WidthBucketArgs { + value, + min, + max, + num_bucket, + } = inputs.try_into()?; + width_bucket_impl(value, min, max, num_bucket) + } + + fn name(&self) -> &'static str { + "width_bucket" + } + + fn get_return_field( + &self, + inputs: FunctionArgs, + schema: &Schema, + ) -> DaftResult { + let WidthBucketArgs { + value, + min, + max, + num_bucket, + } = inputs.try_into()?; + let value_field = value.to_field(schema)?; + let min_field = min.to_field(schema)?; + let max_field = max.to_field(schema)?; + let num_bucket_field = num_bucket.to_field(schema)?; + for (name, field) in [ + ("value", &value_field), + ("min", &min_field), + ("max", &max_field), + ("num_bucket", &num_bucket_field), + ] { + if !field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected `{name}` of width_bucket to be numeric, got {}", + field.dtype + ))); + } + } + Ok(Field::new(value_field.name, DataType::Int64)) + } + + fn docstring(&self) -> &'static str { + "Returns the bucket number for `value` in an equiwidth histogram with \ + `num_bucket` buckets in the range [min, max]. Returns 0 below the range and \ + num_bucket+1 at or above. Supports descending bounds (min > max). Returns NULL \ + if num_bucket <= 0, min == max, value/min/max is NaN, or min/max is infinite." + } +} + +fn width_bucket_impl( + value: Series, + min: Series, + max: Series, + num_bucket: Series, +) -> DaftResult { + let value = value.cast(&DataType::Float64)?; + let min = min.cast(&DataType::Float64)?; + let max = max.cast(&DataType::Float64)?; + let num_bucket = num_bucket.cast(&DataType::Int64)?; + let (value, min, max, num_bucket) = align_lengths(value, min, max, num_bucket)?; + + let v_arr = value.f64().unwrap(); + let mn_arr = min.f64().unwrap(); + let mx_arr = max.f64().unwrap(); + let nb_arr = num_bucket.i64().unwrap(); + + let iter = v_arr + .iter() + .zip(mn_arr.iter()) + .zip(mx_arr.iter()) + .zip(nb_arr.iter()) + .map(|(((v, mn), mx), nb)| match (v, mn, mx, nb) { + (Some(v), Some(mn), Some(mx), Some(nb)) => compute_bucket(v, mn, mx, nb), + _ => None, + }); + Ok(Int64Array::from_iter(Field::new(v_arr.name(), DataType::Int64), iter).into_series()) +} + +fn align_lengths( + a: Series, + b: Series, + c: Series, + d: Series, +) -> DaftResult<(Series, Series, Series, Series)> { + let lens = [a.len(), b.len(), c.len(), d.len()]; + let max_len = *lens.iter().max().unwrap(); + for &l in &lens { + if l != 1 && l != max_len { + return Err(DaftError::ValueError(format!( + "Cannot apply width_bucket to arrays of different lengths: {} vs {} vs {} vs {}", + lens[0], lens[1], lens[2], lens[3] + ))); + } + } + let bcast = |s: Series| -> DaftResult { + if s.len() == max_len { + Ok(s) + } else { + s.broadcast(max_len) + } + }; + Ok((bcast(a)?, bcast(b)?, bcast(c)?, bcast(d)?)) +} + +/// Mirrors Spark's `WidthBucket` NULL and bucketing semantics for parity. +fn compute_bucket(v: f64, mn: f64, mx: f64, nb: i64) -> Option { + if nb <= 0 + || nb == i64::MAX + || v.is_nan() + || mn == mx + || mn.is_nan() + || mn.is_infinite() + || mx.is_nan() + || mx.is_infinite() + { + return None; + } + // f64 cast can round nb just below i64::MAX up to i64::MAX, so guard +1 with saturating_add. + let bucket = if mn < mx { + if v < mn { + 0 + } else if v >= mx { + nb.saturating_add(1) + } else { + (((nb as f64) * (v - mn) / (mx - mn)) as i64).saturating_add(1) + } + } else { + // Descending: roles of below/above flip so a smaller value sorts higher. + if v > mn { + 0 + } else if v <= mx { + nb.saturating_add(1) + } else { + (((nb as f64) * (mn - v) / (mn - mx)) as i64).saturating_add(1) + } + }; + Some(bucket) +} + +#[must_use] +pub fn width_bucket(value: ExprRef, min: ExprRef, max: ExprRef, num_bucket: ExprRef) -> ExprRef { + ScalarFn::builtin(WidthBucket, vec![value, min, max, num_bucket]).into() +} diff --git a/tests/recordbatch/numeric/test_numeric.py b/tests/recordbatch/numeric/test_numeric.py index 0f6bf83e0e3..50acfc2520c 100644 --- a/tests/recordbatch/numeric/test_numeric.py +++ b/tests/recordbatch/numeric/test_numeric.py @@ -9,6 +9,7 @@ import pytest from daft import col, lit +from daft.functions import width_bucket from daft.recordbatch import MicroPartition from tests.recordbatch import daft_numeric_types @@ -1186,3 +1187,94 @@ def test_pmod_bad_input() -> None: table = MicroPartition.from_pydict({"a": ["x", "y"], "b": [1, 2]}) with pytest.raises(ValueError, match="Expected inputs to pmod to be numeric"): table.eval_expression_list([pmod(col("a"), col("b"))]) + + +def test_width_bucket_basic() -> None: + table = MicroPartition.from_pydict( + { + "v": [5.3, -2.1, 8.1, -0.9], + "mn": [0.2, 1.3, 0.0, 5.2], + "mx": [10.6, 3.4, 5.7, 0.5], + "nb": [5, 3, 4, 2], + } + ) + result = table.eval_expression_list([width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")]) + assert result.get_column_by_name("bucket").to_pylist() == [3, 0, 5, 3] + + +def test_width_bucket_descending() -> None: + """Min > max flips orientation; v > min underflows, v <= max overflows.""" + table = MicroPartition.from_pydict({"v": [11.0, 5.0, 0.0, -1.0]}) + result = table.eval_expression_list([width_bucket(col("v"), lit(10.0), lit(0.0), lit(5)).alias("bucket")]) + # min=10, max=0, nb=5: v=11 underflows -> 0; v=5 mid -> floor(5*(10-5)/10)+1=3; + # v=0 == max -> nb+1=6; v=-1 < max -> nb+1=6. + assert result.get_column_by_name("bucket").to_pylist() == [0, 3, 6, 6] + + +def test_width_bucket_integer_inputs() -> None: + table = MicroPartition.from_pydict({"v": pa.array([3, 5, 7], type=pa.int32())}) + result = table.eval_expression_list([width_bucket(col("v"), lit(0), lit(10), lit(5)).alias("bucket")]) + assert result.get_column_by_name("bucket").to_pylist() == [2, 3, 4] + + +def test_width_bucket_below_and_above_range() -> None: + table = MicroPartition.from_pydict({"v": [-0.1, 0.0, 5.0, 10.0, 10.1]}) + result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")]) + # v=0 is at min so bucket 1; v=10 hits the v>=max branch so nb+1=6. + assert result.get_column_by_name("bucket").to_pylist() == [0, 1, 3, 6, 6] + + +def test_width_bucket_null_returning() -> None: + """Each invalid row exercises a distinct NULL-returning guard in compute_bucket.""" + i64_max = 2**63 - 1 + nan, inf = math.nan, math.inf + table = MicroPartition.from_pydict( + { + # nb<=0 nb==i64::MAX mn==mx v=NaN mn=NaN mn=Inf mx=NaN mx=Inf valid + "v": [1.0, 1.0, 1.0, nan, 1.0, 1.0, 1.0, 1.0, 1.0], + "mn": [0.0, 0.0, 5.0, 0.0, nan, inf, 0.0, 0.0, 0.0], + "mx": [10.0, 10.0, 5.0, 10.0, 10.0, 10.0, nan, inf, 10.0], + "nb": [0, i64_max, 4, 4, 4, 4, 4, 4, 4], + } + ) + result = table.eval_expression_list([width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")]) + values = result.get_column_by_name("bucket").to_pylist() + assert values[:-1] == [None] * 8 + assert values[-1] is not None + + +def test_width_bucket_null_propagation() -> None: + table = MicroPartition.from_pydict({"v": [1.0, None, 5.0]}) + result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")]) + assert result.get_column_by_name("bucket").to_pylist() == [1, None, 3] + + +def test_width_bucket_scalar_broadcast() -> None: + table = MicroPartition.from_pydict({"v": [-1.0, 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 11.0]}) + result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")]) + assert result.get_column_by_name("bucket").to_pylist() == [0, 1, 2, 3, 4, 5, 6, 6] + + +def test_width_bucket_fractional_num_bucket_truncates() -> None: + """Daft promotes any numeric num_bucket to Int64 by truncation.""" + table = MicroPartition.from_pydict({"v": [3.0]}) + result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5.7)).alias("bucket")]) + # 5.7 truncates to 5; floor(5 * 3 / 10) + 1 = 2. + assert result.get_column_by_name("bucket").to_pylist() == [2] + + +def test_width_bucket_bad_input() -> None: + table = MicroPartition.from_pydict({"v": ["a", "b"]}) + with pytest.raises(ValueError, match="Expected `value` of width_bucket to be numeric"): + table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5))]) + + +def test_width_bucket_huge_num_bucket() -> None: + """Nb just below i64::MAX must saturate, not wrap to i64::MIN.""" + nb = 2**63 - 2 # one below the i64::MAX None-guard; (nb as f64) rounds up to 2^63 + table = MicroPartition.from_pydict({"v": [-1.0, 5.0, 11.0]}) + result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(nb)).alias("bucket")]) + values = result.get_column_by_name("bucket").to_pylist() + assert values[0] == 0 # below range + assert values[1] == 2**62 + 1 # 2^63 * 0.5 truncated to i64, then saturating_add(1) + assert values[2] == 2**63 - 1 # at/above range: nb.saturating_add(1) -> i64::MAX