-
Notifications
You must be signed in to change notification settings - Fork 498
feat(grouped-agg): shard PartitionThenAgg execution per morsel #7079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b926b9c
ade6c28
59a03b2
f8a9185
cfbc890
77e9854
3ad1f42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,7 @@ use daft_dsl::expr::{ | |
| }; | ||
| use daft_micropartition::MicroPartition; | ||
| use itertools::Itertools; | ||
| use tracing::{Span, instrument}; | ||
| use tracing::{Instrument, Span, instrument}; | ||
|
|
||
| use super::blocking_sink::{ | ||
| BlockingSink, BlockingSinkFinalizeResult, BlockingSinkOutput, BlockingSinkSinkResult, | ||
|
|
@@ -23,72 +23,184 @@ use crate::{ | |
| pipeline::{InputId, NodeName}, | ||
| }; | ||
|
|
||
| /// Minimum input rows before the sharded `AggThenPartition` / `PartitionThenAgg` | ||
| /// strategies fan a single morsel out across multiple shard tasks. Smaller inputs | ||
| /// run the existing single-threaded path so the per-task overhead doesn't dominate. | ||
| const SHARD_THRESHOLD: usize = 32_768; | ||
|
|
||
| /// Number of shard tasks spawned per morsel when the input crosses | ||
| /// `SHARD_THRESHOLD`. Fixed rather than tied to `max_concurrency` because the | ||
| /// framework already runs `max_concurrency` morsels concurrently; fanning out | ||
| /// further per morsel would oversubscribe. | ||
| const NUM_SHARDS_PER_MORSEL: usize = 4; | ||
|
|
||
| #[derive(Clone, Debug)] | ||
| pub(crate) enum AggStrategy { | ||
| // TODO: This would probably benefit from doing sharded aggs. | ||
| AggThenPartition, | ||
| PartitionThenAgg(usize), | ||
| PartitionOnly, | ||
| } | ||
|
|
||
| impl AggStrategy { | ||
| fn execute_strategy( | ||
| async fn execute_strategy( | ||
| &self, | ||
| inner_states: &mut [Option<SinglePartitionAggregateState>], | ||
| input: MicroPartition, | ||
| params: &GroupedAggregateParams, | ||
| ) -> DaftResult<()> { | ||
| match self { | ||
| Self::AggThenPartition => Self::execute_agg_then_partition(inner_states, input, params), | ||
| Self::AggThenPartition => { | ||
| Self::execute_agg_then_partition(inner_states, input, params).await | ||
| } | ||
| Self::PartitionThenAgg(threshold) => { | ||
| Self::execute_partition_then_agg(inner_states, input, params, *threshold) | ||
| Self::execute_partition_then_agg(inner_states, input, params, *threshold).await | ||
| } | ||
| Self::PartitionOnly => Self::execute_partition_only(inner_states, input, params), | ||
| } | ||
| } | ||
|
|
||
| fn execute_agg_then_partition( | ||
| async fn execute_agg_then_partition( | ||
| inner_states: &mut [Option<SinglePartitionAggregateState>], | ||
| input: MicroPartition, | ||
| params: &GroupedAggregateParams, | ||
| ) -> DaftResult<()> { | ||
| let agged = input.agg( | ||
| params.partial_agg_exprs.as_slice(), | ||
| params.group_by.as_slice(), | ||
| )?; | ||
| let partitioned = | ||
| agged.partition_by_hash(params.final_group_by.as_slice(), inner_states.len())?; | ||
| for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { | ||
| let state = state.get_or_insert_default(); | ||
| state.partially_aggregated.push(p); | ||
| let num_slots = inner_states.len(); | ||
|
|
||
| // Small inputs: the existing single-threaded path. Shard overhead would | ||
| // dominate the K-way fan-out otherwise. | ||
| if input.len() < SHARD_THRESHOLD { | ||
| let agged = input.agg( | ||
| params.partial_agg_exprs.as_slice(), | ||
| params.group_by.as_slice(), | ||
| )?; | ||
| let partitioned = | ||
| agged.partition_by_hash(params.final_group_by.as_slice(), num_slots)?; | ||
| for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { | ||
| let state = state.get_or_insert_default(); | ||
| state.partially_aggregated.push(p); | ||
| } | ||
| return Ok(()); | ||
| } | ||
|
|
||
| // Large inputs: row-range slice into K shards, run `agg + partition_by_hash` | ||
| // per shard concurrently, then append each shard's slot-i output to slot i. | ||
| let total_rows = input.len(); | ||
| let shard_size = total_rows.div_ceil(NUM_SHARDS_PER_MORSEL); | ||
| let mut tasks = tokio::task::JoinSet::new(); | ||
| for shard_idx in 0..NUM_SHARDS_PER_MORSEL { | ||
| let start = shard_idx * shard_size; | ||
| if start >= total_rows { | ||
| break; | ||
| } | ||
| let end = (start + shard_size).min(total_rows); | ||
| let shard = input.slice(start, end)?; | ||
| let partial_agg_exprs = params.partial_agg_exprs.clone(); | ||
| let group_by = params.group_by.clone(); | ||
| let final_group_by = params.final_group_by.clone(); | ||
| tasks.spawn( | ||
| async move { | ||
| let agged = shard.agg(&partial_agg_exprs, &group_by)?; | ||
| let partitioned = agged.partition_by_hash(&final_group_by, num_slots)?; | ||
| DaftResult::Ok(partitioned) | ||
| } | ||
| .instrument(Span::current()), | ||
| ); | ||
| } | ||
|
|
||
| let shard_results: Vec<Vec<MicroPartition>> = tasks | ||
| .join_all() | ||
| .await | ||
| .into_iter() | ||
| .collect::<DaftResult<Vec<_>>>()?; | ||
|
|
||
| for shard_partitioned in shard_results { | ||
| for (p, state) in shard_partitioned.into_iter().zip(inner_states.iter_mut()) { | ||
| let state = state.get_or_insert_default(); | ||
| state.partially_aggregated.push(p); | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn execute_partition_then_agg( | ||
| async fn execute_partition_then_agg( | ||
| inner_states: &mut [Option<SinglePartitionAggregateState>], | ||
| input: MicroPartition, | ||
| params: &GroupedAggregateParams, | ||
| partial_agg_threshold: usize, | ||
| ) -> DaftResult<()> { | ||
| let partitioned = | ||
| input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?; | ||
| for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { | ||
| let state = state.get_or_insert_default(); | ||
| if state.unaggregated_size + p.len() >= partial_agg_threshold { | ||
| let mut unaggregated = std::mem::take(&mut state.unaggregated); | ||
| unaggregated.push(p); | ||
| let aggregated = MicroPartition::concat(unaggregated)?.agg( | ||
| params.partial_agg_exprs.as_slice(), | ||
| params.group_by.as_slice(), | ||
| )?; | ||
| state.partially_aggregated.push(aggregated); | ||
| state.unaggregated_size = 0; | ||
| } else { | ||
| state.unaggregated_size += p.len(); | ||
| state.unaggregated.push(p); | ||
| let num_slots = inner_states.len(); | ||
|
|
||
| // Small inputs: the existing single-threaded path. Shard overhead would | ||
| // dominate the K-way fan-out otherwise. | ||
| if input.len() < SHARD_THRESHOLD { | ||
| let partitioned = input.partition_by_hash(params.group_by.as_slice(), num_slots)?; | ||
| for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { | ||
| let state = state.get_or_insert_default(); | ||
| if state.unaggregated_size + p.len() >= partial_agg_threshold { | ||
| let mut unaggregated = std::mem::take(&mut state.unaggregated); | ||
| unaggregated.push(p); | ||
| let aggregated = MicroPartition::concat(unaggregated)?.agg( | ||
| params.partial_agg_exprs.as_slice(), | ||
| params.group_by.as_slice(), | ||
| )?; | ||
| state.partially_aggregated.push(aggregated); | ||
| state.unaggregated_size = 0; | ||
| } else { | ||
| state.unaggregated_size += p.len(); | ||
| state.unaggregated.push(p); | ||
| } | ||
| } | ||
| return Ok(()); | ||
| } | ||
|
|
||
| // Large inputs: row-range slice into K shards, run `partition_by_hash` | ||
| // per shard concurrently, then apply the existing flush-on-threshold | ||
| // logic in slot order over each shard's slot-i output. The flush logic | ||
| // stays sequential because it reads and mutates `state.unaggregated_size`, | ||
| // which is per-slot shared state. | ||
| let total_rows = input.len(); | ||
| let shard_size = total_rows.div_ceil(NUM_SHARDS_PER_MORSEL); | ||
| let mut tasks = tokio::task::JoinSet::new(); | ||
| for shard_idx in 0..NUM_SHARDS_PER_MORSEL { | ||
| let start = shard_idx * shard_size; | ||
| if start >= total_rows { | ||
| break; | ||
| } | ||
| let end = (start + shard_size).min(total_rows); | ||
| let shard = input.slice(start, end)?; | ||
| let group_by = params.group_by.clone(); | ||
| tasks.spawn( | ||
| async move { shard.partition_by_hash(&group_by, num_slots) } | ||
| .instrument(Span::current()), | ||
| ); | ||
| } | ||
|
|
||
| let shard_results: Vec<Vec<MicroPartition>> = tasks | ||
| .join_all() | ||
| .await | ||
| .into_iter() | ||
| .collect::<DaftResult<Vec<_>>>()?; | ||
|
Comment on lines
+179
to
+183
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keeping `join_all`. Correctness is preserved (commutative combinators as noted), and forcing spawn order via indexed collection adds code without changing observable output. `finalize()` in the same file uses the same `join_all` pattern. |
||
|
|
||
| for shard_partitioned in shard_results { | ||
| for (p, state) in shard_partitioned.into_iter().zip(inner_states.iter_mut()) { | ||
| let state = state.get_or_insert_default(); | ||
| if state.unaggregated_size + p.len() >= partial_agg_threshold { | ||
| let mut unaggregated = std::mem::take(&mut state.unaggregated); | ||
| unaggregated.push(p); | ||
| let aggregated = MicroPartition::concat(unaggregated)?.agg( | ||
| params.partial_agg_exprs.as_slice(), | ||
| params.group_by.as_slice(), | ||
| )?; | ||
| state.partially_aggregated.push(aggregated); | ||
| state.unaggregated_size = 0; | ||
| } else { | ||
| state.unaggregated_size += p.len(); | ||
| state.unaggregated.push(p); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
|
|
@@ -140,7 +252,7 @@ impl GroupedAggregateState { | |
| } | ||
| } | ||
|
|
||
| fn push( | ||
| async fn push( | ||
| &mut self, | ||
| input: MicroPartition, | ||
| params: &GroupedAggregateParams, | ||
|
|
@@ -158,7 +270,9 @@ impl GroupedAggregateState { | |
|
|
||
| // If we have determined a strategy, execute it. | ||
| if let Some(strategy) = strategy { | ||
| strategy.execute_strategy(inner_states, input, params)?; | ||
| strategy | ||
| .execute_strategy(inner_states, input, params) | ||
| .await?; | ||
| } else { | ||
| // Otherwise, determine the strategy and execute | ||
| let decided_strategy = Self::determine_agg_strategy( | ||
|
|
@@ -169,7 +283,9 @@ impl GroupedAggregateState { | |
| strategy, | ||
| global_strategy_lock, | ||
| )?; | ||
| decided_strategy.execute_strategy(inner_states, input, params)?; | ||
| decided_strategy | ||
| .execute_strategy(inner_states, input, params) | ||
| .await?; | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
@@ -336,7 +452,7 @@ impl BlockingSink for GroupedAggregateSink { | |
| spawner | ||
| .spawn( | ||
| async move { | ||
| state.push(input, &params, &strategy_lock)?; | ||
| state.push(input, &params, &strategy_lock).await?; | ||
| Ok(state) | ||
| }, | ||
| Span::current(), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CPU-bound work (`partition_by_hash`, `agg`) is spawned with `tokio::task::JoinSet::spawn`, which places tasks on tokio's cooperative async thread pool. Because these closures contain no `.await` points, they run to completion without yielding, potentially blocking tokio workers and starving concurrent I/O tasks. For genuinely CPU-bound work, `tokio::task::spawn_blocking` (which dispatches to a dedicated blocking pool) is the idiomatic choice. The same pattern exists in the `AggThenPartition` sharding above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Matches the existing pattern in finalize() in the same file, which also dispatches CPU-bound concat+agg work via tokio::task::JoinSet::spawn rather than spawn_blocking. If this is worth changing it would be a sink-wide refactor across all spawn sites, not just this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right. The `finalize()` method already uses `tokio::task::JoinSet::spawn` for CPU-bound `concat` + `agg` + `eval_expression_list` work with no `.await` points — the same pattern this PR follows. Changing to `spawn_blocking` would be a sink-wide refactor touching `finalize()` and all the sharding sites, not something scoped to this PR. Fair point, acknowledged.