feat(grouped-agg): shard PartitionThenAgg execution per morsel#7079
feat(grouped-agg): shard PartitionThenAgg execution per morsel#7079BABTUNA wants to merge 7 commits into
Conversation
Greptile SummaryThis PR fans large morsels (≥ 32 768 rows) across 4 concurrent
Confidence Score: 4/5The change is functionally correct — partial-agg combinators are commutative so merge order doesn't affect results — but CPU-bound work on tokio's async pool and non-deterministic shard ordering from Both strategies produce correct aggregation results for all input sizes. The main concerns are that CPU-bound src/daft-local-execution/src/sinks/grouped_aggregate.rs — specifically the Important Files Changed
Reviews (1): Last reviewed commit: "feat(grouped-agg): shard PartitionThenAg..." | Re-trigger Greptile |
| /// Minimum input rows before the `AggThenPartition` strategy fans a single morsel | ||
| /// out across multiple shard tasks. Smaller inputs run the existing single-threaded | ||
| /// path so the per-task overhead doesn't dominate. | ||
| const SHARD_THRESHOLD: usize = 32_768; | ||
|
|
||
| /// Number of shard tasks spawned per morsel when the input crosses | ||
| /// `SHARD_THRESHOLD`. Fixed rather than tied to `max_concurrency` because the | ||
| /// framework already runs `max_concurrency` morsels concurrently; fanning out | ||
| /// further per morsel would oversubscribe. | ||
| const NUM_SHARDS_PER_MORSEL: usize = 4; |
There was a problem hiding this comment.
Stale doc comment on
SHARD_THRESHOLD
The constant's doc comment says "before the AggThenPartition strategy fans…" but after this PR the same constant gates sharding in PartitionThenAgg as well. A reader scanning just the constant will think only AggThenPartition is affected.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| tasks.spawn( | ||
| async move { shard.partition_by_hash(&group_by, num_slots) } | ||
| .instrument(Span::current()), | ||
| ); |
There was a problem hiding this comment.
CPU-bound work (
partition_by_hash, agg) is spawned with tokio::task::JoinSet::spawn, which places tasks on tokio's cooperative async thread pool. Because these closures contain no .await points they run to completion without yielding, potentially blocking tokio workers and starving concurrent I/O tasks. For genuinely CPU-bound work, tokio::task::spawn_blocking (which dispatches to a dedicated blocking pool) is the idiomatic choice. The same pattern exists in the AggThenPartition sharding above.
| tasks.spawn( | |
| async move { shard.partition_by_hash(&group_by, num_slots) } | |
| .instrument(Span::current()), | |
| ); | |
| let span = Span::current(); | |
| tasks.spawn(tokio::task::spawn_blocking(move || { | |
| let _enter = span.enter(); | |
| shard.partition_by_hash(&group_by, num_slots) | |
| }).map(|r| r.expect("shard task panicked"))); |
There was a problem hiding this comment.
Matches the existing pattern in finalize() in the same file, which also dispatches CPU-bound concat+agg work via tokio::task::JoinSet::spawn rather than spawn_blocking. If this is worth changing it would be a sink-wide refactor across all spawn sites, not just this PR.
There was a problem hiding this comment.
You're right. The finalize() method already uses tokio::task::JoinSet::spawn for CPU-bound concat + agg + eval_expression_list work with no .await points — the same pattern this PR follows. Changing to spawn_blocking would be a sink-wide refactor touching finalize() and all the sharding sites, not something scoped to this PR. Fair point, acknowledged.
| let shard_results: Vec<Vec<MicroPartition>> = tasks | ||
| .join_all() | ||
| .await | ||
| .into_iter() | ||
| .collect::<DaftResult<Vec<_>>>()?; |
There was a problem hiding this comment.
JoinSet::join_all returns results in completion order, not spawn order
JoinSet::join_all() is implemented via repeated join_next(), which yields tasks in the order they finish, not the order they were spawned. For PartitionThenAgg this means the flush-on-threshold logic in the merge loop runs over shards in a non-deterministic order. The PR author correctly argues correctness is preserved because partial-agg combinators are commutative, but the non-deterministic flush ordering makes memory-usage patterns harder to reason about and can make test assertions around intermediate state order-sensitive. Consider using an ordered collection (e.g. collecting into a Vec indexed by shard_idx) so processing order is deterministic.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
Keeping join_all. Correctness is preserved (commutative combinators as noted) and forcing spawn-order via indexed collection adds code without changing observable output. finalize() in the same file uses the same join_all pattern.
Summary
Shard the
PartitionThenAggexecution path insideGroupedAggregateSink::sinkso that a single large input MicroPartition is row-range fanned out across 4 concurrentpartition_by_hashtasks before the existing flush-on-threshold logic runs. Small inputs continue to run the existing path. Single-file diff tosrc/daft-local-execution/src/sinks/grouped_aggregate.rs. Stacked on #7060.Why
Continues #6585 item 5. PR #7060 sharded the
AggThenPartitionstrategy (selected for low-cardinality groupbys). This PR extends the same row-range sharding toPartitionThenAgg(selected for high-cardinality groupbys). Both share the same single-threadedpartition_by_hashwork that benefits from K-way parallelism on large morsels.Changes Made
execute_partition_then_aggbecomes async; dispatch inexecute_strategyawaits itSHARD_THRESHOLD = 32_768andNUM_SHARDS_PER_MORSEL = 4constants from feat(grouped-agg): shard AggThenPartition execution per morsel #7060:< SHARD_THRESHOLD: existing sync logic (unchanged behavior)>= SHARD_THRESHOLD: row-range slice into 4 shards, spawn 4 tokio tasks each runningshard.partition_by_hash(group_by, num_slots), instrumented with the parent span. Then merge each shard's slot-i output into slot i of inner_states using the same flush-on-threshold logic as the sync pathPartitionOnlyleft unchanged for a follow-up PRBehavior
Functionally equivalent for all input sizes:
partition_by_hashresults are merged in shard order into the existing per-slot state. The flush-on-threshold check runs sequentially in the merge loop becausestate.unaggregated_sizeis per-slot shared state across shardsTest Plan
cargo build -p daft-local-execution --libcleancargo fmt -p daft-local-executioncleantests/dataframe/test_groupby*.pyand broader query testsRelated Issues
Part of #6585 item 5. Builds on #7060.