diff --git a/Cargo.lock b/Cargo.lock index 6a6bb511037b..14f251e2d02a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -388,6 +388,7 @@ version = "59.0.0" dependencies = [ "arrow", "arrow-buffer", + "arrow-data", "arrow-flight", "arrow-integration-test", "clap", diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 46fcd0810315..c4541fdb6012 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -32,7 +32,7 @@ arrow-array = { workspace = true } arrow-buffer = { workspace = true } # Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389 arrow-cast = { workspace = true } -arrow-data = { workspace = true, optional = true } +arrow-data = { workspace = true } arrow-ipc = { workspace = true } arrow-ord = { workspace = true, optional = true } arrow-row = { workspace = true, optional = true } @@ -62,7 +62,7 @@ all-features = true [features] default = [] -flight-sql = ["dep:arrow-arith", "dep:arrow-data", "dep:arrow-ord", "dep:arrow-row", "dep:arrow-select", "dep:arrow-string", "dep:once_cell", "dep:paste"] +flight-sql = ["dep:arrow-arith", "dep:arrow-ord", "dep:arrow-row", "dep:arrow-select", "dep:arrow-string", "dep:once_cell", "dep:paste"] # TODO: Remove in the next release flight-sql-experimental = ["flight-sql"] tls-aws-lc= ["tonic/tls-aws-lc"] diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index 8c518ac9d454..6d5ebb04c1d1 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::{FlightData, trailers::LazyTrailers, utils::flight_data_to_arrow_batch}; +use crate::{FlightData, trailers::LazyTrailers}; use arrow_array::{ArrayRef, RecordBatch}; use arrow_buffer::Buffer; +use arrow_data::UnsafeFlag; use arrow_schema::{Schema, SchemaRef}; use bytes::Bytes; use futures::{Stream, StreamExt, ready, stream::BoxStream}; @@ -228,6 +229,8 @@ pub struct FlightDataDecoder { state: Option, /// Seen the end of the inner stream? done: bool, + /// Skip validation of decoded arrays (UTF-8, offset bounds, null counts). + skip_validation: UnsafeFlag, } impl Debug for FlightDataDecoder { @@ -236,6 +239,7 @@ impl Debug for FlightDataDecoder { .field("response", &"") .field("state", &self.state) .field("done", &self.done) + .field("skip_validation", &self.skip_validation) .finish() } } @@ -250,9 +254,17 @@ impl FlightDataDecoder { state: None, response: response.boxed(), done: false, + skip_validation: UnsafeFlag::new(), } } + /// # Safety + /// Invalid data may cause undefined behavior. Only use for trusted senders. + pub unsafe fn with_skip_validation(mut self) -> Self { + unsafe { self.skip_validation.set(true) }; + self + } + /// Returns the current schema for this stream pub fn schema(&self) -> Option<&SchemaRef> { self.state.as_ref().map(|state| &state.schema) @@ -319,11 +331,27 @@ impl FlightDataDecoder { )); }; - let batch = flight_data_to_arrow_batch( - &data, + let record_batch = message.header_as_record_batch().ok_or_else(|| { + FlightError::DecodeError( + "Unable to convert flight data header to a record batch".to_string(), + ) + })?; + let buf = if data.data_body.as_ptr() as usize % 64 == 0 { + Buffer::from(data.data_body.clone()) + } else { + Buffer::from(data.data_body.as_ref()) + }; + let batch = arrow_ipc::reader::RecordBatchDecoder::try_new( + &buf, + record_batch, Arc::clone(&state.schema), &state.dictionaries_by_field, + &message.version(), ) + .and_then(|d| { + d.with_skip_validation(self.skip_validation.clone()) + .read_record_batch() + }) .map_err(|e| { FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}")) })?; diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 6effb5f86aaf..0e38e7ed77aa 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -69,8 +69,13 @@ pub fn flight_data_to_arrow_batch( ) }) .map(|batch| { + let buf = if data.data_body.as_ptr() as usize % 64 == 0 { + Buffer::from(data.data_body.clone()) + } else { + Buffer::from(data.data_body.as_ref()) + }; reader::read_record_batch( - &Buffer::from(data.data_body.as_ref()), + &buf, batch, schema, dictionaries_by_id, diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml index ae13d32b57a9..cb488f5ff791 100644 --- a/arrow-integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -35,6 +35,7 @@ logging = ["tracing-subscriber"] [dependencies] arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json", "ffi"] } +arrow-data = { workspace = true } arrow-flight = { path = "../arrow-flight", default-features = false } arrow-integration-test = { path = "../arrow-integration-test", default-features = false } clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 6d1e799d43c9..dd7557784e53 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -146,7 +146,7 @@ impl RecordBatchDecoder<'_> { let null_buffer = self.next_buffer()?; // read the arrays for each field - let mut struct_arrays = vec![]; + let mut struct_arrays = Vec::with_capacity(struct_fields.len()); // TODO investigate whether just knowing the number of buffers could // still work for struct_field in struct_fields { @@ -474,7 +474,7 @@ pub struct RecordBatchDecoder<'a> { impl<'a> RecordBatchDecoder<'a> { /// Create a reader for decoding arrays from an encoded [`RecordBatch`] - fn try_new( + pub fn try_new( buf: &'a Buffer, batch: crate::RecordBatch<'a>, schema: SchemaRef, @@ -530,6 +530,11 @@ impl<'a> RecordBatchDecoder<'a> { /// Specifies if validation should be skipped when reading data (defaults to `false`) /// + /// When enabled, the following checks are bypassed: + /// - Offset bounds (e.g. list/string offsets pointing past the end of their value buffer) + /// - UTF-8 validity of string columns (`Utf8` / `LargeUtf8`) + /// - Null count consistency and buffer length checks + /// /// Note this API is somewhat "funky" as it allows the caller to skip validation /// without having to use `unsafe` code. If this is ever made public /// it should be made clearer that this is a potentially unsafe by @@ -538,14 +543,15 @@ impl<'a> RecordBatchDecoder<'a> { /// # Safety /// /// Relies on the caller only passing a flag with `true` value if they are - /// certain that the data is valid - pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self { + /// certain that the data is valid. Invalid data that bypasses these checks + /// may cause undefined behavior when the arrays are later accessed. + pub fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self { self.skip_validation = skip_validation; self } /// Read the record batch, consuming the reader - fn read_record_batch(mut self) -> Result { + pub fn read_record_batch(mut self) -> Result { let mut variadic_counts: VecDeque = self .batch .variadicBufferCounts() @@ -557,7 +563,7 @@ impl<'a> RecordBatchDecoder<'a> { let schema = Arc::clone(&self.schema); if let Some(projection) = self.projection { - let mut arrays = vec![]; + let mut arrays = Vec::with_capacity(projection.len()); // project fields for (idx, field) in schema.fields().iter().enumerate() { // A projected field can appear more than once, so collect all matching positions. @@ -597,7 +603,7 @@ impl<'a> RecordBatchDecoder<'a> { RecordBatch::try_new_with_options(schema, columns, &options) } } else { - let mut children = vec![]; + let mut children = Vec::with_capacity(schema.fields().len()); // keep track of index as lists require more than one node for field in schema.fields() { let child = self.create_array(field, &mut variadic_counts)?;