Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions daft/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,12 @@
expm1,
between,
bin,
bround,
conv,
greatest,
hex,
least,
unhex,
is_nan,
is_inf,
not_nan,
Expand Down Expand Up @@ -330,6 +335,7 @@
"bitwise_xor",
"bool_and",
"bool_or",
"bround",
"capitalize",
"cast",
"cbrt",
Expand Down Expand Up @@ -422,10 +428,12 @@
"from_utc_timestamp",
"get",
"great_circle_distance",
"greatest",
"guess_mime_type",
"hamming_distance",
"hamming_distance_str",
"hash",
"hex",
"hour",
"hypot",
"ilike",
Expand Down Expand Up @@ -453,6 +461,7 @@
"last_day",
"last_value",
"lead",
"least",
"left",
"length",
"length_bytes",
Expand Down Expand Up @@ -604,6 +613,7 @@
"try_decompress",
"try_deserialize",
"try_encode",
"unhex",
"unix_date",
"unnest",
"upload",
Expand Down
15 changes: 13 additions & 2 deletions daft/functions/columnar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from daft.expressions import Expression, col
from daft.functions.list import to_list
from daft.functions.numeric import greatest, least


def columns_sum(*exprs: Expression | str) -> Expression:
Expand Down Expand Up @@ -105,6 +106,11 @@ def columns_avg(*exprs: Expression | str) -> Expression:
def columns_min(*exprs: Expression | str) -> Expression:
"""Find the minimum value across columns.

This is an alias for :func:`daft.functions.least`. Unlike a list-based
aggregation, this works on any comparable dtype (numeric, boolean, string,
temporal, etc.) and skips NULLs row-wise: the result is NULL only when all
inputs in that row are NULL.

Args:
exprs: The columns to find the minimum of.

Expand All @@ -131,12 +137,17 @@ def columns_min(*exprs: Expression | str) -> Expression:
if not exprs:
raise ValueError("columns_min requires at least one expression")
exprs_list = [col(e) if isinstance(e, str) else e for e in exprs]
return to_list(*exprs_list).list_min().alias("columns_min")
return least(*exprs_list).alias("columns_min")


def columns_max(*exprs: Expression | str) -> Expression:
"""Find the maximum value across columns.

This is an alias for :func:`daft.functions.greatest`. Unlike a list-based
aggregation, this works on any comparable dtype (numeric, boolean, string,
temporal, etc.) and skips NULLs row-wise: the result is NULL only when all
inputs in that row are NULL.

Args:
exprs: The columns to find the maximum of.

Expand All @@ -163,4 +174,4 @@ def columns_max(*exprs: Expression | str) -> Expression:
if not exprs:
raise ValueError("columns_max requires at least one expression")
exprs_list = [col(e) if isinstance(e, str) else e for e in exprs]
return to_list(*exprs_list).list_max().alias("columns_max")
return greatest(*exprs_list).alias("columns_max")
90 changes: 90 additions & 0 deletions daft/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,96 @@ def conv(expr: Expression, from_base: int, to_base: int) -> Expression:
return Expression._call_builtin_scalar_fn("conv", expr, from_base, to_base)


def bround(expr: Expression, decimals: Expression | int = 0) -> Expression:
"""Rounds a numeric expression to ``decimals`` places using HALF_EVEN (banker's) rounding.

Negative ``decimals`` rounds to powers of 10 above the decimal point
(e.g. ``bround(125, -1) == 120`` because 12.5 rounds to even -> 12).

Args:
expr: The expression to round.
decimals: Number of decimal places to round to. Defaults to 0.
"""
return Expression._call_builtin_scalar_fn("bround", expr, decimals)


def hex(expr: Expression) -> Expression:
"""Converts an integer/string/binary expression to its uppercase hexadecimal string.

For integer inputs, negatives are encoded as 64-bit two's complement
(``hex(-1) == 'FFFFFFFFFFFFFFFF'``). For string and binary inputs, returns
the uppercase hex of the underlying bytes (``hex('Spark') == '537061726B'``).
"""
return Expression._call_builtin_scalar_fn("hex", expr)


def unhex(expr: Expression) -> Expression:
r"""Inverse of :func:`hex`: decodes a hexadecimal string into binary bytes.

Odd-length inputs are left-padded with ``'0'`` (``unhex('F') == b'\x0f'``).
Returns NULL when the input contains characters outside ``[0-9a-fA-F]``.
"""
return Expression._call_builtin_scalar_fn("unhex", expr)


def greatest(*exprs: Expression) -> Expression:
"""Returns the largest value among the inputs, skipping NULLs row-wise.

Returns NULL only when all inputs in a row are NULL. Inputs are promoted
to a common supertype before comparison. Requires at least one argument.

Examples:
>>> import daft
>>> from daft.functions import greatest
>>> df = daft.from_pydict({"a": [1, None, 3], "b": [2, 5, 1], "c": [None, 4, 6]})
>>> df = df.with_column("g", greatest(df["a"], df["b"], df["c"]))
>>> df.show()
╭───────┬───────┬───────┬───────╮
│ a ┆ b ┆ c ┆ g │
│ --- ┆ --- ┆ --- ┆ --- │
│ Int64 ┆ Int64 ┆ Int64 ┆ Int64 │
╞═══════╪═══════╪═══════╪═══════╡
│ 1 ┆ 2 ┆ None ┆ 2 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ None ┆ 5 ┆ 4 ┆ 5 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 1 ┆ 6 ┆ 6 │
╰───────┴───────┴───────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
return Expression._call_builtin_scalar_fn("greatest", *exprs)


def least(*exprs: Expression) -> Expression:
"""Returns the smallest value among the inputs, skipping NULLs row-wise.

Returns NULL only when all inputs in a row are NULL. Inputs are promoted
to a common supertype before comparison. Requires at least one argument.

Examples:
>>> import daft
>>> from daft.functions import least
>>> df = daft.from_pydict({"a": [1, None, 3], "b": [2, 5, 1], "c": [None, 4, 6]})
>>> df = df.with_column("l", least(df["a"], df["b"], df["c"]))
>>> df.show()
╭───────┬───────┬───────┬───────╮
│ a ┆ b ┆ c ┆ l │
│ --- ┆ --- ┆ --- ┆ --- │
│ Int64 ┆ Int64 ┆ Int64 ┆ Int64 │
╞═══════╪═══════╪═══════╪═══════╡
│ 1 ┆ 2 ┆ None ┆ 1 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ None ┆ 5 ┆ 4 ┆ 4 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 1 ┆ 6 ┆ 1 │
╰───────┴───────┴───────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
"""
return Expression._call_builtin_scalar_fn("least", *exprs)


def is_nan(expr: Expression) -> Expression:
"""Checks if values are NaN (a special float value indicating not-a-number).

Expand Down
Loading
Loading