Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions parquet-variant-compute/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ parquet-variant-json = { workspace = true }
chrono = { workspace = true }
uuid = { version = "1.18.0", features = ["v4"] }
serde_json = "1.0"
num-traits = { version = "0.2", default-features = false }

# uuid requires the `js` feature to run on wasm
[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand Down
222 changes: 218 additions & 4 deletions parquet-variant-compute/src/shred_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub(crate) fn shred_variant_with_options(
cast_options,
array.len(),
NullValue::TopLevelVariant,
true,
)?;
for i in 0..array.len() {
if array.is_null(i) {
Expand Down Expand Up @@ -145,6 +146,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
cast_options: &'a CastOptions,
capacity: usize,
null_value: NullValue,
shred: bool,
) -> Result<VariantToShreddedVariantRowBuilder<'a>> {
let builder = match data_type {
DataType::Struct(fields) => {
Expand All @@ -153,6 +155,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
cast_options,
capacity,
null_value,
shred,
)?;
VariantToShreddedVariantRowBuilder::Object(typed_value_builder)
}
Expand Down Expand Up @@ -193,7 +196,7 @@ pub(crate) fn make_variant_to_shredded_variant_arrow_row_builder<'a>(
| DataType::FixedSizeBinary(16) // UUID
=> {
let builder =
make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity)?;
make_primitive_variant_to_arrow_row_builder(data_type, cast_options, capacity, shred)?;
let typed_value_builder =
VariantToShreddedPrimitiveVariantRowBuilder::new(builder, capacity, null_value);
VariantToShreddedVariantRowBuilder::Primitive(typed_value_builder)
Expand Down Expand Up @@ -369,13 +372,15 @@ impl<'a> VariantToShreddedObjectVariantRowBuilder<'a> {
cast_options: &'a CastOptions,
capacity: usize,
null_value: NullValue,
shred: bool,
) -> Result<Self> {
let typed_value_builders = fields.iter().map(|field| {
let builder = make_variant_to_shredded_variant_arrow_row_builder(
field.data_type(),
cast_options,
capacity,
NullValue::ObjectField,
shred,
)?;
Ok((field.name().as_str(), builder))
});
Expand Down Expand Up @@ -710,9 +715,12 @@ mod tests {
use arrow::datatypes::{
ArrowPrimitiveType, DataType, Field, Fields, Int64Type, TimeUnit, UnionFields, UnionMode,
};
use arrow_schema::IntervalUnit;
use chrono::{DateTime, NaiveDate, NaiveTime};
use parquet_variant::{
BuilderSpecificState, EMPTY_VARIANT_METADATA_BYTES, ObjectBuilder, ReadOnlyMetadataBuilder,
Variant, VariantBuilder, VariantPath, VariantPathElement,
ShortString, Variant, VariantBuilder, VariantDecimal4, VariantDecimal8, VariantDecimal16,
VariantPath, VariantPathElement,
};
use std::sync::Arc;
use uuid::Uuid;
Expand Down Expand Up @@ -1046,6 +1054,7 @@ mod tests {
&cast_options,
1,
mode,
true,
)
.unwrap();
primitive_builder.append_null().unwrap();
Expand Down Expand Up @@ -1076,6 +1085,7 @@ mod tests {
&cast_options,
1,
mode,
true,
)
.unwrap();
array_builder.append_null().unwrap();
Expand Down Expand Up @@ -1104,6 +1114,7 @@ mod tests {
&cast_options,
1,
mode,
true,
)
.unwrap();
object_builder.append_null().unwrap();
Expand Down Expand Up @@ -1310,7 +1321,7 @@ mod tests {
.downcast_ref::<arrow::array::Int32Array>()
.unwrap();
assert_eq!(typed_value_int32.value(0), 42);
assert_eq!(typed_value_int32.value(1), 3);
assert!(typed_value_int32.is_null(1)); // float doesn't shred to int32
assert!(typed_value_int32.is_null(2)); // string doesn't convert to int32

// Test Float64 target
Expand All @@ -1321,7 +1332,7 @@ mod tests {
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(typed_value_float64.value(0), 42.0); // int converts to float
assert!(typed_value_float64.is_null(0)); // int doesn't shred to float
assert_eq!(typed_value_float64.value(1), 3.15);
assert!(typed_value_float64.is_null(2)); // string doesn't convert
}
Expand Down Expand Up @@ -2807,4 +2818,207 @@ mod tests {
let shredding_type = ShreddedSchemaBuilder::default().build();
assert_eq!(shredding_type, DataType::Null);
}

// Verifies, for every primitive variant kind and a broad list of Arrow data types, whether
// each variant value is (or is not) shredded when that data type is the shredding target.
//
// For a data type accepted by `shred_variant`, every row that `can_shred_to` marks as
// shreddable must have a NULL entry in the `value` column, and every other row must keep a
// valid `value` entry. A data type rejected by `shred_variant` must fail with an error that
// mentions "is not a valid variant shredding type".
#[test]
fn test_variant_type_shredded_correctly() {
    // Build an array that contains one value of every primitive variant kind.
    let mut array_builder = VariantArrayBuilder::new(30);
    array_builder.append_value(Variant::Null);
    array_builder.append_value(Variant::Int8(1));
    array_builder.append_value(Variant::Int16(2));
    array_builder.append_value(Variant::Int32(3));
    array_builder.append_value(Variant::Int64(4));
    array_builder.append_value(Variant::Date(NaiveDate::from_epoch_days(12345).unwrap()));
    array_builder.append_value(Variant::TimestampMicros(
        DateTime::from_timestamp_micros(123456789).unwrap(),
    ));
    array_builder.append_value(Variant::TimestampNtzMicros(
        DateTime::from_timestamp_micros(123456789)
            .unwrap()
            .naive_utc(),
    ));
    array_builder.append_value(Variant::TimestampNanos(DateTime::from_timestamp_nanos(
        1234567890123,
    )));
    array_builder.append_value(Variant::TimestampNtzNanos(
        DateTime::from_timestamp_nanos(1234567890123).naive_utc(),
    ));
    array_builder.append_value(VariantDecimal4::try_new(123, 2).unwrap());
    array_builder.append_value(VariantDecimal8::try_new(123, 3).unwrap());
    array_builder.append_value(VariantDecimal16::try_new(123, 4).unwrap());
    array_builder.append_value(Variant::Float(5.2));
    array_builder.append_value(Variant::Double(6.4));
    array_builder.append_value(Variant::BooleanTrue);
    array_builder.append_value(Variant::BooleanFalse);
    array_builder.append_value(Variant::Binary("helow".as_bytes()));
    array_builder.append_value(Variant::String("hello"));
    array_builder.append_value(Variant::ShortString(
        ShortString::try_from("world").unwrap(),
    ));
    array_builder.append_value(Variant::Time(
        NaiveTime::from_num_seconds_from_midnight_opt(12345, 123).unwrap(),
    ));

    let array = array_builder.build();

    // Expectation table: returns true when the variant value `v` is expected to be shredded
    // into a column of data type `dt`. Any pair not listed here is expected to stay in `value`.
    fn can_shred_to(v: &Variant, dt: &DataType) -> bool {
        matches!(
            (v, dt),
            // Every integer width in the test data is expected to shred to every signed
            // integer width.
            (
                Variant::Int8(_) | Variant::Int16(_) | Variant::Int32(_) | Variant::Int64(_),
                DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
            ) | (Variant::Date(_), DataType::Date32)
                // Microsecond timestamps may widen to nanoseconds, but the time-zone-ness
                // (Some / None) has to match.
                | (
                    Variant::TimestampMicros(_),
                    DataType::Timestamp(TimeUnit::Microsecond | TimeUnit::Nanosecond, Some(_)),
                )
                | (
                    Variant::TimestampNtzMicros(_),
                    DataType::Timestamp(TimeUnit::Microsecond | TimeUnit::Nanosecond, None),
                )
                // Nanosecond timestamps are only expected to shred to nanosecond columns.
                | (
                    Variant::TimestampNanos(_),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some(_)),
                )
                | (
                    Variant::TimestampNtzNanos(_),
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                )
                | (
                    Variant::Decimal4(_) | Variant::Decimal8(_) | Variant::Decimal16(_),
                    DataType::Decimal32(_, _)
                        | DataType::Decimal64(_, _)
                        | DataType::Decimal128(_, _),
                )
                | (
                    Variant::Float(_) | Variant::Double(_),
                    DataType::Float32 | DataType::Float64,
                )
                | (
                    Variant::BooleanFalse | Variant::BooleanTrue,
                    DataType::Boolean,
                )
                | (
                    Variant::Binary(_),
                    DataType::Binary | DataType::BinaryView | DataType::LargeBinary,
                )
                | (
                    Variant::ShortString(_) | Variant::String(_),
                    DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8,
                )
                | (Variant::Time(_), DataType::Time64(_))
        )
    }

    let types = [
        DataType::Null,
        DataType::Boolean,
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float32,
        DataType::Float64,
        DataType::Timestamp(TimeUnit::Second, Some("+00:00".into())),
        DataType::Timestamp(TimeUnit::Second, None),
        DataType::Timestamp(TimeUnit::Millisecond, Some("-00:00".into())),
        DataType::Timestamp(TimeUnit::Millisecond, None),
        DataType::Timestamp(TimeUnit::Microsecond, Some("-00:00".into())),
        DataType::Timestamp(TimeUnit::Microsecond, None),
        DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
        DataType::Date32,
        DataType::Date64,
        DataType::Time32(TimeUnit::Second),
        DataType::Time32(TimeUnit::Millisecond),
        DataType::Time64(TimeUnit::Microsecond),
        DataType::Time64(TimeUnit::Nanosecond),
        DataType::Duration(TimeUnit::Nanosecond),
        DataType::Interval(IntervalUnit::DayTime),
        DataType::Binary,
        DataType::FixedSizeBinary(16), // uuid
        DataType::FixedSizeBinary(32),
        DataType::LargeBinary,
        DataType::BinaryView,
        DataType::Utf8,
        DataType::LargeUtf8,
        DataType::Utf8View,
        DataType::Decimal32(4, 2),
        DataType::Decimal64(10, 4),
        DataType::Decimal128(20, 10),
        DataType::Decimal256(30, 10),
    ];

    for data_type in &types {
        let shredded_array = match shred_variant(&array, data_type) {
            Ok(shredded_array) => shredded_array,
            Err(e) => {
                // The only acceptable failure is the target type being rejected outright.
                assert!(
                    e.to_string()
                        .contains("is not a valid variant shredding type"),
                    "{} => {}",
                    data_type,
                    e
                );
                continue;
            }
        };

        let value_column = shredded_array.inner().column_by_name("value").unwrap();
        for (idx, variant) in array.iter().enumerate() {
            // Every row appended above is non-null, so the unwrap cannot fail.
            let variant = variant.unwrap();
            if can_shred_to(&variant, data_type) {
                // A shredded row must not leave anything behind in `value`.
                assert!(
                    value_column.is_null(idx),
                    "{:?} should be shredded to {}",
                    variant,
                    data_type
                );
            } else {
                // A row that cannot be shredded must still be present in `value`.
                assert!(
                    value_column.is_valid(idx),
                    "{:?} should not be shredded to {}",
                    variant,
                    data_type
                );
            }
        }
    }
}
}
Loading
Loading