diff --git a/src/daft-local-execution/src/sinks/grouped_aggregate.rs b/src/daft-local-execution/src/sinks/grouped_aggregate.rs index eda94bd42e7..025c7b4c7a6 100644 --- a/src/daft-local-execution/src/sinks/grouped_aggregate.rs +++ b/src/daft-local-execution/src/sinks/grouped_aggregate.rs @@ -13,7 +13,7 @@ use daft_dsl::expr::{ }; use daft_micropartition::MicroPartition; use itertools::Itertools; -use tracing::{Span, instrument}; +use tracing::{Instrument, Span, instrument}; use super::blocking_sink::{ BlockingSink, BlockingSinkFinalizeResult, BlockingSinkOutput, BlockingSinkSinkResult, @@ -23,72 +23,184 @@ use crate::{ pipeline::{InputId, NodeName}, }; +/// Minimum input rows before the sharded `AggThenPartition` / `PartitionThenAgg` +/// strategies fan a single morsel out across multiple shard tasks. Smaller inputs +/// run the existing single-threaded path so the per-task overhead doesn't dominate. +const SHARD_THRESHOLD: usize = 32_768; + +/// Number of shard tasks spawned per morsel when the input crosses +/// `SHARD_THRESHOLD`. Fixed rather than tied to `max_concurrency` because the +/// framework already runs `max_concurrency` morsels concurrently; fanning out +/// further per morsel would oversubscribe. +const NUM_SHARDS_PER_MORSEL: usize = 4; + #[derive(Clone, Debug)] pub(crate) enum AggStrategy { - // TODO: This would probably benefit from doing sharded aggs. AggThenPartition, PartitionThenAgg(usize), PartitionOnly, } impl AggStrategy { - fn execute_strategy( + async fn execute_strategy( &self, inner_states: &mut [Option], input: MicroPartition, params: &GroupedAggregateParams, ) -> DaftResult<()> { match self { - Self::AggThenPartition => Self::execute_agg_then_partition(inner_states, input, params), + Self::AggThenPartition => { + Self::execute_agg_then_partition(inner_states, input, params).await + } Self::PartitionThenAgg(threshold) => { - Self::execute_partition_then_agg(inner_states, input, params, *threshold) + Self::execute_partition_then_agg(inner_states, input, params, *threshold).await } Self::PartitionOnly => Self::execute_partition_only(inner_states, input, params), } } - fn execute_agg_then_partition( + async fn execute_agg_then_partition( inner_states: &mut [Option], input: MicroPartition, params: &GroupedAggregateParams, ) -> DaftResult<()> { - let agged = input.agg( - params.partial_agg_exprs.as_slice(), - params.group_by.as_slice(), - )?; - let partitioned = - agged.partition_by_hash(params.final_group_by.as_slice(), inner_states.len())?; - for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { - let state = state.get_or_insert_default(); - state.partially_aggregated.push(p); + let num_slots = inner_states.len(); + + // Small inputs: the existing single-threaded path. Shard overhead would + // dominate the K-way fan-out otherwise. + if input.len() < SHARD_THRESHOLD { + let agged = input.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + let partitioned = + agged.partition_by_hash(params.final_group_by.as_slice(), num_slots)?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + state.partially_aggregated.push(p); + } + return Ok(()); } + + // Large inputs: row-range slice into K shards, run `agg + partition_by_hash` + // per shard concurrently, then append each shard's slot-i output to slot i. + let total_rows = input.len(); + let shard_size = total_rows.div_ceil(NUM_SHARDS_PER_MORSEL); + let mut tasks = tokio::task::JoinSet::new(); + for shard_idx in 0..NUM_SHARDS_PER_MORSEL { + let start = shard_idx * shard_size; + if start >= total_rows { + break; + } + let end = (start + shard_size).min(total_rows); + let shard = input.slice(start, end)?; + let partial_agg_exprs = params.partial_agg_exprs.clone(); + let group_by = params.group_by.clone(); + let final_group_by = params.final_group_by.clone(); + tasks.spawn( + async move { + let agged = shard.agg(&partial_agg_exprs, &group_by)?; + let partitioned = agged.partition_by_hash(&final_group_by, num_slots)?; + DaftResult::Ok(partitioned) + } + .instrument(Span::current()), + ); + } + + let shard_results: Vec> = tasks + .join_all() + .await + .into_iter() + .collect::>>()?; + + for shard_partitioned in shard_results { + for (p, state) in shard_partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + state.partially_aggregated.push(p); + } + } + Ok(()) } - fn execute_partition_then_agg( + async fn execute_partition_then_agg( inner_states: &mut [Option], input: MicroPartition, params: &GroupedAggregateParams, partial_agg_threshold: usize, ) -> DaftResult<()> { - let partitioned = - input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?; - for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { - let state = state.get_or_insert_default(); - if state.unaggregated_size + p.len() >= partial_agg_threshold { - let mut unaggregated = std::mem::take(&mut state.unaggregated); - unaggregated.push(p); - let aggregated = MicroPartition::concat(unaggregated)?.agg( - params.partial_agg_exprs.as_slice(), - params.group_by.as_slice(), - )?; - state.partially_aggregated.push(aggregated); - state.unaggregated_size = 0; - } else { - state.unaggregated_size += p.len(); - state.unaggregated.push(p); + let num_slots = inner_states.len(); + + // Small inputs: the existing single-threaded path. Shard overhead would + // dominate the K-way fan-out otherwise. + if input.len() < SHARD_THRESHOLD { + let partitioned = input.partition_by_hash(params.group_by.as_slice(), num_slots)?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + if state.unaggregated_size + p.len() >= partial_agg_threshold { + let mut unaggregated = std::mem::take(&mut state.unaggregated); + unaggregated.push(p); + let aggregated = MicroPartition::concat(unaggregated)?.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + state.partially_aggregated.push(aggregated); + state.unaggregated_size = 0; + } else { + state.unaggregated_size += p.len(); + state.unaggregated.push(p); + } } + return Ok(()); } + + // Large inputs: row-range slice into K shards, run `partition_by_hash` + // per shard concurrently, then apply the existing flush-on-threshold + // logic in slot order over each shard's slot-i output. The flush logic + // stays sequential because it reads and mutates `state.unaggregated_size`, + // which is per-slot shared state. + let total_rows = input.len(); + let shard_size = total_rows.div_ceil(NUM_SHARDS_PER_MORSEL); + let mut tasks = tokio::task::JoinSet::new(); + for shard_idx in 0..NUM_SHARDS_PER_MORSEL { + let start = shard_idx * shard_size; + if start >= total_rows { + break; + } + let end = (start + shard_size).min(total_rows); + let shard = input.slice(start, end)?; + let group_by = params.group_by.clone(); + tasks.spawn( + async move { shard.partition_by_hash(&group_by, num_slots) } + .instrument(Span::current()), + ); + } + + let shard_results: Vec> = tasks + .join_all() + .await + .into_iter() + .collect::>>()?; + + for shard_partitioned in shard_results { + for (p, state) in shard_partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + if state.unaggregated_size + p.len() >= partial_agg_threshold { + let mut unaggregated = std::mem::take(&mut state.unaggregated); + unaggregated.push(p); + let aggregated = MicroPartition::concat(unaggregated)?.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + state.partially_aggregated.push(aggregated); + state.unaggregated_size = 0; + } else { + state.unaggregated_size += p.len(); + state.unaggregated.push(p); + } + } + } + Ok(()) } @@ -140,7 +252,7 @@ impl GroupedAggregateState { } } - fn push( + async fn push( &mut self, input: MicroPartition, params: &GroupedAggregateParams, @@ -158,7 +270,9 @@ impl GroupedAggregateState { // If we have determined a strategy, execute it. if let Some(strategy) = strategy { - strategy.execute_strategy(inner_states, input, params)?; + strategy + .execute_strategy(inner_states, input, params) + .await?; } else { // Otherwise, determine the strategy and execute let decided_strategy = Self::determine_agg_strategy( @@ -169,7 +283,9 @@ impl GroupedAggregateState { strategy, global_strategy_lock, )?; - decided_strategy.execute_strategy(inner_states, input, params)?; + decided_strategy + .execute_strategy(inner_states, input, params) + .await?; } Ok(()) } @@ -336,7 +452,7 @@ impl BlockingSink for GroupedAggregateSink { spawner .spawn( async move { - state.push(input, ¶ms, &strategy_lock)?; + state.push(input, ¶ms, &strategy_lock).await?; Ok(state) }, Span::current(),