Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@
is_inf,
not_nan,
fill_nan,
width_bucket,
)
from .partition import (
partition_days,
Expand Down Expand Up @@ -612,5 +613,6 @@
"week_of_year",
"weekofyear",
"when",
"width_bucket",
"year",
]
17 changes: 17 additions & 0 deletions daft/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,23 @@ def conv(expr: Expression, from_base: int, to_base: int) -> Expression:
return Expression._call_builtin_scalar_fn("conv", expr, from_base, to_base)


def width_bucket(
    value: Expression,
    min: Expression,
    max: Expression,
    num_bucket: Expression,
) -> Expression:
    """Assigns ``value`` to a 1-indexed bucket of an equal-width histogram spanning ``min`` to ``max``.

    Args:
        value: Numeric expression to place into a bucket.
        min: Numeric expression giving the first bound of the histogram range.
        max: Numeric expression giving the second bound of the histogram range.
        num_bucket: Numeric expression giving the number of buckets; a non-integer
            value is truncated toward zero.

    Returns:
        Expression: An Int64 expression holding the bucket number. Values before the
        range map to ``0`` and values at or past its end map to ``num_bucket + 1``.
        When ``min > max`` the orientation is reversed, so larger values land in
        lower buckets. The result is NULL when any input is NULL, ``num_bucket <= 0``,
        ``num_bucket`` is the largest signed 64-bit integer, ``min == max``, ``value``
        is NaN, or ``min``/``max`` is NaN or infinite.

    Examples:
        ``width_bucket(5.3, 0.2, 10.6, 5) == 3`` and
        ``width_bucket(-2.1, 1.3, 3.4, 3) == 0``.
    """
    bucket_args = (value, min, max, num_bucket)
    return Expression._call_builtin_scalar_fn("width_bucket", *bucket_args)


def is_nan(expr: Expression) -> Expression:
"""Checks if values are NaN (a special float value indicating not-a-number).

Expand Down
3 changes: 3 additions & 0 deletions src/daft-functions/src/numeric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod round;
pub mod sign;
pub mod sqrt;
pub mod trigonometry;
pub mod width_bucket;

use abs::Abs;
use bin::Bin;
Expand All @@ -44,6 +45,7 @@ use power::Power;
use round::Round;
use sign::{Negate, Sign};
use sqrt::Sqrt;
use width_bucket::WidthBucket;

fn to_field_numeric(f: &dyn ScalarUDF, input: &Expr, schema: &Schema) -> DaftResult<Field> {
let field = input.to_field(schema)?;
Expand Down Expand Up @@ -91,6 +93,7 @@ impl FunctionModule for NumericFunctions {
parent.add_fn(Sign);
parent.add_fn(Negate);
parent.add_fn(Sqrt);
parent.add_fn(WidthBucket);

// trig functions
use trigonometry::*;
Expand Down
176 changes: 176 additions & 0 deletions src/daft-functions/src/numeric/width_bucket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::Int64Array,
prelude::{DataType, Field, Schema},
series::{IntoSeries, Series},
};
use daft_dsl::{
ExprRef,
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
};
use serde::{Deserialize, Serialize};

/// Unit type that carries the `ScalarUDF` implementation registered under the
/// name `width_bucket`; it holds no state of its own.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct WidthBucket;

/// The four arguments of `width_bucket`, generic over the argument
/// representation: it is unpacked from `FunctionArgs<Series>` at evaluation time
/// and from `FunctionArgs<ExprRef>` when resolving the return field.
#[derive(FunctionArgs)]
struct WidthBucketArgs<T> {
// Value to place into a bucket.
value: T,
// First bound of the histogram range.
min: T,
// Second bound of the histogram range; may be smaller than `min`.
max: T,
// Number of buckets the range is divided into.
num_bucket: T,
}

#[typetag::serde]
impl ScalarUDF for WidthBucket {
fn call(
&self,
inputs: FunctionArgs<Series>,
_ctx: &daft_dsl::functions::scalar::EvalContext,
) -> DaftResult<Series> {
let WidthBucketArgs {
value,
min,
max,
num_bucket,
} = inputs.try_into()?;
width_bucket_impl(value, min, max, num_bucket)
}

fn name(&self) -> &'static str {
"width_bucket"
}

fn get_return_field(
&self,
inputs: FunctionArgs<ExprRef>,
schema: &Schema,
) -> DaftResult<Field> {
let WidthBucketArgs {
value,
min,
max,
num_bucket,
} = inputs.try_into()?;
let value_field = value.to_field(schema)?;
let min_field = min.to_field(schema)?;
let max_field = max.to_field(schema)?;
let num_bucket_field = num_bucket.to_field(schema)?;
for (name, field) in [
("value", &value_field),
("min", &min_field),
("max", &max_field),
("num_bucket", &num_bucket_field),
] {
if !field.dtype.is_numeric() {
return Err(DaftError::TypeError(format!(
"Expected `{name}` of width_bucket to be numeric, got {}",
field.dtype
)));
}
}
Ok(Field::new(value_field.name, DataType::Int64))
}

fn docstring(&self) -> &'static str {
"Returns the bucket number for `value` in an equiwidth histogram with \
`num_bucket` buckets in the range [min, max]. Returns 0 below the range and \
num_bucket+1 at or above. Supports descending bounds (min > max). Returns NULL \
if num_bucket <= 0, min == max, value/min/max is NaN, or min/max is infinite."
}
}

/// Computes the bucket number for every row of the four argument series.
///
/// `value`, `min` and `max` are cast to `Float64` and `num_bucket` to `Int64`,
/// then all four are brought to a common length via `align_lengths`. A row with a
/// null in any argument yields null; otherwise the row is handed to
/// `compute_bucket`. The output is an `Int64` series named after `value`.
fn width_bucket_impl(
    value: Series,
    min: Series,
    max: Series,
    num_bucket: Series,
) -> DaftResult<Series> {
    let (value, min, max, num_bucket) = align_lengths(
        value.cast(&DataType::Float64)?,
        min.cast(&DataType::Float64)?,
        max.cast(&DataType::Float64)?,
        num_bucket.cast(&DataType::Int64)?,
    )?;

    // The casts above fix the physical types, so these downcasts are expected to succeed.
    let values = value.f64().unwrap();
    let lows = min.f64().unwrap();
    let highs = max.f64().unwrap();
    let counts = num_bucket.i64().unwrap();

    let out_field = Field::new(values.name(), DataType::Int64);
    let buckets = values
        .iter()
        .zip(lows.iter())
        .zip(highs.iter())
        .zip(counts.iter())
        .map(|(((v, lo), hi), n)| -> Option<i64> {
            // `?` short-circuits to a null output as soon as any argument is null.
            compute_bucket(v?, lo?, hi?, n?)
        });
    Ok(Int64Array::from_iter(out_field, buckets).into_series())
}

/// Brings the four argument series to one common length.
///
/// A series of length 1 is treated as a scalar and broadcast to the length shared
/// by the non-scalar series. That shared length may be zero: an empty column
/// combined with length-1 scalars yields four empty series rather than an error.
/// If every series has length 1, all are returned unchanged.
///
/// # Errors
///
/// Returns `DaftError::ValueError` when two non-scalar series have different
/// lengths, and propagates any error raised by `Series::broadcast`.
fn align_lengths(
    a: Series,
    b: Series,
    c: Series,
    d: Series,
) -> DaftResult<(Series, Series, Series, Series)> {
    let lens = [a.len(), b.len(), c.len(), d.len()];
    // The target is the first non-scalar length rather than the maximum: taking the
    // maximum would pick 1 for an empty column next to scalars and reject the
    // empty column as a mismatch.
    let target_len = lens.iter().copied().find(|&l| l != 1).unwrap_or(1);
    if lens.iter().any(|&l| l != 1 && l != target_len) {
        return Err(DaftError::ValueError(format!(
            "Cannot apply width_bucket to arrays of different lengths: {} vs {} vs {} vs {}",
            lens[0], lens[1], lens[2], lens[3]
        )));
    }
    let bcast = |s: Series| -> DaftResult<Series> {
        if s.len() == target_len {
            Ok(s)
        } else {
            s.broadcast(target_len)
        }
    };
    Ok((bcast(a)?, bcast(b)?, bcast(c)?, bcast(d)?))
}

/// Bucket number of `v` in an equal-width histogram of `nb` buckets between `mn`
/// and `mx`, following Spark's `WidthBucket` NULL and bucketing rules for parity.
///
/// Returns `None` when `nb <= 0`, `nb == i64::MAX`, `v` is NaN, `mn == mx`, or
/// either bound is NaN or infinite. Otherwise returns `Some(0)` for values before
/// the range, `Some(nb + 1)` (saturating) for values at or past its end, and the
/// 1-indexed bucket for values inside it. When `mn > mx` the range is descending,
/// so values greater than `mn` fall before it and values `<= mx` fall past it.
fn compute_bucket(v: f64, mn: f64, mx: f64, nb: i64) -> Option<i64> {
    let invalid_count = nb <= 0 || nb == i64::MAX;
    let invalid_bounds = mn == mx || !mn.is_finite() || !mx.is_finite();
    if invalid_count || invalid_bounds || v.is_nan() {
        return None;
    }

    let ascending = mn < mx;
    // In a descending range the comparisons flip, so a smaller value sorts higher.
    let before_range = if ascending { v < mn } else { v > mn };
    let past_range = if ascending { v >= mx } else { v <= mx };

    if before_range {
        return Some(0);
    }
    if past_range {
        return Some(nb.saturating_add(1));
    }

    let (offset, span) = if ascending {
        (v - mn, mx - mn)
    } else {
        (mn - v, mn - mx)
    };
    // Converting `nb` to f64 can round a count just below i64::MAX up to 2^63, so
    // the `+ 1` is done with `saturating_add` to avoid wrapping.
    let zero_based = ((nb as f64) * offset / span) as i64;
    Some(zero_based.saturating_add(1))
}

/// Builds the expression that applies the builtin `width_bucket` function to
/// `value`, `min`, `max` and `num_bucket`, in that argument order.
#[must_use]
pub fn width_bucket(value: ExprRef, min: ExprRef, max: ExprRef, num_bucket: ExprRef) -> ExprRef {
    let args = vec![value, min, max, num_bucket];
    ScalarFn::builtin(WidthBucket, args).into()
}
92 changes: 92 additions & 0 deletions tests/recordbatch/numeric/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

from daft import col, lit
from daft.functions import width_bucket
from daft.recordbatch import MicroPartition
from tests.recordbatch import daft_numeric_types

Expand Down Expand Up @@ -1186,3 +1187,94 @@ def test_pmod_bad_input() -> None:
table = MicroPartition.from_pydict({"a": ["x", "y"], "b": [1, 2]})
with pytest.raises(ValueError, match="Expected inputs to pmod to be numeric"):
table.eval_expression_list([pmod(col("a"), col("b"))])


def test_width_bucket_basic() -> None:
    """Row-wise bucketing where each of the four arguments comes from its own column."""
    columns = {
        "v": [5.3, -2.1, 8.1, -0.9],
        "mn": [0.2, 1.3, 0.0, 5.2],
        "mx": [10.6, 3.4, 5.7, 0.5],
        "nb": [5, 3, 4, 2],
    }
    partition = MicroPartition.from_pydict(columns)
    bucket_expr = width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [3, 0, 5, 3]


def test_width_bucket_descending() -> None:
    """With min > max the orientation flips: v > min underflows and v <= max overflows."""
    partition = MicroPartition.from_pydict({"v": [11.0, 5.0, 0.0, -1.0]})
    bucket_expr = width_bucket(col("v"), lit(10.0), lit(0.0), lit(5)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    # Bounds are min=10, max=0 with 5 buckets:
    #   v=11 lies before the range             -> 0
    #   v=5  is inside: floor(5*(10-5)/10) + 1 -> 3
    #   v=0  equals max                        -> nb + 1 = 6
    #   v=-1 is past max                       -> nb + 1 = 6
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [0, 3, 6, 6]


def test_width_bucket_integer_inputs() -> None:
    """Integer-typed value and integer literal bounds are accepted."""
    int_values = pa.array([3, 5, 7], type=pa.int32())
    partition = MicroPartition.from_pydict({"v": int_values})
    bucket_expr = width_bucket(col("v"), lit(0), lit(10), lit(5)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [2, 3, 4]


def test_width_bucket_below_and_above_range() -> None:
    """Checks the values just outside, exactly on, and inside the [0, 10] range."""
    partition = MicroPartition.from_pydict({"v": [-0.1, 0.0, 5.0, 10.0, 10.1]})
    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    # v == min belongs to the first bucket; v == max already counts as overflow (nb + 1 = 6).
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [0, 1, 3, 6, 6]


def test_width_bucket_null_returning() -> None:
    """Every invalid row trips a different NULL-returning guard in compute_bucket."""
    i64_max = 2**63 - 1
    nan, inf = math.nan, math.inf
    # One (v, mn, mx, nb) tuple per row; the trailing row is the only valid one.
    rows = [
        (1.0, 0.0, 10.0, 0),  # nb <= 0
        (1.0, 0.0, 10.0, i64_max),  # nb == i64::MAX
        (1.0, 5.0, 5.0, 4),  # mn == mx
        (nan, 0.0, 10.0, 4),  # v is NaN
        (1.0, nan, 10.0, 4),  # mn is NaN
        (1.0, inf, 10.0, 4),  # mn is infinite
        (1.0, 0.0, nan, 4),  # mx is NaN
        (1.0, 0.0, inf, 4),  # mx is infinite
        (1.0, 0.0, 10.0, 4),  # valid
    ]
    v_col, mn_col, mx_col, nb_col = (list(column) for column in zip(*rows))
    partition = MicroPartition.from_pydict({"v": v_col, "mn": mn_col, "mx": mx_col, "nb": nb_col})
    bucket_expr = width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets[:-1] == [None] * 8
    assert buckets[-1] is not None


def test_width_bucket_null_propagation() -> None:
    """A null input value produces a null bucket while its neighbours are unaffected."""
    partition = MicroPartition.from_pydict({"v": [1.0, None, 5.0]})
    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [1, None, 3]


def test_width_bucket_scalar_broadcast() -> None:
    """Literal bounds and bucket count are applied to every row of the value column."""
    partition = MicroPartition.from_pydict({"v": [-1.0, 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 11.0]})
    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [0, 1, 2, 3, 4, 5, 6, 6]


def test_width_bucket_fractional_num_bucket_truncates() -> None:
    """A fractional num_bucket is converted to Int64 by truncation."""
    partition = MicroPartition.from_pydict({"v": [3.0]})
    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5.7)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    # 5.7 becomes 5 buckets, so the result is floor(5 * 3 / 10) + 1 = 2.
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [2]


def test_width_bucket_bad_input() -> None:
    """A string-typed value column is rejected with an error naming the `value` argument."""
    partition = MicroPartition.from_pydict({"v": ["a", "b"]})
    with pytest.raises(ValueError, match="Expected `value` of width_bucket to be numeric"):
        partition.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5))])


def test_width_bucket_huge_num_bucket() -> None:
    """A bucket count just below i64::MAX must saturate instead of wrapping to i64::MIN."""
    # One less than the i64::MAX value that returns NULL; as an f64 it rounds up to 2^63.
    huge_count = 2**63 - 2
    partition = MicroPartition.from_pydict({"v": [-1.0, 5.0, 11.0]})
    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(huge_count)).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    below, middle, above = evaluated.get_column_by_name("bucket").to_pylist()
    assert below == 0  # before the range
    assert middle == 2**62 + 1  # 2^63 * 0.5 truncated to i64, then saturating_add(1)
    assert above == 2**63 - 1  # at or past the range end: nb.saturating_add(1) -> i64::MAX
Loading