Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 59 additions & 73 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{
ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
bit_util,
ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, RunEndBuffer,
ScalarBuffer, bit_util,
};
use arrow_data::{ArrayDataBuilder, transform::MutableArrayData};
use arrow_data::transform::MutableArrayData;
use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};

use num_traits::Zero;
Expand Down Expand Up @@ -256,11 +256,18 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
*length as u32,
)?))
}
DataType::Map(_, _) => {
DataType::Map(field, ordered) => {
let list_arr = ListArray::from(values.as_map().clone());
let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
let (_, offsets, entries, nulls) = list_data.into_parts();
let entries = entries.as_struct().clone();
Ok(Arc::new(MapArray::try_new(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i suppose the checks inside try_new are cheap enough to not be too much of an impact 👍

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we switch back to MapArray::new_unchecked

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(can do as a follow on PR)

field.clone(),
offsets,
entries,
nulls,
*ordered,
)?))
}
DataType::Struct(fields) => {
let array: &StructArray = values.as_struct();
Expand Down Expand Up @@ -710,18 +717,15 @@ where
"New offsets was filled under/over the expected capacity"
);

let child_data = array_data.freeze();
let value_offsets = Buffer::from_vec(new_offsets);

let list_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
.nulls(nulls)
.offset(0)
.add_child_data(child_data)
.add_buffer(value_offsets);
let field = match values.data_type() {
DataType::List(field) | DataType::LargeList(field) => field.clone(),
d => unreachable!("take_list called with non-list data type {d}"),
};
// SAFETY: `new_offsets` is constructed to be monotonically increasing above
let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(new_offsets)) };
let child = make_array(array_data.freeze());

let list_data = unsafe { list_data.build_unchecked() };
Ok(GenericListArray::<OffsetType::Native>::from(list_data))
GenericListArray::<OffsetType::Native>::try_new(field, offsets, child, nulls)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, as above I think we should use new_unchecked

}

fn take_list_view<IndexType, OffsetType>(
Expand All @@ -737,18 +741,22 @@ where
let taken_sizes = take_native(values.sizes(), indices);
let nulls = take_nulls(values.nulls(), indices);

let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
.nulls(nulls)
.buffers(vec![taken_offsets.into(), taken_sizes.into()])
.child_data(vec![values.values().to_data()]);

// SAFETY: all buffers and child nodes for ListView added in constructor
let list_view_data = unsafe { list_view_data.build_unchecked() };
let field = match values.data_type() {
DataType::ListView(field) | DataType::LargeListView(field) => field.clone(),
d => unreachable!("take_list_view called with non-list-view data type {d}"),
};

Ok(GenericListViewArray::<OffsetType::Native>::from(
list_view_data,
))
// SAFETY: the taken offsets/sizes are a permutation of the (valid) input
// offsets/sizes, so they remain within the bounds of the child array.
Ok(unsafe {
GenericListViewArray::<OffsetType::Native>::new_unchecked(
field,
taken_offsets,
taken_sizes,
Arc::clone(values.values()),
nulls,
)
})
}

/// `take` implementation for `FixedSizeListArray`
Expand Down Expand Up @@ -779,15 +787,13 @@ fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
}
}

let list_data = ArrayDataBuilder::new(values.data_type().clone())
.len(indices.len())
.null_bit_buffer(Some(null_buf.into()))
.offset(0)
.add_child_data(taken.into_data());

let list_data = unsafe { list_data.build_unchecked() };
let field = match values.data_type() {
DataType::FixedSizeList(field, _) => field.clone(),
d => unreachable!("take_fixed_size_list called with non-fixed-size-list data type {d}"),
};
let nulls = NullBuffer::from_unsliced_buffer(null_buf, indices.len());

Ok(FixedSizeListArray::from(list_data))
FixedSizeListArray::try_new(field, length as i32, taken, nulls)
}

/// The take kernel implementation for `FixedSizeBinaryArray`.
Expand Down Expand Up @@ -815,14 +821,8 @@ fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(

let value_nulls = take_nulls(values.nulls(), indices);
let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
.len(indices.len())
.nulls(final_nulls)
.offset(0)
.add_buffer(result_buffer)
.build()?;

return Ok(FixedSizeBinaryArray::from(array_data));
return FixedSizeBinaryArray::try_new(size, result_buffer, final_nulls);

/// Implementation of the take kernel for fixed size binary arrays.
#[inline(never)]
Expand Down Expand Up @@ -959,50 +959,36 @@ fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
// `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
let mut new_physical_len = 1;
for ix in 1..physical_indices.len() {
if physical_indices[ix] != physical_indices[ix - 1] {
take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
new_physical_len += 1;
}
}
take_value_indices
.append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
let new_run_ends = unsafe {
// Safety:
// The function builds a valid run_ends array and hence need not be validated.
ArrayDataBuilder::new(T::DATA_TYPE)
.len(new_physical_len)
.null_count(0)
.add_buffer(new_run_ends_builder.finish())
.build_unchecked()
};

let take_value_indices: PrimitiveArray<I> = unsafe {
// Safety:
// The function builds a valid take_value_indices array and hence need not be validated.
ArrayDataBuilder::new(I::DATA_TYPE)
.len(new_physical_len)
.null_count(0)
.add_buffer(take_value_indices.finish())
.build_unchecked()
.into()
// SAFETY: run-ends are strictly increasing with last value == logical length.
let run_ends = unsafe {
RunEndBuffer::new_unchecked(
ScalarBuffer::from(new_run_ends_builder.finish()),
0,
physical_indices.len(),
)
};

let take_value_indices =
PrimitiveArray::<I>::new(ScalarBuffer::from(take_value_indices.finish()), None);

let new_values = take(run_array.values(), &take_value_indices, None)?;

let builder = ArrayDataBuilder::new(run_array.data_type().clone())
.len(physical_indices.len())
.add_child_data(new_run_ends)
.add_child_data(new_values.into_data());
let array_data = unsafe {
// Safety:
// This function builds a valid run array and hence can skip validation.
builder.build_unchecked()
};
Ok(array_data.into())
// SAFETY: `new_values` has one entry per run.
Ok(
unsafe {
RunArray::<T>::new_unchecked(run_array.data_type().clone(), run_ends, new_values)
},
)
}

/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
Expand Down
Loading