-
Notifications
You must be signed in to change notification settings - Fork 496
feat: add width_bucket function for PySpark parity #7146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
48b13ee
e3b0af4
60909b7
b5b3088
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| use common_error::{DaftError, DaftResult}; | ||
| use daft_core::{ | ||
| datatypes::Int64Array, | ||
| prelude::{DataType, Field, Schema}, | ||
| series::{IntoSeries, Series}, | ||
| }; | ||
| use daft_dsl::{ | ||
| ExprRef, | ||
| functions::{FunctionArgs, ScalarUDF, scalar::ScalarFn}, | ||
| }; | ||
| use serde::{Deserialize, Serialize}; | ||
|
|
||
| #[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] | ||
| pub struct WidthBucket; | ||
|
|
||
| #[derive(FunctionArgs)] | ||
| struct WidthBucketArgs<T> { | ||
| value: T, | ||
| min: T, | ||
| max: T, | ||
| num_bucket: T, | ||
| } | ||
|
|
||
| #[typetag::serde] | ||
| impl ScalarUDF for WidthBucket { | ||
| fn call( | ||
| &self, | ||
| inputs: FunctionArgs<Series>, | ||
| _ctx: &daft_dsl::functions::scalar::EvalContext, | ||
| ) -> DaftResult<Series> { | ||
| let WidthBucketArgs { | ||
| value, | ||
| min, | ||
| max, | ||
| num_bucket, | ||
| } = inputs.try_into()?; | ||
| width_bucket_impl(value, min, max, num_bucket) | ||
| } | ||
|
|
||
| fn name(&self) -> &'static str { | ||
| "width_bucket" | ||
| } | ||
|
|
||
| fn get_return_field( | ||
| &self, | ||
| inputs: FunctionArgs<ExprRef>, | ||
| schema: &Schema, | ||
| ) -> DaftResult<Field> { | ||
| let WidthBucketArgs { | ||
| value, | ||
| min, | ||
| max, | ||
| num_bucket, | ||
| } = inputs.try_into()?; | ||
| let value_field = value.to_field(schema)?; | ||
| let min_field = min.to_field(schema)?; | ||
| let max_field = max.to_field(schema)?; | ||
| let num_bucket_field = num_bucket.to_field(schema)?; | ||
| for (name, field) in [ | ||
| ("value", &value_field), | ||
| ("min", &min_field), | ||
| ("max", &max_field), | ||
| ("num_bucket", &num_bucket_field), | ||
| ] { | ||
| if !field.dtype.is_numeric() { | ||
| return Err(DaftError::TypeError(format!( | ||
| "Expected `{name}` of width_bucket to be numeric, got {}", | ||
| field.dtype | ||
| ))); | ||
| } | ||
| } | ||
| Ok(Field::new(value_field.name, DataType::Int64)) | ||
| } | ||
|
|
||
| fn docstring(&self) -> &'static str { | ||
| "Returns the bucket number for `value` in an equiwidth histogram with \ | ||
| `num_bucket` buckets in the range [min, max]. Returns 0 below the range and \ | ||
| num_bucket+1 at or above. Supports descending bounds (min > max). Returns NULL \ | ||
| if num_bucket <= 0, min == max, value/min/max is NaN, or min/max is infinite." | ||
| } | ||
| } | ||
|
|
||
| fn width_bucket_impl( | ||
| value: Series, | ||
| min: Series, | ||
| max: Series, | ||
| num_bucket: Series, | ||
| ) -> DaftResult<Series> { | ||
| let value = value.cast(&DataType::Float64)?; | ||
| let min = min.cast(&DataType::Float64)?; | ||
| let max = max.cast(&DataType::Float64)?; | ||
| let num_bucket = num_bucket.cast(&DataType::Int64)?; | ||
| let (value, min, max, num_bucket) = align_lengths(value, min, max, num_bucket)?; | ||
|
|
||
| let v_arr = value.f64().unwrap(); | ||
| let mn_arr = min.f64().unwrap(); | ||
| let mx_arr = max.f64().unwrap(); | ||
| let nb_arr = num_bucket.i64().unwrap(); | ||
|
|
||
| let iter = v_arr | ||
| .iter() | ||
| .zip(mn_arr.iter()) | ||
| .zip(mx_arr.iter()) | ||
| .zip(nb_arr.iter()) | ||
| .map(|(((v, mn), mx), nb)| match (v, mn, mx, nb) { | ||
| (Some(v), Some(mn), Some(mx), Some(nb)) => compute_bucket(v, mn, mx, nb), | ||
| _ => None, | ||
| }); | ||
| Ok(Int64Array::from_iter(Field::new(v_arr.name(), DataType::Int64), iter).into_series()) | ||
| } | ||
|
|
||
| fn align_lengths( | ||
| a: Series, | ||
| b: Series, | ||
| c: Series, | ||
| d: Series, | ||
| ) -> DaftResult<(Series, Series, Series, Series)> { | ||
| let lens = [a.len(), b.len(), c.len(), d.len()]; | ||
| let max_len = *lens.iter().max().unwrap(); | ||
| for &l in &lens { | ||
| if l != 1 && l != max_len { | ||
| return Err(DaftError::ValueError(format!( | ||
| "Cannot apply width_bucket to arrays of different lengths: {} vs {} vs {} vs {}", | ||
| lens[0], lens[1], lens[2], lens[3] | ||
| ))); | ||
| } | ||
| } | ||
| let bcast = |s: Series| -> DaftResult<Series> { | ||
| if s.len() == max_len { | ||
| Ok(s) | ||
| } else { | ||
| s.broadcast(max_len) | ||
| } | ||
| }; | ||
| Ok((bcast(a)?, bcast(b)?, bcast(c)?, bcast(d)?)) | ||
| } | ||
|
|
||
| /// Mirrors Spark's `WidthBucket` NULL and bucketing semantics for parity. | ||
| fn compute_bucket(v: f64, mn: f64, mx: f64, nb: i64) -> Option<i64> { | ||
| if nb <= 0 | ||
| || nb == i64::MAX | ||
| || v.is_nan() | ||
| || mn == mx | ||
| || mn.is_nan() | ||
| || mn.is_infinite() | ||
| || mx.is_nan() | ||
| || mx.is_infinite() | ||
| { | ||
| return None; | ||
| } | ||
| let bucket = if mn < mx { | ||
| if v < mn { | ||
| 0 | ||
| } else if v >= mx { | ||
| nb + 1 | ||
| } else { | ||
| ((nb as f64) * (v - mn) / (mx - mn)) as i64 + 1 | ||
| } | ||
| } else { | ||
| // Descending: roles of below/above flip so a smaller value sorts higher. | ||
| if v > mn { | ||
| 0 | ||
| } else if v <= mx { | ||
| nb + 1 | ||
| } else { | ||
| ((nb as f64) * (mn - v) / (mn - mx)) as i64 + 1 | ||
| } | ||
| }; | ||
| Some(bucket) | ||
| } | ||
|
|
||
| #[must_use] | ||
| pub fn width_bucket(value: ExprRef, min: ExprRef, max: ExprRef, num_bucket: ExprRef) -> ExprRef { | ||
| ScalarFn::builtin(WidthBucket, vec![value, min, max, num_bucket]).into() | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1186,3 +1186,99 @@ def test_pmod_bad_input() -> None: | |
| table = MicroPartition.from_pydict({"a": ["x", "y"], "b": [1, 2]}) | ||
| with pytest.raises(ValueError, match="Expected inputs to pmod to be numeric"): | ||
| table.eval_expression_list([pmod(col("a"), col("b"))]) | ||
|
|
||
|
|
||
| def test_width_bucket_basic() -> None: | ||
| from daft.functions import width_bucket | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Rule Used: Import statements should be placed at the top of t... (source) Learned From Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this seems like a valid concern
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, updated! |
||
|
|
||
| table = MicroPartition.from_pydict( | ||
| { | ||
| "v": [5.3, -2.1, 8.1, -0.9], | ||
| "mn": [0.2, 1.3, 0.0, 5.2], | ||
| "mx": [10.6, 3.4, 5.7, 0.5], | ||
| "nb": [5, 3, 4, 2], | ||
| } | ||
| ) | ||
| result = table.eval_expression_list([width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")]) | ||
| assert result.get_column_by_name("bucket").to_pylist() == [3, 0, 5, 3] | ||
|
|
||
|
|
||
| def test_width_bucket_descending() -> None: | ||
| """Min > max flips orientation; v > min underflows, v <= max overflows.""" | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": [11.0, 5.0, 0.0, -1.0]}) | ||
| result = table.eval_expression_list([width_bucket(col("v"), lit(10.0), lit(0.0), lit(5)).alias("bucket")]) | ||
| # min=10, max=0, nb=5: v=11 underflows -> 0; v=5 mid -> floor(5*(10-5)/10)+1=3; | ||
| # v=0 == max -> nb+1=6; v=-1 < max -> nb+1=6. | ||
| assert result.get_column_by_name("bucket").to_pylist() == [0, 3, 6, 6] | ||
|
|
||
|
|
||
| def test_width_bucket_integer_inputs() -> None: | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": pa.array([3, 5, 7], type=pa.int32())}) | ||
| result = table.eval_expression_list([width_bucket(col("v"), lit(0), lit(10), lit(5)).alias("bucket")]) | ||
| assert result.get_column_by_name("bucket").to_pylist() == [2, 3, 4] | ||
|
|
||
|
|
||
| def test_width_bucket_below_and_above_range() -> None: | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": [-0.1, 0.0, 5.0, 10.0, 10.1]}) | ||
| result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")]) | ||
| # v=0 is at min so bucket 1; v=10 hits the v>=max branch so nb+1=6. | ||
| assert result.get_column_by_name("bucket").to_pylist() == [0, 1, 3, 6, 6] | ||
|
|
||
|
|
||
| def test_width_bucket_null_returning() -> None: | ||
| from daft.functions import width_bucket | ||
|
|
||
| i64_max = 2**63 - 1 | ||
| table = MicroPartition.from_pydict( | ||
| { | ||
| "v": [1.0, 1.0, 1.0, 1.0, math.nan, 1.0, 1.0, 1.0], | ||
| "mn": [0.0, 0.0, 0.0, 5.0, 0.0, math.nan, 0.0, 0.0], | ||
| "mx": [10.0, 10.0, 10.0, 5.0, 10.0, 10.0, math.inf, 10.0], | ||
| "nb": [0, -3, i64_max, 4, 4, 4, 4, 4], | ||
| } | ||
| ) | ||
| result = table.eval_expression_list([width_bucket(col("v"), col("mn"), col("mx"), col("nb")).alias("bucket")]) | ||
| values = result.get_column_by_name("bucket").to_pylist() | ||
| # Only the last row (all valid) should be non-NULL. | ||
| assert values[:-1] == [None] * 7 | ||
| assert values[-1] is not None | ||
|
|
||
|
|
||
| def test_width_bucket_null_propagation() -> None: | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": [1.0, None, 5.0]}) | ||
| result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")]) | ||
| assert result.get_column_by_name("bucket").to_pylist() == [1, None, 3] | ||
|
|
||
|
|
||
| def test_width_bucket_scalar_broadcast() -> None: | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": [-1.0, 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 11.0]}) | ||
| result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5)).alias("bucket")]) | ||
| assert result.get_column_by_name("bucket").to_pylist() == [0, 1, 2, 3, 4, 5, 6, 6] | ||
|
|
||
|
|
||
| def test_width_bucket_fractional_num_bucket_truncates() -> None: | ||
| """Daft promotes any numeric num_bucket to Int64 by truncation.""" | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": [3.0]}) | ||
| result = table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5.7)).alias("bucket")]) | ||
| # 5.7 truncates to 5; floor(5 * 3 / 10) + 1 = 2. | ||
| assert result.get_column_by_name("bucket").to_pylist() == [2] | ||
|
|
||
|
|
||
| def test_width_bucket_bad_input() -> None: | ||
| from daft.functions import width_bucket | ||
|
|
||
| table = MicroPartition.from_pydict({"v": ["a", "b"]}) | ||
| with pytest.raises(ValueError, match="Expected `value` of width_bucket to be numeric"): | ||
| table.eval_expression_list([width_bucket(col("v"), lit(0.0), lit(10.0), lit(5))]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nbvalues just belowi64::MAXThe guard
nb == i64::MAXdoes not cover the full risky range. Fornbin approximately[i64::MAX − 1023, i64::MAX − 1], the castnb as f64rounds up to2^63(the nearest representable f64). Whenvis close tomx, the formula((nb as f64) * fraction) as i64saturates toi64::MAX(Rust's saturating f64-to-i64 cast), and the subsequent+ 1wraps toi64::MINin release mode (or panics in debug). Usingsaturating_add(1)instead of+ 1for both branch returns makes this safe without changing observed behaviour for all practical inputs.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated