Skip to content
53 changes: 43 additions & 10 deletions arrow-flight/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use crate::{FlightData, trailers::LazyTrailers, utils::flight_data_to_arrow_batch};
use crate::{FlightData, trailers::LazyTrailers};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_ipc::reader;
use arrow_schema::{Schema, SchemaRef};
use bytes::Bytes;
use futures::{Stream, StreamExt, ready, stream::BoxStream};
Expand Down Expand Up @@ -228,6 +229,8 @@ pub struct FlightDataDecoder {
state: Option<FlightStreamState>,
/// Seen the end of the inner stream?
done: bool,
/// Skip validation of decoded arrays (UTF-8, offset bounds, null counts).
skip_validation: bool,
}

impl Debug for FlightDataDecoder {
Expand All @@ -236,6 +239,7 @@ impl Debug for FlightDataDecoder {
.field("response", &"<stream>")
.field("state", &self.state)
.field("done", &self.done)
.field("skip_validation", &self.skip_validation)
.finish()
}
}
Expand All @@ -250,9 +254,17 @@ impl FlightDataDecoder {
state: None,
response: response.boxed(),
done: false,
skip_validation: false,
}
}

/// Sets whether validation of decoded arrays is skipped (defaults to `false`).
///
/// Skipping validation can improve decoding performance.
///
/// Only enable this for trusted senders: with validation skipped, invalid
/// data may cause undefined behavior.
// NOTE(review): this is a safe fn taking a plain `bool`, yet `true` can lead to
// undefined behavior on malformed input; consider requiring an explicit unsafe
// opt-in (e.g. an `UnsafeFlag`) instead. Confirm against the latest revision.
pub fn with_skip_validation(mut self, skip_validation: bool) -> Self {
self.skip_validation = skip_validation;
self
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than exposing this as a plain bool flag, I think we should require an UnsafeFlag here.

By requiring an UnsafeFlag, we force consumers to explicitly have an unsafe block in their codebase, making sure they are aware that what they are doing is not safe, and that they are responsible for ensuring memory safety there.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me! pushed up a revision


/// Returns the current schema for this stream
pub fn schema(&self) -> Option<&SchemaRef> {
self.state.as_ref().map(|state| &state.schema)
Expand Down Expand Up @@ -319,14 +331,35 @@ impl FlightDataDecoder {
));
};

let batch = flight_data_to_arrow_batch(
&data,
Arc::clone(&state.schema),
&state.dictionaries_by_field,
)
.map_err(|e| {
FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
})?;
let data_buffer = if data.data_body.as_ptr() as usize % 64 != 0 {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see context here

let mut buf = MutableBuffer::with_capacity(data.data_body.len());
buf.extend_from_slice(&data.data_body);
Buffer::from(buf)
} else {
Buffer::from(data.data_body.clone())
};

let batch = message
.header_as_record_batch()
.ok_or_else(|| {
FlightError::DecodeError(
"Unable to convert flight data header to a record batch".to_string(),
)
})
.and_then(|record_batch| {
reader::read_record_batch(
&data_buffer,
record_batch,
Arc::clone(&state.schema),
&state.dictionaries_by_field,
None,
&message.version(),
self.skip_validation,
)
.map_err(|e| {
FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
})
})?;

Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
}
Expand Down
1 change: 1 addition & 0 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ pub fn arrow_data_from_flight_data(
&dictionaries_by_field,
None,
&ipc_message.version(),
false,
)?;
Ok(ArrowFlightData::RecordBatch(record_batch))
}
Expand Down
6 changes: 4 additions & 2 deletions arrow-flight/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub fn flight_data_to_batches(flight_data: &[FlightData]) -> Result<Vec<RecordBa
let mut batches = vec![];
let dictionaries_by_id = HashMap::new();
for datum in flight_data[1..].iter() {
let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?;
let batch = flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id, false)?;
batches.push(batch);
}
Ok(batches)
Expand All @@ -56,6 +56,7 @@ pub fn flight_data_to_arrow_batch(
data: &FlightData,
schema: SchemaRef,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
skip_validation: bool,
) -> Result<RecordBatch, ArrowError> {
// check that the data_header is a record batch message
let message = arrow_ipc::root_as_message(&data.data_header[..])
Expand All @@ -70,12 +71,13 @@ pub fn flight_data_to_arrow_batch(
})
.map(|batch| {
reader::read_record_batch(
&Buffer::from(data.data_body.as_ref()),
&Buffer::from(data.data_body.clone()),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is a sneaky one, but indeed this is avoiding a full clone.

batch,
schema,
dictionaries_by_id,
None,
&message.version(),
skip_validation,
)
})?
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async fn consume_flight_location(
assert_eq!(metadata, data.app_metadata);

let actual_batch =
flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id)
flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id, false)
.expect("Unable to convert flight data to Arrow batch");

assert_eq!(actual_schema, actual_batch.schema());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ async fn record_batch_from_message(
dictionaries_by_id,
None,
&message.version(),
false,
);

arrow_batch_result
Expand Down
21 changes: 14 additions & 7 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ impl RecordBatchDecoder<'_> {
let null_buffer = self.next_buffer()?;

// read the arrays for each field
let mut struct_arrays = vec![];
let mut struct_arrays = Vec::with_capacity(struct_fields.len());
// TODO investigate whether just knowing the number of buffers could
// still work
for struct_field in struct_fields {
Expand Down Expand Up @@ -557,7 +557,7 @@ impl<'a> RecordBatchDecoder<'a> {

let schema = Arc::clone(&self.schema);
if let Some(projection) = self.projection {
let mut arrays = vec![];
let mut arrays = Vec::with_capacity(projection.len());
// project fields
for (idx, field) in schema.fields().iter().enumerate() {
// A projected field can appear more than once, so collect all matching positions.
Expand Down Expand Up @@ -597,7 +597,7 @@ impl<'a> RecordBatchDecoder<'a> {
RecordBatch::try_new_with_options(schema, columns, &options)
}
} else {
let mut children = vec![];
let mut children = Vec::with_capacity(schema.fields().len());
// keep track of index as lists require more than one node
for field in schema.fields() {
let child = self.create_array(field, &mut variadic_counts)?;
Expand Down Expand Up @@ -771,11 +771,18 @@ pub fn read_record_batch(
dictionaries_by_id: &HashMap<i64, ArrayRef>,
projection: Option<&[usize]>,
metadata: &MetadataVersion,
skip_validation: bool,
) -> Result<RecordBatch, ArrowError> {
RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
.with_projection(projection)
.with_require_alignment(false)
.read_record_batch()
let mut decoder =
RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
.with_projection(projection)
.with_require_alignment(false);
if skip_validation {
let mut flag = UnsafeFlag::new();
unsafe { flag.set(true) };
decoder = decoder.with_skip_validation(flag);
}
decoder.read_record_batch()
}

/// Read the dictionary from the buffer and provided metadata,
Expand Down
Loading