Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
between,
bin,
conv,
width_bucket,
is_nan,
is_inf,
not_nan,
Expand Down Expand Up @@ -612,5 +613,6 @@
"week_of_year",
"weekofyear",
"when",
"width_bucket",
"year",
]
17 changes: 17 additions & 0 deletions daft/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,23 @@ def conv(expr: Expression, from_base: int, to_base: int) -> Expression:
return Expression._call_builtin_scalar_fn("conv", expr, from_base, to_base)


def width_bucket(
    value: Expression,
    min: Expression,
    max: Expression,
    num_bucket: Expression,
) -> Expression:
    """Assigns ``value`` to a 1-indexed bucket of an equal-width histogram spanning ``[min, max]``.

    Values below the range map to ``0``; values at or above the upper bound map to
    ``num_bucket + 1``. When ``min > max`` the bounds are treated as descending and the
    orientation is flipped. A non-integer ``num_bucket`` is truncated toward zero.

    Args:
        value: Numeric expression to place into a bucket.
        min: Numeric expression for the first bound of the histogram range.
        max: Numeric expression for the second bound of the histogram range.
        num_bucket: Numeric expression for the number of buckets.

    Returns:
        Expression: the bucket number, or NULL when ``num_bucket <= 0``, ``min == max``,
        ``value`` is NaN, or ``min``/``max`` is NaN/Infinite.

    Examples:
        ``width_bucket(5.3, 0.2, 10.6, 5) == 3`` and ``width_bucket(-2.1, 1.3, 3.4, 3) == 0``.
    """
    bucket_args = (value, min, max, num_bucket)
    return Expression._call_builtin_scalar_fn("width_bucket", *bucket_args)


def is_nan(expr: Expression) -> Expression:
"""Checks if values are NaN (a special float value indicating not-a-number).

Expand Down
3 changes: 3 additions & 0 deletions src/daft-functions/src/numeric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod round;
pub mod sign;
pub mod sqrt;
pub mod trigonometry;
pub mod width_bucket;

use abs::Abs;
use bin::Bin;
Expand All @@ -44,6 +45,7 @@ use power::Power;
use round::Round;
use sign::{Negate, Sign};
use sqrt::Sqrt;
use width_bucket::WidthBucket;

fn to_field_numeric(f: &dyn ScalarUDF, input: &Expr, schema: &Schema) -> DaftResult<Field> {
let field = input.to_field(schema)?;
Expand Down Expand Up @@ -91,6 +93,7 @@ impl FunctionModule for NumericFunctions {
parent.add_fn(Sign);
parent.add_fn(Negate);
parent.add_fn(Sqrt);
parent.add_fn(WidthBucket);

// trig functions
use trigonometry::*;
Expand Down
175 changes: 175 additions & 0 deletions src/daft-functions/src/numeric/width_bucket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::Int64Array,
prelude::{DataType, Field, Schema},
series::{IntoSeries, Series},
};
use daft_dsl::{
ExprRef,
functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn},
};
use serde::{Deserialize, Serialize};

/// Stateless marker type for the `width_bucket` scalar function; the behaviour
/// lives in its `ScalarUDF` implementation.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct WidthBucket;

// The four inputs of `width_bucket`, generic over how an argument is
// represented: it is destructured from `FunctionArgs<Series>` at evaluation
// time and from `FunctionArgs<ExprRef>` when the return field is resolved.
#[derive(FunctionArgs)]
struct WidthBucketArgs<T> {
// Value to place into a bucket.
value: T,
// First bound of the histogram range.
min: T,
// Second bound of the histogram range (may be below `min` for descending ranges).
max: T,
// Number of buckets.
num_bucket: T,
}

#[typetag::serde]
impl ScalarUDF for WidthBucket {
fn call(
&self,
inputs: FunctionArgs<Series>,
_ctx: &daft_dsl::functions::scalar::EvalContext,
) -> DaftResult<Series> {
let WidthBucketArgs {
value,
min,
max,
num_bucket,
} = inputs.try_into()?;
width_bucket_impl(value, min, max, num_bucket)
}

fn name(&self) -> &'static str {
"width_bucket"
}

fn get_return_field(
&self,
inputs: FunctionArgs<ExprRef>,
schema: &Schema,
) -> DaftResult<Field> {
let WidthBucketArgs {
value,
min,
max,
num_bucket,
} = inputs.try_into()?;
let value_field = value.to_field(schema)?;
let min_field = min.to_field(schema)?;
let max_field = max.to_field(schema)?;
let num_bucket_field = num_bucket.to_field(schema)?;
for (name, field) in [
("value", &value_field),
("min", &min_field),
("max", &max_field),
("num_bucket", &num_bucket_field),
] {
if !field.dtype.is_numeric() {
return Err(DaftError::TypeError(format!(
"Expected `{name}` of width_bucket to be numeric, got {}",
field.dtype
)));
}
}
Ok(Field::new(value_field.name, DataType::Int64))
}

fn docstring(&self) -> &'static str {
"Returns the bucket number for `value` in an equiwidth histogram with \
`num_bucket` buckets in the range [min, max]. Returns 0 below the range and \
num_bucket+1 at or above. Supports descending bounds (min > max). Returns NULL \
if num_bucket <= 0, min == max, value/min/max is NaN, or min/max is infinite."
}
}

/// Row-wise `width_bucket` over four series.
///
/// `value`, `min` and `max` are cast to `Float64` and `num_bucket` to `Int64`,
/// then all four are brought to a common length via `align_lengths`. A row
/// with a null in any input yields null; otherwise the result comes from
/// `compute_bucket`. The output is an `Int64` series named after `value`.
fn width_bucket_impl(
    value: Series,
    min: Series,
    max: Series,
    num_bucket: Series,
) -> DaftResult<Series> {
    let (value, min, max, num_bucket) = align_lengths(
        value.cast(&DataType::Float64)?,
        min.cast(&DataType::Float64)?,
        max.cast(&DataType::Float64)?,
        num_bucket.cast(&DataType::Int64)?,
    )?;

    // The casts above fix the physical types, so these downcasts are expected to succeed.
    let values = value.f64().unwrap();
    let lows = min.f64().unwrap();
    let highs = max.f64().unwrap();
    let counts = num_bucket.i64().unwrap();

    let out_field = Field::new(values.name(), DataType::Int64);
    let buckets = values
        .iter()
        .zip(lows.iter())
        .zip(highs.iter())
        .zip(counts.iter())
        // `?` short-circuits to `None` as soon as any of the four inputs is null.
        .map(|(((v, lo), hi), n)| compute_bucket(v?, lo?, hi?, n?));

    Ok(Int64Array::from_iter(out_field, buckets).into_series())
}

/// Brings four series to a common length.
///
/// The target length is the largest of the four lengths. Every series must
/// already have that length or have length 1; anything shorter than the target
/// is passed through `Series::broadcast`.
///
/// # Errors
/// Returns `DaftError::ValueError` when some series has a length that is
/// neither 1 nor the target length, and propagates any broadcast error.
fn align_lengths(
    a: Series,
    b: Series,
    c: Series,
    d: Series,
) -> DaftResult<(Series, Series, Series, Series)> {
    let (len_a, len_b, len_c, len_d) = (a.len(), b.len(), c.len(), d.len());
    let target = len_a.max(len_b).max(len_c).max(len_d);

    let compatible = [len_a, len_b, len_c, len_d]
        .iter()
        .all(|&n| n == 1 || n == target);
    if !compatible {
        return Err(DaftError::ValueError(format!(
            "Cannot apply width_bucket to arrays of different lengths: {} vs {} vs {} vs {}",
            len_a, len_b, len_c, len_d
        )));
    }

    let stretch = |s: Series| -> DaftResult<Series> {
        if s.len() != target {
            s.broadcast(target)
        } else {
            Ok(s)
        }
    };
    Ok((stretch(a)?, stretch(b)?, stretch(c)?, stretch(d)?))
}

/// Computes the 1-indexed equal-width bucket of `v` for bounds `mn`/`mx` and
/// `nb` buckets. Mirrors Spark's `WidthBucket` NULL and bucketing semantics
/// for parity.
///
/// Returns `None` when `nb <= 0`, `nb == i64::MAX`, `v` is NaN, `mn == mx`, or
/// either bound is NaN or infinite.
///
/// Otherwise, for ascending bounds (`mn < mx`): `0` when `v < mn`, `nb + 1`
/// when `v >= mx`, else `trunc(nb * (v - mn) / (mx - mn)) + 1`. For descending
/// bounds the comparisons flip: `0` when `v > mn`, `nb + 1` when `v <= mx`.
///
/// The final `+ 1` saturates at `i64::MAX` instead of overflowing. That case
/// is reachable for `nb` just below `i64::MAX`: `nb as f64` rounds up to 2^63,
/// and when the fraction rounds to 1.0 the float-to-int cast saturates to
/// `i64::MAX`.
fn compute_bucket(v: f64, mn: f64, mx: f64, nb: i64) -> Option<i64> {
    if nb <= 0
        || nb == i64::MAX
        || v.is_nan()
        || mn == mx
        || mn.is_nan()
        || mn.is_infinite()
        || mx.is_nan()
        || mx.is_infinite()
    {
        return None;
    }

    // Bucket for a value strictly inside the range: `offset` is the distance
    // from the starting bound and `width` the full span of the range.
    let interior =
        |offset: f64, width: f64| (((nb as f64) * offset / width) as i64).saturating_add(1);

    let bucket = if mn < mx {
        if v < mn {
            0
        } else if v >= mx {
            // Cannot overflow: `nb == i64::MAX` was rejected above.
            nb + 1
        } else {
            interior(v - mn, mx - mn)
        }
    } else {
        // Descending: roles of below/above flip so a smaller value sorts higher.
        if v > mn {
            0
        } else if v <= mx {
            nb + 1
        } else {
            interior(mn - v, mn - mx)
        }
    };
    Some(bucket)
}

/// Builds the expression node for `width_bucket(value, min, max, num_bucket)`,
/// backed by the builtin [`WidthBucket`] scalar function.
#[must_use]
pub fn width_bucket(value: ExprRef, min: ExprRef, max: ExprRef, num_bucket: ExprRef) -> ExprRef {
    let args = vec![value, min, max, num_bucket];
    ScalarFn::builtin(WidthBucket, args).into()
}
96 changes: 96 additions & 0 deletions tests/recordbatch/numeric/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,3 +1186,99 @@ def test_pmod_bad_input() -> None:
table = MicroPartition.from_pydict({"a": ["x", "y"], "b": [1, 2]})
with pytest.raises(ValueError, match="Expected inputs to pmod to be numeric"):
table.eval_expression_list([pmod(col("a"), col("b"))])


def test_width_bucket_basic() -> None:
from daft.functions import width_bucket

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Inline imports inside test functions

from daft.functions import width_bucket is repeated at the start of each of the nine new test functions. Per the project's style rule, import statements should be at the top of the file rather than inside function bodies. A single top-level import (alongside from daft import col, lit) covers all tests.

Rule Used: Import statements should be placed at the top of t... (source)

Learned From
Eventual-Inc/Daft#5078

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems like a valid concern

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, updated!


table = MicroPartition.from_pydict(
{
"v": [5.3, -2.1, 8.1, -0.9],
"mn": [0.2, 1.3, 0.0, 5.2],
"mx": [10.6, 3.4, 5.7, 0.5],
"nb": [5, 3, 4, 2],
}
)
result = table.eval_expression_list([width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")])
assert result.get_column_by_name("bucket").to_pylist() == [3, 0, 5, 3]


def test_width_bucket_descending() -> None:
    """With ``min > max`` the orientation flips: ``v > min`` gives 0 and ``v <= max`` gives ``nb + 1``."""
    from daft.functions import width_bucket

    bucket_expr = width_bucket(col("v"), lit(10.0), lit(0.0), lit(5)).alias("bucket")
    partition = MicroPartition.from_pydict({"v": [11.0, 5.0, 0.0, -1.0]})
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    # min=10, max=0, nb=5:
    #   v=11 is beyond min  -> 0
    #   v=5  is mid-range   -> floor(5 * (10 - 5) / 10) + 1 = 3
    #   v=0  equals max     -> nb + 1 = 6
    #   v=-1 is beyond max  -> nb + 1 = 6
    assert buckets == [0, 3, 6, 6]


def test_width_bucket_integer_inputs() -> None:
    """Integer-typed value and bounds are accepted and bucketed like floats."""
    from daft.functions import width_bucket

    bucket_expr = width_bucket(col("v"), lit(0), lit(10), lit(5)).alias("bucket")
    int_values = pa.array([3, 5, 7], type=pa.int32())
    partition = MicroPartition.from_pydict({"v": int_values})
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [2, 3, 4]


def test_width_bucket_below_and_above_range() -> None:
    """Values outside ``[min, max)`` land in bucket 0 or bucket ``nb + 1``."""
    from daft.functions import width_bucket

    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")
    partition = MicroPartition.from_pydict({"v": [-0.1, 0.0, 5.0, 10.0, 10.1]})
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    # v=0 sits exactly on min, so it belongs to bucket 1;
    # v=10 takes the v >= max branch, so it gets nb + 1 = 6.
    assert buckets == [0, 1, 3, 6, 6]


def test_width_bucket_null_returning() -> None:
    """Each invalid-input row yields NULL; the final, fully valid row does not."""
    from daft.functions import width_bucket

    i64_max = 2**63 - 1
    # One invalid condition per row, in order: nb == 0, nb < 0, nb == i64::MAX,
    # min == max, NaN value, NaN min, infinite max. The last row is valid.
    columns = {
        "v": [1.0, 1.0, 1.0, 1.0, math.nan, 1.0, 1.0, 1.0],
        "mn": [0.0, 0.0, 0.0, 5.0, 0.0, math.nan, 0.0, 0.0],
        "mx": [10.0, 10.0, 10.0, 5.0, 10.0, 10.0, math.inf, 10.0],
        "nb": [0, -3, i64_max, 4, 4, 4, 4, 4],
    }
    partition = MicroPartition.from_pydict(columns)
    bucket_expr = width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets[:-1] == [None] * 7
    assert buckets[-1] is not None


def test_width_bucket_null_propagation() -> None:
    """A NULL ``value`` produces a NULL bucket while other rows are unaffected."""
    from daft.functions import width_bucket

    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")
    partition = MicroPartition.from_pydict({"v": [1.0, None, 5.0]})
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [1, None, 3]


def test_width_bucket_scalar_broadcast() -> None:
    """Literal bounds and bucket count are applied to every row of the value column."""
    from daft.functions import width_bucket

    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")
    partition = MicroPartition.from_pydict({"v": [-1.0, 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 11.0]})
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    assert buckets == [0, 1, 2, 3, 4, 5, 6, 6]


def test_width_bucket_fractional_num_bucket_truncates() -> None:
    """A fractional ``num_bucket`` is truncated to an integer before bucketing."""
    from daft.functions import width_bucket

    bucket_expr = width_bucket(col("v"), lit(0.0), lit(10.0), lit(5.7)).alias("bucket")
    partition = MicroPartition.from_pydict({"v": [3.0]})
    evaluated = partition.eval_expression_list([bucket_expr])
    buckets = evaluated.get_column_by_name("bucket").to_pylist()
    # num_bucket 5.7 becomes 5, so floor(5 * 3 / 10) + 1 = 2.
    assert buckets == [2]


def test_width_bucket_bad_input() -> None:
    """A non-numeric ``value`` column is rejected with a descriptive error."""
    from daft.functions import width_bucket

    partition = MicroPartition.from_pydict({"v": ["a", "b"]})
    expected_message = "Expected `value` of width_bucket to be numeric"
    with pytest.raises(ValueError, match=expected_message):
        # Build the expression inside the context so an error raised at
        # construction time is also caught.
        partition.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5))])
Loading