diff --git a/cpp/src/arrow/extension/parquet_variant.cc b/cpp/src/arrow/extension/parquet_variant.cc index 95aa5a0eb68e..4630c1d39321 100644 --- a/cpp/src/arrow/extension/parquet_variant.cc +++ b/cpp/src/arrow/extension/parquet_variant.cc @@ -17,28 +17,30 @@ #include "arrow/extension/parquet_variant.h" +#include #include #include "arrow/extension_type.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/logging_internal.h" namespace arrow::extension { VariantExtensionType::VariantExtensionType(const std::shared_ptr& storage_type) : ExtensionType(storage_type) { - // GH-45948: Shredded variants will need to handle an optional shredded_value as - // well as value_ becoming optional. - - // IsSupportedStorageType should have been called already, asserting that both - // metadata and value are present. - if (storage_type->field(0)->name() == "metadata") { - metadata_ = storage_type->field(0); - value_ = storage_type->field(1); - } else { - value_ = storage_type->field(0); - metadata_ = storage_type->field(1); + // IsSupportedStorageType should have been called already, asserting that + // metadata is present and at least one of value / typed_value is present. 
+ for (const auto& field : storage_type->fields()) { + if (field->name() == "metadata") { + metadata_ = field; + } else if (field->name() == "value") { + value_ = field; + } else if (field->name() == "typed_value") { + typed_value_ = field; + } } } @@ -49,6 +51,9 @@ bool VariantExtensionType::ExtensionEquals(const ExtensionType& other) const { Result> VariantExtensionType::Deserialize( std::shared_ptr storage_type, const std::string& serialized) const { + if (!serialized.empty()) { + return Status::Invalid("Unexpected serialized metadata: '", serialized, "'"); + } return VariantExtensionType::Make(std::move(storage_type)); } @@ -57,50 +62,148 @@ std::string VariantExtensionType::Serialize() const { return ""; } std::shared_ptr VariantExtensionType::MakeArray( std::shared_ptr data) const { DCHECK_EQ(data->type->id(), Type::EXTENSION); - DCHECK_EQ("arrow.parquet.variant", + DCHECK_EQ(kVariantExtensionName, internal::checked_cast(*data->type).extension_name()); return std::make_shared(data); } namespace { -bool IsBinaryField(const std::shared_ptr field) { - return field->type()->storage_id() == Type::BINARY || - field->type()->storage_id() == Type::LARGE_BINARY; + +bool IsSupportedPrimitiveTypedValue(const std::shared_ptr& type) { + switch (type->id()) { + case Type::BOOL: + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::BINARY: + case Type::LARGE_BINARY: + case Type::BINARY_VIEW: + case Type::STRING: + case Type::LARGE_STRING: + case Type::STRING_VIEW: + return true; + case Type::DECIMAL32: + case Type::DECIMAL64: + case Type::DECIMAL128: { + const auto& decimal = internal::checked_cast(*type); + return decimal.scale() >= 0 && decimal.scale() <= decimal.precision(); + } + case Type::TIME64: + return internal::checked_cast(*type).unit() == TimeUnit::MICRO; + case Type::TIMESTAMP: { + const auto unit = internal::checked_cast(*type).unit(); + return unit == 
TimeUnit::MICRO || unit == TimeUnit::NANO; + } + case Type::FIXED_SIZE_BINARY: + return internal::checked_cast(*type).byte_width() == 16; + case Type::EXTENSION: { + const auto& ext_type = internal::checked_cast(*type); + return ext_type.extension_name() == "arrow.uuid"; + } + default: + return false; + } } + +bool IsSupportedTypedValue(const std::shared_ptr& field); + +bool IsVariantFieldGroup(const std::shared_ptr& type) { + if (type->id() != Type::STRUCT) { + return false; + } + + std::shared_ptr value; + std::shared_ptr typed_value; + for (const auto& field : type->fields()) { + if (field->name() == "value") { + if (value != nullptr || !field->nullable() || + !is_binary_or_binary_view(field->type()->storage_id())) { + return false; + } + value = field; + } else if (field->name() == "typed_value") { + if (typed_value != nullptr || !IsSupportedTypedValue(field)) { + return false; + } + typed_value = field; + } else { + return false; + } + } + return value != nullptr || typed_value != nullptr; +} + +bool IsSupportedTypedValue(const std::shared_ptr& field) { + if (!field->nullable()) { + return false; + } + auto is_variant_field_group = [](const auto& field) { + return !field->nullable() && IsVariantFieldGroup(field->type()); + }; + + switch (field->type()->id()) { + case Type::STRUCT: + return field->type()->num_fields() > 0 && + std::ranges::all_of(field->type()->fields(), is_variant_field_group); + case Type::LIST: + case Type::LARGE_LIST: + case Type::LIST_VIEW: + case Type::LARGE_LIST_VIEW: + case Type::FIXED_SIZE_LIST: + return is_variant_field_group(field->type()->field(0)); + default: + return IsSupportedPrimitiveTypedValue(field->type()); + } +} + } // namespace bool VariantExtensionType::IsSupportedStorageType( const std::shared_ptr& storage_type) { - // For now we only supported unshredded variants. Unshredded variant storage - // type should be a struct with a binary metadata and binary value. 
- // - // GH-45948: In shredded variants, the binary value field can be replaced - // with one or more of the following: object, array, typed_value, and - // variant_value. - if (storage_type->id() == Type::STRUCT) { - if (storage_type->num_fields() == 2) { - // Ordering of metadata and value fields does not matter, as we will assign - // these to the VariantExtensionType's member shared_ptrs in the constructor. - // Here we just need to check that they are both present. - - const auto& field0 = storage_type->field(0); - const auto& field1 = storage_type->field(1); - - bool metadata_and_value_present = - (field0->name() == "metadata" && field1->name() == "value") || - (field1->name() == "metadata" && field0->name() == "value"); - - if (metadata_and_value_present) { - // Both metadata and value must be non-nullable binary types for unshredded - // variants. This will change in GH-46948, when we will require a Visitor - // to traverse the structure of the variant. - return IsBinaryField(field0) && IsBinaryField(field1) && !field0->nullable() && - !field1->nullable(); + if (storage_type->id() != Type::STRUCT) { + return false; + } + + std::shared_ptr metadata; + std::shared_ptr value; + std::shared_ptr typed_value; + + for (const auto& field : storage_type->fields()) { + if (field->name() == "metadata") { + if (metadata != nullptr || !is_binary_or_binary_view(field->type()->storage_id()) || + field->nullable()) { + return false; + } + metadata = field; + } else if (field->name() == "value") { + if (value != nullptr || !is_binary_or_binary_view(field->type()->storage_id())) { + return false; + } + value = field; + } else if (field->name() == "typed_value") { + if (typed_value != nullptr || !IsSupportedTypedValue(field)) { + return false; } + typed_value = field; + } else { + return false; } } - return false; + if (metadata == nullptr || (value == nullptr && typed_value == nullptr)) { + return false; + } + if (value == nullptr) { + return true; + } + if (typed_value == 
nullptr) { + return !value->nullable(); + } + return value->nullable(); } Result> VariantExtensionType::Make( @@ -113,9 +216,6 @@ Result> VariantExtensionType::Make( return std::make_shared(std::move(storage_type)); } -/// NOTE: this is still experimental. GH-45948 will add shredding support, at which point -/// we need to separate this into unshredded_variant and shredded_variant helper -/// functions. std::shared_ptr variant(std::shared_ptr storage_type) { return VariantExtensionType::Make(std::move(storage_type)).ValueOrDie(); } diff --git a/cpp/src/arrow/extension/parquet_variant.h b/cpp/src/arrow/extension/parquet_variant.h index be90923f14e6..c74a794f6a01 100644 --- a/cpp/src/arrow/extension/parquet_variant.h +++ b/cpp/src/arrow/extension/parquet_variant.h @@ -18,12 +18,16 @@ #pragma once #include +#include #include "arrow/extension_type.h" #include "arrow/util/visibility.h" namespace arrow::extension { +/// \brief The extension name for the Variant extension type. +inline constexpr std::string_view kVariantExtensionName = "arrow.parquet.variant"; + class ARROW_EXPORT VariantArray : public ExtensionArray { public: using ExtensionArray::ExtensionArray; @@ -43,13 +47,25 @@ class ARROW_EXPORT VariantArray : public ExtensionArray { /// To read more about variant encoding, see the variant encoding spec at /// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md /// +/// Shredded variant representation: +/// optional group shredded_variant_name (VARIANT) { +/// required binary metadata; +/// optional binary value; +/// optional typed_value; +/// } +/// +/// The value and typed_value fields are optional in the schema, but at least one +/// must be present. 
+/// /// To read more about variant shredding, see the variant shredding spec at /// https://github.com/apache/parquet-format/blob/master/VariantShredding.md class ARROW_EXPORT VariantExtensionType : public ExtensionType { public: explicit VariantExtensionType(const std::shared_ptr& storage_type); - std::string extension_name() const override { return "arrow.parquet.variant"; } + std::string extension_name() const override { + return std::string(kVariantExtensionName); + } bool ExtensionEquals(const ExtensionType& other) const override; @@ -69,10 +85,12 @@ class ARROW_EXPORT VariantExtensionType : public ExtensionType { std::shared_ptr value() const { return value_; } + std::shared_ptr typed_value() const { return typed_value_; } + private: - // TODO GH-45948 added shredded_value std::shared_ptr metadata_; std::shared_ptr value_; + std::shared_ptr typed_value_; }; /// \brief Return a VariantExtensionType instance. diff --git a/cpp/src/parquet/CMakeLists.txt b/cpp/src/parquet/CMakeLists.txt index f8a42b5b96bf..11da3eccecab 100644 --- a/cpp/src/parquet/CMakeLists.txt +++ b/cpp/src/parquet/CMakeLists.txt @@ -180,7 +180,6 @@ set(PARQUET_SRCS level_comparison.cc level_conversion.cc metadata.cc - xxhasher.cc page_index.cc "${PARQUET_THRIFT_SOURCE_DIR}/parquet_types.cpp" platform.cc @@ -191,7 +190,11 @@ set(PARQUET_SRCS statistics.cc stream_reader.cc stream_writer.cc - types.cc) + types.cc + variant/builder.cc + variant/encoding.cc + variant/validate.cc + xxhasher.cc) if(ARROW_HAVE_RUNTIME_AVX2) # AVX2 is used as a proxy for BMI2. 
@@ -363,6 +366,7 @@ add_subdirectory(api) add_subdirectory(arrow) add_subdirectory(encryption) add_subdirectory(geospatial) +add_subdirectory(variant) arrow_install_all_headers("parquet") @@ -409,7 +413,8 @@ add_parquet_test(arrow-reader-writer-test SOURCES arrow/arrow_reader_writer_test.cc arrow/arrow_statistics_test.cc - arrow/variant_test.cc) + arrow/variant_test.cc + variant/test_util_internal.cc) add_parquet_test(arrow-index-test SOURCES arrow/index_test.cc) diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc index 8735aea731ce..917f2fe95a64 100644 --- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc +++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc @@ -5999,11 +5999,8 @@ TEST(TestArrowReadWrite, AllNulls) { auto schema = ::arrow::schema({::arrow::field("all_nulls", ::arrow::int8())}); constexpr int64_t length = 3; - ASSERT_OK_AND_ASSIGN(auto null_bitmap, ::arrow::AllocateEmptyBitmap(length)); - auto array_data = ::arrow::ArrayData::Make( - ::arrow::int8(), length, {null_bitmap, /*values=*/nullptr}, /*null_count=*/length); - auto array = ::arrow::MakeArray(array_data); - auto record_batch = ::arrow::RecordBatch::Make(schema, length, {array}); + ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(::arrow::int8(), length)); + auto record_batch = ::arrow::RecordBatch::Make(schema, length, {std::move(array)}); auto sink = CreateOutputStream(); ASSERT_OK_AND_ASSIGN(auto writer, parquet::arrow::FileWriter::Open( diff --git a/cpp/src/parquet/arrow/arrow_schema_test.cc b/cpp/src/parquet/arrow/arrow_schema_test.cc index 7a7b5a336939..f8196240648f 100644 --- a/cpp/src/parquet/arrow/arrow_schema_test.cc +++ b/cpp/src/parquet/arrow/arrow_schema_test.cc @@ -31,7 +31,7 @@ #include "parquet/test_util.h" #include "parquet/thrift_internal.h" -#include "arrow/array.h" +#include "arrow/array.h" // IWYU pragma: keep #include "arrow/extension/json.h" #include "arrow/extension/parquet_variant.h" #include 
"arrow/extension/uuid.h" @@ -86,6 +86,9 @@ static const std::vector kListCases = { [](std::shared_ptr<::arrow::Field> field) { return ::arrow::large_list(field); }}, }; +Status ArrowSchemaToParquetMetadata(std::shared_ptr<::arrow::Schema>& arrow_schema, + std::shared_ptr& metadata); + class TestConvertParquetSchema : public ::testing::Test { public: virtual void SetUp() {} @@ -111,6 +114,55 @@ class TestConvertParquetSchema : public ::testing::Test { return FromParquetSchema(&descr_, props, key_value_metadata, &result_schema_); } + void CheckParquetVariantSchema(const std::string& name, + std::vector parquet_children, + const std::shared_ptr<::arrow::DataType>& storage_type, + bool check_serialized_arrow_schema = false) { + SCOPED_TRACE(name); + auto variant = GroupNode::Make(name, Repetition::OPTIONAL, + std::move(parquet_children), LogicalType::Variant()); + std::vector parquet_fields = {variant}; + auto variant_extension = ::arrow::extension::variant(storage_type); + + { + auto arrow_schema = ::arrow::schema({::arrow::field(name, storage_type)}); + ASSERT_OK(ConvertSchema(parquet_fields)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema, /*check_metadata=*/true)); + } + + for (bool register_extension : {true, false}) { + ::arrow::ExtensionTypeGuard guard(register_extension + ? ::arrow::DataTypeVector{variant_extension} + : ::arrow::DataTypeVector{}); + + ArrowReaderProperties props; + props.set_arrow_extensions_enabled(true); + auto arrow_schema = ::arrow::schema( + {::arrow::field(name, register_extension ? variant_extension : storage_type)}); + + ASSERT_OK(ConvertSchema(parquet_fields, /*metadata=*/nullptr, props)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema, /*check_metadata=*/true)); + } + + if (check_serialized_arrow_schema) { + for (bool register_extension : {true, false}) { + ::arrow::ExtensionTypeGuard guard(register_extension + ? 
::arrow::DataTypeVector{variant_extension} + : ::arrow::DataTypeVector{}); + + ArrowReaderProperties props; + props.set_arrow_extensions_enabled(false); + auto arrow_schema = ::arrow::schema({::arrow::field( + name, register_extension ? variant_extension : storage_type)}); + + std::shared_ptr metadata; + ASSERT_OK(ArrowSchemaToParquetMetadata(arrow_schema, metadata)); + ASSERT_OK(ConvertSchema(parquet_fields, metadata, props)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema, /*check_metadata=*/true)); + } + } + } + protected: SchemaDescriptor descr_; std::shared_ptr<::arrow::Schema> result_schema_; @@ -992,73 +1044,43 @@ Status ArrowSchemaToParquetMetadata(std::shared_ptr<::arrow::Schema>& arrow_sche TEST_F(TestConvertParquetSchema, ParquetVariant) { // Unshredded variant - // optional group variant_col { + // optional group variant_unshredded { // required binary metadata; // required binary value; // } - // - // GH-45948: add shredded variants - std::vector parquet_fields; auto metadata = PrimitiveNode::Make("metadata", Repetition::REQUIRED, ParquetType::BYTE_ARRAY); auto value = PrimitiveNode::Make("value", Repetition::REQUIRED, ParquetType::BYTE_ARRAY); - auto variant = GroupNode::Make("variant_unshredded", Repetition::OPTIONAL, - {metadata, value}, LogicalType::Variant()); - parquet_fields.push_back(variant); - - // Arrow schema for unshredded variant struct. - auto arrow_metadata = ::arrow::field("metadata", ::arrow::binary(), /*nullable=*/false); - auto arrow_value = ::arrow::field("value", ::arrow::binary(), /*nullable=*/false); - auto arrow_variant = ::arrow::struct_({arrow_metadata, arrow_value}); - auto variant_extension = ::arrow::extension::variant(arrow_variant); - - { - // Parquet file does not contain Arrow schema. - // By default, field should be treated as a normal struct in Arrow. 
- auto arrow_schema = - ::arrow::schema({::arrow::field("variant_unshredded", arrow_variant)}); - ASSERT_OK(ConvertSchema(parquet_fields)); - ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema, /*check_metadata=*/true)); - } - - for (bool register_extension : {true, false}) { - ::arrow::ExtensionTypeGuard guard(register_extension - ? ::arrow::DataTypeVector{variant_extension} - : ::arrow::DataTypeVector{}); - - // Parquet file does not contain Arrow schema. - // If Arrow extensions are enabled, field should be interpreted as Parquet Variant - // extension type if registered. - ArrowReaderProperties props; - props.set_arrow_extensions_enabled(true); - - auto arrow_schema = ::arrow::schema({::arrow::field( - "variant_unshredded", register_extension ? variant_extension : arrow_variant)}); - - ASSERT_OK(ConvertSchema(parquet_fields, /*metadata=*/nullptr, props)); - ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema, /*check_metadata=*/true)); - } - - for (bool register_extension : {true, false}) { - ::arrow::ExtensionTypeGuard guard(register_extension - ? ::arrow::DataTypeVector{variant_extension} - : ::arrow::DataTypeVector{}); + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), /*nullable=*/false), + ::arrow::field("value", ::arrow::binary(), /*nullable=*/false)}); - // Parquet file does contain Arrow schema. - // Field should be interpreted as Parquet Variant extension, if registered, - // even though extensions are not enabled. - ArrowReaderProperties props; - props.set_arrow_extensions_enabled(false); - - auto arrow_schema = ::arrow::schema({::arrow::field( - "variant_unshredded", register_extension ? 
variant_extension : arrow_variant)}); + ASSERT_NO_FATAL_FAILURE( + CheckParquetVariantSchema("variant_unshredded", {metadata, value}, storage_type, + /*check_serialized_arrow_schema=*/true)); +} - std::shared_ptr metadata; - ASSERT_OK(ArrowSchemaToParquetMetadata(arrow_schema, metadata)); - ASSERT_OK(ConvertSchema(parquet_fields, metadata, props)); - ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(arrow_schema, /*check_metadata=*/true)); - } +TEST_F(TestConvertParquetSchema, ParquetVariantShredded) { + // Shredded variant + // optional group variant_shredded { + // required binary metadata; + // optional binary value; + // optional int64 typed_value; + // } + auto metadata = + PrimitiveNode::Make("metadata", Repetition::REQUIRED, ParquetType::BYTE_ARRAY); + auto value = + PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY); + auto typed_value = + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, ParquetType::INT64); + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), false), + ::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", ::arrow::int64())}); + + ASSERT_NO_FATAL_FAILURE(CheckParquetVariantSchema( + "variant_shredded", {metadata, value, typed_value}, storage_type)); } TEST_F(TestConvertParquetSchema, ParquetSchemaArrowJsonExtension) { @@ -1534,6 +1556,222 @@ TEST_F(TestConvertArrowSchema, ParquetFlatPrimitivesAsDictionaries) { ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields)); } +TEST_F(TestConvertArrowSchema, ParquetVariantShredded) { + auto storage_type = + ::arrow::struct_({::arrow::field("typed_value", ::arrow::int64()), + ::arrow::field("value", ::arrow::binary()), + ::arrow::field("metadata", ::arrow::binary(), false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + + std::vector> arrow_fields = { + ::arrow::field("variant", variant_type)}; + + std::vector parquet_fields = {GroupNode::Make( + "variant", Repetition::OPTIONAL, + 
{PrimitiveNode::Make("metadata", Repetition::REQUIRED, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, ParquetType::INT64)}, + LogicalType::Variant())}; + + ASSERT_OK(ConvertSchema(arrow_fields)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields)); +} + +TEST_F(TestConvertArrowSchema, ParquetVariantDictionary) { + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), false), + ::arrow::field("value", ::arrow::binary(), false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto dictionary_variant = ::arrow::dictionary(::arrow::int32(), variant_type); + auto no_validation = + ArrowWriterProperties::Builder().set_variant_validation_enabled(false)->build(); + + ASSERT_RAISES( + NotImplemented, + ConvertSchema({::arrow::field("variant", dictionary_variant)}, no_validation)); +} + +TEST_F(TestConvertArrowSchema, ParquetVariantTypedSchema) { + auto ShreddedType = [](std::shared_ptr<::arrow::DataType> type) { + return ::arrow::struct_({::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", std::move(type))}); + }; + auto typed_value = ::arrow::struct_( + {::arrow::field("d4", ShreddedType(::arrow::decimal32(8, 2)), + /*nullable=*/false), + ::arrow::field("d8", ShreddedType(::arrow::decimal64(16, 4)), + /*nullable=*/false), + ::arrow::field("d64p8", ShreddedType(::arrow::decimal64(8, 2)), + /*nullable=*/false), + ::arrow::field("d128p8", ShreddedType(::arrow::decimal128(8, 2)), + /*nullable=*/false), + ::arrow::field("d128p16", ShreddedType(::arrow::decimal128(16, 4)), + /*nullable=*/false), + ::arrow::field("d128p32", ShreddedType(::arrow::decimal128(32, 8)), + /*nullable=*/false), + ::arrow::field("ts", ShreddedType(::arrow::timestamp(TimeUnit::NANO, "UTC")), + /*nullable=*/false), + ::arrow::field("time", ShreddedType(::arrow::time64(TimeUnit::MICRO)), + /*nullable=*/false), 
+ ::arrow::field("id", ShreddedType(::arrow::fixed_size_binary(16)), + /*nullable=*/false), + ::arrow::field("uuid", ShreddedType(::arrow::extension::uuid()), + /*nullable=*/false)}); + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), false), + ::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", typed_value)}); + auto variant_type = ::arrow::extension::variant(storage_type); + + std::vector> arrow_fields = { + ::arrow::field("variant", variant_type)}; + + auto ShreddedField = [](const std::string& name, NodePtr typed) { + return GroupNode::Make( + name, Repetition::REQUIRED, + {PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + std::move(typed)}); + }; + std::vector parquet_fields = {GroupNode::Make( + "variant", Repetition::OPTIONAL, + {PrimitiveNode::Make("metadata", Repetition::REQUIRED, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + GroupNode::Make( + "typed_value", Repetition::OPTIONAL, + {ShreddedField("d4", PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, + LogicalType::Decimal(8, 2), + ParquetType::INT32)), + ShreddedField("d8", PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, + LogicalType::Decimal(16, 4), + ParquetType::INT64)), + ShreddedField("d64p8", PrimitiveNode::Make( + "typed_value", Repetition::OPTIONAL, + LogicalType::Decimal(8, 2), ParquetType::INT32)), + ShreddedField("d128p8", PrimitiveNode::Make( + "typed_value", Repetition::OPTIONAL, + LogicalType::Decimal(8, 2), ParquetType::INT32)), + ShreddedField( + "d128p16", + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, + LogicalType::Decimal(16, 4), ParquetType::INT64)), + ShreddedField("d128p32", + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, + LogicalType::Decimal(32, 8), + ParquetType::FIXED_LEN_BYTE_ARRAY, 14)), + ShreddedField("ts", + PrimitiveNode::Make( + "typed_value", Repetition::OPTIONAL, + 
LogicalType::Timestamp(true, LogicalType::TimeUnit::NANOS), + ParquetType::INT64)), + ShreddedField("time", + PrimitiveNode::Make( + "typed_value", Repetition::OPTIONAL, + LogicalType::Time(false, LogicalType::TimeUnit::MICROS), + ParquetType::INT64)), + ShreddedField("id", + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, + LogicalType::UUID(), + ParquetType::FIXED_LEN_BYTE_ARRAY, 16)), + ShreddedField("uuid", + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, + LogicalType::UUID(), + ParquetType::FIXED_LEN_BYTE_ARRAY, 16))})}, + LogicalType::Variant())}; + + ASSERT_OK(ConvertSchema(arrow_fields)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields)); +} + +TEST_F(TestConvertArrowSchema, ParquetVariantShreddedObject) { + auto field_group = ::arrow::struct_({::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", ::arrow::int64())}); + auto typed_value = + ::arrow::struct_({::arrow::field("a", field_group, /*nullable=*/false), + ::arrow::field("b", field_group, /*nullable=*/false)}); + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), false), + ::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", typed_value)}); + auto variant_type = ::arrow::extension::variant(storage_type); + + std::vector> arrow_fields = { + ::arrow::field("variant", variant_type)}; + + auto ShreddedField = [](const std::string& name) { + return GroupNode::Make( + name, Repetition::REQUIRED, + {PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, ParquetType::INT64)}); + }; + std::vector parquet_fields = {GroupNode::Make( + "variant", Repetition::OPTIONAL, + {PrimitiveNode::Make("metadata", Repetition::REQUIRED, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + GroupNode::Make("typed_value", Repetition::OPTIONAL, + {ShreddedField("a"), ShreddedField("b")})}, 
+ LogicalType::Variant())}; + + ASSERT_OK(ConvertSchema(arrow_fields)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields)); +} + +TEST_F(TestConvertArrowSchema, ParquetVariantEmptyObject) { + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), false), + ::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", ::arrow::struct_({}))}); + auto variant_type = + std::make_shared<::arrow::extension::VariantExtensionType>(storage_type); + + ASSERT_RAISES(Invalid, + ConvertSchema({::arrow::field("variant", std::move(variant_type))})); +} + +TEST_F(TestConvertArrowSchema, ParquetVariantShreddedList) { + auto field_group = ::arrow::struct_({::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", ::arrow::int64())}); + + auto element = GroupNode::Make( + "element", Repetition::REQUIRED, + {PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("typed_value", Repetition::OPTIONAL, ParquetType::INT64)}); + auto list = GroupNode::Make("list", Repetition::REPEATED, {element}); + const auto expected = GroupNode::Make( + "variant", Repetition::OPTIONAL, + {PrimitiveNode::Make("metadata", Repetition::REQUIRED, ParquetType::BYTE_ARRAY), + PrimitiveNode::Make("value", Repetition::OPTIONAL, ParquetType::BYTE_ARRAY), + GroupNode::Make("typed_value", Repetition::OPTIONAL, {list}, ConvertedType::LIST)}, + LogicalType::Variant()); + + const std::vector< + std::function(std::shared_ptr<::arrow::Field>)>> + make_lists = { + [](std::shared_ptr<::arrow::Field> field) { return ::arrow::list(field); }, + [](std::shared_ptr<::arrow::Field> field) { + return ::arrow::large_list(field); + }, + [](std::shared_ptr<::arrow::Field> field) { return ::arrow::list_view(field); }, + [](std::shared_ptr<::arrow::Field> field) { + return ::arrow::large_list_view(field); + }}; + for (const auto& make_list : make_lists) { + auto typed_value = + make_list(::arrow::field("element", 
field_group, /*nullable=*/false)); + auto storage_type = + ::arrow::struct_({::arrow::field("metadata", ::arrow::binary(), false), + ::arrow::field("value", ::arrow::binary()), + ::arrow::field("typed_value", typed_value)}); + auto variant_type = ::arrow::extension::variant(storage_type); + + std::vector> arrow_fields = { + ::arrow::field("variant", variant_type)}; + std::vector parquet_fields = {expected}; + + ASSERT_OK(ConvertSchema(arrow_fields)); + ASSERT_NO_FATAL_FAILURE(CheckFlatSchema(parquet_fields)); + } +} + TEST_F(TestConvertArrowSchema, ParquetGeoArrowCrsLonLat) { // All the Arrow Schemas below should convert to the type defaults for GEOMETRY // and GEOGRAPHY when GeoArrow extension types are registered and the appropriate diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc index bc4de6c39b5e..9d7d71d64aa6 100644 --- a/cpp/src/parquet/arrow/schema.cc +++ b/cpp/src/parquet/arrow/schema.cc @@ -50,6 +50,8 @@ using arrow::Field; using arrow::FieldVector; using arrow::KeyValueMetadata; using arrow::Status; +using arrow::extension::kVariantExtensionName; +using arrow::extension::VariantExtensionType; using arrow::internal::checked_cast; using arrow::internal::ToChars; @@ -112,6 +114,239 @@ Status ListToNode(const std::shared_ptr<::arrow::BaseListType>& type, return Status::OK(); } +static constexpr char FIELD_ID_KEY[] = "PARQUET:field_id"; + +int FieldIdFromMetadata( + const std::shared_ptr& metadata) { + if (!metadata) { + return -1; + } + int key = metadata->FindKey(FIELD_ID_KEY); + if (key < 0) { + return -1; + } + const std::string& field_id_str = metadata->value(key); + int field_id; + if (::arrow::internal::ParseValue<::arrow::Int32Type>( + field_id_str.c_str(), field_id_str.length(), &field_id)) { + if (field_id < 0) { + // Thrift should convert any negative value to null but normalize to -1 here in + // case we later check this in logic. 
+ return -1; + } + return field_id; + } else { + return -1; + } +} + +Status VariantTypedValueToNode(const std::shared_ptr& field, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, + NodePtr* out); + +Status VariantFieldGroupToNode(const std::shared_ptr& field, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, + NodePtr* out) { + const int field_id = FieldIdFromMetadata(field->metadata()); + const auto& type = field->type(); + if (type->id() != ArrowTypeId::STRUCT) { + return Status::Invalid("Invalid Variant shredded field group: ", type->ToString()); + } + + std::vector children; + children.reserve(type->num_fields()); + for (const auto& child : type->fields()) { + NodePtr node; + if (child->name() == "value") { + RETURN_NOT_OK( + FieldToNode(child->name(), child, properties, arrow_properties, &node)); + } else if (child->name() == "typed_value") { + RETURN_NOT_OK(VariantTypedValueToNode(child, properties, arrow_properties, &node)); + } else { + return Status::Invalid("Invalid Variant shredded field group child: ", + child->name()); + } + children.emplace_back(std::move(node)); + } + + *out = GroupNode::Make(field->name(), RepetitionFromNullable(field->nullable()), + std::move(children), nullptr, field_id); + return Status::OK(); +} + +Status VariantObjectToNode(const std::shared_ptr& field, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, NodePtr* out) { + const int field_id = FieldIdFromMetadata(field->metadata()); + const auto& type = field->type(); + if (type->num_fields() == 0) { + return Status::Invalid("Invalid Variant object typed_value: expected at least one ", + "shredded field"); + } + + std::vector children(type->num_fields()); + for (int i = 0; i < type->num_fields(); ++i) { + RETURN_NOT_OK(VariantFieldGroupToNode(type->field(i), properties, arrow_properties, + &children[i])); + } + + *out = GroupNode::Make(field->name(), 
RepetitionFromNullable(field->nullable()), + std::move(children), nullptr, field_id); + return Status::OK(); +} + +Status VariantListToNode(const std::shared_ptr& field, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, NodePtr* out) { + const int field_id = FieldIdFromMetadata(field->metadata()); + const auto list_type = std::static_pointer_cast<::arrow::BaseListType>(field->type()); + + NodePtr element; + RETURN_NOT_OK(VariantFieldGroupToNode(list_type->value_field()->WithName("element"), + properties, arrow_properties, &element)); + + NodePtr list = GroupNode::Make("list", Repetition::REPEATED, {element}); + *out = GroupNode::Make(field->name(), RepetitionFromNullable(field->nullable()), {list}, + LogicalType::List(), field_id); + return Status::OK(); +} + +Status VariantPrimitiveToNode(const std::shared_ptr& field, NodePtr* out) { + std::shared_ptr logical_type = LogicalType::None(); + ParquetType::type type = ParquetType::UNDEFINED; + const int field_id = FieldIdFromMetadata(field->metadata()); + + int length = -1; + int precision = -1; + int scale = -1; + + switch (field->type()->id()) { + case ArrowTypeId::BOOL: + type = ParquetType::BOOLEAN; + break; + case ArrowTypeId::INT8: + type = ParquetType::INT32; + logical_type = LogicalType::Int(8, true); + break; + case ArrowTypeId::INT16: + type = ParquetType::INT32; + logical_type = LogicalType::Int(16, true); + break; + case ArrowTypeId::INT32: + type = ParquetType::INT32; + break; + case ArrowTypeId::INT64: + type = ParquetType::INT64; + break; + case ArrowTypeId::FLOAT: + type = ParquetType::FLOAT; + break; + case ArrowTypeId::DOUBLE: + type = ParquetType::DOUBLE; + break; + case ArrowTypeId::DECIMAL32: + case ArrowTypeId::DECIMAL64: + case ArrowTypeId::DECIMAL128: { + const auto& decimal = checked_cast(*field->type()); + precision = decimal.precision(); + scale = decimal.scale(); + if (precision <= 9) { + type = ParquetType::INT32; + } else if (precision <= 18) { + type = 
ParquetType::INT64; + } else { + type = ParquetType::FIXED_LEN_BYTE_ARRAY; + length = ::arrow::DecimalType::DecimalSize(precision); + } + PARQUET_CATCH_NOT_OK(logical_type = LogicalType::Decimal(precision, scale)); + } break; + case ArrowTypeId::DATE32: + type = ParquetType::INT32; + logical_type = LogicalType::Date(); + break; + case ArrowTypeId::TIME64: + type = ParquetType::INT64; + logical_type = + LogicalType::Time(/*is_adjusted_to_utc=*/false, LogicalType::TimeUnit::MICROS); + break; + case ArrowTypeId::TIMESTAMP: { + type = ParquetType::INT64; + const auto& timestamp = checked_cast(*field->type()); + const bool utc = !timestamp.timezone().empty(); + switch (timestamp.unit()) { + case ::arrow::TimeUnit::MICRO: + logical_type = LogicalType::Timestamp(utc, LogicalType::TimeUnit::MICROS, + /*is_from_converted_type=*/false, + /*force_set_converted_type=*/true); + break; + case ::arrow::TimeUnit::NANO: + logical_type = LogicalType::Timestamp(utc, LogicalType::TimeUnit::NANOS, + /*is_from_converted_type=*/false, + /*force_set_converted_type=*/false); + break; + default: + return Status::Invalid("Invalid Variant typed_value timestamp unit: ", + field->type()->ToString()); + } + } break; + case ArrowTypeId::LARGE_BINARY: + case ArrowTypeId::BINARY: + case ArrowTypeId::BINARY_VIEW: + type = ParquetType::BYTE_ARRAY; + break; + case ArrowTypeId::LARGE_STRING: + case ArrowTypeId::STRING: + case ArrowTypeId::STRING_VIEW: + type = ParquetType::BYTE_ARRAY; + logical_type = LogicalType::String(); + break; + case ArrowTypeId::FIXED_SIZE_BINARY: + type = ParquetType::FIXED_LEN_BYTE_ARRAY; + logical_type = LogicalType::UUID(); + length = 16; + break; + case ArrowTypeId::EXTENSION: { + const auto& ext_type = checked_cast(*field->type()); + if (ext_type.extension_name() == "arrow.uuid") { + type = ParquetType::FIXED_LEN_BYTE_ARRAY; + logical_type = LogicalType::UUID(); + length = 16; + break; + } + return Status::Invalid("Invalid Variant typed_value type: ", + 
field->type()->ToString()); + } + default: + return Status::Invalid("Invalid Variant typed_value type: ", + field->type()->ToString()); + } + + PARQUET_CATCH_NOT_OK(*out = PrimitiveNode::Make( + field->name(), RepetitionFromNullable(field->nullable()), + std::move(logical_type), type, length, field_id)); + return Status::OK(); +} + +Status VariantTypedValueToNode(const std::shared_ptr& field, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, + NodePtr* out) { + switch (field->type()->id()) { + case ArrowTypeId::STRUCT: + return VariantObjectToNode(field, properties, arrow_properties, out); + case ArrowTypeId::FIXED_SIZE_LIST: + case ArrowTypeId::LARGE_LIST: + case ArrowTypeId::LIST: + case ArrowTypeId::LIST_VIEW: + case ArrowTypeId::LARGE_LIST_VIEW: + return VariantListToNode(field, properties, arrow_properties, out); + default: + return VariantPrimitiveToNode(field, out); + } +} + Status MapToNode(const std::shared_ptr<::arrow::MapType>& type, const std::string& name, bool nullable, int field_id, const WriterProperties& properties, const ArrowWriterProperties& arrow_properties, NodePtr* out) { @@ -131,21 +366,33 @@ Status MapToNode(const std::shared_ptr<::arrow::MapType>& type, const std::strin return Status::OK(); } -Status VariantToNode( - const std::shared_ptr<::arrow::extension::VariantExtensionType>& type, - const std::string& name, bool nullable, int field_id, - const WriterProperties& properties, const ArrowWriterProperties& arrow_properties, - NodePtr* out) { - NodePtr metadata_node; - RETURN_NOT_OK(FieldToNode("metadata", type->metadata(), properties, arrow_properties, - &metadata_node)); +Status VariantToNode(const std::shared_ptr& type, + const std::string& name, bool nullable, int field_id, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, NodePtr* out) { + std::vector children; + children.reserve(type->storage_type()->num_fields()); - NodePtr value_node; - RETURN_NOT_OK( - 
FieldToNode("value", type->value(), properties, arrow_properties, &value_node)); + auto AddChild = [&](const std::shared_ptr<::arrow::Field>& field) { + if (field == nullptr) { + return Status::OK(); + } + NodePtr child; + if (field->name() == "typed_value") { + RETURN_NOT_OK(VariantTypedValueToNode(field, properties, arrow_properties, &child)); + } else { + RETURN_NOT_OK( + FieldToNode(field->name(), field, properties, arrow_properties, &child)); + } + children.emplace_back(std::move(child)); + return Status::OK(); + }; + + RETURN_NOT_OK(AddChild(type->metadata())); + RETURN_NOT_OK(AddChild(type->value())); + RETURN_NOT_OK(AddChild(type->typed_value())); - *out = GroupNode::Make(name, RepetitionFromNullable(nullable), - {std::move(metadata_node), std::move(value_node)}, + *out = GroupNode::Make(name, RepetitionFromNullable(nullable), std::move(children), LogicalType::Variant(), field_id); return Status::OK(); @@ -280,8 +527,6 @@ static Status GetTimestampMetadata(const ::arrow::TimestampType& type, return Status::OK(); } -static constexpr char FIELD_ID_KEY[] = "PARQUET:field_id"; - std::shared_ptr<::arrow::KeyValueMetadata> FieldIdMetadata(int field_id) { if (field_id >= 0) { return ::arrow::key_value_metadata({FIELD_ID_KEY}, {ToChars(field_id)}); @@ -290,30 +535,6 @@ std::shared_ptr<::arrow::KeyValueMetadata> FieldIdMetadata(int field_id) { } } -int FieldIdFromMetadata( - const std::shared_ptr& metadata) { - if (!metadata) { - return -1; - } - int key = metadata->FindKey(FIELD_ID_KEY); - if (key < 0) { - return -1; - } - std::string field_id_str = metadata->value(key); - int field_id; - if (::arrow::internal::ParseValue<::arrow::Int32Type>( - field_id_str.c_str(), field_id_str.length(), &field_id)) { - if (field_id < 0) { - // Thrift should convert any negative value to null but normalize to -1 here in - // case we later check this in logic. 
- return -1; - } - return field_id; - } else { - return -1; - } -} - Status FieldToNode(const std::string& name, const std::shared_ptr& field, const WriterProperties& properties, const ArrowWriterProperties& arrow_properties, NodePtr* out) { @@ -466,8 +687,15 @@ Status FieldToNode(const std::string& name, const std::shared_ptr& field, case ArrowTypeId::DICTIONARY: { // Parquet has no Dictionary type, dictionary-encoded is handled on // the encoding, not the schema level. - const ::arrow::DictionaryType& dict_type = - static_cast(*field->type()); + const auto& dict_type = static_cast(*field->type()); + if (dict_type.value_type()->id() == ArrowTypeId::EXTENSION) { + const auto& ext_type = + checked_cast(*dict_type.value_type()); + if (ext_type.extension_name() == kVariantExtensionName) { + return Status::NotImplemented( + "Dictionary-encoded Variant arrays are not supported"); + } + } std::shared_ptr<::arrow::Field> unpacked_field = ::arrow::field( name, dict_type.value_type(), field->nullable(), field->metadata()); return FieldToNode(name, unpacked_field, properties, arrow_properties, out); @@ -490,10 +718,8 @@ Status FieldToNode(const std::string& name, const std::shared_ptr& field, ARROW_ASSIGN_OR_RAISE(logical_type, LogicalTypeFromGeoArrowMetadata(ext_type->Serialize())); break; - } else if (ext_type->extension_name() == std::string("arrow.parquet.variant")) { - auto variant_type = - std::static_pointer_cast<::arrow::extension::VariantExtensionType>( - field->type()); + } else if (ext_type->extension_name() == kVariantExtensionName) { + auto variant_type = std::static_pointer_cast(field->type()); return VariantToNode(variant_type, name, field->nullable(), field_id, properties, arrow_properties, out); @@ -543,6 +769,17 @@ bool IsDictionaryReadSupported(const ArrowType& type) { return type.id() == ::arrow::Type::BINARY || type.id() == ::arrow::Type::STRING; } +bool IsInsideVariantLogicalType(const Node& node) { + const Node* current = node.parent(); + while 
(current != nullptr) { + if (current->logical_type()->is_variant()) { + return true; + } + current = current->parent(); + } + return false; +} + // ---------------------------------------------------------------------- // Schema logic @@ -552,7 +789,8 @@ ::arrow::Result> GetTypeForNode( ARROW_ASSIGN_OR_RAISE(std::shared_ptr storage_type, GetArrowType(primitive_node, ctx->properties, ctx->metadata)); if (ctx->properties.read_dictionary(column_index) && - IsDictionaryReadSupported(*storage_type)) { + IsDictionaryReadSupported(*storage_type) && + !IsInsideVariantLogicalType(primitive_node)) { return ::arrow::dictionary(::arrow::int32(), storage_type); } return storage_type; @@ -604,7 +842,7 @@ Status GroupToStruct(const GroupNode& node, LevelInfo current_levels, auto struct_type = ::arrow::struct_(arrow_fields); if (ctx->properties.get_arrow_extensions_enabled() && node.logical_type()->is_variant()) { - auto extension_type = ::arrow::GetExtensionType("arrow.parquet.variant"); + auto extension_type = ::arrow::GetExtensionType(std::string(kVariantExtensionName)); if (extension_type) { ARROW_ASSIGN_OR_RAISE( struct_type, @@ -1194,10 +1432,10 @@ Result ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer extension_supports_inferred_storage = arrow_extension_inferred || ::arrow::extension::UuidType::IsSupportedStorageType(inferred_type); - } else if (origin_extension_name == "arrow.parquet.variant") { + } else if (origin_extension_name == kVariantExtensionName) { extension_supports_inferred_storage = arrow_extension_inferred || - ::arrow::extension::VariantExtensionType::IsSupportedStorageType(inferred_type); + VariantExtensionType::IsSupportedStorageType(inferred_type); } else { extension_supports_inferred_storage = origin_extension_type.storage_type()->Equals(*inferred_type); diff --git a/cpp/src/parquet/arrow/variant_test.cc b/cpp/src/parquet/arrow/variant_test.cc index 04f46d2e444d..62e68f3e3ad5 100644 --- a/cpp/src/parquet/arrow/variant_test.cc +++ 
b/cpp/src/parquet/arrow/variant_test.cc @@ -15,58 +15,348 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/array/validate.h" +#include "arrow/array.h" // IWYU pragma: keep #include "arrow/extension/parquet_variant.h" -#include "arrow/ipc/test_common.h" +#include "arrow/extension/uuid.h" +#include "arrow/io/memory.h" #include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" +#include "parquet/arrow/reader.h" +#include "parquet/arrow/writer.h" #include "parquet/exception.h" +#include "parquet/variant/builder.h" +#include "parquet/variant/test_util_internal.h" + +#include +#include +#include namespace parquet::arrow { using ::arrow::binary; +using ::arrow::field; using ::arrow::struct_; +using variant::internal::BinaryArrayFromValues; +using variant::internal::BinaryViewArrayFromValues; +using variant::internal::EmptyVariantMetadata; +using variant::internal::Int32ArrayFromValues; +using variant::internal::Int64ArrayFromValues; +using variant::internal::Int8Variant; +using variant::internal::StringArrayFromValues; +using variant::internal::UuidArrayFromValues; +using variant::internal::VariantTable; +using variant::internal::WriteVariantRecordBatch; +using variant::internal::WriteVariantTable; + +TEST(TestVariantExtensionType, WriterValidatesUnshreddedVariantBytes) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues({std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues({std::string_view{*encoded.value}}); + auto table = + VariantTable(variant_type, {metadata_array, value_array}, storage_type->fields()); + ASSERT_OK(WriteVariantTable(table)); + + auto invalid_value = 
BinaryArrayFromValues({std::string_view("\xff", 1)}); + auto invalid_table = + VariantTable(variant_type, {metadata_array, invalid_value}, storage_type->fields()); + ASSERT_RAISES(Invalid, WriteVariantTable(invalid_table)); + + auto no_validation = + ArrowWriterProperties::Builder().set_variant_validation_enabled(false)->build(); + ASSERT_OK(WriteVariantTable(invalid_table, default_writer_properties(), no_validation)); +} + +TEST(TestVariantExtensionType, WriteRecordBatchValidatesVariantBytes) { + ASSERT_OK_AND_ASSIGN(auto metadata, EmptyVariantMetadata()); + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + + auto metadata_array = BinaryArrayFromValues({std::string_view{*metadata}}); + auto invalid_value = BinaryArrayFromValues({std::string_view("\xff", 1)}); + auto table = + VariantTable(variant_type, {metadata_array, invalid_value}, storage_type->fields()); + + ASSERT_RAISES(Invalid, WriteVariantRecordBatch(table)); + + auto no_validation = + ArrowWriterProperties::Builder().set_variant_validation_enabled(false)->build(); + ASSERT_OK(WriteVariantRecordBatch(table, no_validation)); +} + +TEST(TestVariantExtensionType, WriteRecordBatchValidatesBatch) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = StringArrayFromValues({"not binary"}); + auto value_array = BinaryArrayFromValues({std::string_view{*encoded.value}}); + ASSERT_OK_AND_ASSIGN( + auto storage, + ::arrow::StructArray::Make({metadata_array, value_array}, storage_type->fields())); + auto variant_array = ::arrow::ExtensionType::WrapArray(variant_type, storage); + auto schema = ::arrow::schema({field("variant", variant_type)}); + auto batch 
= ::arrow::RecordBatch::Make(schema, 1, {variant_array}); + + ASSERT_OK_AND_ASSIGN(auto sink, ::arrow::io::BufferOutputStream::Create( + 1024, ::arrow::default_memory_pool())); + ASSERT_OK_AND_ASSIGN( + auto writer, + FileWriter::Open(*schema, ::arrow::default_memory_pool(), sink, + default_writer_properties(), default_arrow_writer_properties())); + + const auto status = writer->WriteRecordBatch(*batch); + ASSERT_TRUE(status.IsInvalid()) << status; + ASSERT_NE(std::string::npos, status.ToStringWithoutContextLines().find( + "Struct child array #0 does not match type field")) + << status; +} + +TEST(TestVariantExtensionType, WriterValidatesBinaryViewVariantBytes) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = + struct_({field("metadata", ::arrow::binary_view(), /*nullable=*/false), + field("value", ::arrow::binary_view(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryViewArrayFromValues({std::string_view{*encoded.metadata}}); + auto value_array = BinaryViewArrayFromValues({std::string_view{*encoded.value}}); + auto table = + VariantTable(variant_type, {metadata_array, value_array}, storage_type->fields()); + ASSERT_OK(WriteVariantTable(table)); +} + +TEST(TestVariantExtensionType, WriterSkipsNullParents) { + ASSERT_OK_AND_ASSIGN(auto metadata, EmptyVariantMetadata()); + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues({std::string_view{*metadata}}); + auto invalid_value = BinaryArrayFromValues({std::string_view("\xff", 1)}); + PARQUET_ASSIGN_OR_THROW(auto storage, + ::arrow::StructArray::Make({metadata_array, invalid_value}, + storage_type->fields())); + auto variant_array = ::arrow::ExtensionType::WrapArray(variant_type, storage); + + auto parent_type = 
struct_({field("child", variant_type)}); + PARQUET_ASSIGN_OR_THROW( + auto parent_array, + ::arrow::StructArray::Make({variant_array}, parent_type->fields(), + ::arrow::Buffer::FromString(std::string("\0", 1)))); + auto parent_table = ::arrow::Table::Make( + ::arrow::schema({field("parent", parent_type)}), {parent_array}); + ASSERT_OK(WriteVariantTable(parent_table)); + + auto map_type = ::arrow::map(::arrow::utf8(), field("item", variant_type)); + ASSERT_OK_AND_ASSIGN( + auto map_array, + ::arrow::MapArray::FromArrays(map_type, Int32ArrayFromValues({0, 1}), + StringArrayFromValues({"hidden"}), variant_array, + ::arrow::default_memory_pool(), + ::arrow::Buffer::FromString(std::string("\0", 1)))); + auto map_table = ::arrow::Table::Make(::arrow::schema({field("variant_map", map_type)}), + {map_array}); + ASSERT_OK(WriteVariantTable(map_table)); +} + +TEST(TestVariantExtensionType, WriterValidatesShreddedPrimitiveConflicts) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = + struct_({field("metadata", binary(), /*nullable=*/false), field("value", binary()), + field("typed_value", ::arrow::int64())}); + auto variant_type = ::arrow::extension::variant(storage_type); + + auto metadata_array = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto value_array = + BinaryArrayFromValues({std::nullopt, std::string_view{*encoded.value}}); + auto typed_array = Int64ArrayFromValues({34, 100}); + auto table = VariantTable(variant_type, {metadata_array, value_array, typed_array}, + storage_type->fields()); + ASSERT_RAISES(Invalid, WriteVariantTable(table)); + + auto empty_value = BinaryArrayFromValues({std::string_view{}, std::nullopt}); + auto empty_table = VariantTable( + variant_type, {metadata_array, empty_value, typed_array}, storage_type->fields()); + ASSERT_RAISES(Invalid, WriteVariantTable(empty_table)); + + auto valid_values = + BinaryArrayFromValues({std::nullopt, 
std::string_view{*encoded.value}}); + auto valid_typed = Int64ArrayFromValues({34, std::nullopt}); + auto valid_table = VariantTable( + variant_type, {metadata_array, valid_values, valid_typed}, storage_type->fields()); + ASSERT_OK(WriteVariantTable(valid_table)); +} + +TEST(TestVariantExtensionType, WriterValidatesShreddedWithoutValue) { + ASSERT_OK_AND_ASSIGN(auto metadata, EmptyVariantMetadata()); + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("typed_value", ::arrow::int64())}); + ASSERT_OK_AND_ASSIGN(auto variant_type, + ::arrow::extension::VariantExtensionType::Make(storage_type)); + + auto metadata_array = BinaryArrayFromValues({std::string_view{*metadata}}); + auto typed_array = Int64ArrayFromValues({34}); + auto table = + VariantTable(variant_type, {metadata_array, typed_array}, storage_type->fields()); + ASSERT_OK(WriteVariantTable(table)); +} + +TEST(TestVariantExtensionType, ReadsDictionaryEncodedMetadata) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues( + {std::string_view{*encoded.value}, std::string_view{*encoded.value}}); + auto table = + VariantTable(variant_type, {metadata_array, value_array}, storage_type->fields()); + + ASSERT_OK_AND_ASSIGN( + auto buffer, + WriteVariantTable(table, WriterProperties::Builder().enable_dictionary()->build())); + + auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(buffer); + ArrowReaderProperties reader_properties; + reader_properties.set_arrow_extensions_enabled(true); + ::arrow::ExtensionTypeGuard guard(::arrow::extension::variant(storage_type)); + FileReaderBuilder builder; + 
ASSERT_OK(builder.Open(buffer_reader)); + builder.properties(reader_properties); + ASSERT_OK_AND_ASSIGN(auto reader, builder.Build()); + + ASSERT_TRUE(reader->parquet_reader() + ->metadata() + ->RowGroup(0) + ->ColumnChunk(0) + ->has_dictionary_page()); + + ASSERT_OK_AND_ASSIGN(auto read_table, reader->ReadTable()); + ASSERT_OK(read_table->ValidateFull()); + + auto field = read_table->schema()->GetFieldByName("variant"); + ASSERT_NE(nullptr, field); + auto read_variant_type = + std::dynamic_pointer_cast<::arrow::extension::VariantExtensionType>(field->type()); + ASSERT_NE(nullptr, read_variant_type); + ASSERT_EQ(::arrow::Type::BINARY, read_variant_type->metadata()->type()->id()); + + ASSERT_NE(nullptr, read_table->GetColumnByName(field->name())); +} + +TEST(TestVariantExtensionType, ReadsWithDictionaryOption) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues( + {std::string_view{*encoded.value}, std::string_view{*encoded.value}}); + auto table = + VariantTable(variant_type, {metadata_array, value_array}, storage_type->fields()); + + ASSERT_OK_AND_ASSIGN(auto buffer, WriteVariantTable(table)); + + auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(buffer); + ArrowReaderProperties reader_properties; + reader_properties.set_arrow_extensions_enabled(true); + reader_properties.set_read_dictionary(0, true); + reader_properties.set_read_dictionary(1, true); + ::arrow::ExtensionTypeGuard guard(::arrow::extension::variant(storage_type)); + FileReaderBuilder builder; + ASSERT_OK(builder.Open(buffer_reader)); + builder.properties(reader_properties); + ASSERT_OK_AND_ASSIGN(auto reader, 
builder.Build()); + + ASSERT_OK_AND_ASSIGN(auto read_table, reader->ReadTable()); + ASSERT_OK(read_table->ValidateFull()); + + auto field = read_table->schema()->GetFieldByName("variant"); + ASSERT_NE(nullptr, field); + auto read_variant_type = + std::dynamic_pointer_cast<::arrow::extension::VariantExtensionType>(field->type()); + ASSERT_NE(nullptr, read_variant_type); + ASSERT_EQ(::arrow::Type::BINARY, read_variant_type->metadata()->type()->id()); + ASSERT_EQ(::arrow::Type::BINARY, read_variant_type->value()->type()->id()); + + ASSERT_NE(nullptr, read_table->GetColumnByName(field->name())); +} + +TEST(TestVariantExtensionType, WriterWritesUuid) { + ASSERT_OK_AND_ASSIGN(auto metadata, EmptyVariantMetadata()); + auto storage_type = + struct_({field("metadata", binary(), /*nullable=*/false), field("value", binary()), + field("typed_value", ::arrow::extension::uuid())}); + auto variant_type = ::arrow::extension::variant(storage_type); + + auto metadata_array = BinaryArrayFromValues({std::string_view{*metadata}}); + auto value_array = BinaryArrayFromValues({std::nullopt}); + auto typed_array = UuidArrayFromValues({std::string_view("0123456789abcdef", 16)}); + auto table = VariantTable(variant_type, {metadata_array, value_array, typed_array}, + storage_type->fields()); + ASSERT_OK(WriteVariantTable(table)); +} + +TEST(TestVariantExtensionType, WriterValidatesShreddedObjectConflicts) { + variant::VariantBuilder object_builder; + ASSERT_OK_AND_ASSIGN(auto object, object_builder.StartObject()); + ASSERT_OK(object.AppendShortString("event_type", "login")); + ASSERT_OK(object.Finish()); + ASSERT_OK_AND_ASSIGN(auto encoded, object_builder.Finish()); + + auto field_group_type = + struct_({field("value", binary()), field("typed_value", ::arrow::utf8())}); + auto typed_value_type = + struct_({field("event_type", field_group_type, /*nullable=*/false)}); + auto storage_type = + struct_({field("metadata", binary(), /*nullable=*/false), field("value", binary()), + 
field("typed_value", typed_value_type)}); + auto variant_type = ::arrow::extension::variant(storage_type); + + auto metadata_array = BinaryArrayFromValues({std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues({std::string_view{*encoded.value}}); + ASSERT_OK_AND_ASSIGN(auto event_type_group, + ::arrow::StructArray::Make({BinaryArrayFromValues({std::nullopt}), + StringArrayFromValues({"login"})}, + field_group_type->fields())); + ASSERT_OK_AND_ASSIGN( + auto typed_array, + ::arrow::StructArray::Make({event_type_group}, typed_value_type->fields())); + + auto table = VariantTable(variant_type, {metadata_array, value_array, typed_array}, + storage_type->fields()); + ASSERT_RAISES(Invalid, WriteVariantTable(table)); + + auto valid_value_array = BinaryArrayFromValues({std::nullopt}); + auto valid_table = + VariantTable(variant_type, {metadata_array, valid_value_array, typed_array}, + storage_type->fields()); + ASSERT_OK(WriteVariantTable(valid_table)); -TEST(TestVariantExtensionType, StorageTypeValidation) { - auto variant1 = ::arrow::extension::variant( - struct_({field("metadata", binary(), /*nullable=*/false), - field("value", binary(), /*nullable=*/false)})); - auto variant2 = ::arrow::extension::variant( - struct_({field("metadata", binary(), /*nullable=*/false), - field("value", binary(), /*nullable=*/false)})); - - ASSERT_TRUE(variant1->Equals(variant2)); - - // Metadata and value fields can be provided in either order - auto variantFieldsFlipped = - std::dynamic_pointer_cast<::arrow::extension::VariantExtensionType>( - ::arrow::extension::variant( - struct_({field("value", binary(), /*nullable=*/false), - field("metadata", binary(), /*nullable=*/false)}))); - - ASSERT_EQ("metadata", variantFieldsFlipped->metadata()->name()); - ASSERT_EQ("value", variantFieldsFlipped->value()->name()); - - auto missing_value = struct_({field("metadata", binary(), /*nullable=*/false)}); - auto missing_metadata = struct_({field("value", binary(), 
/*nullable=*/false)}); - auto bad_value_type = struct_({field("metadata", binary(), /*nullable=*/false), - field("value", ::arrow::int32(), /*nullable=*/false)}); - auto extra_field = struct_({field("metadata", binary(), /*nullable=*/false), - field("value", binary(), /*nullable=*/false), - field("extra", binary(), /*nullable=*/false)}); - auto nullable_metadata = struct_( - {field("metadata", binary()), field("value", binary(), /*nullable=*/false)}); - auto nullable_value = struct_( - {field("metadata", binary(), /*nullable=*/false), field("value", binary())}); - - for (const auto& storage_type : {missing_value, missing_metadata, bad_value_type, - extra_field, nullable_metadata, nullable_value}) { - ASSERT_RAISES_WITH_MESSAGE( - Invalid, - "Invalid: Invalid storage type for VariantExtensionType: " + - storage_type->ToString(), - ::arrow::extension::VariantExtensionType::Make(storage_type)); - } + ASSERT_OK_AND_ASSIGN(auto missing_event_type_group, + ::arrow::StructArray::Make({BinaryArrayFromValues({std::nullopt}), + StringArrayFromValues({std::nullopt})}, + field_group_type->fields())); + ASSERT_OK_AND_ASSIGN( + auto missing_typed_array, + ::arrow::StructArray::Make({missing_event_type_group}, typed_value_type->fields())); + auto missing_table = + VariantTable(variant_type, {metadata_array, valid_value_array, missing_typed_array}, + storage_type->fields()); + ASSERT_OK(WriteVariantTable(missing_table)); } } // namespace parquet::arrow diff --git a/cpp/src/parquet/arrow/writer.cc b/cpp/src/parquet/arrow/writer.cc index e0fbe308219c..024f5f6e7699 100644 --- a/cpp/src/parquet/arrow/writer.cc +++ b/cpp/src/parquet/arrow/writer.cc @@ -18,14 +18,12 @@ #include "parquet/arrow/writer.h" #include -#include #include #include -#include #include #include -#include "arrow/array.h" +#include "arrow/array.h" // IWYU pragma: keep #include "arrow/array/concatenate.h" #include "arrow/extension_type.h" #include "arrow/ipc/writer.h" @@ -46,6 +44,7 @@ #include "parquet/file_writer.h" 
#include "parquet/platform.h" #include "parquet/schema.h" +#include "parquet/variant/validate.h" using arrow::Array; using arrow::BinaryArray; @@ -62,7 +61,6 @@ using arrow::MemoryPool; using arrow::NumericArray; using arrow::PrimitiveArray; using arrow::RecordBatch; -using arrow::ResizableBuffer; using arrow::Result; using arrow::Status; using arrow::Table; @@ -323,6 +321,7 @@ class FileWriterImpl : public FileWriter { std::unique_ptr writer, std::shared_ptr arrow_properties) : schema_(std::move(schema)), + pool_(pool), writer_(std::move(writer)), row_group_writer_(nullptr), column_write_context_(pool, arrow_properties.get()), @@ -382,6 +381,9 @@ class FileWriterImpl : public FileWriter { Status WriteColumnChunk(const std::shared_ptr& data, int64_t offset, int64_t size) override { RETURN_NOT_OK(CheckClosed()); + if (arrow_properties_->variant_validation_enabled()) { + RETURN_NOT_OK(variant::ValidateVariants(*data->Slice(offset, size), pool_)); + } if (arrow_properties_->engine_version() == ArrowWriterProperties::V2 || arrow_properties_->engine_version() == ArrowWriterProperties::V1) { if (row_group_writer_->buffered()) { @@ -450,6 +452,9 @@ class FileWriterImpl : public FileWriter { Status WriteRecordBatch(const RecordBatch& batch) override { RETURN_NOT_OK(CheckClosed()); + if (arrow_properties_->variant_validation_enabled()) { + RETURN_NOT_OK(batch.Validate()); + } if (batch.num_rows() == 0) { return Status::OK(); } @@ -469,6 +474,10 @@ class FileWriterImpl : public FileWriter { for (int i = 0; i < batch.num_columns(); i++) { ChunkedArray chunked_array{batch.column(i)}; + if (arrow_properties_->variant_validation_enabled()) { + RETURN_NOT_OK( + variant::ValidateVariants(*chunked_array.Slice(offset, size), pool_)); + } ARROW_ASSIGN_OR_RAISE( std::unique_ptr writer, ArrowColumnWriterV2::Make(chunked_array, offset, size, schema_manifest_, @@ -532,6 +541,7 @@ class FileWriterImpl : public FileWriter { friend class FileWriter; std::shared_ptr<::arrow::Schema> schema_; 
+ MemoryPool* pool_; SchemaManifest schema_manifest_; diff --git a/cpp/src/parquet/meson.build b/cpp/src/parquet/meson.build index 9069ccb5fd1a..d7737f7db0f9 100644 --- a/cpp/src/parquet/meson.build +++ b/cpp/src/parquet/meson.build @@ -55,6 +55,9 @@ parquet_srcs = files( 'stream_reader.cc', 'stream_writer.cc', 'types.cc', + 'variant/builder.cc', + 'variant/encoding.cc', + 'variant/validate.cc', 'xxhasher.cc', ) @@ -130,6 +133,7 @@ subdir('api') subdir('arrow') subdir('encryption') subdir('geospatial') +subdir('variant') install_headers( [ @@ -229,6 +233,7 @@ parquet_tests = { 'arrow/arrow_reader_writer_test.cc', 'arrow/arrow_statistics_test.cc', 'arrow/variant_test.cc', + 'variant/test_util_internal.cc', ), }, 'arrow-index-test': {'sources': files('arrow/index_test.cc')}, diff --git a/cpp/src/parquet/properties.h b/cpp/src/parquet/properties.h index e2244a1176e3..e00ec2fe7dd6 100644 --- a/cpp/src/parquet/properties.h +++ b/cpp/src/parquet/properties.h @@ -1259,10 +1259,9 @@ class PARQUET_EXPORT ArrowReaderProperties { /// Enable Parquet-supported Arrow extension types. /// /// When enabled, Parquet logical types will be mapped to their corresponding Arrow - /// extension types at read time, if such exist. Currently only arrow::extension::json() - /// extension type is supported. Columns whose LogicalType is JSON will be interpreted - /// as arrow::extension::json(), with storage type inferred from the serialized Arrow - /// schema if present, or `utf8` by default. + /// extension types at read time, if such exist. For example, columns whose LogicalType + /// is JSON will be interpreted as arrow::extension::json(), with storage type inferred + /// from the serialized Arrow schema if present, or `utf8` by default. 
void set_arrow_extensions_enabled(bool extensions_enabled) { arrow_extensions_enabled_ = extensions_enabled; } @@ -1332,7 +1331,8 @@ class PARQUET_EXPORT ArrowWriterProperties { engine_version_(V2), use_threads_(kArrowDefaultUseThreads), executor_(NULLPTR), - write_time_adjusted_to_utc_(false) {} + write_time_adjusted_to_utc_(false), + variant_validation_enabled_(true) {} /// \brief Disable writing legacy int96 timestamps (default disabled). Builder* disable_deprecated_int96_timestamps() { @@ -1436,12 +1436,23 @@ class PARQUET_EXPORT ArrowWriterProperties { return this; } + /// \brief Set whether to validate Parquet Variant binary values before writing. + /// + /// This is enabled by default. When enabled, Variant metadata/value bytes + /// are checked against the Parquet Variant encoding, and shredded value / + /// typed_value combinations are checked for conflicts. + Builder* set_variant_validation_enabled(bool enabled) { + variant_validation_enabled_ = enabled; + return this; + } + /// Create the final properties. std::shared_ptr build() { return std::shared_ptr(new ArrowWriterProperties( write_timestamps_as_int96_, coerce_timestamps_enabled_, coerce_timestamps_unit_, truncated_timestamps_allowed_, store_schema_, compliant_nested_types_, - engine_version_, use_threads_, executor_, write_time_adjusted_to_utc_)); + engine_version_, use_threads_, executor_, write_time_adjusted_to_utc_, + variant_validation_enabled_)); } private: @@ -1459,6 +1470,7 @@ class PARQUET_EXPORT ArrowWriterProperties { ::arrow::internal::Executor* executor_; bool write_time_adjusted_to_utc_; + bool variant_validation_enabled_; }; bool support_deprecated_int96_timestamps() const { return write_timestamps_as_int96_; } @@ -1497,15 +1509,16 @@ class PARQUET_EXPORT ArrowWriterProperties { /// Note this setting doesn't affect TIMESTAMP data. 
bool write_time_adjusted_to_utc() const { return write_time_adjusted_to_utc_; } + /// \brief Returns whether Parquet Variant binary values are validated before writing. + bool variant_validation_enabled() const { return variant_validation_enabled_; } + private: - explicit ArrowWriterProperties(bool write_nanos_as_int96, - bool coerce_timestamps_enabled, - ::arrow::TimeUnit::type coerce_timestamps_unit, - bool truncated_timestamps_allowed, bool store_schema, - bool compliant_nested_types, - EngineVersion engine_version, bool use_threads, - ::arrow::internal::Executor* executor, - bool write_time_adjusted_to_utc) + explicit ArrowWriterProperties( + bool write_nanos_as_int96, bool coerce_timestamps_enabled, + ::arrow::TimeUnit::type coerce_timestamps_unit, bool truncated_timestamps_allowed, + bool store_schema, bool compliant_nested_types, EngineVersion engine_version, + bool use_threads, ::arrow::internal::Executor* executor, + bool write_time_adjusted_to_utc, bool variant_validation_enabled) : write_timestamps_as_int96_(write_nanos_as_int96), coerce_timestamps_enabled_(coerce_timestamps_enabled), coerce_timestamps_unit_(coerce_timestamps_unit), @@ -1515,7 +1528,8 @@ class PARQUET_EXPORT ArrowWriterProperties { engine_version_(engine_version), use_threads_(use_threads), executor_(executor), - write_time_adjusted_to_utc_(write_time_adjusted_to_utc) {} + write_time_adjusted_to_utc_(write_time_adjusted_to_utc), + variant_validation_enabled_(variant_validation_enabled) {} const bool write_timestamps_as_int96_; const bool coerce_timestamps_enabled_; @@ -1527,6 +1541,7 @@ class PARQUET_EXPORT ArrowWriterProperties { const bool use_threads_; ::arrow::internal::Executor* executor_; const bool write_time_adjusted_to_utc_; + const bool variant_validation_enabled_; }; /// \brief State object used for writing Arrow data directly to a Parquet diff --git a/cpp/src/parquet/variant/CMakeLists.txt b/cpp/src/parquet/variant/CMakeLists.txt new file mode 100644 index 
000000000000..8c9b99266331 --- /dev/null +++ b/cpp/src/parquet/variant/CMakeLists.txt @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_parquet_test(variant-test + SOURCES + builder_test.cc + encoding_test.cc + test_util_internal.cc + type_test.cc + validate_test.cc) + +arrow_install_all_headers("parquet/variant") diff --git a/cpp/src/parquet/variant/builder.cc b/cpp/src/parquet/variant/builder.cc new file mode 100644 index 000000000000..bd8540b368ee --- /dev/null +++ b/cpp/src/parquet/variant/builder.cc @@ -0,0 +1,1273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "parquet/variant/builder.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/buffer_builder.h" +#include "arrow/builder.h" // IWYU pragma: keep +#include "arrow/util/checked_cast.h" +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" +#include "parquet/variant/encoding.h" +#include "parquet/variant/encoding_internal.h" + +namespace parquet::variant { + +using ::arrow::binary; +using ::arrow::BinaryBuilder; +using ::arrow::BooleanArray; +using ::arrow::BooleanBuilder; +using ::arrow::BufferBuilder; +using ::arrow::ExtensionType; +using ::arrow::field; +using ::arrow::struct_; +using ::arrow::StructType; +using ::arrow::Type; +using ::arrow::extension::VariantExtensionType; + +namespace bit_util = ::arrow::bit_util; + +namespace { + +Status AppendLittleEndian(BufferBuilder& out, uint32_t value, uint8_t width) { + DCHECK_LE(width, sizeof(uint32_t)); + const auto little_endian = bit_util::ToLittleEndian(value); + return out.Append(&little_endian, width); +} + +template + requires(std::is_arithmetic_v) +Status AppendFixedLittleEndian(BufferBuilder& out, T value) { + const auto little_endian = bit_util::ToLittleEndian(value); + return out.Append(&little_endian, sizeof(T)); +} + +void AppendLittleEndianToString(std::string& out, uint32_t value, uint8_t width) { + DCHECK_LE(width, sizeof(uint32_t)); + const auto little_endian = bit_util::ToLittleEndian(value); + out.append(reinterpret_cast(&little_endian), width); +} + +Status 
InsertBytes(BufferBuilder& out, int64_t offset, std::string_view bytes) { + const int64_t old_size = out.length(); + const auto insert_size = bytes.size(); + ARROW_RETURN_NOT_OK(out.Reserve(insert_size)); + uint8_t* data = out.mutable_data(); + std::memmove(data + offset + insert_size, data + offset, old_size - offset); + std::memcpy(data + offset, bytes.data(), insert_size); + out.UnsafeAdvance(insert_size); + return Status::OK(); +} + +uint8_t WidthForValue(uint64_t value) { + if (value <= std::numeric_limits::max()) { + return 1; + } + if (value <= std::numeric_limits::max()) { + return 2; + } + if (value <= 0xFFFFFFU) { + return 3; + } + return 4; +} + +class VariantMetadataBuilder { + public: + explicit VariantMetadataBuilder(MemoryPool* pool) : pool_(pool) {} + + Status Reserve(int64_t capacity) { + if (capacity < 0) { + return Status::Invalid("Variant metadata capacity must be non-negative"); + } + field_ids_.reserve(capacity); + return Status::OK(); + } + + Result Upsert(std::string_view name) { + auto it = field_ids_.find(name); + if (it != field_ids_.end()) { + return it->second; + } + if (field_names_.size() >= std::numeric_limits::max()) { + return Status::Invalid("Variant metadata dictionary is too large"); + } + ARROW_RETURN_NOT_OK(internal::ValidateUtf8(name, "metadata dictionary string")); + + const auto id = field_names_.size(); + if (field_names_.empty()) { + is_sorted_ = true; + } else if (is_sorted_ && !(field_names_.back() < name)) { + is_sorted_ = false; + } + field_names_.emplace_back(name); + field_ids_.emplace(field_names_.back(), id); + return static_cast(id); + } + + std::string_view FieldName(uint32_t id) const { + DCHECK_LT(id, field_names_.size()); + return field_names_[id]; + } + + size_t size() const { return field_names_.size(); } + + void Truncate(size_t size) { + DCHECK_LE(size, field_names_.size()); + field_names_.resize(size); + RebuildIndex(); + } + + Result> Finish() const { + uint64_t bytes_size = 0; + for (const auto& string 
: field_names_) { + bytes_size += string.size(); + } + if (field_names_.size() > std::numeric_limits::max() || + bytes_size > std::numeric_limits::max()) { + return Status::Invalid("Variant metadata dictionary is too large"); + } + + const uint8_t offset_size = + WidthForValue(std::max(field_names_.size(), bytes_size)); + BufferBuilder out(pool_); + ARROW_RETURN_NOT_OK(out.Reserve( + 1 + offset_size + (field_names_.size() + 1) * offset_size + bytes_size)); + + const bool sorted_strings = !field_names_.empty() && is_sorted_; + const auto header = + static_cast(internal::kVariantVersion | + (sorted_strings ? internal::kMetadataSortedStringsMask : 0) | + ((offset_size - 1) << 6)); + ARROW_RETURN_NOT_OK(out.Append(&header, sizeof(header))); + ARROW_RETURN_NOT_OK( + AppendLittleEndian(out, static_cast(field_names_.size()), offset_size)); + + uint32_t offset = 0; + ARROW_RETURN_NOT_OK(AppendLittleEndian(out, offset, offset_size)); + for (const auto& string : field_names_) { + offset += static_cast(string.size()); + ARROW_RETURN_NOT_OK(AppendLittleEndian(out, offset, offset_size)); + } + for (const auto& string : field_names_) { + ARROW_RETURN_NOT_OK(out.Append(string)); + } + return out.Finish(); + } + + private: + void RebuildIndex() { + field_ids_.clear(); + field_ids_.reserve(field_names_.size()); + is_sorted_ = false; + for (uint32_t i = 0; i < field_names_.size(); ++i) { + field_ids_.emplace(field_names_[i], i); + if (i == 0) { + is_sorted_ = true; + } else if (is_sorted_ && !(field_names_[i - 1] < field_names_[i])) { + is_sorted_ = false; + } + } + } + + MemoryPool* pool_; + std::deque field_names_; + std::unordered_map field_ids_; + bool is_sorted_ = false; +}; + +class VariantValueWriter { + public: + explicit VariantValueWriter(BufferBuilder& out) : out_(out) {} + + template + requires internal::HeaderOnlyVariantPrimitive + Status Append() { + return AppendPrimitiveHeader(); + } + + template + requires internal::FixedVariantPrimitive + Status Append(typename 
internal::VariantFixedPrimitiveTraits::CType value) { + using CType = typename internal::VariantFixedPrimitiveTraits::CType; + ARROW_RETURN_NOT_OK(AppendPrimitiveHeader()); + if constexpr (sizeof(CType) == 1) { + const auto byte = static_cast(value); + return out_.Append(&byte, sizeof(byte)); + } else { + return AppendFixedLittleEndian(out_, value); + } + } + + template + requires internal::DecimalVariantPrimitive + Status Append( + typename internal::VariantDecimalPrimitiveTraits::CType unscaled_value, + uint8_t scale) { + ARROW_RETURN_NOT_OK(internal::ValidateDecimalScale(scale)); + ARROW_RETURN_NOT_OK(AppendPrimitiveHeader()); + ARROW_RETURN_NOT_OK(out_.Append(&scale, sizeof(scale))); + return AppendFixedLittleEndian(out_, unscaled_value); + } + + template + requires internal::Decimal16VariantPrimitive + Status Append(std::string_view little_endian_unscaled_value, uint8_t scale) { + if (little_endian_unscaled_value.size() != 16) { + return Status::Invalid("Variant Decimal16 values must be 16 bytes"); + } + ARROW_RETURN_NOT_OK(internal::ValidateDecimalScale(scale)); + ARROW_RETURN_NOT_OK(AppendPrimitiveHeader()); + ARROW_RETURN_NOT_OK(out_.Append(&scale, sizeof(scale))); + return out_.Append(little_endian_unscaled_value); + } + + template + requires internal::LengthPrefixedVariantPrimitive + Status Append(std::string_view value) { + if (value.size() > std::numeric_limits::max()) { + return Status::Invalid("Variant ", + type == VariantPrimitiveType::kBinary ? 
"binary" : "string", + " value is too large"); + } + if constexpr (type == VariantPrimitiveType::kString) { + ARROW_RETURN_NOT_OK(internal::ValidateUtf8(value, "primitive string value")); + } + ARROW_RETURN_NOT_OK(AppendPrimitiveHeader()); + ARROW_RETURN_NOT_OK(AppendLittleEndian(out_, static_cast(value.size()), 4)); + return out_.Append(value); + } + + template + requires internal::UuidVariantPrimitive + Status Append(std::string_view big_endian_bytes) { + if (big_endian_bytes.size() != 16) { + return Status::Invalid("Variant UUID values must be 16 bytes"); + } + ARROW_RETURN_NOT_OK(AppendPrimitiveHeader()); + return out_.Append(big_endian_bytes); + } + + Status AppendShortString(std::string_view value) { + if (value.size() >= 64) { + return Status::Invalid("Variant short string value must be shorter than 64 bytes"); + } + ARROW_RETURN_NOT_OK(internal::ValidateUtf8(value, "short string value")); + const auto header = static_cast( + (value.size() << 2) | static_cast(VariantBasicType::kShortString)); + ARROW_RETURN_NOT_OK(out_.Append(&header, sizeof(header))); + return out_.Append(value); + } + + private: + template + Status AppendPrimitiveHeader() { + const auto header = + static_cast((static_cast(type) << 2) | + static_cast(VariantBasicType::kPrimitive)); + return out_.Append(&header, sizeof(header)); + } + + BufferBuilder& out_; +}; + +enum class VariantContainerKind { Object, List }; + +struct VariantFieldDescriptor { + uint32_t field_id = 0; + uint32_t offset = 0; +}; + +struct VariantBuildFrame { + VariantContainerKind kind; + int64_t value_start = 0; + size_t metadata_size = 0; + size_t parent_frame = 0; + size_t parent_entry_count = 0; + bool has_parent = false; + std::vector fields{}; + std::vector offsets{}; + std::unordered_set object_field_ids{}; + bool finished = false; +}; + +struct VariantBuildState { + explicit VariantBuildState(MemoryPool* pool) + : pool(pool), value(pool), metadata(pool) {} + + MemoryPool* pool; + BufferBuilder value; + 
VariantMetadataBuilder metadata; + std::vector frames; + bool root_has_value = false; +}; + +Status CheckRootWritable(const VariantBuildState& state) { + if (!state.frames.empty()) { + return Status::Invalid("VariantBuilder has an active container"); + } + if (state.root_has_value) { + return Status::Invalid("VariantBuilder already has a root value"); + } + return Status::OK(); +} + +Status CheckTopFrame(const VariantBuildState& state, size_t frame_index, + VariantContainerKind kind) { + if (state.frames.empty() || frame_index + 1 != state.frames.size()) { + return Status::Invalid("Variant nested builder is not the active container"); + } + const auto& frame = state.frames.back(); + if (frame.finished || frame.kind != kind) { + return Status::Invalid("Variant nested builder has invalid state"); + } + return Status::OK(); +} + +void TruncateFrameEntries(VariantBuildFrame& frame, size_t entry_count) { + if (frame.kind == VariantContainerKind::Object) { + frame.fields.resize(entry_count); + frame.object_field_ids.clear(); + for (const auto& field : frame.fields) { + frame.object_field_ids.insert(field.field_id); + } + } else { + frame.offsets.resize(entry_count); + } +} + +template +void RollbackIfActive(Impl* impl) { + if (impl != nullptr && impl->active) { + if (impl->state == nullptr || impl->frame_index >= impl->state->frames.size()) { + return; + } + const auto& frame = impl->state->frames[impl->frame_index]; // borrow: no deep copy + impl->state->value.Rewind(frame.value_start); + impl->state->metadata.Truncate(frame.metadata_size); + if (frame.has_parent && frame.parent_frame < impl->state->frames.size()) { + TruncateFrameEntries(impl->state->frames[frame.parent_frame], + frame.parent_entry_count); + } + impl->state->frames.resize(impl->frame_index); + impl->active = false; + } +} + +Result BuildObjectHeader(const VariantBuildState& state, + const VariantBuildFrame& frame, + uint32_t values_size) { + if (frame.fields.size() > std::numeric_limits::max()) { + return
Status::Invalid("Variant object has too many fields"); + } + + std::vector fields = frame.fields; + std::ranges::sort(fields, [&](const VariantFieldDescriptor& left, + const VariantFieldDescriptor& right) { + return state.metadata.FieldName(left.field_id) < + state.metadata.FieldName(right.field_id); + }); + + uint32_t max_field_id = 0; + for (const auto& field : fields) { + max_field_id = std::max(max_field_id, field.field_id); + } + const uint8_t id_size = WidthForValue(max_field_id); + const uint8_t offset_size = WidthForValue(values_size); + const bool is_large = fields.size() > std::numeric_limits::max(); + const auto header = static_cast(((is_large ? 1 : 0) << 4) | + ((id_size - 1) << 2) | (offset_size - 1)); + + std::string out; + out.push_back( + static_cast((header << 2) | static_cast(VariantBasicType::kObject))); + AppendLittleEndianToString(out, static_cast(fields.size()), is_large ? 4 : 1); + for (const auto& field : fields) { + AppendLittleEndianToString(out, field.field_id, id_size); + } + for (const auto& field : fields) { + AppendLittleEndianToString(out, field.offset, offset_size); + } + AppendLittleEndianToString(out, values_size, offset_size); + return out; +} + +Result BuildListHeader(const VariantBuildFrame& frame, + uint32_t values_size) { + if (frame.offsets.size() > std::numeric_limits::max()) { + return Status::Invalid("Variant array has too many elements"); + } + + const uint8_t offset_size = WidthForValue(values_size); + const bool is_large = frame.offsets.size() > std::numeric_limits::max(); + const auto header = static_cast(((is_large ? 1 : 0) << 2) | (offset_size - 1)); + + std::string out; + out.push_back( + static_cast((header << 2) | static_cast(VariantBasicType::kArray))); + AppendLittleEndianToString(out, static_cast(frame.offsets.size()), + is_large ? 
4 : 1); + for (const auto offset : frame.offsets) { + AppendLittleEndianToString(out, offset, offset_size); + } + AppendLittleEndianToString(out, values_size, offset_size); + return out; +} + +using RootFinishCallback = std::function; + +Status FinishFrame(const std::shared_ptr& state, size_t frame_index, + VariantContainerKind kind, const RootFinishCallback& callback) { + ARROW_RETURN_NOT_OK(CheckTopFrame(*state, frame_index, kind)); + + auto& frame = state->frames.back(); + const auto values_size = state->value.length() - frame.value_start; + DCHECK_GE(values_size, 0); + if (values_size > std::numeric_limits::max()) { + return Status::Invalid("Variant container values are too large"); + } + + ARROW_ASSIGN_OR_RAISE( + auto header, + kind == VariantContainerKind::Object + ? BuildObjectHeader(*state, frame, static_cast(values_size)) + : BuildListHeader(frame, static_cast(values_size))); + ARROW_RETURN_NOT_OK(InsertBytes(state->value, frame.value_start, header)); + + const bool is_root = !frame.has_parent; + frame.finished = true; + state->frames.pop_back(); + if (!is_root) { + return Status::OK(); + } + + state->root_has_value = true; + if (!callback) { + return Status::OK(); + } + + ARROW_ASSIGN_OR_RAISE(auto metadata, state->metadata.Finish()); + ARROW_ASSIGN_OR_RAISE(auto value, state->value.Finish()); + return callback({.metadata = std::move(metadata), .value = std::move(value)}); +} + +template +Status AppendRootPrimitiveWith(const std::shared_ptr& state, + Write&& write) { + ARROW_RETURN_NOT_OK(CheckRootWritable(*state)); + const auto value_size = state->value.length(); + VariantValueWriter writer(state->value); + const Status status = std::invoke(std::forward(write), writer); + if (!status.ok()) { + state->value.Rewind(value_size); + return status; + } + state->root_has_value = true; + return Status::OK(); +} + +template +Status AppendRootPrimitive(const std::shared_ptr& state, + Args&&... 
args) { + return AppendRootPrimitiveWith(state, [&](VariantValueWriter& writer) { + return writer.template Append(std::forward(args)...); + }); +} + +template +Status AppendObjectPrimitiveWith(const std::shared_ptr& state, + size_t frame_index, std::string_view field_name, + Write&& write) { + ARROW_RETURN_NOT_OK(CheckTopFrame(*state, frame_index, VariantContainerKind::Object)); + auto& frame = state->frames.back(); + const auto metadata_size = state->metadata.size(); + const auto value_size = state->value.length(); + const auto field_count = frame.fields.size(); + + ARROW_ASSIGN_OR_RAISE(auto field_id, state->metadata.Upsert(field_name)); + if (!frame.object_field_ids.insert(field_id).second) { + state->metadata.Truncate(metadata_size); + return Status::Invalid("Duplicate Variant object field: ", field_name); + } + const auto offset = state->value.length() - frame.value_start; + DCHECK_GE(offset, 0); + if (offset > std::numeric_limits::max()) { + state->metadata.Truncate(metadata_size); + TruncateFrameEntries(frame, field_count); + return Status::Invalid("Variant object values are too large"); + } + frame.fields.push_back(VariantFieldDescriptor{.field_id = field_id, + .offset = static_cast(offset)}); + + VariantValueWriter writer(state->value); + const Status status = std::invoke(std::forward(write), writer); + if (!status.ok()) { + state->value.Rewind(value_size); + state->metadata.Truncate(metadata_size); + TruncateFrameEntries(frame, field_count); + return status; + } + return Status::OK(); +} + +template +Status AppendObjectPrimitive(const std::shared_ptr& state, + size_t frame_index, std::string_view field_name, + Args&&... 
args) { + return AppendObjectPrimitiveWith( + state, frame_index, field_name, [&](VariantValueWriter& writer) { + return writer.template Append(std::forward(args)...); + }); +} + +template +Status AppendListPrimitiveWith(const std::shared_ptr& state, + size_t frame_index, Write&& write) { + ARROW_RETURN_NOT_OK(CheckTopFrame(*state, frame_index, VariantContainerKind::List)); + auto& frame = state->frames.back(); + const auto value_size = state->value.length(); + const auto element_count = frame.offsets.size(); + const auto offset = state->value.length() - frame.value_start; + DCHECK_GE(offset, 0); + if (offset > std::numeric_limits::max()) { + return Status::Invalid("Variant array values are too large"); + } + frame.offsets.push_back(static_cast(offset)); + + VariantValueWriter writer(state->value); + const Status status = std::invoke(std::forward(write), writer); + if (!status.ok()) { + state->value.Rewind(value_size); + TruncateFrameEntries(frame, element_count); + return status; + } + return Status::OK(); +} + +template +Status AppendListPrimitive(const std::shared_ptr& state, + size_t frame_index, Args&&... 
args) { + return AppendListPrimitiveWith(state, frame_index, [&](VariantValueWriter& writer) { + return writer.template Append(std::forward(args)...); + }); +} + +} // namespace + +namespace internal { + +struct NestedVariantBuilderImpl { + NestedVariantBuilderImpl(std::shared_ptr state, size_t frame_index, + RootFinishCallback callback) + : state(std::move(state)), + frame_index(frame_index), + callback(std::move(callback)) {} + + std::shared_ptr state; + size_t frame_index = 0; + RootFinishCallback callback; + bool active = true; +}; + +} // namespace internal + +struct VariantBuilder::Impl { + explicit Impl(MemoryPool* pool) + : pool(pool), state(std::make_shared(pool)) {} + + MemoryPool* pool; + std::shared_ptr state; +}; + +VariantBuilder::VariantBuilder(MemoryPool* pool) : impl_(std::make_unique(pool)) {} +VariantBuilder::~VariantBuilder() = default; +VariantBuilder::VariantBuilder(VariantBuilder&&) noexcept = default; +VariantBuilder& VariantBuilder::operator=(VariantBuilder&&) noexcept = default; + +Status VariantBuilder::ReserveFieldNames(int64_t capacity) { + return impl_->state->metadata.Reserve(capacity); +} + +Result VariantBuilder::AddFieldName(std::string_view name) { + return impl_->state->metadata.Upsert(name); +} + +Status VariantBuilder::AppendVariantNull() { + return AppendRootPrimitive(impl_->state); +} + +Status VariantBuilder::AppendBoolean(bool value) { + if (value) { + return AppendRootPrimitive(impl_->state); + } + return AppendRootPrimitive(impl_->state); +} + +#define VARIANT_ROOT_APPEND_ONE_ARG(NAME, TYPE, C_TYPE) \ + Status VariantBuilder::Append##NAME(C_TYPE value) { \ + return AppendRootPrimitive(impl_->state, value); \ + } + +VARIANT_ROOT_APPEND_ONE_ARG(Int8, Int8, int8_t) +VARIANT_ROOT_APPEND_ONE_ARG(Int16, Int16, int16_t) +VARIANT_ROOT_APPEND_ONE_ARG(Int32, Int32, int32_t) +VARIANT_ROOT_APPEND_ONE_ARG(Int64, Int64, int64_t) +VARIANT_ROOT_APPEND_ONE_ARG(Float, Float, float) +VARIANT_ROOT_APPEND_ONE_ARG(Double, Double, double) 
+VARIANT_ROOT_APPEND_ONE_ARG(Binary, Binary, std::string_view) +VARIANT_ROOT_APPEND_ONE_ARG(String, String, std::string_view) +VARIANT_ROOT_APPEND_ONE_ARG(Date, Date, int32_t) +VARIANT_ROOT_APPEND_ONE_ARG(TimeNTZMicros, TimeNTZMicros, int64_t) +VARIANT_ROOT_APPEND_ONE_ARG(Uuid, Uuid, std::string_view) + +#undef VARIANT_ROOT_APPEND_ONE_ARG + +Status VariantBuilder::AppendDecimal4(int32_t unscaled_value, uint8_t scale) { + return AppendRootPrimitive(impl_->state, + unscaled_value, scale); +} + +Status VariantBuilder::AppendDecimal8(int64_t unscaled_value, uint8_t scale) { + return AppendRootPrimitive(impl_->state, + unscaled_value, scale); +} + +Status VariantBuilder::AppendDecimal16(std::string_view little_endian_unscaled_value, + uint8_t scale) { + return AppendRootPrimitive( + impl_->state, little_endian_unscaled_value, scale); +} + +Status VariantBuilder::AppendShortString(std::string_view value) { + return AppendRootPrimitiveWith(impl_->state, [&](VariantValueWriter& writer) { + return writer.AppendShortString(value); + }); +} + +Status VariantBuilder::AppendTimestampMicros(int64_t micros, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return AppendRootPrimitive(impl_->state, + micros); + } + return AppendRootPrimitive(impl_->state, + micros); +} + +Status VariantBuilder::AppendTimestampNanos(int64_t nanos, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return AppendRootPrimitive(impl_->state, + nanos); + } + return AppendRootPrimitive(impl_->state, + nanos); +} + +Result VariantBuilder::StartObject() { + ARROW_RETURN_NOT_OK(CheckRootWritable(*impl_->state)); + impl_->state->frames.push_back( + VariantBuildFrame{.kind = VariantContainerKind::Object, + .value_start = impl_->state->value.length(), + .metadata_size = impl_->state->metadata.size()}); + return VariantObjectBuilder(std::make_unique( + impl_->state, impl_->state->frames.size() - 1, RootFinishCallback{})); +} + +Result VariantBuilder::StartList() { + 
ARROW_RETURN_NOT_OK(CheckRootWritable(*impl_->state)); + impl_->state->frames.push_back( + VariantBuildFrame{.kind = VariantContainerKind::List, + .value_start = impl_->state->value.length(), + .metadata_size = impl_->state->metadata.size()}); + return VariantListBuilder(std::make_unique( + impl_->state, impl_->state->frames.size() - 1, RootFinishCallback{})); +} + +Result VariantBuilder::Finish() { + if (!impl_->state->frames.empty()) { + return Status::Invalid("Cannot finish VariantBuilder with active containers"); + } + if (!impl_->state->root_has_value) { + return Status::Invalid("Cannot finish empty VariantBuilder"); + } + + ARROW_ASSIGN_OR_RAISE(auto metadata, impl_->state->metadata.Finish()); + ARROW_ASSIGN_OR_RAISE(auto value, impl_->state->value.Finish()); + EncodedVariantValue out{.metadata = std::move(metadata), .value = std::move(value)}; + Reset(); + return out; +} + +void VariantBuilder::Reset() { + impl_->state = std::make_shared(impl_->pool); +} + +VariantObjectBuilder::VariantObjectBuilder( + std::unique_ptr impl) + : impl_(std::move(impl)) {} +VariantObjectBuilder::~VariantObjectBuilder() { RollbackIfActive(impl_.get()); } +VariantObjectBuilder::VariantObjectBuilder(VariantObjectBuilder&&) noexcept = default; +VariantObjectBuilder& VariantObjectBuilder::operator=(VariantObjectBuilder&&) noexcept = + default; + +Status VariantObjectBuilder::AppendVariantNull(std::string_view field_name) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name); +} + +Status VariantObjectBuilder::AppendBoolean(std::string_view field_name, bool value) { + if (value) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name); + } + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name); +} + +#define VARIANT_OBJECT_APPEND_ONE_ARG(NAME, TYPE, C_TYPE) \ + Status VariantObjectBuilder::Append##NAME(std::string_view field_name, C_TYPE value) { \ + return AppendObjectPrimitive( \ + impl_->state, 
impl_->frame_index, field_name, value); \ + } + +VARIANT_OBJECT_APPEND_ONE_ARG(Int8, Int8, int8_t) +VARIANT_OBJECT_APPEND_ONE_ARG(Int16, Int16, int16_t) +VARIANT_OBJECT_APPEND_ONE_ARG(Int32, Int32, int32_t) +VARIANT_OBJECT_APPEND_ONE_ARG(Int64, Int64, int64_t) +VARIANT_OBJECT_APPEND_ONE_ARG(Float, Float, float) +VARIANT_OBJECT_APPEND_ONE_ARG(Double, Double, double) +VARIANT_OBJECT_APPEND_ONE_ARG(Binary, Binary, std::string_view) +VARIANT_OBJECT_APPEND_ONE_ARG(String, String, std::string_view) +VARIANT_OBJECT_APPEND_ONE_ARG(Date, Date, int32_t) +VARIANT_OBJECT_APPEND_ONE_ARG(TimeNTZMicros, TimeNTZMicros, int64_t) +VARIANT_OBJECT_APPEND_ONE_ARG(Uuid, Uuid, std::string_view) + +#undef VARIANT_OBJECT_APPEND_ONE_ARG + +Status VariantObjectBuilder::AppendDecimal4(std::string_view field_name, + int32_t unscaled_value, uint8_t scale) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name, unscaled_value, scale); +} + +Status VariantObjectBuilder::AppendDecimal8(std::string_view field_name, + int64_t unscaled_value, uint8_t scale) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name, unscaled_value, scale); +} + +Status VariantObjectBuilder::AppendDecimal16( + std::string_view field_name, std::string_view little_endian_unscaled_value, + uint8_t scale) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name, little_endian_unscaled_value, scale); +} + +Status VariantObjectBuilder::AppendShortString(std::string_view field_name, + std::string_view value) { + return AppendObjectPrimitiveWith( + impl_->state, impl_->frame_index, field_name, + [&](VariantValueWriter& writer) { return writer.AppendShortString(value); }); +} + +Status VariantObjectBuilder::AppendTimestampMicros(std::string_view field_name, + int64_t micros, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name, micros); + } + return AppendObjectPrimitive( + 
impl_->state, impl_->frame_index, field_name, micros); +} + +Status VariantObjectBuilder::AppendTimestampNanos(std::string_view field_name, + int64_t nanos, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name, nanos); + } + return AppendObjectPrimitive( + impl_->state, impl_->frame_index, field_name, nanos); +} + +Result VariantObjectBuilder::StartObject( + std::string_view field_name) { + ARROW_RETURN_NOT_OK( + CheckTopFrame(*impl_->state, impl_->frame_index, VariantContainerKind::Object)); + auto& frame = impl_->state->frames.back(); + const auto metadata_size = impl_->state->metadata.size(); + const auto field_count = frame.fields.size(); + + ARROW_ASSIGN_OR_RAISE(auto field_id, impl_->state->metadata.Upsert(field_name)); + if (!frame.object_field_ids.insert(field_id).second) { + impl_->state->metadata.Truncate(metadata_size); + return Status::Invalid("Duplicate Variant object field: ", field_name); + } + const auto offset = impl_->state->value.length() - frame.value_start; + DCHECK_GE(offset, 0); + if (offset > std::numeric_limits::max()) { + impl_->state->metadata.Truncate(metadata_size); + TruncateFrameEntries(frame, field_count); + return Status::Invalid("Variant object values are too large"); + } + + frame.fields.push_back(VariantFieldDescriptor{.field_id = field_id, + .offset = static_cast(offset)}); + const auto value_start = impl_->state->value.length(); + impl_->state->frames.push_back(VariantBuildFrame{.kind = VariantContainerKind::Object, + .value_start = value_start, + .metadata_size = metadata_size, + .parent_frame = impl_->frame_index, + .parent_entry_count = field_count, + .has_parent = true}); + return VariantObjectBuilder(std::make_unique( + impl_->state, impl_->state->frames.size() - 1, impl_->callback)); +} + +Result VariantObjectBuilder::StartList(std::string_view field_name) { + ARROW_RETURN_NOT_OK( + CheckTopFrame(*impl_->state, impl_->frame_index, 
VariantContainerKind::Object)); + auto& frame = impl_->state->frames.back(); + const auto metadata_size = impl_->state->metadata.size(); + const auto field_count = frame.fields.size(); + + ARROW_ASSIGN_OR_RAISE(auto field_id, impl_->state->metadata.Upsert(field_name)); + if (!frame.object_field_ids.insert(field_id).second) { + impl_->state->metadata.Truncate(metadata_size); + return Status::Invalid("Duplicate Variant object field: ", field_name); + } + const auto offset = impl_->state->value.length() - frame.value_start; + DCHECK_GE(offset, 0); + if (offset > std::numeric_limits::max()) { + impl_->state->metadata.Truncate(metadata_size); + TruncateFrameEntries(frame, field_count); + return Status::Invalid("Variant object values are too large"); + } + + frame.fields.push_back(VariantFieldDescriptor{.field_id = field_id, + .offset = static_cast(offset)}); + const auto value_start = impl_->state->value.length(); + impl_->state->frames.push_back(VariantBuildFrame{.kind = VariantContainerKind::List, + .value_start = value_start, + .metadata_size = metadata_size, + .parent_frame = impl_->frame_index, + .parent_entry_count = field_count, + .has_parent = true}); + return VariantListBuilder(std::make_unique( + impl_->state, impl_->state->frames.size() - 1, impl_->callback)); +} + +Status VariantObjectBuilder::Finish() { + ARROW_RETURN_NOT_OK(FinishFrame(impl_->state, impl_->frame_index, + VariantContainerKind::Object, impl_->callback)); + impl_->active = false; + return Status::OK(); +} + +VariantListBuilder::VariantListBuilder( + std::unique_ptr impl) + : impl_(std::move(impl)) {} +VariantListBuilder::~VariantListBuilder() { RollbackIfActive(impl_.get()); } +VariantListBuilder::VariantListBuilder(VariantListBuilder&&) noexcept = default; +VariantListBuilder& VariantListBuilder::operator=(VariantListBuilder&&) noexcept = + default; + +Status VariantListBuilder::AppendVariantNull() { + return AppendListPrimitive(impl_->state, + impl_->frame_index); +} + +Status 
VariantListBuilder::AppendBoolean(bool value) { + if (value) { + return AppendListPrimitive(impl_->state, + impl_->frame_index); + } + return AppendListPrimitive(impl_->state, + impl_->frame_index); +} + +#define VARIANT_LIST_APPEND_ONE_ARG(NAME, TYPE, C_TYPE) \ + Status VariantListBuilder::Append##NAME(C_TYPE value) { \ + return AppendListPrimitive( \ + impl_->state, impl_->frame_index, value); \ + } + +VARIANT_LIST_APPEND_ONE_ARG(Int8, Int8, int8_t) +VARIANT_LIST_APPEND_ONE_ARG(Int16, Int16, int16_t) +VARIANT_LIST_APPEND_ONE_ARG(Int32, Int32, int32_t) +VARIANT_LIST_APPEND_ONE_ARG(Int64, Int64, int64_t) +VARIANT_LIST_APPEND_ONE_ARG(Float, Float, float) +VARIANT_LIST_APPEND_ONE_ARG(Double, Double, double) +VARIANT_LIST_APPEND_ONE_ARG(Binary, Binary, std::string_view) +VARIANT_LIST_APPEND_ONE_ARG(String, String, std::string_view) +VARIANT_LIST_APPEND_ONE_ARG(Date, Date, int32_t) +VARIANT_LIST_APPEND_ONE_ARG(TimeNTZMicros, TimeNTZMicros, int64_t) +VARIANT_LIST_APPEND_ONE_ARG(Uuid, Uuid, std::string_view) + +#undef VARIANT_LIST_APPEND_ONE_ARG + +Status VariantListBuilder::AppendDecimal4(int32_t unscaled_value, uint8_t scale) { + return AppendListPrimitive( + impl_->state, impl_->frame_index, unscaled_value, scale); +} + +Status VariantListBuilder::AppendDecimal8(int64_t unscaled_value, uint8_t scale) { + return AppendListPrimitive( + impl_->state, impl_->frame_index, unscaled_value, scale); +} + +Status VariantListBuilder::AppendDecimal16(std::string_view little_endian_unscaled_value, + uint8_t scale) { + return AppendListPrimitive( + impl_->state, impl_->frame_index, little_endian_unscaled_value, scale); +} + +Status VariantListBuilder::AppendShortString(std::string_view value) { + return AppendListPrimitiveWith( + impl_->state, impl_->frame_index, + [&](VariantValueWriter& writer) { return writer.AppendShortString(value); }); +} + +Status VariantListBuilder::AppendTimestampMicros(int64_t micros, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return 
AppendListPrimitive( + impl_->state, impl_->frame_index, micros); + } + return AppendListPrimitive( + impl_->state, impl_->frame_index, micros); +} + +Status VariantListBuilder::AppendTimestampNanos(int64_t nanos, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return AppendListPrimitive( + impl_->state, impl_->frame_index, nanos); + } + return AppendListPrimitive( + impl_->state, impl_->frame_index, nanos); +} + +Result VariantListBuilder::StartObject() { + ARROW_RETURN_NOT_OK( + CheckTopFrame(*impl_->state, impl_->frame_index, VariantContainerKind::List)); + auto& frame = impl_->state->frames.back(); + const auto element_count = frame.offsets.size(); + const auto offset = impl_->state->value.length() - frame.value_start; + DCHECK_GE(offset, 0); + if (offset > std::numeric_limits::max()) { + return Status::Invalid("Variant array values are too large"); + } + + frame.offsets.push_back(static_cast(offset)); + const auto value_start = impl_->state->value.length(); + impl_->state->frames.push_back( + VariantBuildFrame{.kind = VariantContainerKind::Object, + .value_start = value_start, + .metadata_size = impl_->state->metadata.size(), + .parent_frame = impl_->frame_index, + .parent_entry_count = element_count, + .has_parent = true}); + return VariantObjectBuilder(std::make_unique( + impl_->state, impl_->state->frames.size() - 1, impl_->callback)); +} + +Result VariantListBuilder::StartList() { + ARROW_RETURN_NOT_OK( + CheckTopFrame(*impl_->state, impl_->frame_index, VariantContainerKind::List)); + auto& frame = impl_->state->frames.back(); + const auto element_count = frame.offsets.size(); + const auto offset = impl_->state->value.length() - frame.value_start; + DCHECK_GE(offset, 0); + if (offset > std::numeric_limits::max()) { + return Status::Invalid("Variant array values are too large"); + } + + frame.offsets.push_back(static_cast(offset)); + const auto value_start = impl_->state->value.length(); + impl_->state->frames.push_back( + VariantBuildFrame{.kind = 
VariantContainerKind::List, + .value_start = value_start, + .metadata_size = impl_->state->metadata.size(), + .parent_frame = impl_->frame_index, + .parent_entry_count = element_count, + .has_parent = true}); + return VariantListBuilder(std::make_unique( + impl_->state, impl_->state->frames.size() - 1, impl_->callback)); +} + +Status VariantListBuilder::Finish() { + ARROW_RETURN_NOT_OK(FinishFrame(impl_->state, impl_->frame_index, + VariantContainerKind::List, impl_->callback)); + impl_->active = false; + return Status::OK(); +} + +struct VariantArrayBuilder::Impl { + explicit Impl(MemoryPool* pool) + : pool(pool), metadata_builder(pool), value_builder(pool), validity_builder(pool) {} + + template + Status AppendValue(Write&& write) { + auto state = std::make_shared(pool); + ARROW_RETURN_NOT_OK(AppendRootPrimitiveWith(state, std::forward(write))); + ARROW_ASSIGN_OR_RAISE(auto metadata, state->metadata.Finish()); + ARROW_ASSIGN_OR_RAISE(auto value, state->value.Finish()); + return AppendEncoded({.metadata = std::move(metadata), .value = std::move(value)}); + } + + template + Status AppendPrimitive(Args&&... 
args) { + return AppendValue([&](VariantValueWriter& writer) { + return writer.template Append(std::forward(args)...); + }); + } + + Status AppendShortString(std::string_view value) { + return AppendValue( + [&](VariantValueWriter& writer) { return writer.AppendShortString(value); }); + } + + Status AppendEncoded(const EncodedVariantValue& value) { + if (value.metadata == nullptr || value.value == nullptr) { + return Status::Invalid( + "Encoded Variant metadata and value buffers must be non-null"); + } + ARROW_ASSIGN_OR_RAISE(auto metadata, + VariantMetadataView::Make(std::string_view{*value.metadata})); + ARROW_RETURN_NOT_OK( + VariantValueView::Validate(std::string_view{*value.value}, metadata)); + ARROW_RETURN_NOT_OK(metadata_builder.Append(std::string_view{*value.metadata})); + ARROW_RETURN_NOT_OK(value_builder.Append(std::string_view{*value.value})); + return validity_builder.Append(true); + } + + MemoryPool* pool; + BinaryBuilder metadata_builder; + BinaryBuilder value_builder; + BooleanBuilder validity_builder; +}; + +VariantArrayBuilder::VariantArrayBuilder(MemoryPool* pool) + : impl_(std::make_unique(pool)) {} +VariantArrayBuilder::~VariantArrayBuilder() = default; +VariantArrayBuilder::VariantArrayBuilder(VariantArrayBuilder&&) noexcept = default; +VariantArrayBuilder& VariantArrayBuilder::operator=(VariantArrayBuilder&&) noexcept = + default; + +Status VariantArrayBuilder::AppendNull() { + ARROW_RETURN_NOT_OK(impl_->metadata_builder.Append("")); + ARROW_RETURN_NOT_OK(impl_->value_builder.Append("")); + return impl_->validity_builder.Append(false); +} + +Status VariantArrayBuilder::AppendVariantNull() { + return impl_->AppendPrimitive(); +} + +Status VariantArrayBuilder::AppendBoolean(bool value) { + if (value) { + return impl_->AppendPrimitive(); + } + return impl_->AppendPrimitive(); +} + +#define VARIANT_ARRAY_APPEND_ONE_ARG(NAME, TYPE, C_TYPE) \ + Status VariantArrayBuilder::Append##NAME(C_TYPE value) { \ + return impl_->AppendPrimitive(value); \ + } 
+ +VARIANT_ARRAY_APPEND_ONE_ARG(Int8, Int8, int8_t) +VARIANT_ARRAY_APPEND_ONE_ARG(Int16, Int16, int16_t) +VARIANT_ARRAY_APPEND_ONE_ARG(Int32, Int32, int32_t) +VARIANT_ARRAY_APPEND_ONE_ARG(Int64, Int64, int64_t) +VARIANT_ARRAY_APPEND_ONE_ARG(Float, Float, float) +VARIANT_ARRAY_APPEND_ONE_ARG(Double, Double, double) +VARIANT_ARRAY_APPEND_ONE_ARG(Binary, Binary, std::string_view) +VARIANT_ARRAY_APPEND_ONE_ARG(String, String, std::string_view) +VARIANT_ARRAY_APPEND_ONE_ARG(Date, Date, int32_t) +VARIANT_ARRAY_APPEND_ONE_ARG(TimeNTZMicros, TimeNTZMicros, int64_t) +VARIANT_ARRAY_APPEND_ONE_ARG(Uuid, Uuid, std::string_view) + +#undef VARIANT_ARRAY_APPEND_ONE_ARG + +Status VariantArrayBuilder::AppendDecimal4(int32_t unscaled_value, uint8_t scale) { + return impl_->AppendPrimitive(unscaled_value, scale); +} + +Status VariantArrayBuilder::AppendDecimal8(int64_t unscaled_value, uint8_t scale) { + return impl_->AppendPrimitive(unscaled_value, scale); +} + +Status VariantArrayBuilder::AppendDecimal16(std::string_view little_endian_unscaled_value, + uint8_t scale) { + return impl_->AppendPrimitive( + little_endian_unscaled_value, scale); +} + +Status VariantArrayBuilder::AppendShortString(std::string_view value) { + return impl_->AppendShortString(value); +} + +Status VariantArrayBuilder::AppendTimestampMicros(int64_t micros, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return impl_->AppendPrimitive(micros); + } + return impl_->AppendPrimitive(micros); +} + +Status VariantArrayBuilder::AppendTimestampNanos(int64_t nanos, bool adjusted_to_utc) { + if (adjusted_to_utc) { + return impl_->AppendPrimitive(nanos); + } + return impl_->AppendPrimitive(nanos); +} + +Status VariantArrayBuilder::AppendEncoded(const EncodedVariantValue& value) { + return impl_->AppendEncoded(value); +} + +Result VariantArrayBuilder::StartObject() { + auto state = std::make_shared(impl_->pool); + state->frames.push_back(VariantBuildFrame{.kind = VariantContainerKind::Object, + .value_start = 
state->value.length(), + .metadata_size = state->metadata.size()}); + auto callback = [this](EncodedVariantValue encoded) { + return impl_->AppendEncoded(encoded); + }; + return VariantObjectBuilder(std::make_unique( + state, state->frames.size() - 1, std::move(callback))); +} + +Result VariantArrayBuilder::StartList() { + auto state = std::make_shared(impl_->pool); + state->frames.push_back(VariantBuildFrame{.kind = VariantContainerKind::List, + .value_start = state->value.length(), + .metadata_size = state->metadata.size()}); + auto callback = [this](EncodedVariantValue encoded) { + return impl_->AppendEncoded(encoded); + }; + return VariantListBuilder(std::make_unique( + state, state->frames.size() - 1, std::move(callback))); +} + +Result> VariantArrayBuilder::Finish() { + std::shared_ptr metadata; + std::shared_ptr value; + std::shared_ptr validity; + ARROW_RETURN_NOT_OK(impl_->metadata_builder.Finish(&metadata)); + ARROW_RETURN_NOT_OK(impl_->value_builder.Finish(&value)); + ARROW_RETURN_NOT_OK(impl_->validity_builder.Finish(&validity)); + + auto null_bitmap = validity->data()->buffers[1]; + const int64_t null_count = validity->false_count(); + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + ARROW_ASSIGN_OR_RAISE(auto storage, + StructArray::Make({metadata, value}, storage_type->fields(), + null_bitmap, null_count)); + return MakeVariantArrayFromStorage(storage); +} + +void VariantArrayBuilder::Reset() { impl_ = std::make_unique(impl_->pool); } + +struct VariantValueArrayBuilder::Impl { + explicit Impl(MemoryPool* pool) : value_builder(pool) {} + + BinaryBuilder value_builder; +}; + +VariantValueArrayBuilder::VariantValueArrayBuilder(MemoryPool* pool) + : impl_(std::make_unique(pool)) {} +VariantValueArrayBuilder::VariantValueArrayBuilder(VariantValueArrayBuilder&&) noexcept = + default; +VariantValueArrayBuilder& VariantValueArrayBuilder::operator=( + VariantValueArrayBuilder&&) 
noexcept = default; +VariantValueArrayBuilder::~VariantValueArrayBuilder() = default; + +Status VariantValueArrayBuilder::AppendNull() { + return impl_->value_builder.AppendNull(); +} + +Status VariantValueArrayBuilder::AppendEncodedValue(std::string_view metadata, + std::string_view value) { + ARROW_ASSIGN_OR_RAISE(auto metadata_view, VariantMetadataView::Make(metadata)); + ARROW_RETURN_NOT_OK(VariantValueView::Validate(value, metadata_view)); + return impl_->value_builder.Append(value); +} + +Result> VariantValueArrayBuilder::Finish() { + std::shared_ptr out; + ARROW_RETURN_NOT_OK(impl_->value_builder.Finish(&out)); + return out; +} + +Result> MakeVariantArrayFromStorage( + std::shared_ptr storage) { + if (storage == nullptr) { + return Status::Invalid("Variant storage array must be non-null"); + } + ARROW_ASSIGN_OR_RAISE(auto type, VariantExtensionType::Make(storage->type())); + auto array = ExtensionType::WrapArray(type, std::move(storage)); + return std::static_pointer_cast(array); +} + +Result> MakeVariantArrayFromChildren( + std::shared_ptr storage_type, std::vector> children, + std::shared_ptr null_bitmap) { + if (storage_type->id() != Type::STRUCT) { + return Status::Invalid("Variant storage type must be struct, got ", + storage_type->ToString()); + } + + const auto& struct_type = + ::arrow::internal::checked_cast(*storage_type); + if (children.size() != static_cast(struct_type.num_fields())) { + return Status::Invalid("Variant storage expected ", struct_type.num_fields(), + " children, got ", children.size()); + } + + const int64_t length = children.empty() ? 
0 : children[0]->length(); + for (int i = 0; i < struct_type.num_fields(); ++i) { + if (children[i] == nullptr) { + return Status::Invalid("Variant storage child ", i, " is null"); + } + if (!children[i]->type()->Equals(struct_type.field(i)->type())) { + return Status::Invalid("Variant storage child ", i, " has type ", + children[i]->type()->ToString(), ", expected ", + struct_type.field(i)->type()->ToString()); + } + if (children[i]->length() != length) { + return Status::Invalid("Variant storage child lengths must match"); + } + } + + ARROW_ASSIGN_OR_RAISE(auto storage, + StructArray::Make(std::move(children), struct_type.fields(), + std::move(null_bitmap))); + return MakeVariantArrayFromStorage(storage); +} + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/builder.h b/cpp/src/parquet/variant/builder.h new file mode 100644 index 000000000000..cc0507bfa9f8 --- /dev/null +++ b/cpp/src/parquet/variant/builder.h @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/extension/parquet_variant.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "parquet/platform.h" + +namespace parquet::variant { + +using ::arrow::Array; +using ::arrow::BinaryArray; +using ::arrow::Buffer; +using ::arrow::DataType; +using ::arrow::MemoryPool; +using ::arrow::Result; +using ::arrow::Status; +using ::arrow::StructArray; +using ::arrow::extension::VariantArray; + +struct PARQUET_EXPORT EncodedVariantValue { + std::shared_ptr metadata; + std::shared_ptr value; +}; + +class VariantObjectBuilder; +class VariantListBuilder; + +class PARQUET_EXPORT VariantBuilder { + public: + explicit VariantBuilder(MemoryPool* pool = ::arrow::default_memory_pool()); + ~VariantBuilder(); + VariantBuilder(const VariantBuilder&) = delete; + VariantBuilder& operator=(const VariantBuilder&) = delete; + VariantBuilder(VariantBuilder&&) noexcept; + VariantBuilder& operator=(VariantBuilder&&) noexcept; + + Status ReserveFieldNames(int64_t capacity); + Result AddFieldName(std::string_view name); + + Status AppendVariantNull(); + Status AppendBoolean(bool value); + Status AppendInt8(int8_t value); + Status AppendInt16(int16_t value); + Status AppendInt32(int32_t value); + Status AppendInt64(int64_t value); + Status AppendFloat(float value); + Status AppendDouble(double value); + Status AppendBinary(std::string_view value); + Status AppendString(std::string_view value); + Status AppendDate(int32_t days); + Status AppendTimeNTZMicros(int64_t micros); + Status AppendUuid(std::string_view big_endian_bytes); + Status AppendDecimal4(int32_t unscaled_value, uint8_t scale); + Status AppendDecimal8(int64_t unscaled_value, uint8_t scale); + Status AppendDecimal16(std::string_view little_endian_unscaled_value, uint8_t scale); + Status AppendShortString(std::string_view value); + Status AppendTimestampMicros(int64_t 
micros, bool adjusted_to_utc); + Status AppendTimestampNanos(int64_t nanos, bool adjusted_to_utc); + + Result StartObject(); + Result StartList(); + + Result Finish(); + void Reset(); + + private: + struct Impl; + std::unique_ptr impl_; +}; + +namespace internal { +struct NestedVariantBuilderImpl; +} + +class PARQUET_EXPORT VariantObjectBuilder { + public: + ~VariantObjectBuilder(); + VariantObjectBuilder(const VariantObjectBuilder&) = delete; + VariantObjectBuilder& operator=(const VariantObjectBuilder&) = delete; + VariantObjectBuilder(VariantObjectBuilder&&) noexcept; + VariantObjectBuilder& operator=(VariantObjectBuilder&&) noexcept; + + Status AppendVariantNull(std::string_view field_name); + Status AppendBoolean(std::string_view field_name, bool value); + Status AppendInt8(std::string_view field_name, int8_t value); + Status AppendInt16(std::string_view field_name, int16_t value); + Status AppendInt32(std::string_view field_name, int32_t value); + Status AppendInt64(std::string_view field_name, int64_t value); + Status AppendFloat(std::string_view field_name, float value); + Status AppendDouble(std::string_view field_name, double value); + Status AppendBinary(std::string_view field_name, std::string_view value); + Status AppendString(std::string_view field_name, std::string_view value); + Status AppendDate(std::string_view field_name, int32_t days); + Status AppendTimeNTZMicros(std::string_view field_name, int64_t micros); + Status AppendUuid(std::string_view field_name, std::string_view big_endian_bytes); + Status AppendDecimal4(std::string_view field_name, int32_t unscaled_value, + uint8_t scale); + Status AppendDecimal8(std::string_view field_name, int64_t unscaled_value, + uint8_t scale); + Status AppendDecimal16(std::string_view field_name, + std::string_view little_endian_unscaled_value, uint8_t scale); + Status AppendShortString(std::string_view field_name, std::string_view value); + Status AppendTimestampMicros(std::string_view field_name, int64_t 
micros, + bool adjusted_to_utc); + Status AppendTimestampNanos(std::string_view field_name, int64_t nanos, + bool adjusted_to_utc); + + Result StartObject(std::string_view field_name); + Result StartList(std::string_view field_name); + + /// Commit this nested object into its parent builder. Destroying an unfinished + /// nested builder rolls back the object contents written through this builder. + Status Finish(); + + private: + friend class VariantBuilder; + friend class VariantListBuilder; + friend class VariantArrayBuilder; + + explicit VariantObjectBuilder(std::unique_ptr impl); + + std::unique_ptr impl_; +}; + +class PARQUET_EXPORT VariantListBuilder { + public: + ~VariantListBuilder(); + VariantListBuilder(const VariantListBuilder&) = delete; + VariantListBuilder& operator=(const VariantListBuilder&) = delete; + VariantListBuilder(VariantListBuilder&&) noexcept; + VariantListBuilder& operator=(VariantListBuilder&&) noexcept; + + Status AppendVariantNull(); + Status AppendBoolean(bool value); + Status AppendInt8(int8_t value); + Status AppendInt16(int16_t value); + Status AppendInt32(int32_t value); + Status AppendInt64(int64_t value); + Status AppendFloat(float value); + Status AppendDouble(double value); + Status AppendBinary(std::string_view value); + Status AppendString(std::string_view value); + Status AppendDate(int32_t days); + Status AppendTimeNTZMicros(int64_t micros); + Status AppendUuid(std::string_view big_endian_bytes); + Status AppendDecimal4(int32_t unscaled_value, uint8_t scale); + Status AppendDecimal8(int64_t unscaled_value, uint8_t scale); + Status AppendDecimal16(std::string_view little_endian_unscaled_value, uint8_t scale); + Status AppendShortString(std::string_view value); + Status AppendTimestampMicros(int64_t micros, bool adjusted_to_utc); + Status AppendTimestampNanos(int64_t nanos, bool adjusted_to_utc); + + Result StartObject(); + Result StartList(); + + /// Commit this nested list into its parent builder. 
Destroying an unfinished nested + /// builder rolls back the list contents written through this builder. + Status Finish(); + + private: + friend class VariantBuilder; + friend class VariantObjectBuilder; + friend class VariantArrayBuilder; + + explicit VariantListBuilder(std::unique_ptr impl); + + std::unique_ptr impl_; +}; + +class PARQUET_EXPORT VariantArrayBuilder { + public: + explicit VariantArrayBuilder(MemoryPool* pool = ::arrow::default_memory_pool()); + ~VariantArrayBuilder(); + VariantArrayBuilder(const VariantArrayBuilder&) = delete; + VariantArrayBuilder& operator=(const VariantArrayBuilder&) = delete; + VariantArrayBuilder(VariantArrayBuilder&&) noexcept; + VariantArrayBuilder& operator=(VariantArrayBuilder&&) noexcept; + + Status AppendNull(); + Status AppendVariantNull(); + Status AppendBoolean(bool value); + Status AppendInt8(int8_t value); + Status AppendInt16(int16_t value); + Status AppendInt32(int32_t value); + Status AppendInt64(int64_t value); + Status AppendFloat(float value); + Status AppendDouble(double value); + Status AppendBinary(std::string_view value); + Status AppendString(std::string_view value); + Status AppendDate(int32_t days); + Status AppendTimeNTZMicros(int64_t micros); + Status AppendUuid(std::string_view big_endian_bytes); + Status AppendDecimal4(int32_t unscaled_value, uint8_t scale); + Status AppendDecimal8(int64_t unscaled_value, uint8_t scale); + Status AppendDecimal16(std::string_view little_endian_unscaled_value, uint8_t scale); + Status AppendShortString(std::string_view value); + Status AppendTimestampMicros(int64_t micros, bool adjusted_to_utc); + Status AppendTimestampNanos(int64_t nanos, bool adjusted_to_utc); + Status AppendEncoded(const EncodedVariantValue& value); + + Result StartObject(); + Result StartList(); + + Result> Finish(); + void Reset(); + + private: + struct Impl; + std::unique_ptr impl_; +}; + +class PARQUET_EXPORT VariantValueArrayBuilder { + public: + explicit VariantValueArrayBuilder(MemoryPool* 
pool = ::arrow::default_memory_pool()); + ~VariantValueArrayBuilder(); + VariantValueArrayBuilder(const VariantValueArrayBuilder&) = delete; + VariantValueArrayBuilder& operator=(const VariantValueArrayBuilder&) = delete; + VariantValueArrayBuilder(VariantValueArrayBuilder&&) noexcept; + VariantValueArrayBuilder& operator=(VariantValueArrayBuilder&&) noexcept; + + Status AppendNull(); + Status AppendEncodedValue(std::string_view metadata, std::string_view value); + Result> Finish(); + + private: + struct Impl; + std::unique_ptr impl_; +}; + +PARQUET_EXPORT +Result> MakeVariantArrayFromStorage( + std::shared_ptr storage); + +PARQUET_EXPORT +Result> MakeVariantArrayFromChildren( + std::shared_ptr storage_type, std::vector> children, + std::shared_ptr null_bitmap = nullptr); + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/builder_test.cc b/cpp/src/parquet/variant/builder_test.cc new file mode 100644 index 000000000000..e82addb27a81 --- /dev/null +++ b/cpp/src/parquet/variant/builder_test.cc @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "parquet/variant/builder.h" +#include "parquet/variant/encoding.h" +#include "parquet/variant/test_util_internal.h" + +#include +#include +#include + +#include "arrow/array.h" // IWYU pragma: keep +#include "arrow/testing/gtest_util.h" +#include "arrow/util/checked_cast.h" + +namespace parquet::variant { + +namespace { + +using ::arrow::ArrayFromJSON; +using ::arrow::binary; +using ::arrow::default_memory_pool; +using ::arrow::field; +using ::arrow::int64; +using ::arrow::ProxyMemoryPool; +using ::arrow::struct_; +using ::arrow::StructArray; +using ::arrow::Type; +using ::arrow::extension::variant; +using internal::BinaryArrayFromValues; +using internal::MakeVariantValueView; + +void AssertPrimitiveType(std::string_view value, const VariantMetadataView& metadata, + VariantPrimitiveType expected) { + ASSERT_OK_AND_ASSIGN(auto view, VariantValueView::Make(value, metadata)); + ASSERT_EQ(VariantBasicType::kPrimitive, view.basic_type()); + ASSERT_EQ(expected, std::get(view.data()).type()); +} + +void AssertPrimitiveFieldType(const VariantObjectView& object, std::string_view name, + const VariantMetadataView& metadata, + VariantPrimitiveType expected) { + const auto* field = object.FindField(name); + ASSERT_NE(nullptr, field) << "Missing Variant object field: " << name; + AssertPrimitiveType(field->value, metadata, expected); +} + +} // namespace + +TEST(TestVariantBuilder, Object) { + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + ASSERT_OK(object.AppendVariantNull("c")); + ASSERT_OK(object.AppendVariantNull("b")); + ASSERT_OK(object.AppendVariantNull("a")); + ASSERT_OK(object.Finish()); + + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto view, MakeVariantValueView(encoded)); + ASSERT_EQ(VariantBasicType::kObject, view.basic_type()); + const auto& fields = std::get(view.data()).fields(); + ASSERT_EQ(3, fields.size()); + ASSERT_EQ("a", fields[0].name); + ASSERT_EQ("b", fields[1].name); + 
ASSERT_EQ("c", fields[2].name); +} + +TEST(TestVariantBuilder, List) { + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto list, builder.StartList()); + ASSERT_OK(list.AppendVariantNull()); + ASSERT_OK(list.AppendInt32(42)); + ASSERT_OK(list.AppendString("x")); + ASSERT_OK(list.Finish()); + + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + ASSERT_OK_AND_ASSIGN( + auto view, VariantValueView::Make(std::string_view{*encoded.value}, metadata)); + ASSERT_EQ(VariantBasicType::kArray, view.basic_type()); + const auto& elements = std::get(view.data()).elements(); + ASSERT_EQ(3, elements.size()); + + ASSERT_OK_AND_ASSIGN(auto element, VariantValueView::Make(elements[1], metadata)); + ASSERT_EQ(VariantPrimitiveType::kInt32, + std::get(element.data()).type()); +} + +TEST(TestVariantBuilder, Nested) { + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + ASSERT_OK_AND_ASSIGN(auto list, object.StartList("items")); + ASSERT_OK(list.AppendInt32(1)); + ASSERT_OK_AND_ASSIGN(auto child, list.StartObject()); + ASSERT_OK(child.AppendString("name", "x")); + ASSERT_OK(child.Finish()); + ASSERT_OK(list.Finish()); + ASSERT_OK(object.Finish()); + + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + ASSERT_OK_AND_ASSIGN( + auto root, VariantValueView::Make(std::string_view{*encoded.value}, metadata)); + const auto& root_fields = std::get(root.data()).fields(); + ASSERT_EQ(1, root_fields.size()); + ASSERT_EQ("items", root_fields[0].name); + + ASSERT_OK_AND_ASSIGN(auto items, + VariantValueView::Make(root_fields[0].value, metadata)); + const auto& item_values = std::get(items.data()).elements(); + ASSERT_EQ(2, item_values.size()); + ASSERT_OK_AND_ASSIGN(auto item, VariantValueView::Make(item_values[1], metadata)); + ASSERT_EQ("name", 
std::get(item.data()).fields()[0].name); +} + +TEST(TestVariantBuilder, ObjectAppends) { + const std::string uuid(16, '\1'); + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + ASSERT_OK(object.AppendInt8("int8", 1)); + ASSERT_OK(object.AppendInt64("int64", 2)); + ASSERT_OK(object.AppendDouble("double", 3)); + ASSERT_OK(object.AppendDecimal4("decimal", 1234, 2)); + ASSERT_OK(object.AppendBinary("binary", "abc")); + ASSERT_OK(object.AppendDate("date", 1)); + ASSERT_OK(object.AppendTimestampNanos("ts", 2, true)); + ASSERT_OK(object.AppendUuid("uuid", uuid)); + ASSERT_OK(object.Finish()); + + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + ASSERT_OK_AND_ASSIGN( + auto root, VariantValueView::Make(std::string_view{*encoded.value}, metadata)); + const auto& object_view = std::get(root.data()); + ASSERT_EQ(8, object_view.fields().size()); + AssertPrimitiveFieldType(object_view, "int8", metadata, VariantPrimitiveType::kInt8); + AssertPrimitiveFieldType(object_view, "int64", metadata, VariantPrimitiveType::kInt64); + AssertPrimitiveFieldType(object_view, "double", metadata, + VariantPrimitiveType::kDouble); + AssertPrimitiveFieldType(object_view, "decimal", metadata, + VariantPrimitiveType::kDecimal4); + AssertPrimitiveFieldType(object_view, "binary", metadata, + VariantPrimitiveType::kBinary); + AssertPrimitiveFieldType(object_view, "date", metadata, VariantPrimitiveType::kDate); + AssertPrimitiveFieldType(object_view, "ts", metadata, + VariantPrimitiveType::kTimestampNanos); + AssertPrimitiveFieldType(object_view, "uuid", metadata, VariantPrimitiveType::kUuid); +} + +TEST(TestVariantBuilder, ListAppends) { + const std::string uuid(16, '\2'); + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto list, builder.StartList()); + ASSERT_OK(list.AppendInt8(1)); + ASSERT_OK(list.AppendInt64(2)); + ASSERT_OK(list.AppendDouble(3)); + 
ASSERT_OK(list.AppendDecimal4(1234, 2)); + ASSERT_OK(list.AppendBinary("abc")); + ASSERT_OK(list.AppendDate(1)); + ASSERT_OK(list.AppendTimestampNanos(2, true)); + ASSERT_OK(list.AppendUuid(uuid)); + ASSERT_OK(list.Finish()); + + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + ASSERT_OK_AND_ASSIGN( + auto root, VariantValueView::Make(std::string_view{*encoded.value}, metadata)); + const auto& elements = std::get(root.data()).elements(); + ASSERT_EQ(8, elements.size()); + AssertPrimitiveType(elements[0], metadata, VariantPrimitiveType::kInt8); + AssertPrimitiveType(elements[1], metadata, VariantPrimitiveType::kInt64); + AssertPrimitiveType(elements[2], metadata, VariantPrimitiveType::kDouble); + AssertPrimitiveType(elements[3], metadata, VariantPrimitiveType::kDecimal4); + AssertPrimitiveType(elements[4], metadata, VariantPrimitiveType::kBinary); + AssertPrimitiveType(elements[5], metadata, VariantPrimitiveType::kDate); + AssertPrimitiveType(elements[6], metadata, VariantPrimitiveType::kTimestampNanos); + AssertPrimitiveType(elements[7], metadata, VariantPrimitiveType::kUuid); +} + +TEST(TestVariantBuilder, Rollback) { + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + { + ASSERT_OK_AND_ASSIGN(auto child, object.StartObject("drop")); + ASSERT_OK(child.AppendString("nested", "x")); + } + ASSERT_OK(object.AppendInt32("ok", 1)); + ASSERT_OK(object.Finish()); + + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto view, MakeVariantValueView(encoded)); + const auto& fields = std::get(view.data()).fields(); + ASSERT_EQ(1, fields.size()); + ASSERT_EQ("ok", fields[0].name); +} + +TEST(TestVariantBuilder, Duplicate) { + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + ASSERT_OK(object.AppendInt32("a", 1)); + ASSERT_RAISES(Invalid, object.AppendString("a", "x")); 
+} + +TEST(TestVariantBuilder, UsesMemoryPool) { + ProxyMemoryPool pool(default_memory_pool()); + VariantBuilder builder(&pool); + ASSERT_OK(builder.AppendString(std::string(128, 'x'))); + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_GT(pool.total_bytes_allocated(), 0); + + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + ASSERT_OK(VariantValueView::Validate(std::string_view{*encoded.value}, metadata)); +} + +TEST(TestVariantBuilder, ArrayBuilder) { + VariantArrayBuilder builder; + ASSERT_OK(builder.AppendNull()); + ASSERT_OK(builder.AppendVariantNull()); + ASSERT_OK(builder.AppendInt32(42)); + ASSERT_OK(builder.AppendInt8(1)); + ASSERT_OK(builder.AppendDouble(2)); + ASSERT_OK(builder.AppendBinary("abc")); + ASSERT_OK(builder.AppendTimestampMicros(3, true)); + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + ASSERT_OK(object.AppendString("a", "x")); + ASSERT_OK(object.Finish()); + + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + ASSERT_EQ(8, array->length()); + ASSERT_TRUE(array->IsNull(0)); + ASSERT_FALSE(array->IsNull(1)); + + const auto& storage = + ::arrow::internal::checked_cast(*array->storage()); + ASSERT_FALSE(storage.type()->field(0)->nullable()); + ASSERT_FALSE(storage.type()->field(1)->nullable()); + ASSERT_EQ(Type::BINARY, storage.field(0)->type_id()); + ASSERT_EQ(Type::BINARY, storage.field(1)->type_id()); +} + +TEST(TestVariantBuilder, FromStorage) { + VariantBuilder value; + ASSERT_OK(value.AppendInt32(1)); + ASSERT_OK_AND_ASSIGN(auto encoded, value.Finish()); + + auto metadata = BinaryArrayFromValues({std::string_view{*encoded.metadata}}); + auto values = BinaryArrayFromValues({std::string_view{*encoded.value}}); + ASSERT_OK_AND_ASSIGN( + auto storage, + StructArray::Make({metadata, values}, {field("metadata", binary(), false), + field("value", binary(), false)})); + + ASSERT_OK_AND_ASSIGN(auto array, MakeVariantArrayFromStorage(storage)); + ASSERT_EQ(1, 
array->length()); + ASSERT_TRUE(array->type()->Equals(variant(storage->type()))); +} + +TEST(TestVariantBuilder, FromShredded) { + VariantBuilder value; + ASSERT_OK(value.AppendVariantNull()); + ASSERT_OK_AND_ASSIGN(auto encoded, value.Finish()); + + auto metadata = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto values = BinaryArrayFromValues({std::nullopt, std::string_view{*encoded.value}}); + auto typed = ArrayFromJSON(int64(), "[1, null]"); + auto storage_type = struct_({field("metadata", binary(), false), + field("value", binary()), field("typed_value", int64())}); + + ASSERT_OK_AND_ASSIGN( + auto array, MakeVariantArrayFromChildren(storage_type, {metadata, values, typed})); + ASSERT_EQ(2, array->length()); + ASSERT_TRUE(array->type()->Equals(variant(storage_type))); +} + +TEST(TestVariantBuilder, FallbackValue) { + VariantBuilder value; + ASSERT_OK(value.AppendString("x")); + ASSERT_OK_AND_ASSIGN(auto encoded, value.Finish()); + + VariantValueArrayBuilder builder; + ASSERT_OK(builder.AppendNull()); + ASSERT_OK(builder.AppendEncodedValue(std::string_view{*encoded.metadata}, + std::string_view{*encoded.value})); + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + ASSERT_EQ(2, array->length()); + ASSERT_TRUE(array->IsNull(0)); + ASSERT_EQ(std::string_view{*encoded.value}, array->GetView(1)); +} + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/encoding.cc b/cpp/src/parquet/variant/encoding.cc new file mode 100644 index 000000000000..95c61a235589 --- /dev/null +++ b/cpp/src/parquet/variant/encoding.cc @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "parquet/variant/encoding.h" + +#include +#include +#include + +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" +#include "parquet/variant/encoding_internal.h" + +namespace parquet::variant { + +namespace bit_util = ::arrow::bit_util; + +namespace { + +uint32_t ReadLittleEndian(std::string_view data, size_t offset, size_t width) { + DCHECK_LE(width, sizeof(uint32_t)); + uint32_t value = 0; + std::memcpy(&value, data.data() + offset, width); + return bit_util::FromLittleEndian(value); +} + +Status CheckAvailable(std::string_view data, size_t offset, size_t size, + std::string_view context) { + if (offset > data.size() || data.size() - offset < size) { + return Status::Invalid("Invalid Variant encoding: truncated ", context); + } + return Status::OK(); +} + +Status PrimitivePayloadSize(std::string_view value, size_t offset, + VariantPrimitiveType primitive, size_t* size) { + switch (primitive) { + case VariantPrimitiveType::kNull: + case VariantPrimitiveType::kBooleanTrue: + case VariantPrimitiveType::kBooleanFalse: + *size = 0; + return Status::OK(); + case VariantPrimitiveType::kInt8: + *size = 1; + return Status::OK(); + case VariantPrimitiveType::kInt16: + *size = 2; + return Status::OK(); + case VariantPrimitiveType::kInt32: + case VariantPrimitiveType::kDate: + case VariantPrimitiveType::kFloat: + *size = 4; + return Status::OK(); + case 
VariantPrimitiveType::kInt64: + case VariantPrimitiveType::kDouble: + case VariantPrimitiveType::kTimestampMicros: + case VariantPrimitiveType::kTimestampNTZMicros: + case VariantPrimitiveType::kTimeNTZMicros: + case VariantPrimitiveType::kTimestampNanos: + case VariantPrimitiveType::kTimestampNTZNanos: + *size = 8; + return Status::OK(); + case VariantPrimitiveType::kDecimal4: + *size = 5; + return Status::OK(); + case VariantPrimitiveType::kDecimal8: + *size = 9; + return Status::OK(); + case VariantPrimitiveType::kDecimal16: + *size = 17; + return Status::OK(); + case VariantPrimitiveType::kUuid: + *size = 16; + return Status::OK(); + case VariantPrimitiveType::kBinary: + case VariantPrimitiveType::kString: { + ARROW_RETURN_NOT_OK(CheckAvailable(value, offset, 4, "variable-length size")); + const uint32_t length = ReadLittleEndian(value, offset, 4); + *size = 4 + static_cast(length); + return Status::OK(); + } + } + return Status::Invalid("Invalid Variant encoding: unknown primitive type"); +} + +Status ParsePrimitive(std::string_view value, size_t offset, + VariantPrimitiveType primitive, size_t& consumed) { + if (!internal::IsKnownVariantPrimitive(primitive)) { + return Status::Invalid("Invalid Variant encoding: unknown primitive type ", + static_cast(primitive)); + } + + size_t payload_size = 0; + ARROW_RETURN_NOT_OK(PrimitivePayloadSize(value, offset, primitive, &payload_size)); + ARROW_RETURN_NOT_OK(CheckAvailable(value, offset, payload_size, "primitive value")); + + if (internal::IsDecimalVariantPrimitive(primitive)) { + const auto scale = static_cast(value[offset]); + ARROW_RETURN_NOT_OK(internal::ValidateDecimalScale(scale)); + } + + if (primitive == VariantPrimitiveType::kString) { + const uint32_t length = ReadLittleEndian(value, offset, 4); + ARROW_RETURN_NOT_OK(internal::ValidateUtf8(value.substr(offset + 4, length), + "primitive string value")); + } + + consumed = payload_size; + return Status::OK(); +} + +Status ParseValue(std::string_view value, 
const VariantMetadataView& metadata, + size_t& consumed, VariantValueView* out); + +Status ParseArray(std::string_view value, const VariantMetadataView& metadata, + uint8_t header, size_t& consumed, VariantValueView* out) { + const auto offset_size = static_cast((header & 0x03) + 1); + const bool is_large = (header & 0x04) != 0; + const size_t count_size = is_large ? 4 : 1; + + size_t offset = 1; + ARROW_RETURN_NOT_OK(CheckAvailable(value, offset, count_size, "array size")); + const uint32_t num_elements = ReadLittleEndian(value, offset, count_size); + offset += count_size; + + ARROW_RETURN_NOT_OK( + CheckAvailable(value, offset, (static_cast(num_elements) + 1) * offset_size, + "array offsets")); + + std::vector offsets(num_elements + 1); + for (uint32_t i = 0; i <= num_elements; ++i) { + offsets[i] = ReadLittleEndian(value, offset, offset_size); + offset += offset_size; + } + + if (offsets[0] != 0) { + return Status::Invalid("Invalid Variant encoding: first array offset must be 0"); + } + for (uint32_t i = 0; i < num_elements; ++i) { + if (offsets[i] > offsets[i + 1]) { + return Status::Invalid("Invalid Variant encoding: array offsets must be monotonic"); + } + } + + const size_t values_start = offset; + const size_t total_value_size = offsets[num_elements]; + ARROW_RETURN_NOT_OK( + CheckAvailable(value, values_start, total_value_size, "array values")); + + size_t current = 0; + std::vector array_elements; + if (out != nullptr) { + array_elements.reserve(num_elements); + } + for (uint32_t i = 0; i < num_elements; ++i) { + if (offsets[i] != current) { + return Status::Invalid( + "Invalid Variant encoding: array offset does not match value boundary"); + } + size_t child_consumed = 0; + ARROW_RETURN_NOT_OK(ParseValue(value.substr(values_start + current), metadata, + child_consumed, + /*out=*/nullptr)); + current += child_consumed; + if (current != offsets[i + 1]) { + return Status::Invalid( + "Invalid Variant encoding: array value does not end at next offset"); + } + 
if (out != nullptr) { + array_elements.push_back(value.substr(values_start + offsets[i], child_consumed)); + } + } + + if (current != total_value_size) { + return Status::Invalid("Invalid Variant encoding: array values have trailing data"); + } + + consumed = values_start + total_value_size; + if (out != nullptr) { + *out = VariantValueView(value.substr(0, consumed), VariantBasicType::kArray, + VariantArrayView(std::move(array_elements))); + } + return Status::OK(); +} + +Status ParseObject(std::string_view value, const VariantMetadataView& metadata, + uint8_t header, size_t& consumed, VariantValueView* out) { + const auto offset_size = static_cast((header & 0x03) + 1); + const auto id_size = static_cast(((header >> 2) & 0x03) + 1); + const bool is_large = (header & 0x10) != 0; + const size_t count_size = is_large ? 4 : 1; + + size_t offset = 1; + ARROW_RETURN_NOT_OK(CheckAvailable(value, offset, count_size, "object size")); + const uint32_t num_elements = ReadLittleEndian(value, offset, count_size); + offset += count_size; + + ARROW_RETURN_NOT_OK(CheckAvailable( + value, offset, static_cast(num_elements) * id_size, "object field ids")); + std::vector field_ids(num_elements); + for (uint32_t i = 0; i < num_elements; ++i) { + field_ids[i] = ReadLittleEndian(value, offset, id_size); + offset += id_size; + } + + ARROW_RETURN_NOT_OK( + CheckAvailable(value, offset, (static_cast(num_elements) + 1) * offset_size, + "object field offsets")); + std::vector field_offsets(num_elements + 1); + for (uint32_t i = 0; i <= num_elements; ++i) { + field_offsets[i] = ReadLittleEndian(value, offset, offset_size); + offset += offset_size; + } + + const size_t values_start = offset; + const size_t total_value_size = field_offsets[num_elements]; + if (num_elements == 0 && total_value_size != 0) { + return Status::Invalid( + "Invalid Variant encoding: empty object must have zero value size"); + } + ARROW_RETURN_NOT_OK( + CheckAvailable(value, values_start, total_value_size, "object 
values")); + + std::vector object_fields; + if (out != nullptr) { + object_fields.reserve(num_elements); + } + + std::string_view previous_name; + for (uint32_t i = 0; i < num_elements; ++i) { + if (field_ids[i] >= metadata.dictionary_size()) { + return Status::Invalid("Invalid Variant encoding: object field id ", field_ids[i], + " is outside metadata dictionary of size ", + metadata.dictionary_size()); + } + const auto name = metadata.string(field_ids[i]); + if (i > 0 && !(previous_name < name)) { + return Status::Invalid( + "Invalid Variant encoding: object field names must be sorted and unique"); + } + previous_name = name; + + const auto field_offset = field_offsets[i]; + if (field_offset >= total_value_size) { + return Status::Invalid( + "Invalid Variant encoding: object field offset is outside values"); + } + } + + std::vector value_offsets = field_offsets; + std::ranges::sort(value_offsets); + if (std::ranges::adjacent_find(value_offsets) != value_offsets.end()) { + return Status::Invalid( + "Invalid Variant encoding: object field offsets must be unique"); + } + if (value_offsets.front() != 0) { + return Status::Invalid("Invalid Variant encoding: object values have leading data"); + } + + for (uint32_t i = 0; i < num_elements; ++i) { + const uint32_t start = value_offsets[i]; + const uint32_t end = value_offsets[i + 1]; + size_t child_consumed = 0; + ARROW_RETURN_NOT_OK(ParseValue(value.substr(values_start + start), metadata, + child_consumed, + /*out=*/nullptr)); + if (child_consumed != end - start) { + return Status::Invalid( + "Invalid Variant encoding: object value does not end at next value boundary"); + } + } + + if (out != nullptr) { + for (uint32_t i = 0; i < num_elements; ++i) { + const auto field_offset = field_offsets[i]; + auto offset_it = std::ranges::lower_bound(value_offsets, field_offset); + DCHECK(offset_it != value_offsets.end()); + DCHECK(offset_it + 1 != value_offsets.end()); + DCHECK(*offset_it == field_offset); + const auto end = 
*(offset_it + 1); + object_fields.push_back(VariantObjectField{ + .name = metadata.string(field_ids[i]), + .field_id = field_ids[i], + .value = value.substr(values_start + field_offset, end - field_offset)}); + } + } + + consumed = values_start + total_value_size; + if (out != nullptr) { + *out = VariantValueView(value.substr(0, consumed), VariantBasicType::kObject, + VariantObjectView(std::move(object_fields))); + } + return Status::OK(); +} + +Status ParseValue(std::string_view value, const VariantMetadataView& metadata, + size_t& consumed, VariantValueView* out) { + ARROW_RETURN_NOT_OK(CheckAvailable(value, 0, 1, "value header")); + + const auto metadata_byte = static_cast(value[0]); + const auto basic_type = static_cast(metadata_byte & 0x03); + const auto header = static_cast(metadata_byte >> 2); + + switch (basic_type) { + case VariantBasicType::kPrimitive: { + const auto primitive = static_cast(header); + size_t payload_size = 0; + ARROW_RETURN_NOT_OK(ParsePrimitive(value, 1, primitive, payload_size)); + consumed = 1 + payload_size; + if (out != nullptr) { + *out = VariantValueView( + value.substr(0, consumed), VariantBasicType::kPrimitive, + VariantPrimitiveView(primitive, value.substr(1, payload_size))); + } + return Status::OK(); + } + case VariantBasicType::kShortString: { + ARROW_RETURN_NOT_OK(CheckAvailable(value, 1, header, "short string value")); + ARROW_RETURN_NOT_OK( + internal::ValidateUtf8(value.substr(1, header), "short string value")); + consumed = 1 + header; + if (out != nullptr) { + *out = VariantValueView(value.substr(0, consumed), VariantBasicType::kShortString, + VariantShortStringView(value.substr(1, header))); + } + return Status::OK(); + } + case VariantBasicType::kObject: + return ParseObject(value, metadata, header, consumed, out); + case VariantBasicType::kArray: + return ParseArray(value, metadata, header, consumed, out); + } + return Status::Invalid("Invalid Variant encoding: unknown basic type"); +} + +} // namespace + +Result 
VariantMetadataView::Make(std::string_view metadata) { + ARROW_RETURN_NOT_OK(CheckAvailable(metadata, 0, 1, "metadata header")); + const auto header = static_cast(metadata[0]); + const auto version = static_cast(header & internal::kMetadataVersionMask); + if (version != internal::kVariantVersion) { + return Status::Invalid("Invalid Variant metadata: expected version 1, got ", + static_cast(version)); + } + + VariantMetadataView view; + view.metadata_ = metadata; + view.sorted_strings_ = (header & internal::kMetadataSortedStringsMask) != 0; + view.offset_size_ = static_cast(((header >> 6) & 0x03) + 1); + + ARROW_RETURN_NOT_OK( + CheckAvailable(metadata, 1, view.offset_size_, "metadata dictionary size")); + const uint32_t dictionary_size = ReadLittleEndian(metadata, 1, view.offset_size_); + const size_t offsets_offset = 1 + view.offset_size_; + ARROW_RETURN_NOT_OK( + CheckAvailable(metadata, offsets_offset, + (static_cast(dictionary_size) + 1) * view.offset_size_, + "metadata dictionary offsets")); + + std::vector offsets(dictionary_size + 1); + for (uint32_t i = 0; i <= dictionary_size; ++i) { + offsets[i] = ReadLittleEndian(metadata, offsets_offset + i * view.offset_size_, + view.offset_size_); + } + + if (offsets[0] != 0) { + return Status::Invalid("Invalid Variant metadata: first dictionary offset must be 0"); + } + for (uint32_t i = 0; i < dictionary_size; ++i) { + if (offsets[i] > offsets[i + 1]) { + return Status::Invalid( + "Invalid Variant metadata: dictionary offsets must be monotonic"); + } + } + + const size_t bytes_offset = + offsets_offset + (static_cast(dictionary_size) + 1) * view.offset_size_; + const size_t bytes_size = offsets[dictionary_size]; + ARROW_RETURN_NOT_OK( + CheckAvailable(metadata, bytes_offset, bytes_size, "metadata dictionary bytes")); + if (metadata.size() != bytes_offset + bytes_size) { + return Status::Invalid("Invalid Variant metadata: trailing bytes after dictionary"); + } + + view.strings_.reserve(dictionary_size); + for 
(uint32_t i = 0; i < dictionary_size; ++i) { + auto string = metadata.substr(bytes_offset + offsets[i], offsets[i + 1] - offsets[i]); + ARROW_RETURN_NOT_OK(internal::ValidateUtf8(string, "metadata dictionary string")); + if (view.sorted_strings_ && i > 0 && !(view.strings_.back() < string)) { + return Status::Invalid( + "Invalid Variant metadata: sorted dictionary strings must be unique and " + "lexicographically sorted"); + } + view.strings_.push_back(string); + } + + return view; +} + +std::string_view VariantMetadataView::string(uint32_t id) const { + DCHECK_LT(id, strings_.size()); + return strings_[id]; +} + +std::optional VariantMetadataView::FindString(std::string_view value) const { + if (sorted_strings_) { + const auto it = std::ranges::lower_bound(strings_, value); + if (it != strings_.end() && *it == value) { + return static_cast(it - strings_.begin()); + } + return std::nullopt; + } + + for (uint32_t i = 0; i < strings_.size(); ++i) { + if (strings_[i] == value) { + return i; + } + } + return std::nullopt; +} + +const VariantObjectField* VariantObjectView::FindField(std::string_view name) const { + const auto it = std::ranges::lower_bound(fields_, name, {}, &VariantObjectField::name); + return (it == fields_.end() || it->name != name) ? nullptr : &*it; +} + +bool VariantObjectView::ContainsField(std::string_view name) const { + // The Parquet Variant encoding requires object fields to be sorted by name, so field + // lookup can use binary search. 
+ return std::ranges::binary_search(fields_, name, {}, &VariantObjectField::name); +} + +Result VariantValueView::Make(std::string_view value, + const VariantMetadataView& metadata) { + size_t consumed = 0; + VariantValueView view({}, VariantBasicType::kPrimitive, + VariantPrimitiveView(VariantPrimitiveType::kNull, {})); + ARROW_RETURN_NOT_OK(ParseValue(value, metadata, consumed, &view)); + if (consumed != value.size()) { + return Status::Invalid("Invalid Variant encoding: trailing bytes after value"); + } + return view; +} + +Status VariantValueView::Validate(std::string_view value, + const VariantMetadataView& metadata) { + size_t consumed = 0; + ARROW_RETURN_NOT_OK(ParseValue(value, metadata, consumed, /*out=*/nullptr)); + if (consumed != value.size()) { + return Status::Invalid("Invalid Variant encoding: trailing bytes after value"); + } + return Status::OK(); +} + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/encoding.h b/cpp/src/parquet/variant/encoding.h new file mode 100644 index 000000000000..1dbf3d9a104d --- /dev/null +++ b/cpp/src/parquet/variant/encoding.h @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "parquet/platform.h" + +namespace parquet::variant { + +using ::arrow::Result; +using ::arrow::Status; + +enum class VariantBasicType : uint8_t { + kPrimitive = 0, + kShortString = 1, + kObject = 2, + kArray = 3, +}; + +enum class VariantPrimitiveType : uint8_t { + kNull = 0, + kBooleanTrue = 1, + kBooleanFalse = 2, + kInt8 = 3, + kInt16 = 4, + kInt32 = 5, + kInt64 = 6, + kDouble = 7, + kDecimal4 = 8, + kDecimal8 = 9, + kDecimal16 = 10, + kDate = 11, + kTimestampMicros = 12, + kTimestampNTZMicros = 13, + kFloat = 14, + kBinary = 15, + kString = 16, + kTimeNTZMicros = 17, + kTimestampNanos = 18, + kTimestampNTZNanos = 19, + kUuid = 20, +}; + +struct PARQUET_EXPORT VariantObjectField { + std::string_view name; + uint32_t field_id = 0; + std::string_view value; +}; + +class PARQUET_EXPORT VariantMetadataView { + public: + /// Parse Variant metadata bytes without copying them. + /// + /// The returned view and all string views borrowed from it are valid only while the + /// input metadata bytes remain alive and unchanged. 
+ static Result Make(std::string_view metadata); + + std::string_view metadata() const { return metadata_; } + bool sorted_strings() const { return sorted_strings_; } + uint8_t offset_size() const { return offset_size_; } + uint32_t dictionary_size() const { return static_cast(strings_.size()); } + + std::string_view string(uint32_t id) const; + std::optional FindString(std::string_view value) const; + + private: + std::string_view metadata_; + bool sorted_strings_ = false; + uint8_t offset_size_ = 0; + std::vector strings_; +}; + +class PARQUET_EXPORT VariantPrimitiveView { + public: + VariantPrimitiveView(VariantPrimitiveType type, std::string_view payload) + : type_(type), payload_(payload) {} + + VariantPrimitiveType type() const { return type_; } + std::string_view payload() const { return payload_; } + + private: + VariantPrimitiveType type_ = VariantPrimitiveType::kNull; + std::string_view payload_; +}; + +class PARQUET_EXPORT VariantShortStringView { + public: + explicit VariantShortStringView(std::string_view string) : string_(string) {} + + std::string_view string() const { return string_; } + + private: + std::string_view string_; +}; + +class PARQUET_EXPORT VariantObjectView { + public: + explicit VariantObjectView(std::vector fields) + : fields_(std::move(fields)) {} + + const std::vector& fields() const { return fields_; } + const VariantObjectField* FindField(std::string_view name) const; + bool ContainsField(std::string_view name) const; + + private: + std::vector fields_; +}; + +class PARQUET_EXPORT VariantArrayView { + public: + explicit VariantArrayView(std::vector elements) + : elements_(std::move(elements)) {} + + const std::vector& elements() const { return elements_; } + + private: + std::vector elements_; +}; + +class PARQUET_EXPORT VariantValueView { + public: + using Data = std::variant; + + VariantValueView(std::string_view value, VariantBasicType basic_type, Data data) + : value_(value), basic_type_(basic_type), data_(std::move(data)) {} 
+ + /// Parse Variant value bytes without copying them. + /// + /// The returned view and any nested object/list/string views are valid only while the + /// input value bytes remain alive and unchanged. The metadata view must also remain + /// valid while object field names are accessed. + static Result Make(std::string_view value, + const VariantMetadataView& metadata); + static Status Validate(std::string_view value, const VariantMetadataView& metadata); + + std::string_view value() const { return value_; } + VariantBasicType basic_type() const { return basic_type_; } + const Data& data() const { return data_; } + + private: + std::string_view value_; + VariantBasicType basic_type_; + Data data_; +}; + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/encoding_internal.h b/cpp/src/parquet/variant/encoding_internal.h new file mode 100644 index 000000000000..f34e088a2ef3 --- /dev/null +++ b/cpp/src/parquet/variant/encoding_internal.h @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include + +#include "arrow/status.h" +#include "arrow/util/utf8.h" +#include "parquet/variant/encoding.h" + +namespace parquet::variant::internal { + +using ::arrow::Status; + +namespace util = ::arrow::util; + +inline constexpr uint8_t kVariantVersion = 1; +inline constexpr uint8_t kMetadataVersionMask = 0x0F; +inline constexpr uint8_t kMetadataSortedStringsMask = 0x10; +inline constexpr uint8_t kMetadataReservedMask = 0x20; + +template +struct VariantBasicTypeTraits {}; + +template <> +struct VariantBasicTypeTraits { + using ViewType = VariantPrimitiveView; +}; + +template <> +struct VariantBasicTypeTraits { + using ViewType = VariantShortStringView; +}; + +template <> +struct VariantBasicTypeTraits { + using ViewType = VariantObjectView; +}; + +template <> +struct VariantBasicTypeTraits { + using ViewType = VariantArrayView; +}; + +inline bool IsKnownVariantPrimitive(VariantPrimitiveType type) { + return type <= VariantPrimitiveType::kUuid; +} + +inline bool IsDecimalVariantPrimitive(VariantPrimitiveType type) { + return type == VariantPrimitiveType::kDecimal4 || + type == VariantPrimitiveType::kDecimal8 || + type == VariantPrimitiveType::kDecimal16; +} + +inline Status ValidateDecimalScale(uint8_t scale) { + if (scale > 38) { + return Status::Invalid("Invalid Variant decimal scale ", scale, " exceeds 38"); + } + return Status::OK(); +} + +inline Status ValidateUtf8(std::string_view value, std::string_view context) { + util::InitializeUTF8(); + if (!util::ValidateUTF8(reinterpret_cast(value.data()), value.size())) { + return Status::Invalid("Invalid Variant encoding: ", context, " is not valid UTF-8"); + } + return Status::OK(); +} + +template +concept HeaderOnlyVariantPrimitive = + type == VariantPrimitiveType::kNull || type == VariantPrimitiveType::kBooleanTrue || + type == VariantPrimitiveType::kBooleanFalse; + +template +struct VariantFixedPrimitiveTraits {}; + +#define VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(TYPE, C_TYPE) \ + 
template <> \ + struct VariantFixedPrimitiveTraits { \ + using CType = C_TYPE; \ + }; + +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Int8, int8_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Int16, int16_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Int32, int32_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Int64, int64_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Float, float) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Double, double) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(Date, int32_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(TimeNTZMicros, int64_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(TimestampMicros, int64_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(TimestampNTZMicros, int64_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(TimestampNanos, int64_t) +VARIANT_FIXED_PRIMITIVE_TRAITS_DEF(TimestampNTZNanos, int64_t) + +#undef VARIANT_FIXED_PRIMITIVE_TRAITS_DEF + +template +concept FixedVariantPrimitive = + requires { typename VariantFixedPrimitiveTraits::CType; }; + +template +struct VariantDecimalPrimitiveTraits {}; + +#define VARIANT_DECIMAL_PRIMITIVE_TRAITS_DEF(TYPE, C_TYPE) \ + template <> \ + struct VariantDecimalPrimitiveTraits { \ + using CType = C_TYPE; \ + }; + +VARIANT_DECIMAL_PRIMITIVE_TRAITS_DEF(Decimal4, int32_t) +VARIANT_DECIMAL_PRIMITIVE_TRAITS_DEF(Decimal8, int64_t) + +#undef VARIANT_DECIMAL_PRIMITIVE_TRAITS_DEF + +template +concept DecimalVariantPrimitive = + requires { typename VariantDecimalPrimitiveTraits::CType; }; + +template +concept LengthPrefixedVariantPrimitive = + type == VariantPrimitiveType::kBinary || type == VariantPrimitiveType::kString; + +template +concept Decimal16VariantPrimitive = type == VariantPrimitiveType::kDecimal16; + +template +concept UuidVariantPrimitive = type == VariantPrimitiveType::kUuid; + +} // namespace parquet::variant::internal diff --git a/cpp/src/parquet/variant/encoding_test.cc b/cpp/src/parquet/variant/encoding_test.cc new file mode 100644 index 000000000000..e68b15692394 --- /dev/null +++ b/cpp/src/parquet/variant/encoding_test.cc @@ -0,0 +1,190 @@ +// Licensed to the 
Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "parquet/variant/encoding.h" +#include "parquet/variant/builder.h" +#include "parquet/variant/test_util_internal.h" + +#include +#include +#include + +#include "arrow/testing/gtest_util.h" + +namespace parquet::variant { + +TEST(TestVariantEncoding, EmptyMetadata) { + VariantBuilder builder; + ASSERT_OK(builder.AppendVariantNull()); + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_EQ(std::string_view("\x01\x00\x00", 3), std::string_view{*encoded.metadata}); + + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + ASSERT_FALSE(metadata.sorted_strings()); + ASSERT_EQ(0, metadata.dictionary_size()); +} + +TEST(TestVariantEncoding, Primitive) { + std::array decimal16 = {0}; + std::array uuid = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + auto check = [](auto append, VariantPrimitiveType expected) { + VariantBuilder builder; + ASSERT_OK(append(builder)); + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto view, internal::MakeVariantValueView(encoded)); + ASSERT_EQ(VariantBasicType::kPrimitive, view.basic_type()); + ASSERT_EQ(expected, std::get(view.data()).type()); + }; 
+ + check([](VariantBuilder& b) { return b.AppendVariantNull(); }, + VariantPrimitiveType::kNull); + check([](VariantBuilder& b) { return b.AppendBoolean(true); }, + VariantPrimitiveType::kBooleanTrue); + check([](VariantBuilder& b) { return b.AppendInt8(1); }, VariantPrimitiveType::kInt8); + check([](VariantBuilder& b) { return b.AppendInt16(2); }, VariantPrimitiveType::kInt16); + check([](VariantBuilder& b) { return b.AppendInt32(3); }, VariantPrimitiveType::kInt32); + check([](VariantBuilder& b) { return b.AppendInt64(4); }, VariantPrimitiveType::kInt64); + check([](VariantBuilder& b) { return b.AppendFloat(1.5F); }, + VariantPrimitiveType::kFloat); + check([](VariantBuilder& b) { return b.AppendDouble(2.5); }, + VariantPrimitiveType::kDouble); + check([](VariantBuilder& b) { return b.AppendDecimal4(123, 2); }, + VariantPrimitiveType::kDecimal4); + check([](VariantBuilder& b) { return b.AppendDecimal8(123, 2); }, + VariantPrimitiveType::kDecimal8); + check( + [&](VariantBuilder& b) { + return b.AppendDecimal16(std::string_view(decimal16.data(), decimal16.size()), 2); + }, + VariantPrimitiveType::kDecimal16); + check([](VariantBuilder& b) { return b.AppendBinary("abc"); }, + VariantPrimitiveType::kBinary); + check([](VariantBuilder& b) { return b.AppendString("abc"); }, + VariantPrimitiveType::kString); + check([](VariantBuilder& b) { return b.AppendDate(1); }, VariantPrimitiveType::kDate); + check([](VariantBuilder& b) { return b.AppendTimeNTZMicros(1); }, + VariantPrimitiveType::kTimeNTZMicros); + check([](VariantBuilder& b) { return b.AppendTimestampMicros(1, true); }, + VariantPrimitiveType::kTimestampMicros); + check([](VariantBuilder& b) { return b.AppendTimestampMicros(1, false); }, + VariantPrimitiveType::kTimestampNTZMicros); + check([](VariantBuilder& b) { return b.AppendTimestampNanos(1, true); }, + VariantPrimitiveType::kTimestampNanos); + check([](VariantBuilder& b) { return b.AppendTimestampNanos(1, false); }, + 
VariantPrimitiveType::kTimestampNTZNanos); + check( + [&](VariantBuilder& b) { + return b.AppendUuid(std::string_view(uuid.data(), uuid.size())); + }, + VariantPrimitiveType::kUuid); +} + +TEST(TestVariantEncoding, ShortString) { + VariantBuilder builder; + ASSERT_OK(builder.AppendShortString("abc")); + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto view, internal::MakeVariantValueView(encoded)); + ASSERT_EQ(VariantBasicType::kShortString, view.basic_type()); + ASSERT_EQ("abc", std::get(view.data()).string()); +} + +TEST(TestVariantEncoding, InvalidMetadata) { + // Metadata version is 2 instead of 1. + ASSERT_RAISES(Invalid, VariantMetadataView::Make(std::string("\x02\x00\x00", 3))); + // Dictionary offsets are truncated. + ASSERT_RAISES(Invalid, VariantMetadataView::Make(std::string("\x01\x01\x00", 3))); + // Dictionary string payload is not valid UTF-8. + ASSERT_RAISES(Invalid, + VariantMetadataView::Make(std::string("\x01\x01\x00\x01\xff", 5))); +} + +TEST(TestVariantEncoding, ReservedBits) { + // Metadata bytes: + // - 0x21: version 1 with reserved bit 0x20 set + // - 0x00: zero dictionary entries, encoded with one-byte width + // - 0x00: first and final dictionary byte offset + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string("\x21\x00\x00", 3))); + ASSERT_EQ(0, metadata.dictionary_size()); + + // Object value bytes: + // - 0x82: basic type Object with reserved header bit 0x20 set + // - 0x00: zero fields + // - 0x00: first and final value offset + ASSERT_OK(VariantValueView::Validate(std::string("\x82\x00\x00", 3), metadata)); + // Array value bytes: + // - 0xe3: basic type Array with reserved header bits 0x38 set + // - 0x00: zero elements + // - 0x00: first and final value offset + ASSERT_OK(VariantValueView::Validate(std::string("\xe3\x00\x00", 3), metadata)); +} + +TEST(TestVariantEncoding, ValidateValue) { + VariantBuilder builder; + ASSERT_OK(builder.AppendVariantNull()); + 
ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + + ASSERT_OK(VariantValueView::Validate(std::string_view{*encoded.value}, metadata)); + std::string invalid_value(std::string_view{*encoded.value}); + invalid_value.push_back('\0'); + ASSERT_RAISES(Invalid, VariantValueView::Validate(invalid_value, metadata)); +} + +TEST(TestVariantEncoding, InvalidValue) { + VariantBuilder builder; + ASSERT_OK(builder.AppendVariantNull()); + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + + // Header byte 0x54 encodes Primitive with primitive tag 21, which is unknown. + ASSERT_RAISES( + Invalid, VariantValueView::Make(std::string(1, static_cast(0x54)), metadata)); + // Short string payload is not valid UTF-8. + ASSERT_RAISES(Invalid, VariantValueView::Make(std::string("\x05\xff", 2), metadata)); + // Null primitive has trailing bytes after the encoded value. + ASSERT_RAISES(Invalid, VariantValueView::Make(std::string("\x00\x00", 2), metadata)); + // Object references field id 0, but the metadata dictionary is empty. 
+ ASSERT_RAISES(Invalid, VariantValueView::Make( + std::string("\x02\x01\x00\x00\x01\x00", 6), metadata)); +} + +TEST(TestVariantEncoding, InvalidObject) { + VariantBuilder builder; + ASSERT_OK_AND_ASSIGN(auto object, builder.StartObject()); + ASSERT_OK(object.AppendVariantNull("a")); + ASSERT_OK(object.AppendVariantNull("b")); + ASSERT_OK(object.Finish()); + ASSERT_OK_AND_ASSIGN(auto encoded, builder.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + + // Object value bytes: + // - 0x02: basic type Object, one-byte field ids, one-byte offsets, one-byte count + // - 0x02: two fields + // - 0x00 0x01: field ids for "a" and "b" + // - 0x00 0x00 0x01: duplicate field start offsets 0/0, final value size 1 + // - 0x00: encoded Variant null primitive value + const std::string duplicate_offsets("\x02\x02\x00\x01\x00\x00\x01\x00", 8); + ASSERT_RAISES(Invalid, VariantValueView::Validate(duplicate_offsets, metadata)); +} + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/meson.build b/cpp/src/parquet/variant/meson.build new file mode 100644 index 000000000000..7a3f74014ad5 --- /dev/null +++ b/cpp/src/parquet/variant/meson.build @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
# Single test executable covering the Variant builder, encoding, extension type
# and validation tests; test_util_internal.cc supplies the shared helpers.
exc = executable(
    'parquet-variant-test',
    sources: [
        'builder_test.cc',
        'encoding_test.cc',
        'test_util_internal.cc',
        'type_test.cc',
        'validate_test.cc',
    ],
    dependencies: parquet_test_dep,
)
test('parquet-variant-test', exc)

# Headers installed under parquet/variant. The *_internal.h headers in this
# directory are not in the list.
install_headers(
    ['builder.h', 'encoding.h', 'validate.h'],
    subdir: 'parquet/variant',
)
+ +#include "parquet/variant/test_util_internal.h" + +#include +#include + +#include "arrow/array.h" // IWYU pragma: keep +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/extension/uuid.h" +#include "arrow/extension_type.h" +#include "arrow/io/memory.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "parquet/arrow/reader.h" +#include "parquet/arrow/writer.h" +#include "parquet/exception.h" + +namespace parquet::variant::internal { + +Result MakeVariantValueView(const EncodedVariantValue& encoded) { + ARROW_ASSIGN_OR_RAISE(auto metadata, + VariantMetadataView::Make(std::string_view{*encoded.metadata})); + return VariantValueView::Make(std::string_view{*encoded.value}, metadata); +} + +std::shared_ptr<::arrow::Table> VariantTable( + const std::shared_ptr& variant_type, + const std::vector>& storage_children, + const FieldVector& storage_fields) { + PARQUET_ASSIGN_OR_THROW(auto storage, + ::arrow::StructArray::Make(storage_children, storage_fields)); + auto array = ::arrow::ExtensionType::WrapArray(variant_type, storage); + return ::arrow::Table::Make(::arrow::schema({::arrow::field("variant", variant_type)}), + {array}); +} + +Result> WriteVariantTable( + const std::shared_ptr<::arrow::Table>& table, + std::shared_ptr writer_properties, + std::shared_ptr arrow_properties) { + ARROW_ASSIGN_OR_RAISE(auto sink, ::arrow::io::BufferOutputStream::Create( + 1024, ::arrow::default_memory_pool())); + RETURN_NOT_OK(parquet::arrow::WriteTable(*table, ::arrow::default_memory_pool(), sink, + /*chunk_size=*/table->num_rows(), + std::move(writer_properties), + std::move(arrow_properties))); + return sink->Finish(); +} + +::arrow::Status WriteVariantRecordBatch( + const std::shared_ptr<::arrow::Table>& table, + std::shared_ptr arrow_properties) { + ARROW_ASSIGN_OR_RAISE(auto sink, ::arrow::io::BufferOutputStream::Create( + 1024, ::arrow::default_memory_pool())); + 
ARROW_ASSIGN_OR_RAISE(auto writer, + parquet::arrow::FileWriter::Open( + *table->schema(), ::arrow::default_memory_pool(), sink, + default_writer_properties(), std::move(arrow_properties))); + ARROW_ASSIGN_OR_RAISE(auto batch, table->CombineChunksToBatch()); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + return writer->Close(); +} + +std::optional ShreddedVariantTestingDir() { + const char* data = std::getenv("PARQUET_TEST_DATA"); + if (data == nullptr) { + return std::nullopt; + } + std::string path(data); + const auto pos = path.find_last_of("/\\"); + if (pos == std::string::npos) { + return std::nullopt; + } + return path.substr(0, pos) + "/shredded_variant"; +} + +Result> ReadVariantTestingTable(const std::string& path) { + ArrowReaderProperties reader_properties; + reader_properties.set_arrow_extensions_enabled(true); + + parquet::arrow::FileReaderBuilder builder; + RETURN_NOT_OK(builder.OpenFile(path, /*memory_map=*/false)); + builder.properties(reader_properties); + ARROW_ASSIGN_OR_RAISE(auto reader, builder.Build()); + return reader->ReadTable(); +} + +Result> EmptyVariantMetadata() { + VariantBuilder builder; + ARROW_RETURN_NOT_OK(builder.AppendVariantNull()); + ARROW_ASSIGN_OR_RAISE(auto encoded, builder.Finish()); + return encoded.metadata; +} + +Result Int8Variant(int8_t value) { + VariantBuilder builder; + ARROW_RETURN_NOT_OK(builder.AppendInt8(value)); + return builder.Finish(); +} + +std::shared_ptr BinaryArrayFromValues( + const std::vector>& values) { + ::arrow::BinaryBuilder builder; + for (const auto& value : values) { + if (value.has_value()) { + ARROW_EXPECT_OK(builder.Append(*value)); + } else { + ARROW_EXPECT_OK(builder.AppendNull()); + } + } + std::shared_ptr out; + ARROW_EXPECT_OK(builder.Finish(&out)); + return out; +} + +std::shared_ptr BinaryViewArrayFromValues( + const std::vector>& values) { + ::arrow::BinaryViewBuilder builder; + for (const auto& value : values) { + if (value.has_value()) { + 
ARROW_EXPECT_OK(builder.Append(*value)); + } else { + ARROW_EXPECT_OK(builder.AppendNull()); + } + } + std::shared_ptr out; + ARROW_EXPECT_OK(builder.Finish(&out)); + return out; +} + +std::shared_ptr Int64ArrayFromValues( + const std::vector>& values) { + ::arrow::Int64Builder builder; + for (const auto& value : values) { + if (value.has_value()) { + ARROW_EXPECT_OK(builder.Append(*value)); + } else { + ARROW_EXPECT_OK(builder.AppendNull()); + } + } + std::shared_ptr out; + ARROW_EXPECT_OK(builder.Finish(&out)); + return out; +} + +std::shared_ptr Int32ArrayFromValues(const std::vector& values) { + ::arrow::Int32Builder builder; + ARROW_EXPECT_OK(builder.AppendValues(values)); + std::shared_ptr out; + ARROW_EXPECT_OK(builder.Finish(&out)); + return out; +} + +std::shared_ptr StringArrayFromValues( + const std::vector>& values) { + ::arrow::StringBuilder builder; + for (const auto& value : values) { + if (value.has_value()) { + ARROW_EXPECT_OK(builder.Append(*value)); + } else { + ARROW_EXPECT_OK(builder.AppendNull()); + } + } + std::shared_ptr out; + ARROW_EXPECT_OK(builder.Finish(&out)); + return out; +} + +std::shared_ptr UuidArrayFromValues(const std::vector& values) { + ::arrow::FixedSizeBinaryBuilder builder(::arrow::fixed_size_binary(16)); + for (const auto& value : values) { + ARROW_EXPECT_OK(builder.Append(value)); + } + std::shared_ptr storage; + ARROW_EXPECT_OK(builder.Finish(&storage)); + return ::arrow::ExtensionType::WrapArray(::arrow::extension::uuid(), storage); +} + +} // namespace parquet::variant::internal diff --git a/cpp/src/parquet/variant/test_util_internal.h b/cpp/src/parquet/variant/test_util_internal.h new file mode 100644 index 000000000000..f36e92d01a4b --- /dev/null +++ b/cpp/src/parquet/variant/test_util_internal.h @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "parquet/properties.h" +#include "parquet/variant/builder.h" +#include "parquet/variant/encoding.h" + +namespace parquet::variant::internal { + +using ::arrow::Array; +using ::arrow::Buffer; +using ::arrow::DataType; +using ::arrow::FieldVector; +using ::arrow::Result; + +Result MakeVariantValueView(const EncodedVariantValue& encoded); + +std::shared_ptr<::arrow::Table> VariantTable( + const std::shared_ptr& variant_type, + const std::vector>& storage_children, + const FieldVector& storage_fields); + +Result> WriteVariantTable( + const std::shared_ptr<::arrow::Table>& table, + std::shared_ptr writer_properties = default_writer_properties(), + std::shared_ptr arrow_properties = + default_arrow_writer_properties()); + +::arrow::Status WriteVariantRecordBatch( + const std::shared_ptr<::arrow::Table>& table, + std::shared_ptr arrow_properties = + default_arrow_writer_properties()); + +std::optional ShreddedVariantTestingDir(); + +Result> ReadVariantTestingTable(const std::string& path); + +Result> EmptyVariantMetadata(); + +Result Int8Variant(int8_t value); + +std::shared_ptr BinaryArrayFromValues( + const std::vector>& values); + 
+std::shared_ptr BinaryViewArrayFromValues( + const std::vector>& values); + +std::shared_ptr Int64ArrayFromValues( + const std::vector>& values); + +std::shared_ptr Int32ArrayFromValues(const std::vector& values); + +std::shared_ptr StringArrayFromValues( + const std::vector>& values); + +std::shared_ptr UuidArrayFromValues(const std::vector& values); + +} // namespace parquet::variant::internal diff --git a/cpp/src/parquet/variant/type_test.cc b/cpp/src/parquet/variant/type_test.cc new file mode 100644 index 000000000000..c6e93263ab83 --- /dev/null +++ b/cpp/src/parquet/variant/type_test.cc @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/extension/parquet_variant.h" +#include "arrow/extension/uuid.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/checked_cast.h" + +#include +#include + +namespace parquet::variant { + +using ::arrow::binary; +using ::arrow::binary_view; +using ::arrow::decimal128; +using ::arrow::decimal256; +using ::arrow::decimal32; +using ::arrow::decimal64; +using ::arrow::dictionary; +using ::arrow::duration; +using ::arrow::ExtensionType; +using ::arrow::field; +using ::arrow::fixed_size_binary; +using ::arrow::fixed_size_list; +using ::arrow::int32; +using ::arrow::int64; +using ::arrow::large_list_view; +using ::arrow::list; +using ::arrow::list_view; +using ::arrow::struct_; +using ::arrow::timestamp; +using ::arrow::TimeUnit; +using ::arrow::uint64; +using ::arrow::utf8; +using ::arrow::utf8_view; +using ::arrow::extension::kVariantExtensionName; +using ::arrow::extension::VariantExtensionType; + +TEST(TestVariantType, Storage) { + auto unshredded = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + ASSERT_OK_AND_ASSIGN(auto type, VariantExtensionType::Make(unshredded)); + ASSERT_EQ( + kVariantExtensionName, + ::arrow::internal::checked_cast(*type).extension_name()); + + auto shredded = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary()), field("typed_value", int64())}); + ASSERT_OK_AND_ASSIGN(auto shredded_type, VariantExtensionType::Make(shredded)); + auto variant_type = + ::arrow::internal::checked_pointer_cast(shredded_type); + ASSERT_EQ("metadata", variant_type->metadata()->name()); + ASSERT_EQ("value", variant_type->value()->name()); + ASSERT_EQ("typed_value", variant_type->typed_value()->name()); + + auto typed_only = struct_( + {field("metadata", binary(), /*nullable=*/false), field("typed_value", int64())}); + ASSERT_OK(VariantExtensionType::Make(typed_only)); + + auto flipped = + std::dynamic_pointer_cast(::arrow::extension::variant( + 
struct_({field("value", binary(), /*nullable=*/false), + field("metadata", binary(), /*nullable=*/false)}))); + ASSERT_EQ("metadata", flipped->metadata()->name()); + ASSERT_EQ("value", flipped->value()->name()); + + ASSERT_OK(VariantExtensionType::Make( + struct_({field("metadata", binary_view(), /*nullable=*/false), + field("value", binary_view(), /*nullable=*/false)}))); + + auto shredded_field_group = + struct_({field("value", binary()), field("typed_value", int64())}); + auto shredded_object = + struct_({field("metadata", binary(), /*nullable=*/false), field("value", binary()), + field("typed_value", + struct_({field("a", shredded_field_group, /*nullable=*/false)}))}); + auto shredded_list = + struct_({field("metadata", binary(), /*nullable=*/false), field("value", binary()), + field("typed_value", + list(field("element", shredded_field_group, /*nullable=*/false)))}); + ASSERT_OK(VariantExtensionType::Make(shredded_object)); + ASSERT_OK(VariantExtensionType::Make(shredded_list)); + + for (const auto& typed_value_type : + {binary_view(), utf8_view(), ::arrow::extension::uuid(), + list_view(field("element", shredded_field_group, /*nullable=*/false)), + large_list_view(field("element", shredded_field_group, /*nullable=*/false)), + fixed_size_list(field("element", shredded_field_group, /*nullable=*/false), + /*list_size=*/2)}) { + auto valid_shredded_type = + struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary()), field("typed_value", typed_value_type)}); + ASSERT_OK(VariantExtensionType::Make(valid_shredded_type)); + } +} + +TEST(TestVariantType, InvalidStorage) { + auto missing_value = struct_({field("metadata", binary(), /*nullable=*/false)}); + auto missing_metadata = struct_({field("value", binary(), /*nullable=*/false)}); + auto nullable_metadata = + struct_({field("metadata", binary()), field("value", binary(), false)}); + auto nullable_unshredded = + struct_({field("metadata", binary(), false), field("value", binary())}); + 
auto bad_value_type = + struct_({field("metadata", binary(), false), field("value", int32(), false)}); + auto extra = struct_({field("metadata", binary(), false), + field("value", binary(), false), field("extra", binary())}); + auto dictionary_typed_value = + struct_({field("metadata", binary(), false), field("value", binary()), + field("typed_value", dictionary(int32(), utf8()))}); + auto dictionary_value = struct_({field("metadata", binary(), false), + field("value", dictionary(int32(), binary()), false)}); + auto dictionary_metadata = + struct_({field("metadata", dictionary(int32(), binary()), false), + field("value", binary(), false)}); + auto required_nested_value = + struct_({field("metadata", binary(), false), field("value", binary()), + field("typed_value", + struct_({field("a", struct_({field("value", binary(), false)}), + /*nullable=*/false)}))}); + + ASSERT_RAISES(Invalid, VariantExtensionType::Make(missing_value)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(missing_metadata)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(nullable_metadata)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(nullable_unshredded)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(bad_value_type)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(extra)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(dictionary_typed_value)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(dictionary_value)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(dictionary_metadata)); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(required_nested_value)); + + std::array invalid_typed_value_types{ + uint64(), + duration(TimeUnit::MICRO), + timestamp(TimeUnit::MILLI), + struct_({}), + fixed_size_binary(/*byte_width=*/8), + decimal32(/*precision=*/8, /*scale=*/9), + decimal64(/*precision=*/16, /*scale=*/17), + decimal128(/*precision=*/32, /*scale=*/-1), + decimal256(/*precision=*/39, /*scale=*/0), + }; + for (const auto& typed_value_type : 
invalid_typed_value_types) { + auto invalid_shredded_type = + struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary()), field("typed_value", typed_value_type)}); + ASSERT_RAISES(Invalid, VariantExtensionType::Make(std::move(invalid_shredded_type))); + } +} + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/validate.cc b/cpp/src/parquet/variant/validate.cc new file mode 100644 index 000000000000..8ff53d4e58be --- /dev/null +++ b/cpp/src/parquet/variant/validate.cc @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "parquet/variant/validate.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" // IWYU pragma: keep +#include "arrow/buffer.h" +#include "arrow/chunked_array.h" +#include "arrow/extension/parquet_variant.h" +#include "arrow/extension_type.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging_internal.h" +#include "parquet/variant/encoding.h" + +namespace parquet::variant { + +namespace { + +using ::arrow::Array; +using ::arrow::Buffer; +using ::arrow::DataType; +using ::arrow::ExtensionArray; +using ::arrow::ExtensionType; +using ::arrow::MemoryPool; +using ::arrow::Result; +using ::arrow::Status; +using ::arrow::extension::kVariantExtensionName; +using ::arrow::internal::checked_cast; + +enum class VariantShreddedTypeKind { + None, + Primitive, + Object, + Array, +}; + +struct VariantValidationPlan { + std::shared_ptr array; + std::vector children; +}; + +Result> ValuesArray(const Array& array) { + switch (array.type_id()) { + case ::arrow::Type::LIST_VIEW: + return checked_cast(array).values(); + case ::arrow::Type::LARGE_LIST_VIEW: + return checked_cast(array).values(); + case ::arrow::Type::LIST: + return checked_cast(array).values(); + case ::arrow::Type::LARGE_LIST: + return checked_cast(array).values(); + case ::arrow::Type::FIXED_SIZE_LIST: + return checked_cast(array).values(); + case ::arrow::Type::MAP: + return checked_cast(array).values(); + default: + return Status::Invalid("Expected list or map storage, got ", + array.type()->ToString()); + } +} + +Result> BuildVariantValidationPlan( + std::shared_ptr array) { + switch (array->type_id()) { + case ::arrow::Type::EXTENSION: { + const auto& ext_array = checked_cast(*array); + const auto& ext_type = checked_cast(*array->type()); + if (ext_type.extension_name() == 
kVariantExtensionName) { + return VariantValidationPlan{.array = std::move(array), .children = {}}; + } + return BuildVariantValidationPlan(ext_array.storage()); + } + case ::arrow::Type::STRUCT: { + const auto& struct_array = checked_cast(*array); + std::vector children; + for (auto field : struct_array.fields()) { + ARROW_ASSIGN_OR_RAISE(auto child_plan, + BuildVariantValidationPlan(std::move(field))); + if (child_plan.has_value()) { + children.push_back(std::move(*child_plan)); + } + } + if (children.empty()) { + return std::nullopt; + } + return VariantValidationPlan{.array = std::move(array), + .children = std::move(children)}; + } + case ::arrow::Type::LIST_VIEW: + case ::arrow::Type::LARGE_LIST_VIEW: + case ::arrow::Type::LIST: + case ::arrow::Type::LARGE_LIST: + case ::arrow::Type::FIXED_SIZE_LIST: + case ::arrow::Type::MAP: { + ARROW_ASSIGN_OR_RAISE(auto values, ValuesArray(*array)); + ARROW_ASSIGN_OR_RAISE(auto child_plan, + BuildVariantValidationPlan(std::move(values))); + if (!child_plan.has_value()) { + return std::nullopt; + } + std::vector children; + children.push_back(std::move(*child_plan)); + return VariantValidationPlan{.array = std::move(array), + .children = std::move(children)}; + } + default: { + std::vector children; + for (const auto& child_data : array->data()->child_data) { + if (child_data != nullptr) { + ARROW_ASSIGN_OR_RAISE(auto child_plan, BuildVariantValidationPlan( + ::arrow::MakeArray(child_data))); + if (child_plan.has_value()) { + children.push_back(std::move(*child_plan)); + } + } + } + if (children.empty()) { + return std::nullopt; + } + return VariantValidationPlan{.array = std::move(array), + .children = std::move(children)}; + } + } +} + +Result BinaryFieldView(const Array& array, int64_t row) { + switch (array.type_id()) { + case ::arrow::Type::BINARY: + return checked_cast(array).GetView(row); + case ::arrow::Type::LARGE_BINARY: + return checked_cast(array).GetView(row); + case ::arrow::Type::BINARY_VIEW: + return 
checked_cast(array).GetView(row); + default: + return Status::Invalid("Expected binary Variant field, got ", + array.type()->ToString()); + } +} + +VariantShreddedTypeKind ShreddedTypeKind(const DataType& type) { + switch (type.id()) { + case ::arrow::Type::STRUCT: + return VariantShreddedTypeKind::Object; + case ::arrow::Type::LIST: + case ::arrow::Type::LARGE_LIST: + case ::arrow::Type::LIST_VIEW: + case ::arrow::Type::LARGE_LIST_VIEW: + case ::arrow::Type::FIXED_SIZE_LIST: + return VariantShreddedTypeKind::Array; + default: + return VariantShreddedTypeKind::Primitive; + } +} + +Status ValidateVariantShredding( + const VariantMetadataView& metadata, std::string_view value, bool typed_value_present, + VariantShreddedTypeKind typed_value_kind, + std::span shredded_field_names = {}) { + if (typed_value_kind == VariantShreddedTypeKind::None) { + if (value.empty()) { + return Status::OK(); + } + return VariantValueView::Validate(value, metadata); + } + + if (typed_value_kind == VariantShreddedTypeKind::Object) { + for (const auto& name : shredded_field_names) { + if (!metadata.FindString(name).has_value()) { + return Status::Invalid("Invalid shredded Variant: shredded field '", name, + "' is not in metadata dictionary"); + } + } + } + + if (value.empty()) { + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE(auto value_view, VariantValueView::Make(value, metadata)); + + switch (typed_value_kind) { + case VariantShreddedTypeKind::Primitive: + if (typed_value_present) { + return Status::Invalid( + "Invalid shredded Variant: value and primitive typed_value are both " + "non-null"); + } + break; + case VariantShreddedTypeKind::Array: + if (typed_value_present) { + return Status::Invalid( + "Invalid shredded Variant: value and array typed_value are both " + "non-null"); + } + if (value_view.basic_type() == VariantBasicType::kArray) { + return Status::Invalid( + "Invalid shredded Variant: array value must be stored in typed_value"); + } + break; + case 
VariantShreddedTypeKind::Object: + default: + if (value_view.basic_type() == VariantBasicType::kObject) { + if (!typed_value_present) { + return Status::Invalid( + "Invalid shredded Variant: object value requires object typed_value"); + } + for (const auto& name : shredded_field_names) { + if (std::get(value_view.data()).ContainsField(name)) { + return Status::Invalid( + "Invalid shredded Variant: value object contains shredded field: ", name); + } + } + } else if (typed_value_present) { + return Status::Invalid( + "Invalid shredded Variant: partially shredded value must be an object"); + } + break; + } + return Status::OK(); +} + +template +std::pair ValueOffsetAndLength(const Array& array, int64_t row) { + const auto& typed_array = checked_cast(array); + return {typed_array.value_offset(row), typed_array.value_length(row)}; +} + +Result> ValuesRangeAt(const Array& array, int64_t row) { + switch (array.type_id()) { + case ::arrow::Type::LIST: + return ValueOffsetAndLength<::arrow::ListArray>(array, row); + case ::arrow::Type::LARGE_LIST: + return ValueOffsetAndLength<::arrow::LargeListArray>(array, row); + case ::arrow::Type::LIST_VIEW: + return ValueOffsetAndLength<::arrow::ListViewArray>(array, row); + case ::arrow::Type::LARGE_LIST_VIEW: + return ValueOffsetAndLength<::arrow::LargeListViewArray>(array, row); + case ::arrow::Type::MAP: + return ValueOffsetAndLength<::arrow::MapArray>(array, row); + case ::arrow::Type::FIXED_SIZE_LIST: + return ValueOffsetAndLength<::arrow::FixedSizeListArray>(array, row); + default: + return Status::Invalid("Expected list or map storage, got ", + array.type()->ToString()); + } +} + +Status ValidateShreddedVariantSlot(const VariantMetadataView& metadata, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_array, int64_t row, + bool allow_missing, std::string_view path); + +Status ValidateShreddedVariantField(const VariantMetadataView& metadata, + const Array& field_array, int64_t row, + bool allow_missing, 
std::string_view path) { + if (field_array.type_id() != ::arrow::Type::STRUCT) { + return Status::Invalid("Invalid shredded Variant field at ", path, + ": expected struct storage, got ", + field_array.type()->ToString()); + } + if (field_array.IsNull(row)) { + return Status::Invalid("Invalid shredded Variant field at ", path, + ": field group must be required"); + } + const auto& field_struct = checked_cast(field_array); + auto value_array = field_struct.GetFieldByName("value"); + auto typed_array = field_struct.GetFieldByName("typed_value"); + return ValidateShreddedVariantSlot(metadata, value_array, typed_array, row, + allow_missing, path); +} + +Status ValidateShreddedVariantSlot(const VariantMetadataView& metadata, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_array, int64_t row, + bool allow_missing, std::string_view path) { + std::string_view value; + if (value_array != nullptr && !value_array->IsNull(row)) { + ARROW_ASSIGN_OR_RAISE(value, BinaryFieldView(*value_array, row)); + if (value.empty()) { + ARROW_RETURN_NOT_OK(VariantValueView::Validate(value, metadata)); + } + } + + const bool typed_value_present = typed_array != nullptr && !typed_array->IsNull(row); + if (value.empty() && !typed_value_present) { + if (allow_missing) { + return Status::OK(); + } + return Status::Invalid("Invalid shredded Variant at ", path, + ": value and typed_value are both null"); + } + + const auto typed_kind = typed_array == nullptr ? 
VariantShreddedTypeKind::None + : ShreddedTypeKind(*typed_array->type()); + std::vector field_names; + if (typed_kind == VariantShreddedTypeKind::Object) { + field_names.reserve(typed_array->type()->num_fields()); + for (const auto& field : typed_array->type()->fields()) { + field_names.emplace_back(field->name()); + } + } + ARROW_RETURN_NOT_OK(ValidateVariantShredding(metadata, value, typed_value_present, + typed_kind, field_names)); + + if (!typed_value_present) { + return Status::OK(); + } + + if (typed_kind == VariantShreddedTypeKind::Object) { + const auto& typed_struct = checked_cast(*typed_array); + for (int i = 0; i < typed_struct.struct_type()->num_fields(); ++i) { + ARROW_RETURN_NOT_OK(ValidateShreddedVariantField( + metadata, *typed_struct.field(i), row, /*allow_missing=*/true, + typed_struct.struct_type()->field(i)->name())); + } + } else if (typed_kind == VariantShreddedTypeKind::Array) { + ARROW_ASSIGN_OR_RAISE(auto values, ValuesArray(*typed_array)); + ARROW_ASSIGN_OR_RAISE(auto range, ValuesRangeAt(*typed_array, row)); + const auto [offset, length] = range; + auto elements = values->Slice(offset, length); + for (int64_t i = 0; i < elements->length(); ++i) { + ARROW_RETURN_NOT_OK(ValidateShreddedVariantField(metadata, *elements, i, + /*allow_missing=*/false, path)); + } + } + return Status::OK(); +} + +template + requires requires(VisitVisible& visit_visible, int64_t row) { + ::arrow::ToStatus(visit_visible(row)); + } +Status VisitVisibleRows(const std::shared_ptr& valid_rows, const Array& array, + VisitVisible&& visit_visible) { + if (valid_rows == nullptr && !array.data()->MayHaveNulls()) { + for (int64_t row = 0; row < array.length(); ++row) { + ARROW_RETURN_NOT_OK(visit_visible(row)); + } + return Status::OK(); + } + return ::arrow::internal::VisitTwoBitBlocks( + valid_rows != nullptr ? valid_rows->data() : nullptr, /*left_offset=*/0, + array.data()->MayHaveNulls() ? 
array.null_bitmap_data() : nullptr, array.offset(), + array.length(), std::forward(visit_visible), + [] { return Status::OK(); }); +} + +Status ValidateVariantExtensionArray(const ExtensionArray& array, + const std::shared_ptr& valid_rows) { + const auto& storage = checked_cast(*array.storage()); + auto metadata_array = storage.GetFieldByName("metadata"); + auto value_array = storage.GetFieldByName("value"); + auto typed_array = storage.GetFieldByName("typed_value"); + + std::string_view last_metadata_bytes; + std::optional last_metadata; + return VisitVisibleRows(valid_rows, array, [&](int64_t row) { + if (metadata_array->IsNull(row)) { + return Status::Invalid("Invalid Variant extension storage: metadata is null"); + } + + ARROW_ASSIGN_OR_RAISE(auto metadata_value, BinaryFieldView(*metadata_array, row)); + if (!last_metadata.has_value() || last_metadata_bytes != metadata_value) { + ARROW_ASSIGN_OR_RAISE(last_metadata, VariantMetadataView::Make(metadata_value)); + last_metadata_bytes = metadata_value; + } + + if (typed_array == nullptr) { + if (value_array->IsNull(row)) { + return Status::Invalid( + "Invalid Variant extension storage: unshredded value is null"); + } + ARROW_ASSIGN_OR_RAISE(auto value, BinaryFieldView(*value_array, row)); + ARROW_RETURN_NOT_OK(VariantValueView::Validate(value, *last_metadata)); + } else { + ARROW_RETURN_NOT_OK( + ValidateShreddedVariantSlot(*last_metadata, value_array, typed_array, row, + /*allow_missing=*/false, "variant")); + } + return Status::OK(); + }); +} + +Status ValidateVariantPlan(const VariantValidationPlan& plan, MemoryPool* pool, + const std::shared_ptr& valid_rows) { + switch (plan.array->type_id()) { + case ::arrow::Type::EXTENSION: { + return ValidateVariantExtensionArray( + checked_cast(*plan.array), valid_rows); + } + case ::arrow::Type::STRUCT: { + const auto& struct_array = checked_cast(*plan.array); + std::shared_ptr child_valid_rows = valid_rows; + if (struct_array.data()->MayHaveNulls()) { + 
ARROW_ASSIGN_OR_RAISE( + child_valid_rows, + ::arrow::internal::OptionalBitmapAnd( + pool, valid_rows, /*left_offset=*/0, struct_array.null_bitmap(), + struct_array.offset(), struct_array.length(), /*out_offset=*/0)); + } + for (const auto& child : plan.children) { + ARROW_RETURN_NOT_OK(ValidateVariantPlan(child, pool, child_valid_rows)); + } + return Status::OK(); + } + case ::arrow::Type::LIST_VIEW: + case ::arrow::Type::LARGE_LIST_VIEW: + case ::arrow::Type::LIST: + case ::arrow::Type::LARGE_LIST: + case ::arrow::Type::MAP: + case ::arrow::Type::FIXED_SIZE_LIST: { + DCHECK_EQ(plan.children.size(), 1); + ARROW_ASSIGN_OR_RAISE(auto values, ValuesArray(*plan.array)); + ARROW_ASSIGN_OR_RAISE(auto values_valid_rows, + ::arrow::AllocateEmptyBitmap(values->length(), pool)); + ARROW_RETURN_NOT_OK( + VisitVisibleRows(valid_rows, *plan.array, [&](int64_t row) -> Status { + ARROW_ASSIGN_OR_RAISE(auto range, ValuesRangeAt(*plan.array, row)); + const auto [offset, length] = range; + ::arrow::bit_util::SetBitsTo(values_valid_rows->mutable_data(), offset, + length, true); + return Status::OK(); + })); + return ValidateVariantPlan(plan.children[0], pool, values_valid_rows); + } + default: { + for (const auto& child : plan.children) { + ARROW_RETURN_NOT_OK(ValidateVariantPlan(child, pool, valid_rows)); + } + return Status::OK(); + } + } +} + +} // namespace + +Status ValidateVariants(const ::arrow::ChunkedArray& data, MemoryPool* pool) { + for (const auto& chunk : data.chunks()) { + ARROW_ASSIGN_OR_RAISE(auto plan, BuildVariantValidationPlan(chunk)); + if (plan.has_value()) { + ARROW_RETURN_NOT_OK(ValidateVariantPlan(*plan, pool, nullptr)); + } + } + return Status::OK(); +} + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/validate.h b/cpp/src/parquet/variant/validate.h new file mode 100644 index 000000000000..2afc808b9182 --- /dev/null +++ b/cpp/src/parquet/variant/validate.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one 
+// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/status.h" +#include "parquet/platform.h" + +namespace arrow { + +class ChunkedArray; +class MemoryPool; + +} // namespace arrow + +namespace parquet::variant { + +PARQUET_EXPORT +::arrow::Status ValidateVariants(const ::arrow::ChunkedArray& data, + ::arrow::MemoryPool* pool); + +} // namespace parquet::variant diff --git a/cpp/src/parquet/variant/validate_test.cc b/cpp/src/parquet/variant/validate_test.cc new file mode 100644 index 000000000000..425cf79df137 --- /dev/null +++ b/cpp/src/parquet/variant/validate_test.cc @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "parquet/variant/validate.h" + +#include +#include +#include +#include +#include + +#include "arrow/array.h" // IWYU pragma: keep +#include "arrow/chunked_array.h" +#include "arrow/extension/parquet_variant.h" +#include "arrow/extension_type.h" +#include "arrow/io/memory.h" +#include "arrow/table.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "parquet/arrow/reader.h" +#include "parquet/variant/test_util_internal.h" + +namespace parquet::variant { + +using ::arrow::binary; +using ::arrow::field; +using ::arrow::struct_; +using internal::BinaryArrayFromValues; +using internal::Int32ArrayFromValues; +using internal::Int8Variant; +using internal::ReadVariantTestingTable; +using internal::ShreddedVariantTestingDir; +using internal::VariantTable; +using internal::WriteVariantTable; + +TEST(TestVariantValidate, ListView) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues( + {std::string_view{*encoded.value}, std::string_view("\xff", 1)}); + ASSERT_OK_AND_ASSIGN( + auto storage, + ::arrow::StructArray::Make({metadata_array, value_array}, storage_type->fields())); + auto variant_array = 
::arrow::ExtensionType::WrapArray(variant_type, storage); + + ASSERT_OK_AND_ASSIGN(auto valid_list, + ::arrow::ListViewArray::FromArrays( + *Int32ArrayFromValues({0}), *Int32ArrayFromValues({1}), + *variant_array, ::arrow::default_memory_pool())); + ::arrow::ChunkedArray valid_data{valid_list}; + ASSERT_OK(ValidateVariants(valid_data, ::arrow::default_memory_pool())); + + ASSERT_OK_AND_ASSIGN(auto invalid_list, + ::arrow::ListViewArray::FromArrays( + *Int32ArrayFromValues({1}), *Int32ArrayFromValues({1}), + *variant_array, ::arrow::default_memory_pool())); + ::arrow::ChunkedArray invalid_data{invalid_list}; + ASSERT_RAISES(Invalid, ValidateVariants(invalid_data, ::arrow::default_memory_pool())); +} + +TEST(TestVariantValidate, ParquetTestingShredded) { + auto maybe_dir = ShreddedVariantTestingDir(); + if (!maybe_dir.has_value()) { + GTEST_SKIP() << "PARQUET_TEST_DATA not set"; + } + if (!std::filesystem::exists(*maybe_dir)) { + GTEST_SKIP() << *maybe_dir << " does not exist"; + } + + auto registered_storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + ::arrow::ExtensionTypeGuard guard(::arrow::extension::variant(registered_storage_type)); + + struct ShreddedCase { + std::string file_name; + bool has_top_level_value; + }; + const std::vector cases = { + {.file_name = "case-041.parquet", .has_top_level_value = false}, + {.file_name = "case-088.parquet", .has_top_level_value = true}, + {.file_name = "case-131.parquet", .has_top_level_value = false}, + {.file_name = "case-132.parquet", .has_top_level_value = true}, + {.file_name = "case-138.parquet", .has_top_level_value = false}, + }; + + for (const auto& test_case : cases) { + SCOPED_TRACE(test_case.file_name); + ASSERT_OK_AND_ASSIGN(auto table, + ReadVariantTestingTable(*maybe_dir + "/" + test_case.file_name)); + ASSERT_OK(table->ValidateFull()); + + auto field = table->schema()->GetFieldByName("var"); + ASSERT_NE(nullptr, field); + auto 
variant_type = + std::dynamic_pointer_cast<::arrow::extension::VariantExtensionType>( + field->type()); + ASSERT_NE(nullptr, variant_type); + ASSERT_EQ(test_case.has_top_level_value, variant_type->value() != nullptr); + + auto column = table->GetColumnByName("var"); + ASSERT_NE(nullptr, column); + ASSERT_OK(ValidateVariants(*column, ::arrow::default_memory_pool())); + } +} + +TEST(TestVariantValidate, DictionaryMetadata) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), /*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues( + {std::string_view{*encoded.value}, std::string_view{*encoded.value}}); + auto table = + VariantTable(variant_type, {metadata_array, value_array}, storage_type->fields()); + + ASSERT_OK_AND_ASSIGN( + auto buffer, + WriteVariantTable(table, WriterProperties::Builder().enable_dictionary()->build())); + + auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(buffer); + ArrowReaderProperties reader_properties; + reader_properties.set_arrow_extensions_enabled(true); + ::arrow::ExtensionTypeGuard guard(::arrow::extension::variant(storage_type)); + parquet::arrow::FileReaderBuilder builder; + ASSERT_OK(builder.Open(buffer_reader)); + builder.properties(reader_properties); + ASSERT_OK_AND_ASSIGN(auto reader, builder.Build()); + + ASSERT_OK_AND_ASSIGN(auto read_table, reader->ReadTable()); + auto column = read_table->GetColumnByName("variant"); + ASSERT_NE(nullptr, column); + ASSERT_OK(ValidateVariants(*column, ::arrow::default_memory_pool())); +} + +TEST(TestVariantValidate, ReadDictionaryOption) { + ASSERT_OK_AND_ASSIGN(auto encoded, Int8Variant(42)); + + auto storage_type = struct_({field("metadata", binary(), 
/*nullable=*/false), + field("value", binary(), /*nullable=*/false)}); + auto variant_type = ::arrow::extension::variant(storage_type); + auto metadata_array = BinaryArrayFromValues( + {std::string_view{*encoded.metadata}, std::string_view{*encoded.metadata}}); + auto value_array = BinaryArrayFromValues( + {std::string_view{*encoded.value}, std::string_view{*encoded.value}}); + auto table = + VariantTable(variant_type, {metadata_array, value_array}, storage_type->fields()); + + ASSERT_OK_AND_ASSIGN(auto buffer, WriteVariantTable(table)); + + auto buffer_reader = std::make_shared<::arrow::io::BufferReader>(buffer); + ArrowReaderProperties reader_properties; + reader_properties.set_arrow_extensions_enabled(true); + reader_properties.set_read_dictionary(0, true); + reader_properties.set_read_dictionary(1, true); + ::arrow::ExtensionTypeGuard guard(::arrow::extension::variant(storage_type)); + parquet::arrow::FileReaderBuilder builder; + ASSERT_OK(builder.Open(buffer_reader)); + builder.properties(reader_properties); + ASSERT_OK_AND_ASSIGN(auto reader, builder.Build()); + + ASSERT_OK_AND_ASSIGN(auto read_table, reader->ReadTable()); + auto column = read_table->GetColumnByName("variant"); + ASSERT_NE(nullptr, column); + ASSERT_OK(ValidateVariants(*column, ::arrow::default_memory_pool())); +} + +} // namespace parquet::variant