diff --git a/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala b/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala index d1361ab7e..6001ef0b6 100644 --- a/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala +++ b/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.auron.plan.NativeAggBase -import org.apache.spark.sql.functions.{collect_list, monotonically_increasing_id, rand, randn, spark_partition_id, sum} +import org.apache.spark.sql.functions.{collect_list, last, monotonically_increasing_id, rand, randn, spark_partition_id, sum} import org.apache.spark.sql.internal.SQLConf class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQueryTestsBase { @@ -75,4 +75,31 @@ class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQue rand(Random.nextLong()), randn(Random.nextLong())).foreach(assertNoExceptions) } + + testAuron("native last / last(ignoreNulls) aggregate") { + // The grouped aggregate is reliably offloaded to NativeAggBase, and the data + // is deterministic by construction (no intra-group ordering dependence): + // k=1 -> all values 10 => last=10, last(ignoreNulls)=10 + // k=2 -> all values null => last=null, last(ignoreNulls)=null + // k=3 -> single row 30 => last=30, last(ignoreNulls)=30 + val df = Seq[(Int, Option[Int])]( + (1, Some(10)), + (1, Some(10)), + (2, None), + (2, None), + (3, Some(30))) + .toDF("k", "v") + + val aggDF = df + .groupBy("k") + .agg(last($"v").as("last_v"), last($"v", ignoreNulls = true).as("last_v_ign")) + + checkAnswer(aggDF, Seq(Row(1, 10, 10), Row(2, null, null), Row(3, 30, 30))) + + // the aggregate must be offloaded to the native engine + assert(getExecutedPlan(aggDF).exists { + case _: NativeAggBase => true + case _ => false + }) + } } diff --git a/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala b/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala index d1361ab7e..6001ef0b6 100644 --- a/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala +++ b/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.auron.plan.NativeAggBase -import org.apache.spark.sql.functions.{collect_list, monotonically_increasing_id, rand, randn, spark_partition_id, sum} +import org.apache.spark.sql.functions.{collect_list, last, monotonically_increasing_id, rand, randn, spark_partition_id, sum} import org.apache.spark.sql.internal.SQLConf class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQueryTestsBase { @@ -75,4 +75,31 @@ class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQue rand(Random.nextLong()), randn(Random.nextLong())).foreach(assertNoExceptions) } + + testAuron("native last / last(ignoreNulls) aggregate") { + // The grouped aggregate is reliably offloaded to NativeAggBase, and the data + // is deterministic by construction (no intra-group ordering dependence): + // k=1 -> all values 10 => last=10, last(ignoreNulls)=10 + // k=2 -> all values null => last=null, last(ignoreNulls)=null + // k=3 -> single row 30 => last=30, last(ignoreNulls)=30 + val df = Seq[(Int, Option[Int])]( + (1, Some(10)), + (1, Some(10)), + (2, None), + (2, None), + (3, Some(30))) + .toDF("k", "v") + + val aggDF = df + .groupBy("k") + .agg(last($"v").as("last_v"), last($"v", ignoreNulls = true).as("last_v_ign")) + + checkAnswer(aggDF, Seq(Row(1, 10, 10), Row(2, null, null), Row(3, 30, 30))) + + // the aggregate must be offloaded to the native engine + assert(getExecutedPlan(aggDF).exists { + case _: NativeAggBase => true + case _ => false + }) + } } diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index 13b9f48bc..fd8390673 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -148,6 +148,8 @@ enum AggFunction { FIRST = 7; FIRST_IGNORES_NULL = 8; BLOOM_FILTER = 9; + LAST = 10; + LAST_IGNORES_NULL = 11; BRICKHOUSE_COLLECT = 1000; BRICKHOUSE_COMBINE_UNIQUE = 1001; UDAF = 1002; diff --git a/native-engine/auron-planner/src/lib.rs b/native-engine/auron-planner/src/lib.rs index a0f7b83d2..fb862bffd 100644 --- a/native-engine/auron-planner/src/lib.rs +++ b/native-engine/auron-planner/src/lib.rs @@ -135,6 +135,8 @@ impl From for AggFunction { protobuf::AggFunction::CollectSet => AggFunction::CollectSet, protobuf::AggFunction::First => AggFunction::First, protobuf::AggFunction::FirstIgnoresNull => AggFunction::FirstIgnoresNull, + protobuf::AggFunction::Last => AggFunction::Last, + protobuf::AggFunction::LastIgnoresNull => AggFunction::LastIgnoresNull, protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter, protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect, protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique, diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 418cc951d..30e4f5757 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -680,6 +680,12 @@ impl PhysicalPlanner { protobuf::AggFunction::FirstIgnoresNull => { WindowFunction::Agg(AggFunction::FirstIgnoresNull) } + protobuf::AggFunction::Last => { + WindowFunction::Agg(AggFunction::Last) + } + protobuf::AggFunction::LastIgnoresNull => { + WindowFunction::Agg(AggFunction::LastIgnoresNull) + } protobuf::AggFunction::BloomFilter => { WindowFunction::Agg(AggFunction::BloomFilter) } diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index 5eb4c3dad..99adc470d 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -33,6 +33,8 @@ use crate::agg::{ count::AggCount, first::AggFirst, first_ignores_null::AggFirstIgnoresNull, + last::AggLast, + last_ignores_null::AggLastIgnoresNull, maxmin::{AggMax, AggMin}, spark_udaf_wrapper::SparkUDAFWrapper, sum::AggSum, @@ -212,6 +214,14 @@ pub fn create_agg( let dt = children[0].data_type(input_schema)?; Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?) } + AggFunction::Last => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggLast::try_new(children[0].clone(), dt)?) + } + AggFunction::LastIgnoresNull => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggLastIgnoresNull::try_new(children[0].clone(), dt)?) + } AggFunction::BloomFilter => { let dt = children[0].data_type(input_schema)?; let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); diff --git a/native-engine/datafusion-ext-plans/src/agg/last.rs b/native-engine/datafusion-ext-plans/src/agg/last.rs new file mode 100644 index 000000000..8cbe776a2 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/agg/last.rs @@ -0,0 +1,330 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use auron_memmgr::spill::{SpillCompressedReader, SpillCompressedWriter}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExprRef, +}; +use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array}; + +use crate::{ + agg::{ + Agg, + acc::{ + AccBooleanColumn, AccBytes, AccBytesColumn, AccColumn, AccColumnRef, AccPrimColumn, + AccScalarValueColumn, create_acc_generic_column, + }, + agg::IdxSelection, + }, + idx_for_zipped, +}; + +pub struct AggLast { + child: PhysicalExprRef, + data_type: DataType, + acc_array_data_types: Vec, +} + +impl AggLast { + pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result { + let acc_array_data_types = vec![data_type.clone(), DataType::Boolean]; + Ok(Self { + child, + data_type, + acc_array_data_types, + }) + } +} + +impl Debug for AggLast { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Last({:?})", self.child) + } +} + +impl Agg for AggLast { + fn as_any(&self) -> &dyn Any { + self + } + + fn exprs(&self) -> Vec { + vec![self.child.clone()] + } + + fn with_new_exprs(&self, exprs: Vec) -> Result> { + Ok(Arc::new(Self::try_new( + exprs[0].clone(), + self.data_type.clone(), + )?)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nullable(&self) -> bool { + true + } + + fn create_acc_column(&self, num_rows: usize) -> AccColumnRef { + Box::new(AccLastColumn { + values: create_acc_generic_column(self.data_type.clone(), num_rows), + flags: AccBooleanColumn::new(num_rows), + }) + } + + fn acc_array_data_types(&self) -> &[DataType] { + &self.acc_array_data_types + } + + fn partial_update( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + partial_args: &[ArrayRef], + partial_arg_idx: IdxSelection<'_>, + ) -> Result<()> { + let partial_arg = &partial_args[0]; + let accs = downcast_any!(accs, mut AccLastColumn)?; + accs.ensure_size(acc_idx); + + let (value_accs, flag_accs) = accs.inner_mut(); + + // last() keeps the latest visited row (including null), so every row + // unconditionally overwrites the accumulator. + macro_rules! handle_bytes { + ($array:expr) => {{ + let value_accs = downcast_any!(value_accs, mut AccBytesColumn)?; + let partial_arg = $array; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + value_accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref()))); + } else { + value_accs.set_value(acc_idx, None); + } + flag_accs.set_value(acc_idx, Some(true)); + } + } + }} + } + + downcast_primitive_array! { + partial_arg => { + let value_accs = downcast_any!(value_accs, mut AccPrimColumn<_>)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + value_accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } else { + value_accs.set_value(acc_idx, None); + } + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + DataType::Boolean => { + let value_accs = downcast_any!(value_accs, mut AccBooleanColumn)?; + let partial_arg = downcast_any!(partial_arg, BooleanArray)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + value_accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } else { + value_accs.set_value(acc_idx, None); + } + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?), + DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?), + _other => { + let value_accs = downcast_any!(value_accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + value_accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?); + } else { + value_accs.set_value(acc_idx, ScalarValue::Null); + } + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + } + Ok(()) + } + + fn partial_merge( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + merging_accs: &mut AccColumnRef, + merging_acc_idx: IdxSelection<'_>, + ) -> Result<()> { + let accs = downcast_any!(accs, mut AccLastColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccLastColumn)?; + accs.ensure_size(acc_idx); + + let (value_accs, flag_accs) = accs.inner_mut(); + let (merging_value_accs, merging_flag_accs) = merging_accs.inner_mut(); + + // the merging accumulator carries later-visited rows, so it overwrites + // whenever it has visited any row. + macro_rules! handle_primitive { + ($ty:ty) => {{ + type TNative = <$ty as ArrowPrimitiveType>::Native; + let value_accs = downcast_any!(value_accs, mut AccPrimColumn)?; + let merging_value_accs = downcast_any!(merging_value_accs, mut AccPrimColumn<_>)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_flag_accs.value(merging_acc_idx).is_some() { + value_accs.set_value(acc_idx, merging_value_accs.value(merging_acc_idx)); + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + }} + } + + macro_rules! handle_boolean { + () => {{ + let value_accs = downcast_any!(value_accs, mut AccBooleanColumn)?; + let merging_value_accs = downcast_any!(merging_value_accs, mut AccBooleanColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_flag_accs.value(merging_acc_idx).is_some() { + value_accs.set_value(acc_idx, merging_value_accs.value(merging_acc_idx)); + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + }} + } + + macro_rules! handle_bytes { + () => {{ + let value_accs = downcast_any!(value_accs, mut AccBytesColumn)?; + let merging_value_accs = downcast_any!(merging_value_accs, mut AccBytesColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_flag_accs.value(merging_acc_idx).is_some() { + value_accs.set_value(acc_idx, merging_value_accs.take_value(merging_acc_idx)); + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + }} + } + + downcast_primitive! { + (&self.data_type) => (handle_primitive), + DataType::Boolean => handle_boolean!(), + DataType::Utf8 | DataType::Binary => handle_bytes!(), + DataType::Null => {} + _ => { + let value_accs = downcast_any!(value_accs, mut AccScalarValueColumn)?; + let merging_value_accs = downcast_any!(merging_value_accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_flag_accs.value(merging_acc_idx).is_some() { + value_accs.set_value(acc_idx, merging_value_accs.take_value(merging_acc_idx)); + flag_accs.set_value(acc_idx, Some(true)); + } + } + } + } + } + Ok(()) + } + + fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { + Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) + } +} + +struct AccLastColumn { + values: AccColumnRef, + flags: AccBooleanColumn, +} + +impl AccLastColumn { + fn inner_mut(&mut self) -> (&mut AccColumnRef, &mut AccBooleanColumn) { + let values = &mut self.values as *mut AccColumnRef; + let flags = &mut self.flags as *mut AccBooleanColumn; + unsafe { (&mut *values, &mut *flags) } // safety: bypass borrow checker + } +} + +impl AccColumn for AccLastColumn { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn resize(&mut self, len: usize) { + self.values.resize(len); + self.flags.resize(len); + } + + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.flags.shrink_to_fit(); + } + + fn num_records(&self) -> usize { + self.values.num_records() + } + + fn mem_used(&self) -> usize { + self.values.mem_used() + self.flags.mem_used() + } + + fn freeze_to_arrays(&mut self, idx: IdxSelection<'_>) -> Result> { + let value_array = self.values.freeze_to_arrays(idx)?[0].clone(); + let flags_array = self.flags.freeze_to_arrays(idx)?[0].clone(); + Ok(vec![value_array, flags_array]) + } + + fn unfreeze_from_arrays(&mut self, arrays: &[ArrayRef]) -> Result<()> { + self.values.unfreeze_from_arrays(&arrays[0..1])?; + self.flags.unfreeze_from_arrays(&arrays[1..2])?; + Ok(()) + } + + fn spill(&mut self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()> { + self.values.spill(idx, w)?; + self.flags.spill(idx, w)?; + Ok(()) + } + + fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> { + self.values.unspill(num_rows, r)?; + self.flags.unspill(num_rows, r)?; + Ok(()) + } +} diff --git a/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs new file mode 100644 index 000000000..1032a674e --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{common::Result, physical_expr::PhysicalExprRef}; +use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array}; + +use crate::{ + agg::{ + Agg, + acc::{ + AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn, + AccScalarValueColumn, create_acc_generic_column, + }, + agg::IdxSelection, + }, + idx_for_zipped, +}; + +pub struct AggLastIgnoresNull { + child: PhysicalExprRef, + data_type: DataType, + acc_array_data_types: Vec, +} + +impl AggLastIgnoresNull { + pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result { + let acc_array_data_types = vec![data_type.clone()]; + Ok(Self { + child, + data_type, + acc_array_data_types, + }) + } +} + +impl Debug for AggLastIgnoresNull { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "LastIgnoresNull({:?})", self.child) + } +} + +impl Agg for AggLastIgnoresNull { + fn as_any(&self) -> &dyn Any { + self + } + + fn exprs(&self) -> Vec { + vec![self.child.clone()] + } + + fn with_new_exprs(&self, exprs: Vec) -> Result> { + Ok(Arc::new(Self::try_new( + exprs[0].clone(), + self.data_type.clone(), + )?)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nullable(&self) -> bool { + true + } + + fn create_acc_column(&self, num_rows: usize) -> AccColumnRef { + create_acc_generic_column(self.data_type.clone(), num_rows) + } + + fn acc_array_data_types(&self) -> &[DataType] { + &self.acc_array_data_types + } + + fn partial_update( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + partial_args: &[ArrayRef], + partial_arg_idx: IdxSelection<'_>, + ) -> Result<()> { + let partial_arg = &partial_args[0]; + accs.ensure_size(acc_idx); + + // last(ignoreNulls=true) keeps the latest visited non-null row, so every + // non-null row unconditionally overwrites the accumulator. + macro_rules! handle_bytes { + ($array:expr) => {{ + let accs = downcast_any!(accs, mut AccBytesColumn)?; + let partial_arg = $array; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref()))); + } + } + } + }} + } + + downcast_primitive_array! { + partial_arg => { + if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) { + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } + } + } + } + } + DataType::Boolean => { + let accs = downcast_any!(accs, mut AccBooleanColumn)?; + let partial_arg = downcast_any!(partial_arg, BooleanArray)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } + } + } + } + DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?), + DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?), + _other => { + let accs = downcast_any!(accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?); + } + } + } + } + } + Ok(()) + } + + fn partial_merge( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + merging_accs: &mut AccColumnRef, + merging_acc_idx: IdxSelection<'_>, + ) -> Result<()> { + accs.ensure_size(acc_idx); + + // the merging accumulator carries later-visited rows, so its non-null + // value overwrites whatever was visited earlier. + macro_rules! handle_primitive { + ($ty:ty) => {{ + type TNative = <$ty as ArrowPrimitiveType>::Native; + let accs = downcast_any!(accs, mut AccPrimColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_accs.value(merging_acc_idx).is_some() { + accs.set_value(acc_idx, merging_accs.value(merging_acc_idx)); + } + } + } + }} + } + + macro_rules! handle_boolean { + () => {{ + let accs = downcast_any!(accs, mut AccBooleanColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_accs.value(merging_acc_idx).is_some() { + accs.set_value(acc_idx, merging_accs.value(merging_acc_idx)); + } + } + } + }}; + } + + macro_rules! handle_bytes { + () => {{ + let accs = downcast_any!(accs, mut AccBytesColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_accs.value(merging_acc_idx).is_some() { + accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx)); + } + } + } + }}; + } + + downcast_primitive! { + (&self.data_type) => (handle_primitive), + DataType::Boolean => handle_boolean!(), + DataType::Utf8 | DataType::Binary => handle_bytes!(), + DataType::Null => {} + _ => { + let accs = downcast_any!(accs, mut AccScalarValueColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if !merging_accs.value(merging_acc_idx).is_null() { + accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx)); + } + } + } + } + } + Ok(()) + } + + fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { + Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) + } +} diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index 565e15b16..9867524ca 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -25,6 +25,8 @@ pub mod collect; pub mod count; pub mod first; pub mod first_ignores_null; +pub mod last; +pub mod last_ignores_null; pub mod maxmin; pub mod spark_udaf_wrapper; pub mod sum; @@ -69,6 +71,8 @@ pub enum AggFunction { Min, First, FirstIgnoresNull, + Last, + LastIgnoresNull, CollectList, CollectSet, BloomFilter, diff --git a/native-engine/datafusion-ext-plans/src/agg_exec.rs b/native-engine/datafusion-ext-plans/src/agg_exec.rs index d75d304f0..bbf166453 100644 --- a/native-engine/datafusion-ext-plans/src/agg_exec.rs +++ b/native-engine/datafusion-ext-plans/src/agg_exec.rs @@ -694,6 +694,112 @@ mod test { Ok(()) } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_agg_last() -> Result<()> { + MemManager::init(10000); + + // group key "k" and a nullable value column "v". rows are visited in + // order, so within each group the last visited row wins. + // k=1: v = [10, null, 40] -> last = 40, last(ignoreNulls) = 40 + // k=2: v = [20, null] -> last = null, last(ignoreNulls) = 20 + let schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 1, 2, 1])), + Arc::new(Int32Array::from(vec![ + Some(10), + Some(20), + None, + None, + Some(40), + ])), + ], + )?; + let input: Arc = + Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None)?); + + let agg_expr_last = create_agg( + AggFunction::Last, + &[phys_expr::col("v", &input.schema())?], + &input.schema(), + DataType::Int32, + )?; + + let agg_expr_last_ign = create_agg( + AggFunction::LastIgnoresNull, + &[phys_expr::col("v", &input.schema())?], + &input.schema(), + DataType::Int32, + )?; + + let aggs_agg_expr = vec![ + AggExpr { + field_name: "agg_expr_last".to_string(), + mode: Partial, + filter: None, + agg: agg_expr_last, + }, + AggExpr { + field_name: "agg_expr_last_ign".to_string(), + mode: Partial, + filter: None, + agg: agg_expr_last_ign, + }, + ]; + + let agg_exec_partial = AggExec::try_new( + HashAgg, + vec![GroupingExpr { + field_name: "k".to_string(), + expr: Arc::new(Column::new("k", 0)), + }], + aggs_agg_expr.clone(), + false, + input, + )?; + + let agg_exec_final = AggExec::try_new( + HashAgg, + vec![GroupingExpr { + field_name: "k".to_string(), + expr: Arc::new(Column::new("k", 0)), + }], + aggs_agg_expr + .into_iter() + .map(|mut agg| { + agg.agg = agg + .agg + .with_new_exprs(vec![Arc::new(phys_expr::Literal::new( + ScalarValue::Null, + ))])?; + agg.mode = Final; + Ok(agg) + }) + .collect::>()?, + false, + Arc::new(agg_exec_partial), + )?; + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let output_final = agg_exec_final.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(output_final).await?; + let expected = vec![ + "+---+---------------+-------------------+", + "| k | agg_expr_last | agg_expr_last_ign |", + "+---+---------------+-------------------+", + "| 1 | 40 | 40 |", + "| 2 | | 20 |", + "+---+---------------+-------------------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn test_agg_with_filter() -> Result<()> { MemManager::init(1000); diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 52e813eba..f87311853 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.auron.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Max, Min, Sum, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Last, Max, Min, Sum, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero @@ -1269,6 +1269,18 @@ object NativeConverters extends Logging { }) aggBuilder.addChildren(convertExpr(child)) + case Last(child, ignoresNullExpr) => + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + aggBuilder.setAggFunction(if (ignoresNull) { + pb.AggFunction.LAST_IGNORES_NULL + } else { + pb.AggFunction.LAST + }) + aggBuilder.addChildren(convertExpr(child)) + case CollectList(child, _, _) => aggBuilder.setAggFunction(pb.AggFunction.COLLECT_LIST) aggBuilder.addChildren(convertExpr(child)) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala index 755fb6466..949e5291c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala @@ -308,6 +308,8 @@ object NativeAggBase extends Logging { case f: Average => Seq(f.dataType, LongType) case f @ First(_, true) => Seq(f.dataType) case f @ First(_, false) => Seq(f.dataType, BooleanType) + case f @ Last(_, true) => Seq(f.dataType) + case f @ Last(_, false) => Seq(f.dataType, BooleanType) case _ => Seq(BinaryType) } }