Skip to content
188 changes: 152 additions & 36 deletions src/daft-local-execution/src/sinks/grouped_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use daft_dsl::expr::{
};
use daft_micropartition::MicroPartition;
use itertools::Itertools;
use tracing::{Span, instrument};
use tracing::{Instrument, Span, instrument};

use super::blocking_sink::{
BlockingSink, BlockingSinkFinalizeResult, BlockingSinkOutput, BlockingSinkSinkResult,
Expand All @@ -23,72 +23,184 @@ use crate::{
pipeline::{InputId, NodeName},
};

/// Minimum input rows before the sharded `AggThenPartition` / `PartitionThenAgg`
/// strategies fan a single morsel out across multiple shard tasks. Smaller inputs
/// run the existing single-threaded path so the per-task overhead doesn't dominate.
///
/// The check is `input.len() < SHARD_THRESHOLD`, so an input with exactly this
/// many rows already takes the sharded path.
const SHARD_THRESHOLD: usize = 32_768;

/// Number of shard tasks spawned per morsel when the input crosses
/// `SHARD_THRESHOLD`. Fixed rather than tied to `max_concurrency` because the
/// framework already runs `max_concurrency` morsels concurrently; fanning out
/// further per morsel would oversubscribe.
///
/// Each shard covers `total_rows.div_ceil(NUM_SHARDS_PER_MORSEL)` consecutive
/// rows (the last shard may be shorter). This is an upper bound: a shard whose
/// start offset falls at or past the end of the input is not spawned.
const NUM_SHARDS_PER_MORSEL: usize = 4;

/// How a grouped-aggregate sink turns an incoming morsel into per-partition
/// state. `execute_strategy` dispatches on the variant.
#[derive(Clone, Debug)]
pub(crate) enum AggStrategy {
/// Partially aggregate the input first (`partial_agg_exprs` grouped by
/// `group_by`), then hash-partition the aggregated result on
/// `final_group_by` and append each partition to its slot's
/// `partially_aggregated` list. Inputs of at least `SHARD_THRESHOLD` rows
/// are split into row-range shards that run this concurrently.
AggThenPartition,
/// Hash-partition the raw input on `group_by` first and buffer the pieces
/// per slot; once a slot's buffered row count reaches the carried threshold
/// the buffer is concatenated and partially aggregated.
PartitionThenAgg(usize),
/// NOTE(review): handled by `execute_partition_only`, which is not visible
/// in this chunk; presumably hash-partitions without any partial
/// aggregation. Confirm against that function.
PartitionOnly,
}

impl AggStrategy {
/// Runs the handler that matches this strategy variant on `input`,
/// accumulating its output into `inner_states` (one entry per output
/// partition slot).
///
/// The `AggThenPartition` and `PartitionThenAgg` handlers are async and are
/// awaited here; `PartitionOnly` is called synchronously. The handler's
/// `DaftResult` is returned unchanged.
fn execute_strategy(
async fn execute_strategy(
&self,
inner_states: &mut [Option<SinglePartitionAggregateState>],
input: MicroPartition,
params: &GroupedAggregateParams,
) -> DaftResult<()> {
match self {
Self::AggThenPartition => Self::execute_agg_then_partition(inner_states, input, params),
Self::AggThenPartition => {
Self::execute_agg_then_partition(inner_states, input, params).await
}
Self::PartitionThenAgg(threshold) => {
Self::execute_partition_then_agg(inner_states, input, params, *threshold)
// The variant's payload is the per-slot row threshold that triggers a
// partial aggregation of the buffered rows.
Self::execute_partition_then_agg(inner_states, input, params, *threshold).await
}
Self::PartitionOnly => Self::execute_partition_only(inner_states, input, params),
}
}

/// Partially aggregates `input`, hash-partitions the aggregated rows on
/// `params.final_group_by` into `inner_states.len()` partitions, and appends
/// partition `i` to slot `i`'s `partially_aggregated` list (creating the
/// slot's state with its default value if it is still `None`).
///
/// Inputs shorter than `SHARD_THRESHOLD` rows are processed inline. Larger
/// inputs are cut into up to `NUM_SHARDS_PER_MORSEL` consecutive row ranges;
/// each range is aggregated and partitioned in its own spawned task and the
/// per-shard outputs are then merged slot by slot.
///
/// Returns the first error produced by slicing, aggregating or partitioning.
fn execute_agg_then_partition(
async fn execute_agg_then_partition(
inner_states: &mut [Option<SinglePartitionAggregateState>],
input: MicroPartition,
params: &GroupedAggregateParams,
) -> DaftResult<()> {
let agged = input.agg(
params.partial_agg_exprs.as_slice(),
params.group_by.as_slice(),
)?;
let partitioned =
agged.partition_by_hash(params.final_group_by.as_slice(), inner_states.len())?;
for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
let state = state.get_or_insert_default();
state.partially_aggregated.push(p);
// Captured once as a plain `usize` so it can be moved into the spawned
// shard tasks without borrowing `inner_states`.
let num_slots = inner_states.len();

// Small inputs: the existing single-threaded path. Shard overhead would
// dominate the K-way fan-out otherwise.
if input.len() < SHARD_THRESHOLD {
let agged = input.agg(
params.partial_agg_exprs.as_slice(),
params.group_by.as_slice(),
)?;
let partitioned =
agged.partition_by_hash(params.final_group_by.as_slice(), num_slots)?;
// Partition i is paired with slot i by position.
for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
let state = state.get_or_insert_default();
state.partially_aggregated.push(p);
}
return Ok(());
}

// Large inputs: row-range slice into K shards, run `agg + partition_by_hash`
// per shard concurrently, then append each shard's slot-i output to slot i.
let total_rows = input.len();
// Ceiling division so the shards together cover every row; the final
// shard is clamped to `total_rows` below and may be shorter.
let shard_size = total_rows.div_ceil(NUM_SHARDS_PER_MORSEL);
let mut tasks = tokio::task::JoinSet::new();
for shard_idx in 0..NUM_SHARDS_PER_MORSEL {
let start = shard_idx * shard_size;
// Nothing left for this shard index; later indices would be empty too.
if start >= total_rows {
break;
}
let end = (start + shard_size).min(total_rows);
let shard = input.slice(start, end)?;
// Owned copies that are moved into the `async move` task below.
let partial_agg_exprs = params.partial_agg_exprs.clone();
let group_by = params.group_by.clone();
let final_group_by = params.final_group_by.clone();
tasks.spawn(
async move {
let agged = shard.agg(&partial_agg_exprs, &group_by)?;
let partitioned = agged.partition_by_hash(&final_group_by, num_slots)?;
DaftResult::Ok(partitioned)
}
// Run the task inside the span that is current at spawn time.
.instrument(Span::current()),
);
}

// Wait for every shard, then surface the first shard error (if any).
// Per the tokio documentation `join_all` yields results in completion
// order rather than spawn order; slot assignment below is positional
// within each shard's output, so it does not depend on that order.
// NOTE(review): tokio documents that `join_all` panics if a task
// panicked — confirm that is the intended failure mode here.
let shard_results: Vec<Vec<MicroPartition>> = tasks
.join_all()
.await
.into_iter()
.collect::<DaftResult<Vec<_>>>()?;

// Each shard contributes one partially aggregated piece per slot.
for shard_partitioned in shard_results {
for (p, state) in shard_partitioned.into_iter().zip(inner_states.iter_mut()) {
let state = state.get_or_insert_default();
state.partially_aggregated.push(p);
}
}

Ok(())
}

fn execute_partition_then_agg(
async fn execute_partition_then_agg(
inner_states: &mut [Option<SinglePartitionAggregateState>],
input: MicroPartition,
params: &GroupedAggregateParams,
partial_agg_threshold: usize,
) -> DaftResult<()> {
let partitioned =
input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?;
for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
let state = state.get_or_insert_default();
if state.unaggregated_size + p.len() >= partial_agg_threshold {
let mut unaggregated = std::mem::take(&mut state.unaggregated);
unaggregated.push(p);
let aggregated = MicroPartition::concat(unaggregated)?.agg(
params.partial_agg_exprs.as_slice(),
params.group_by.as_slice(),
)?;
state.partially_aggregated.push(aggregated);
state.unaggregated_size = 0;
} else {
state.unaggregated_size += p.len();
state.unaggregated.push(p);
let num_slots = inner_states.len();

// Small inputs: the existing single-threaded path. Shard overhead would
// dominate the K-way fan-out otherwise.
if input.len() < SHARD_THRESHOLD {
let partitioned = input.partition_by_hash(params.group_by.as_slice(), num_slots)?;
for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
let state = state.get_or_insert_default();
if state.unaggregated_size + p.len() >= partial_agg_threshold {
let mut unaggregated = std::mem::take(&mut state.unaggregated);
unaggregated.push(p);
let aggregated = MicroPartition::concat(unaggregated)?.agg(
params.partial_agg_exprs.as_slice(),
params.group_by.as_slice(),
)?;
state.partially_aggregated.push(aggregated);
state.unaggregated_size = 0;
} else {
state.unaggregated_size += p.len();
state.unaggregated.push(p);
}
}
return Ok(());
}

// Large inputs: row-range slice into K shards, run `partition_by_hash`
// per shard concurrently, then apply the existing flush-on-threshold
// logic in slot order over each shard's slot-i output. The flush logic
// stays sequential because it reads and mutates `state.unaggregated_size`,
// which is per-slot shared state.
let total_rows = input.len();
let shard_size = total_rows.div_ceil(NUM_SHARDS_PER_MORSEL);
let mut tasks = tokio::task::JoinSet::new();
for shard_idx in 0..NUM_SHARDS_PER_MORSEL {
let start = shard_idx * shard_size;
if start >= total_rows {
break;
}
let end = (start + shard_size).min(total_rows);
let shard = input.slice(start, end)?;
let group_by = params.group_by.clone();
tasks.spawn(
async move { shard.partition_by_hash(&group_by, num_slots) }
.instrument(Span::current()),
);
Comment on lines +173 to +176

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 CPU-bound work (partition_by_hash, agg) is spawned with tokio::task::JoinSet::spawn, which places tasks on tokio's cooperative async thread pool. Because these async blocks contain no .await points, they run to completion without yielding, potentially blocking tokio workers and starving concurrent I/O tasks. For genuinely CPU-bound work, tokio::task::spawn_blocking (which dispatches to a dedicated blocking pool) is the idiomatic choice. The same pattern exists in the AggThenPartition sharding above.

Suggested change
tasks.spawn(
async move { shard.partition_by_hash(&group_by, num_slots) }
.instrument(Span::current()),
);
let span = Span::current();
tasks.spawn(tokio::task::spawn_blocking(move || {
let _enter = span.enter();
shard.partition_by_hash(&group_by, num_slots)
}).map(|r| r.expect("shard task panicked")));

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matches the existing pattern in finalize() in the same file, which also dispatches CPU-bound concat+agg work via tokio::task::JoinSet::spawn rather than spawn_blocking. If this is worth changing it would be a sink-wide refactor across all spawn sites, not just this PR.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. The finalize() method already uses tokio::task::JoinSet::spawn for CPU-bound concat + agg + eval_expression_list work with no .await points — the same pattern this PR follows. Changing to spawn_blocking would be a sink-wide refactor touching finalize() and all the sharding sites, not something scoped to this PR. Fair point, acknowledged.

}

let shard_results: Vec<Vec<MicroPartition>> = tasks
.join_all()
.await
.into_iter()
.collect::<DaftResult<Vec<_>>>()?;
Comment on lines +179 to +183

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 JoinSet::join_all returns results in completion order, not spawn order

JoinSet::join_all() is implemented via repeated join_next(), which yields tasks in the order they finish, not the order they were spawned. For PartitionThenAgg this means the flush-on-threshold logic in the merge loop runs over shards in a non-deterministic order. The PR author correctly argues correctness is preserved because partial-agg combinators are commutative, but the non-deterministic flush ordering makes memory-usage patterns harder to reason about and can make test assertions around intermediate state order-sensitive. Consider using an ordered collection (e.g. collecting into a Vec indexed by shard_idx) so processing order is deterministic.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping join_all. Correctness is preserved (commutative combinators as noted) and forcing spawn-order via indexed collection adds code without changing observable output. finalize() in the same file uses the same join_all pattern.


for shard_partitioned in shard_results {
for (p, state) in shard_partitioned.into_iter().zip(inner_states.iter_mut()) {
let state = state.get_or_insert_default();
if state.unaggregated_size + p.len() >= partial_agg_threshold {
let mut unaggregated = std::mem::take(&mut state.unaggregated);
unaggregated.push(p);
let aggregated = MicroPartition::concat(unaggregated)?.agg(
params.partial_agg_exprs.as_slice(),
params.group_by.as_slice(),
)?;
state.partially_aggregated.push(aggregated);
state.unaggregated_size = 0;
} else {
state.unaggregated_size += p.len();
state.unaggregated.push(p);
}
}
}

Ok(())
}

Expand Down Expand Up @@ -140,7 +252,7 @@ impl GroupedAggregateState {
}
}

fn push(
async fn push(
&mut self,
input: MicroPartition,
params: &GroupedAggregateParams,
Expand All @@ -158,7 +270,9 @@ impl GroupedAggregateState {

// If we have determined a strategy, execute it.
if let Some(strategy) = strategy {
strategy.execute_strategy(inner_states, input, params)?;
strategy
.execute_strategy(inner_states, input, params)
.await?;
} else {
// Otherwise, determine the strategy and execute
let decided_strategy = Self::determine_agg_strategy(
Expand All @@ -169,7 +283,9 @@ impl GroupedAggregateState {
strategy,
global_strategy_lock,
)?;
decided_strategy.execute_strategy(inner_states, input, params)?;
decided_strategy
.execute_strategy(inner_states, input, params)
.await?;
}
Ok(())
}
Expand Down Expand Up @@ -336,7 +452,7 @@ impl BlockingSink for GroupedAggregateSink {
spawner
.spawn(
async move {
state.push(input, &params, &strategy_lock)?;
state.push(input, &params, &strategy_lock).await?;
Ok(state)
},
Span::current(),
Expand Down
Loading