From 162d5032762b0851b120cf6d5b14218807f8f128 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Thu, 25 Jun 2026 22:24:48 -0700 Subject: [PATCH 1/2] GH-45946: [C++][Parquet] Variant decoding --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 3 +- cpp/src/arrow/extension/meson.build | 3 +- cpp/src/arrow/extension/variant.cc | 1314 +++++++++ cpp/src/arrow/extension/variant.h | 583 ++++ .../extension/variant_internal_test_util.h | 137 + .../arrow/extension/variant_internal_util.h | 71 + cpp/src/arrow/extension/variant_test.cc | 2412 +++++++++++++++++ cpp/src/arrow/meson.build | 1 + 9 files changed, 4523 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/extension/variant.cc create mode 100644 cpp/src/arrow/extension/variant.h create mode 100644 cpp/src/arrow/extension/variant_internal_test_util.h create mode 100644 cpp/src/arrow/extension/variant_internal_util.h create mode 100644 cpp/src/arrow/extension/variant_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 45cd7e838121..149ec9c6ff19 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -391,6 +391,7 @@ set(ARROW_SRCS extension/bool8.cc extension/json.cc extension/parquet_variant.cc + extension/variant.cc extension/uuid.cc pretty_print.cc record_batch.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index ae52bc32a998..283a328a9098 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc) +set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc + variant_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/meson.build b/cpp/src/arrow/extension/meson.build index 84dafe4bbe32..6d2222698c12 100644 --- a/cpp/src/arrow/extension/meson.build +++ b/cpp/src/arrow/extension/meson.build @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc'] +canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', 'variant_test.cc'] if needs_json canonical_extension_tests += [ @@ -40,5 +40,6 @@ install_headers( 'parquet_variant.h', 'uuid.h', 'variable_shape_tensor.h', + 'variant.h', ], ) diff --git a/cpp/src/arrow/extension/variant.cc b/cpp/src/arrow/extension/variant.cc new file mode 100644 index 000000000000..3deff3d0610b --- /dev/null +++ b/cpp/src/arrow/extension/variant.cc @@ -0,0 +1,1314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant.h" + +#include + +#include "arrow/extension/variant_internal_util.h" +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" + +namespace arrow::extension::variant { + +// Ensure view classes remain lightweight (stack-allocated, cache-friendly). +static_assert(sizeof(VariantView) <= 32, "VariantView should fit in 32 bytes"); +static_assert(sizeof(VariantObjectView) <= 80, + "VariantObjectView should fit in 80 bytes"); +static_assert(sizeof(VariantArrayView) <= 64, "VariantArrayView should fit in 64 bytes"); + +namespace { + +// --------------------------------------------------------------------------- +// Little-endian helpers (delegate to shared internal utility) +// --------------------------------------------------------------------------- + +using internal::ReadUnsignedLE; + +/// \brief Validate that offsets are monotonically non-decreasing and in bounds. +Status ValidateOffsets(const std::vector& offsets, int64_t data_length) { + for (size_t i = 1; i < offsets.size(); ++i) { + if (offsets[i] < offsets[i - 1]) { + return Status::Invalid( + "Variant metadata: string offsets are not monotonically " + "non-decreasing at index ", + i); + } + } + if (!offsets.empty() && offsets.back() > static_cast(data_length)) { + return Status::Invalid("Variant metadata: last string offset ", offsets.back(), + " exceeds data length ", data_length); + } + return Status::OK(); +} + +// --------------------------------------------------------------------------- +// Recursive visitor traversal (internal) +// --------------------------------------------------------------------------- + +Status VisitValueAt(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, VariantVisitor* visitor, int64_t* bytes_consumed, + int32_t depth); + +Status VisitPrimitive(const uint8_t* data, int64_t length, int64_t offset, uint8_t header, + VariantVisitor* visitor, int64_t* bytes_consumed) { + auto primitive_type = GetPrimitiveType(header); + int64_t pos = offset + 1; + + auto check_remaining = [&](int64_t needed) -> Status { + if (pos + needed > length) { + return Status::Invalid("Variant value: truncated primitive at offset ", offset, + ", need ", needed, " bytes but only ", length - pos, + " remaining"); + } + return Status::OK(); + }; + + switch (primitive_type) { + case PrimitiveType::kNull: + ARROW_RETURN_NOT_OK(visitor->Null()); + *bytes_consumed = 1; + return Status::OK(); + case PrimitiveType::kTrue: + ARROW_RETURN_NOT_OK(visitor->Bool(true)); + *bytes_consumed = 1; + return Status::OK(); + case PrimitiveType::kFalse: + ARROW_RETURN_NOT_OK(visitor->Bool(false)); + *bytes_consumed = 1; + return Status::OK(); + case PrimitiveType::kInt8: { + ARROW_RETURN_NOT_OK(check_remaining(1)); + ARROW_RETURN_NOT_OK(visitor->Int8(static_cast(data[pos]))); + *bytes_consumed = 2; + return Status::OK(); + } + case PrimitiveType::kInt16: { + ARROW_RETURN_NOT_OK(check_remaining(2)); + int16_t value; + std::memcpy(&value, data + pos, 2); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int16(value)); + *bytes_consumed = 3; + return Status::OK(); + } + case PrimitiveType::kInt32: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + int32_t value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int32(value)); + *bytes_consumed = 5; + return Status::OK(); + } + case PrimitiveType::kInt64: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int64(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kFloat: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + float value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Float(value)); + *bytes_consumed = 5; + return Status::OK(); + } + case PrimitiveType::kDouble: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + double value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Double(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kDecimal4: { + ARROW_RETURN_NOT_OK(check_remaining(5)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal4(data + pos + 1, scale)); + *bytes_consumed = 6; + return Status::OK(); + } + case PrimitiveType::kDecimal8: { + ARROW_RETURN_NOT_OK(check_remaining(9)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal8(data + pos + 1, scale)); + *bytes_consumed = 10; + return Status::OK(); + } + case PrimitiveType::kDecimal16: { + ARROW_RETURN_NOT_OK(check_remaining(17)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal16(data + pos + 1, scale)); + *bytes_consumed = 18; + return Status::OK(); + } + case PrimitiveType::kDate: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + int32_t value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Date(value)); + *bytes_consumed = 5; + return Status::OK(); + } + case PrimitiveType::kTimestampMicros: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampMicros(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kTimestampMicrosNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampMicrosNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kBinary: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + uint32_t bin_length; + std::memcpy(&bin_length, data + pos, 4); + bin_length = bit_util::FromLittleEndian(bin_length); + ARROW_RETURN_NOT_OK(check_remaining(4 + static_cast(bin_length))); + auto view = + std::string_view(reinterpret_cast(data + pos + 4), bin_length); + ARROW_RETURN_NOT_OK(visitor->Binary(view)); + *bytes_consumed = 1 + 4 + static_cast(bin_length); + return Status::OK(); + } + case PrimitiveType::kString: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + uint32_t str_length; + std::memcpy(&str_length, data + pos, 4); + str_length = bit_util::FromLittleEndian(str_length); + ARROW_RETURN_NOT_OK(check_remaining(4 + static_cast(str_length))); + auto view = + std::string_view(reinterpret_cast(data + pos + 4), str_length); + ARROW_RETURN_NOT_OK(visitor->String(view)); + *bytes_consumed = 1 + 4 + static_cast(str_length); + return Status::OK(); + } + case PrimitiveType::kTimeNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimeNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kTimestampNanos: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampNanos(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kTimestampNanosNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampNanosNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + case PrimitiveType::kUUID: { + ARROW_RETURN_NOT_OK(check_remaining(kUUIDByteLength)); + ARROW_RETURN_NOT_OK(visitor->UUID(data + pos)); + *bytes_consumed = kUUIDByteLength + 1; + return Status::OK(); + } + default: + return Status::Invalid("Variant value: unknown primitive type ", + static_cast(primitive_type)); + } +} + +Status VisitShortString(const uint8_t* data, int64_t length, int64_t offset, + uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed) { + int32_t str_len = (header >> 2) & 0x3F; + int64_t pos = offset + 1; + if (pos + str_len > length) { + return Status::Invalid("Variant value: truncated short string at offset ", offset); + } + auto view = std::string_view(reinterpret_cast(data + pos), str_len); + ARROW_RETURN_NOT_OK(visitor->String(view)); + *bytes_consumed = 1 + str_len; + return Status::OK(); +} + +Status VisitObject(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed, int32_t depth) { + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + int64_t pos = offset + 1; + if (pos + num_fields_size > length) { + return Status::Invalid("Variant value: truncated object num_fields at offset ", + offset); + } + auto num_fields = static_cast(ReadUnsignedLE(data + pos, num_fields_size)); + pos += num_fields_size; + + int64_t field_ids_size = static_cast(num_fields) * field_id_size; + if (pos + field_ids_size > length) { + return Status::Invalid("Variant value: truncated object field_ids at offset ", + offset); + } + std::vector field_ids(num_fields); + for (int32_t i = 0; i < num_fields; ++i) { + field_ids[i] = ReadUnsignedLE(data + pos, field_id_size); + pos += field_id_size; + } + + int64_t offsets_size = (static_cast(num_fields) + 1) * field_offset_size; + if (pos + offsets_size > length) { + return Status::Invalid("Variant value: truncated object offsets at offset ", offset); + } + std::vector value_offsets(num_fields + 1); + for (int32_t i = 0; i <= num_fields; ++i) { + value_offsets[i] = ReadUnsignedLE(data + pos, field_offset_size); + pos += field_offset_size; + } + + int64_t data_start = pos; + int64_t total_data_size = static_cast(value_offsets[num_fields]); + if (data_start + total_data_size > length) { + return Status::Invalid("Variant value: object data exceeds buffer at offset ", + offset); + } + + for (int32_t i = 0; i < num_fields; ++i) { + if (value_offsets[i] > static_cast(total_data_size)) { + return Status::Invalid("Variant value: object field offset ", value_offsets[i], + " at index ", i, " exceeds data size ", total_data_size); + } + } + + ARROW_RETURN_NOT_OK(visitor->StartObject(num_fields)); + + for (int32_t i = 0; i < num_fields; ++i) { + auto field_id = field_ids[i]; + if (field_id >= metadata.strings.size()) { + return Status::Invalid("Variant value: field_id ", field_id, + " exceeds metadata dictionary size ", + metadata.strings.size()); + } + ARROW_RETURN_NOT_OK(visitor->FieldName(metadata.strings[field_id])); + + int64_t field_offset = data_start + value_offsets[i]; + int64_t consumed = 0; + ARROW_RETURN_NOT_OK(VisitValueAt(metadata, data, data_start + total_data_size, + field_offset, visitor, &consumed, depth)); + } + + ARROW_RETURN_NOT_OK(visitor->EndObject()); + *bytes_consumed = (data_start - offset) + total_data_size; + return Status::OK(); +} + +Status VisitArray(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed, int32_t depth) { + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + + int64_t pos = offset + 1; + if (pos + num_elements_size > length) { + return Status::Invalid("Variant value: truncated array num_elements at offset ", + offset); + } + auto num_elements = static_cast(ReadUnsignedLE(data + pos, num_elements_size)); + pos += num_elements_size; + + int64_t offsets_size = (static_cast(num_elements) + 1) * field_offset_size; + if (pos + offsets_size > length) { + return Status::Invalid("Variant value: truncated array offsets at offset ", offset); + } + std::vector value_offsets(num_elements + 1); + for (int32_t i = 0; i <= num_elements; ++i) { + value_offsets[i] = ReadUnsignedLE(data + pos, field_offset_size); + pos += field_offset_size; + } + + for (int32_t i = 1; i <= num_elements; ++i) { + if (value_offsets[i] < value_offsets[i - 1]) { + return Status::Invalid( + "Variant value: array value offsets are not monotonically " + "non-decreasing at index ", + i); + } + } + + int64_t data_start = pos; + int64_t total_data_size = static_cast(value_offsets[num_elements]); + if (data_start + total_data_size > length) { + return Status::Invalid("Variant value: array data exceeds buffer at offset ", offset); + } + + ARROW_RETURN_NOT_OK(visitor->StartArray(num_elements)); + + for (int32_t i = 0; i < num_elements; ++i) { + int64_t elem_offset = data_start + value_offsets[i]; + int64_t consumed = 0; + ARROW_RETURN_NOT_OK(VisitValueAt(metadata, data, data_start + total_data_size, + elem_offset, visitor, &consumed, depth)); + } + + ARROW_RETURN_NOT_OK(visitor->EndArray()); + *bytes_consumed = (data_start - offset) + total_data_size; + return Status::OK(); +} + +Status VisitValueAt(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, VariantVisitor* visitor, int64_t* bytes_consumed, + int32_t depth) { + if (offset >= length) { + return Status::Invalid("Variant value: offset ", offset, + " is at or beyond buffer length ", length); + } + if (depth > kMaxNestingDepth) { + return Status::Invalid("Variant value: nesting depth exceeds maximum of ", + kMaxNestingDepth); + } + + uint8_t header = data[offset]; + auto basic_type = GetBasicType(header); + + switch (basic_type) { + case BasicType::kPrimitive: + return VisitPrimitive(data, length, offset, header, visitor, bytes_consumed); + case BasicType::kShortString: + return VisitShortString(data, length, offset, header, visitor, bytes_consumed); + case BasicType::kObject: + return VisitObject(metadata, data, length, offset, header, visitor, bytes_consumed, + depth + 1); + case BasicType::kArray: + return VisitArray(metadata, data, length, offset, header, visitor, bytes_consumed, + depth + 1); + default: + return Status::Invalid("Variant value: unknown basic type ", + static_cast(basic_type)); + } +} + +} // namespace + +// =========================================================================== +// Public API: Metadata +// =========================================================================== + +Result DecodeMetadata(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant metadata: buffer is null or empty"); + } + + uint8_t header = data[0]; + uint8_t version = header & 0x0F; + if (version != kVariantVersion) { + return Status::Invalid("Variant metadata: unsupported version ", + static_cast(version), ", expected ", + static_cast(kVariantVersion)); + } + + if ((header >> 5) & 0x01) { + return Status::Invalid("Variant metadata: reserved bit 5 is set in header"); + } + + bool is_sorted = ((header >> 4) & 0x01) != 0; + int32_t offset_size = ((header >> 6) & 0x03) + 1; + + int64_t pos = 1; + if (pos + offset_size > length) { + return Status::Invalid("Variant metadata: truncated dictionary size at byte ", pos); + } + auto dict_size = static_cast(ReadUnsignedLE(data + pos, offset_size)); + pos += offset_size; + + int64_t offsets_bytes = static_cast(dict_size + 1) * offset_size; + if (pos + offsets_bytes > length) { + return Status::Invalid("Variant metadata: truncated string offsets, need ", + offsets_bytes, " bytes at position ", pos, + " but buffer length is ", length); + } + + std::vector offsets(dict_size + 1); + for (int32_t i = 0; i <= dict_size; ++i) { + offsets[i] = ReadUnsignedLE(data + pos, offset_size); + pos += offset_size; + } + + int64_t string_data_length = length - pos; + ARROW_RETURN_NOT_OK(ValidateOffsets(offsets, string_data_length)); + + std::vector strings(dict_size); + for (int32_t i = 0; i < dict_size; ++i) { + auto start = static_cast(offsets[i]); + auto end = static_cast(offsets[i + 1]); + strings[i] = + std::string_view(reinterpret_cast(data + pos + start), end - start); + } + + VariantMetadata result; + result.version = version; + result.is_sorted = is_sorted; + result.offset_size = offset_size; + result.strings = std::move(strings); + return result; +} + +int32_t FindMetadataKey(const VariantMetadata& metadata, std::string_view key) { + if (metadata.is_sorted) { + int32_t lo = 0; + int32_t hi = static_cast(metadata.strings.size()) - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + int cmp = metadata.strings[mid].compare(key); + if (cmp == 0) return mid; + if (cmp < 0) + lo = mid + 1; + else + hi = mid - 1; + } + return -1; + } + for (int32_t i = 0; i < static_cast(metadata.strings.size()); ++i) { + if (metadata.strings[i] == key) return i; + } + return -1; +} + +int32_t PrimitiveValueSize(PrimitiveType primitive_type) { + switch (primitive_type) { + case PrimitiveType::kNull: + case PrimitiveType::kTrue: + case PrimitiveType::kFalse: + return 0; + case PrimitiveType::kInt8: + return 1; + case PrimitiveType::kInt16: + return 2; + case PrimitiveType::kInt32: + case PrimitiveType::kFloat: + case PrimitiveType::kDate: + return 4; + case PrimitiveType::kInt64: + case PrimitiveType::kDouble: + case PrimitiveType::kTimestampMicros: + case PrimitiveType::kTimestampMicrosNTZ: + case PrimitiveType::kTimeNTZ: + case PrimitiveType::kTimestampNanos: + case PrimitiveType::kTimestampNanosNTZ: + return 8; + case PrimitiveType::kDecimal4: + return 5; + case PrimitiveType::kDecimal8: + return 9; + case PrimitiveType::kDecimal16: + return 17; + case PrimitiveType::kUUID: + return kUUIDByteLength; + case PrimitiveType::kBinary: + case PrimitiveType::kString: + return -1; + default: + return -1; + } +} + +Result ValueSize(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("ValueSize: buffer is null or empty"); + } + + uint8_t header = data[0]; + auto basic_type = GetBasicType(header); + uint8_t type_info = (header >> 2) & 0x3F; + + switch (basic_type) { + case BasicType::kShortString: + return 1 + static_cast(type_info); + + case BasicType::kObject: { + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t sz_bytes = is_large ? 4 : 1; + if (1 + sz_bytes > length) { + return Status::Invalid("ValueSize: truncated object header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, sz_bytes)); + int32_t id_size = ((type_info >> 2) & 0x03) + 1; + int32_t offset_size = (type_info & 0x03) + 1; + int64_t id_start = 1 + sz_bytes; + int64_t offset_start = id_start + num_elements * id_size; + int64_t data_start = offset_start + (num_elements + 1) * offset_size; + int64_t last_offset_pos = offset_start + num_elements * offset_size; + if (last_offset_pos + offset_size > length) { + return Status::Invalid("ValueSize: truncated object offsets"); + } + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, offset_size)); + return data_start + total_data; + } + + case BasicType::kArray: { + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t sz_bytes = is_large ? 4 : 1; + if (1 + sz_bytes > length) { + return Status::Invalid("ValueSize: truncated array header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, sz_bytes)); + int32_t offset_size = (type_info & 0x03) + 1; + int64_t offset_start = 1 + sz_bytes; + int64_t data_start = offset_start + (num_elements + 1) * offset_size; + int64_t last_offset_pos = offset_start + num_elements * offset_size; + if (last_offset_pos + offset_size > length) { + return Status::Invalid("ValueSize: truncated array offsets"); + } + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, offset_size)); + return data_start + total_data; + } + + case BasicType::kPrimitive: { + auto ptype = static_cast(type_info); + int32_t payload_size = PrimitiveValueSize(ptype); + if (payload_size >= 0) { + return 1 + static_cast(payload_size); + } + if (1 + 4 > length) { + return Status::Invalid("ValueSize: truncated variable-length header"); + } + uint32_t var_len; + std::memcpy(&var_len, data + 1, 4); + var_len = bit_util::FromLittleEndian(var_len); + return 1 + 4 + static_cast(var_len); + } + + default: + return Status::Invalid("ValueSize: unknown basic type"); + } +} + +// =========================================================================== +// Public API: VariantView +// =========================================================================== + +VariantView::VariantView(const VariantMetadata* metadata, const uint8_t* data, + int64_t size, BasicType type) + : metadata_(metadata), data_(data), size_(size), type_(type) {} + +Result VariantView::Make(const VariantMetadata& metadata, + const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("VariantView: buffer is null or empty"); + } + ARROW_ASSIGN_OR_RAISE(auto size, ValueSize(data, length)); + if (size > length) { + return Status::Invalid("VariantView: value size ", size, " exceeds buffer length ", + length); + } + auto type = GetBasicType(data[0]); + return VariantView(&metadata, data, size, type); +} + +bool VariantView::is_null() const { + return type_ == BasicType::kPrimitive && + GetPrimitiveType(data_[0]) == PrimitiveType::kNull; +} + +Status VariantView::Visit(VariantVisitor* visitor) const { + DCHECK_NE(visitor, nullptr); + int64_t bytes_consumed = 0; + return VisitValueAt(*metadata_, data_, size_, 0, visitor, &bytes_consumed, 0); +} + +// --- Primitive accessors --- + +Result VariantView::as_bool() const { + if (type_ != BasicType::kPrimitive) { + return Status::Invalid("VariantView::as_bool: not a primitive"); + } + auto pt = GetPrimitiveType(data_[0]); + if (pt == PrimitiveType::kTrue) return true; + if (pt == PrimitiveType::kFalse) return false; + return Status::Invalid("VariantView::as_bool: not a boolean"); +} + +Result VariantView::as_int8() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kInt8) { + return Status::Invalid("VariantView::as_int8: type mismatch"); + } + return static_cast(data_[1]); +} + +Result VariantView::as_int16() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kInt16) { + return Status::Invalid("VariantView::as_int16: type mismatch"); + } + int16_t value; + std::memcpy(&value, data_ + 1, 2); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_int32() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kInt32) { + return Status::Invalid("VariantView::as_int32: type mismatch"); + } + int32_t value; + std::memcpy(&value, data_ + 1, 4); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_int64() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kInt64) { + return Status::Invalid("VariantView::as_int64: type mismatch"); + } + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_float() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kFloat) { + return Status::Invalid("VariantView::as_float: type mismatch"); + } + float value; + std::memcpy(&value, data_ + 1, 4); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_double() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kDouble) { + return Status::Invalid("VariantView::as_double: type mismatch"); + } + double value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_string() const { + if (type_ == BasicType::kShortString) { + int32_t len = (data_[0] >> 2) & 0x3F; + return std::string_view(reinterpret_cast(data_ + 1), len); + } + if (type_ == BasicType::kPrimitive && + GetPrimitiveType(data_[0]) == PrimitiveType::kString) { + uint32_t len; + std::memcpy(&len, data_ + 1, 4); + len = bit_util::FromLittleEndian(len); + return std::string_view(reinterpret_cast(data_ + 5), len); + } + return Status::Invalid("VariantView::as_string: not a string"); +} + +Result VariantView::as_binary() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kBinary) { + return Status::Invalid("VariantView::as_binary: not binary"); + } + uint32_t len; + std::memcpy(&len, data_ + 1, 4); + len = bit_util::FromLittleEndian(len); + return std::string_view(reinterpret_cast(data_ + 5), len); +} + +Result VariantView::as_date() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kDate) { + return Status::Invalid("VariantView::as_date: type mismatch"); + } + int32_t value; + std::memcpy(&value, data_ + 1, 4); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_timestamp_micros() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kTimestampMicros) { + return Status::Invalid("VariantView::as_timestamp_micros: type mismatch"); + } + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_timestamp_micros_ntz() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kTimestampMicrosNTZ) { + return Status::Invalid("VariantView::as_timestamp_micros_ntz: type mismatch"); + } + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_timestamp_nanos() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kTimestampNanos) { + return Status::Invalid("VariantView::as_timestamp_nanos: type mismatch"); + } + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_timestamp_nanos_ntz() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kTimestampNanosNTZ) { + return Status::Invalid("VariantView::as_timestamp_nanos_ntz: type mismatch"); + } + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_time_ntz() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kTimeNTZ) { + return Status::Invalid("VariantView::as_time_ntz: type mismatch"); + } + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); +} + +Result VariantView::as_uuid() const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kUUID) { + return Status::Invalid("VariantView::as_uuid: type mismatch"); + } + return data_ + 1; +} + +Result VariantView::as_decimal4(int32_t* scale) const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kDecimal4) { + return Status::Invalid("VariantView::as_decimal4: type mismatch"); + } + *scale = static_cast(data_[1]); + return data_ + 2; +} + +Result VariantView::as_decimal8(int32_t* scale) const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kDecimal8) { + return Status::Invalid("VariantView::as_decimal8: type mismatch"); + } + *scale = static_cast(data_[1]); + return data_ + 2; +} + +Result VariantView::as_decimal16(int32_t* scale) const { + if (type_ != BasicType::kPrimitive || + GetPrimitiveType(data_[0]) != PrimitiveType::kDecimal16) { + return Status::Invalid("VariantView::as_decimal16: type mismatch"); + } + *scale = static_cast(data_[1]); + return data_ + 2; +} + +Result VariantView::as_object() const { + if (type_ != BasicType::kObject) { + return Status::Invalid("VariantView::as_object: not an object"); + } + return VariantObjectView::Make(*metadata_, data_, size_); +} + +Result VariantView::as_array() const { + if (type_ != BasicType::kArray) { + return Status::Invalid("VariantView::as_array: not an array"); + } + return VariantArrayView::Make(*metadata_, data_, size_); +} + +// =========================================================================== +// Public API: VariantObjectView +// =========================================================================== + +VariantObjectView::VariantObjectView(const VariantMetadata* metadata, const uint8_t* data, + int64_t length, int32_t num_fields, + int8_t field_id_size, int8_t field_offset_size, + int64_t id_start, int64_t offset_start, + int64_t data_start) + : metadata_(metadata), + data_(data), + length_(length), + num_fields_(num_fields), + field_id_size_(field_id_size), + field_offset_size_(field_offset_size), + id_start_(id_start), + offset_start_(offset_start), + data_start_(data_start) {} + +Result VariantObjectView::Make(const VariantMetadata& metadata, + const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("VariantObjectView: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("VariantObjectView: not an object"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int8_t field_offset_size = static_cast((type_info & 0x03) + 1); + int8_t field_id_size = static_cast(((type_info >> 2) & 0x03) + 1); + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + if (1 + num_fields_size > length) { + return Status::Invalid("VariantObjectView: truncated num_fields"); + } + auto num_fields = static_cast(ReadUnsignedLE(data + 1, num_fields_size)); + + int64_t id_start = 1 + num_fields_size; + int64_t offset_start = id_start + static_cast(num_fields) * field_id_size; + int64_t data_start = + offset_start + (static_cast(num_fields) + 1) * field_offset_size; + + if (data_start > length) { + return Status::Invalid("VariantObjectView: truncated object structure"); + } + + // Validate last offset is within buffer + int64_t last_offset_pos = + offset_start + static_cast(num_fields) * field_offset_size; + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, field_offset_size)); + if (data_start + total_data > length) { + return Status::Invalid("VariantObjectView: object data exceeds buffer"); + } + + return VariantObjectView(&metadata, data, length, num_fields, field_id_size, + field_offset_size, id_start, offset_start, data_start); +} + +uint32_t VariantObjectView::field_id_at(int32_t i) const { + return ReadUnsignedLE(data_ + id_start_ + i * field_id_size_, field_id_size_); +} + +int64_t VariantObjectView::value_offset_at(int32_t i) const { + return data_start_ + + static_cast(ReadUnsignedLE( + data_ + offset_start_ + i * field_offset_size_, field_offset_size_)); +} + +Result VariantObjectView::field_name(int32_t index) const { + if (index < 0 || index >= num_fields_) { + return Status::Invalid("VariantObjectView::field_name: index out of range"); + } + auto id = field_id_at(index); + if (id >= metadata_->strings.size()) { + return Status::Invalid("VariantObjectView::field_name: field_id exceeds dictionary"); + } + return metadata_->strings[id]; +} + +Result VariantObjectView::field_value(int32_t index) const { + if (index < 0 || index >= num_fields_) { + return Status::Invalid("VariantObjectView::field_value: index out of range"); + } + int64_t offset = value_offset_at(index); + return VariantView::Make(*metadata_, data_ + offset, length_ - offset); +} + +std::optional VariantObjectView::get(std::string_view name) const { + // Binary search — field IDs are sorted by lexicographic key order per spec. + int32_t lo = 0, hi = num_fields_ - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + auto id = field_id_at(mid); + if (id >= metadata_->strings.size()) { + return std::nullopt; // Malformed data — graceful degradation + } + auto key = metadata_->strings[id]; + if (key == name) { + int64_t offset = value_offset_at(mid); + auto view = VariantView::Make(*metadata_, data_ + offset, length_ - offset); + if (view.ok()) return *view; + return std::nullopt; + } + if (key < name) + lo = mid + 1; + else + hi = mid - 1; + } + return std::nullopt; +} + +bool VariantObjectView::contains(std::string_view name) const { + return get(name).has_value(); +} + +std::optional VariantObjectView::locate( + std::string_view name) const { + // Binary search for the field + int32_t lo = 0, hi = num_fields_ - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + auto id = field_id_at(mid); + if (id >= metadata_->strings.size()) { + return std::nullopt; + } + auto key = metadata_->strings[id]; + if (key == name) { + int64_t offset = value_offset_at(mid); + auto size_result = ValueSize(data_ + offset, length_ - offset); + if (!size_result.ok()) return std::nullopt; + return FieldLocation{offset, *size_result}; + } + if (key < name) + lo = mid + 1; + else + hi = mid - 1; + } + return std::nullopt; +} + +// Object iterator +VariantObjectView::Iterator::Iterator(const VariantObjectView* obj, int32_t index) + : obj_(obj), index_(index) {} + +VariantObjectView::Iterator::value_type VariantObjectView::Iterator::operator*() const { + auto name = obj_->field_name(index_).ValueOrDie(); + auto value = obj_->field_value(index_).ValueOrDie(); + return {name, value}; +} + +VariantObjectView::Iterator& VariantObjectView::Iterator::operator++() { + ++index_; + return *this; +} + +bool VariantObjectView::Iterator::operator!=(const Iterator& other) const { + return index_ != other.index_; +} + +VariantObjectView::Iterator VariantObjectView::begin() const { return Iterator(this, 0); } + +VariantObjectView::Iterator VariantObjectView::end() const { + return Iterator(this, num_fields_); +} + +// =========================================================================== +// Public API: VariantArrayView +// =========================================================================== + +VariantArrayView::VariantArrayView(const VariantMetadata* metadata, const uint8_t* data, + int64_t length, int32_t num_elements, + int8_t offset_size, int64_t offset_start, + int64_t data_start) + : metadata_(metadata), + data_(data), + length_(length), + num_elements_(num_elements), + offset_size_(offset_size), + offset_start_(offset_start), + data_start_(data_start) {} + +Result VariantArrayView::Make(const VariantMetadata& metadata, + const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("VariantArrayView: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kArray) { + return Status::Invalid("VariantArrayView: not an array"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int8_t offset_size = static_cast((type_info & 0x03) + 1); + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + + if (1 + num_elements_size > length) { + return Status::Invalid("VariantArrayView: truncated num_elements"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, num_elements_size)); + + int64_t offset_start = 1 + num_elements_size; + int64_t data_start = + offset_start + (static_cast(num_elements) + 1) * offset_size; + + if (data_start > length) { + return Status::Invalid("VariantArrayView: truncated array structure"); + } + + // Validate monotonicity and last offset in bounds + uint32_t prev = 0; + for (int32_t i = 0; i <= num_elements; ++i) { + auto off = ReadUnsignedLE(data + offset_start + i * offset_size, offset_size); + if (i > 0 && off < prev) { + return Status::Invalid( + "VariantArrayView: offsets not monotonically non-decreasing at index ", i); + } + prev = off; + } + + auto total_data = static_cast(prev); // last offset = total data size + if (data_start + total_data > length) { + return Status::Invalid("VariantArrayView: array data exceeds buffer"); + } + + return VariantArrayView(&metadata, data, length, num_elements, offset_size, + offset_start, data_start); +} + +int64_t VariantArrayView::element_offset_at(int32_t i) const { + return data_start_ + static_cast(ReadUnsignedLE( + data_ + offset_start_ + i * offset_size_, offset_size_)); +} + +Result VariantArrayView::get(int32_t index) const { + if (index < 0 || index >= num_elements_) { + return Status::Invalid("VariantArrayView::get: index ", index, " out of range [0, ", + num_elements_, ")"); + } + int64_t offset = element_offset_at(index); + return VariantView::Make(*metadata_, data_ + offset, length_ - offset); +} + +// Array iterator +VariantArrayView::Iterator::Iterator(const VariantArrayView* arr, int32_t index) + : arr_(arr), index_(index) {} + +VariantArrayView::Iterator::value_type VariantArrayView::Iterator::operator*() const { + return arr_->get(index_).ValueOrDie(); +} + +VariantArrayView::Iterator& VariantArrayView::Iterator::operator++() { + ++index_; + return *this; +} + +bool VariantArrayView::Iterator::operator!=(const Iterator& other) const { + return index_ != other.index_; +} + +VariantArrayView::Iterator VariantArrayView::begin() const { return Iterator(this, 0); } + +VariantArrayView::Iterator VariantArrayView::end() const { + return Iterator(this, num_elements_); +} + +// =========================================================================== +// Widening numeric accessors +// =========================================================================== + +Result VariantView::as_int64_coerced() const { + if (type_ != BasicType::kPrimitive) { + return Status::Invalid("VariantView::as_int64_coerced: not a primitive"); + } + auto pt = GetPrimitiveType(data_[0]); + switch (pt) { + case PrimitiveType::kInt8: + return static_cast(static_cast(data_[1])); + case PrimitiveType::kInt16: { + int16_t value; + std::memcpy(&value, data_ + 1, 2); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kInt32: { + int32_t value; + std::memcpy(&value, data_ + 1, 4); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kInt64: { + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); + } + default: + return Status::Invalid("VariantView::as_int64_coerced: not an integer type"); + } +} + +Result VariantView::as_int32_coerced() const { + if (type_ != BasicType::kPrimitive) { + return Status::Invalid("VariantView::as_int32_coerced: not a primitive"); + } + auto pt = GetPrimitiveType(data_[0]); + switch (pt) { + case PrimitiveType::kInt8: + return static_cast(static_cast(data_[1])); + case PrimitiveType::kInt16: { + int16_t value; + std::memcpy(&value, data_ + 1, 2); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kInt32: { + int32_t value; + std::memcpy(&value, data_ + 1, 4); + return bit_util::FromLittleEndian(value); + } + default: + return Status::Invalid( + "VariantView::as_int32_coerced: not a 32-bit-or-narrower " + "integer type"); + } +} + +Result VariantView::as_double_coerced() const { + if (type_ != BasicType::kPrimitive) { + return Status::Invalid("VariantView::as_double_coerced: not a primitive"); + } + auto pt = GetPrimitiveType(data_[0]); + switch (pt) { + case PrimitiveType::kInt8: + return static_cast(static_cast(data_[1])); + case PrimitiveType::kInt16: { + int16_t value; + std::memcpy(&value, data_ + 1, 2); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kInt32: { + int32_t value; + std::memcpy(&value, data_ + 1, 4); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kInt64: { + int64_t value; + std::memcpy(&value, data_ + 1, 8); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kFloat: { + float value; + std::memcpy(&value, data_ + 1, 4); + return static_cast(bit_util::FromLittleEndian(value)); + } + case PrimitiveType::kDouble: { + double value; + std::memcpy(&value, data_ + 1, 8); + return bit_util::FromLittleEndian(value); + } + default: + return Status::Invalid("VariantView::as_double_coerced: not a numeric type"); + } +} + +// =========================================================================== +// ValidateVariant — full recursive validation +// =========================================================================== + +namespace { + +Status ValidateVariantRecursive(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int64_t offset, int32_t depth) { + if (offset >= length) { + return Status::Invalid("ValidateVariant: offset ", offset, + " at or beyond buffer length ", length); + } + if (depth > kMaxNestingDepth) { + return Status::Invalid("ValidateVariant: nesting depth exceeds maximum of ", + kMaxNestingDepth); + } + + uint8_t header = data[offset]; + auto basic_type = GetBasicType(header); + + switch (basic_type) { + case BasicType::kPrimitive: { + // Validate the value size is computable and within bounds + ARROW_ASSIGN_OR_RAISE(auto size, ValueSize(data + offset, length - offset)); + if (offset + size > length) { + return Status::Invalid("ValidateVariant: primitive value at offset ", offset, + " requires ", size, " bytes but only ", length - offset, + " available"); + } + return Status::OK(); + } + case BasicType::kShortString: { + int32_t str_len = (header >> 2) & 0x3F; + if (offset + 1 + str_len > length) { + return Status::Invalid("ValidateVariant: truncated short string at offset ", + offset); + } + return Status::OK(); + } + case BasicType::kObject: { + // Construct a view (validates structure) then recursively validate children + ARROW_ASSIGN_OR_RAISE( + auto obj, VariantObjectView::Make(metadata, data + offset, length - offset)); + for (int32_t i = 0; i < obj.num_fields(); ++i) { + // Validate field name (checks field_id within dictionary bounds) + ARROW_RETURN_NOT_OK(obj.field_name(i).status()); + // Validate field value recursively + ARROW_ASSIGN_OR_RAISE(auto field_view, obj.field_value(i)); + auto field_offset = field_view.data() - data; + ARROW_RETURN_NOT_OK( + ValidateVariantRecursive(metadata, data, length, field_offset, depth + 1)); + } + return Status::OK(); + } + case BasicType::kArray: { + // Construct a view (validates structure) then recursively validate elements + ARROW_ASSIGN_OR_RAISE( + auto arr, VariantArrayView::Make(metadata, data + offset, length - offset)); + for (int32_t i = 0; i < arr.num_elements(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto elem_view, arr.get(i)); + auto elem_offset = elem_view.data() - data; + ARROW_RETURN_NOT_OK( + ValidateVariantRecursive(metadata, data, length, elem_offset, depth + 1)); + } + return Status::OK(); + } + default: + return Status::Invalid("ValidateVariant: unknown basic type at offset ", offset); + } +} + +} // namespace + +Status ValidateVariant(const VariantMetadata& metadata, const uint8_t* data, + int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("ValidateVariant: buffer is null or empty"); + } + return ValidateVariantRecursive(metadata, data, length, 0, 0); +} + +// =========================================================================== + +} // namespace arrow::extension::variant diff --git a/cpp/src/arrow/extension/variant.h b/cpp/src/arrow/extension/variant.h new file mode 100644 index 000000000000..ec9bfab2000b --- /dev/null +++ b/cpp/src/arrow/extension/variant.h @@ -0,0 +1,583 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +/// \file variant.h +/// \brief Public C++ API for Variant binary encoding/decoding. +/// +/// Provides zero-copy view classes for reading variant values (VariantView, +/// VariantObjectView, VariantArrayView) and a visitor interface for full +/// tree traversal. Implements the Variant Encoding Spec: +/// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md +/// +/// Design principles: +/// - Parse once, query many (view classes pre-parse headers at construction) +/// - Zero-copy (string_view into source buffers, no heap allocation for reads) +/// - Type safety at boundaries (views validate at construction, not on access) +/// - O(log n) field lookup always (no threshold heuristics) + +#include +#include +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow::extension::variant { + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Variant encoding spec version 1. +constexpr uint8_t kVariantVersion = 1; + +/// Maximum nesting depth for recursive value decoding. +/// Prevents stack overflow on deeply nested (possibly malicious) input. +constexpr int32_t kMaxNestingDepth = 128; + +/// UUID values are always 16 bytes (128-bit, big-endian per RFC 4122). +constexpr int32_t kUUIDByteLength = 16; + +// --------------------------------------------------------------------------- +// Enumerations +// --------------------------------------------------------------------------- + +/// \brief Basic type codes from bits 0-1 of the value header byte. +/// +/// See: +/// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types +enum class BasicType : uint8_t { + kPrimitive = 0, + kShortString = 1, + kObject = 2, + kArray = 3, +}; + +/// \brief Primitive type codes from bits 2-7 when basic_type == kPrimitive. +/// +/// See: +/// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types +enum class PrimitiveType : uint8_t { + kNull = 0, + kTrue = 1, + kFalse = 2, + kInt8 = 3, + kInt16 = 4, + kInt32 = 5, + kInt64 = 6, + kDouble = 7, + kDecimal4 = 8, + kDecimal8 = 9, + kDecimal16 = 10, + kDate = 11, + kTimestampMicros = 12, + kTimestampMicrosNTZ = 13, + kFloat = 14, + kBinary = 15, + kString = 16, + kTimeNTZ = 17, + kTimestampNanos = 18, + kTimestampNanosNTZ = 19, + kUUID = 20, +}; + +// --------------------------------------------------------------------------- +// Metadata +// --------------------------------------------------------------------------- + +/// \brief Parsed variant metadata (string dictionary). +/// +/// The metadata buffer contains a header byte followed by a dictionary of +/// interned strings used as object field names. String views reference the +/// raw buffer and are valid only as long as the underlying buffer is alive. +/// +/// \note This is NOT a schema — it contains key names only, not value types. +struct ARROW_EXPORT VariantMetadata { + /// Spec version (must be kVariantVersion). + uint8_t version = 0; + + /// Whether the dictionary strings are sorted lexicographically. + bool is_sorted = false; + + /// Number of bytes used for each offset (1, 2, 3, or 4). + int32_t offset_size = 0; + + /// Dictionary of interned strings. Views into the raw metadata buffer. + std::vector strings; +}; + +/// \brief Decode a variant metadata buffer. +/// +/// Parses the header byte and string dictionary from the raw metadata +/// buffer. The returned VariantMetadata contains string_views that +/// reference the input buffer directly (zero-copy). +/// +/// \param[in] data Pointer to the metadata buffer (must not be null) +/// \param[in] length Length of the metadata buffer in bytes +/// \return Parsed VariantMetadata on success, Status::Invalid on +/// malformed input +/// +/// \note The input buffer must outlive the returned VariantMetadata. +ARROW_EXPORT Result DecodeMetadata(const uint8_t* data, int64_t length); + +/// \brief Find the dictionary ID for a given key name. +/// +/// Uses binary search if the metadata is sorted, otherwise linear scan. +/// +/// \param[in] metadata Parsed metadata +/// \param[in] key The key to search for +/// \return The dictionary ID if found, or -1 if not present +ARROW_EXPORT int32_t FindMetadataKey(const VariantMetadata& metadata, + std::string_view key); + +// --------------------------------------------------------------------------- +// Header utilities +// --------------------------------------------------------------------------- + +/// \brief Extract the basic type from a value header byte. +inline BasicType GetBasicType(uint8_t header) { + return static_cast(header & 0x03); +} + +/// \brief Extract the primitive type from a value header byte. +/// Only valid when GetBasicType(header) == BasicType::kPrimitive. +inline PrimitiveType GetPrimitiveType(uint8_t header) { + return static_cast((header >> 2) & 0x3F); +} + +/// \brief Get the byte size of a primitive value (excluding header). +/// Returns -1 for variable-length types (Binary, String). +ARROW_EXPORT int32_t PrimitiveValueSize(PrimitiveType primitive_type); + +/// \brief Compute the total byte size of a variant value (header + data). +/// +/// Determines how many bytes a variant value occupies without full decoding. +/// \param[in] data Pointer to the start of a variant value +/// \param[in] length Maximum bytes available +/// \return Total byte count, or Status::Invalid if truncated +ARROW_EXPORT Result ValueSize(const uint8_t* data, int64_t length); + +/// \brief Recursively validate a variant value and all nested children. +/// +/// Performs deep structural validation of the entire value tree: +/// - All headers are well-formed +/// - All field IDs reference valid dictionary entries +/// - All offsets are within bounds and monotonically non-decreasing +/// - All nested values are recursively valid +/// - Nesting depth does not exceed kMaxNestingDepth +/// +/// Use this for untrusted input where you want a single pass/fail before +/// operating on the data. For trusted data (e.g., builder output), the +/// per-access validation in view classes is sufficient. +/// +/// \param[in] metadata Parsed metadata (for validating field ID references) +/// \param[in] data Pointer to the variant value buffer +/// \param[in] length Length of the variant value buffer in bytes +/// \return Status::OK if valid, Status::Invalid with description of first error +ARROW_EXPORT Status ValidateVariant(const VariantMetadata& metadata, const uint8_t* data, + int64_t length); + +// --------------------------------------------------------------------------- +// Forward declarations +// --------------------------------------------------------------------------- + +class VariantObjectView; +class VariantArrayView; + +// --------------------------------------------------------------------------- +// VariantView — non-owning view over a single variant value +// --------------------------------------------------------------------------- + +/// \brief A non-owning, zero-copy view over a single variant value. +/// +/// Construction validates the header byte is readable. Subsequent typed +/// accessors validate type compatibility and return errors on mismatch. +/// +/// Stack-allocated, ~32 bytes. No heap allocation. +/// +/// \note The metadata and data buffers must outlive this view. +class ARROW_EXPORT VariantView { + public: + /// \brief Construct a view over a variant value. + /// + /// Validates that the buffer has at least one byte and computes the + /// value's total size. Returns Invalid on empty/null buffers or if + /// the value is truncated. + /// + /// \param[in] metadata Parsed metadata (for resolving object field names) + /// \param[in] data Pointer to the value buffer + /// \param[in] length Length of the value buffer in bytes + static Result Make(const VariantMetadata& metadata, const uint8_t* data, + int64_t length); + + /// \brief The basic type of this value. + BasicType type() const { return type_; } + + /// \brief Whether this value is null (PrimitiveType::kNull). + bool is_null() const; + + /// \brief Total byte size of this value (header + payload). + int64_t size_bytes() const { return size_; } + + /// \brief Raw pointer to the value bytes. + const uint8_t* data() const { return data_; } + + /// @name Primitive accessors + /// Each returns the value if the type matches, or Status::Invalid. + /// @{ + Result as_bool() const; + Result as_int8() const; + Result as_int16() const; + Result as_int32() const; + Result as_int64() const; + Result as_float() const; + Result as_double() const; + Result as_string() const; + Result as_binary() const; + Result as_date() const; + Result as_timestamp_micros() const; + Result as_timestamp_micros_ntz() const; + Result as_timestamp_nanos() const; + Result as_timestamp_nanos_ntz() const; + Result as_time_ntz() const; + /// @} + + /// @name Widening numeric accessors (Rust parity) + /// These coerce narrower integer types to a wider target type. + /// e.g., as_int64_coerced() succeeds on Int8, Int16, Int32, or Int64 values. + /// @{ + + /// \brief Read any integer variant as int64, widening if necessary. + /// Succeeds for Int8, Int16, Int32, and Int64 encoded values. + Result as_int64_coerced() const; + + /// \brief Read any integer variant as int32, widening if necessary. + /// Succeeds for Int8 and Int16 and Int32 encoded values. + /// Returns Invalid for Int64 (would narrow). + Result as_int32_coerced() const; + + /// \brief Read any numeric variant as double, coercing if necessary. + /// Succeeds for Int8, Int16, Int32, Int64, Float, and Double. + /// Note: Int64 -> double may lose precision for large values. + Result as_double_coerced() const; + /// @} + + /// \brief Access UUID value (16 bytes, big-endian). + /// Returns pointer to the 16 UUID bytes within the source buffer. + Result as_uuid() const; + + /// \brief Access Decimal4 (scale + 4 bytes unscaled value). + /// \param[out] scale Set to the decimal scale + /// \return Pointer to the 4 unscaled value bytes + Result as_decimal4(int32_t* scale) const; + + /// \brief Access Decimal8 (scale + 8 bytes unscaled value). + Result as_decimal8(int32_t* scale) const; + + /// \brief Access Decimal16 (scale + 16 bytes unscaled value). + Result as_decimal16(int32_t* scale) const; + /// @} + + /// @name Container accessors + /// @{ + + /// \brief Interpret this value as an object. + /// Returns Invalid if the basic type is not kObject. + Result as_object() const; + + /// \brief Interpret this value as an array. + /// Returns Invalid if the basic type is not kArray. + Result as_array() const; + /// @} + + /// @name Visitor traversal + /// @{ + + /// \brief Full recursive traversal via visitor pattern. + /// + /// Visits every node in the value tree, calling appropriate visitor + /// methods. For bulk processing of entire variant values (e.g., JSON + /// serialization, schema inference). + /// + /// \param[in] visitor Callback interface for decoded values + /// \return Status::OK on success, or first error from visitor/decode + Status Visit(class VariantVisitor* visitor) const; + /// @} + + private: + VariantView(const VariantMetadata* metadata, const uint8_t* data, int64_t size, + BasicType type); + + const VariantMetadata* metadata_; + const uint8_t* data_; + int64_t size_; + BasicType type_; +}; + +// --------------------------------------------------------------------------- +// VariantObjectView — pre-parsed object for O(log n) field lookup +// --------------------------------------------------------------------------- + +/// \brief A non-owning view over a variant object value with pre-parsed header. +/// +/// Construction validates the object header structure (field counts, array +/// bounds). After construction, field lookup is O(log n) via binary search +/// with no redundant parsing. +/// +/// Stack-allocated, ~72 bytes. No heap allocation. +/// +/// \note The metadata and data buffers must outlive this view. +class ARROW_EXPORT VariantObjectView { + public: + /// \brief Construct an object view, pre-parsing the header. + /// + /// Validates: + /// 1. Basic type is kObject + /// 2. num_fields is readable + /// 3. Field ID array fits in buffer + /// 4. Offset array fits in buffer + /// 5. Last offset (total data size) is within buffer + /// + /// After Make() succeeds, all field accessors are bounds-safe. + static Result Make(const VariantMetadata& metadata, + const uint8_t* data, int64_t length); + + /// \brief Number of fields in this object. + int32_t num_fields() const { return num_fields_; } + + /// \brief Get the name of the i-th field (0-indexed). + /// \return The field name, or Invalid if index is out of range or + /// field ID references an invalid dictionary entry. + Result field_name(int32_t index) const; + + /// \brief Get the value of the i-th field (0-indexed). + /// \return A VariantView over the field's value. + Result field_value(int32_t index) const; + + /// \brief Lookup a field by name using binary search. + /// + /// Per spec, field IDs are sorted by lexicographic order of their + /// corresponding key names. This enables O(log n) lookup for all + /// object sizes. + /// + /// \param[in] name The field name to search for + /// \return The field's value if found, or std::nullopt if not present. + /// + /// \note Returns std::nullopt for both "field not found" AND "malformed data" + /// (e.g., field ID exceeds dictionary, or field value bytes are truncated). + /// For untrusted data where error reporting is needed, use field_name(i) + /// and field_value(i) directly which return Result. + std::optional get(std::string_view name) const; + + /// \brief Check if a field exists by name. + bool contains(std::string_view name) const; + + /// \brief Location of a field's value bytes within the object buffer. + struct FieldLocation { + int64_t offset; ///< Byte offset from object start to field value + int64_t size; ///< Byte size of the field value + }; + + /// \brief Locate a field's raw bytes by name without constructing a view. + /// + /// Useful for zero-copy extraction when you need the offset+size for + /// raw byte operations (e.g., shredding's UnsafeAppendEncoded). + /// + /// \param[in] name The field name to search for + /// \return The field location, or std::nullopt if not present. + std::optional locate(std::string_view name) const; + + /// @name Iteration support + /// @{ + + /// \brief Iterator over (name, value) pairs. + /// + /// \warning Iteration uses ValueOrDie() internally. If the object contains + /// malformed field data (corrupt value bytes), dereferencing the + /// iterator will abort. Use field_name(i)/field_value(i) for + /// error-safe access on untrusted data. + class Iterator { + public: + using value_type = std::pair; + + Iterator(const VariantObjectView* obj, int32_t index); + + value_type operator*() const; + Iterator& operator++(); + bool operator!=(const Iterator& other) const; + + private: + const VariantObjectView* obj_; + int32_t index_; + }; + + Iterator begin() const; + Iterator end() const; + /// @} + + private: + VariantObjectView(const VariantMetadata* metadata, const uint8_t* data, int64_t length, + int32_t num_fields, int8_t field_id_size, int8_t field_offset_size, + int64_t id_start, int64_t offset_start, int64_t data_start); + + /// \brief Read the field ID at position i. + uint32_t field_id_at(int32_t i) const; + + /// \brief Read the value offset at position i. + int64_t value_offset_at(int32_t i) const; + + const VariantMetadata* metadata_; + const uint8_t* data_; + int64_t length_; + int32_t num_fields_; + int8_t field_id_size_; + int8_t field_offset_size_; + int64_t id_start_; + int64_t offset_start_; + int64_t data_start_; +}; + +// --------------------------------------------------------------------------- +// VariantArrayView — pre-parsed array for O(1) element access +// --------------------------------------------------------------------------- + +/// \brief A non-owning view over a variant array value with pre-parsed header. +/// +/// Construction validates the array header structure. After construction, +/// element access is O(1) via the pre-computed offset table. +/// +/// Stack-allocated, ~56 bytes. No heap allocation. +class ARROW_EXPORT VariantArrayView { + public: + /// \brief Construct an array view, pre-parsing the header. + /// + /// Validates: + /// 1. Basic type is kArray + /// 2. num_elements is readable + /// 3. Offset array fits in buffer + /// 4. Offsets are monotonically non-decreasing + /// 5. Last offset is within buffer + static Result Make(const VariantMetadata& metadata, + const uint8_t* data, int64_t length); + + /// \brief Number of elements in this array. + int32_t num_elements() const { return num_elements_; } + + /// \brief Get the i-th element (0-indexed, O(1) access). + /// \return A VariantView over the element, or Invalid if index is out of range. + Result get(int32_t index) const; + + /// @name Iteration support + /// @{ + + /// \warning Iteration uses ValueOrDie() internally. If the array contains + /// malformed element data, dereferencing the iterator will abort. + /// Use get(i) for error-safe access on untrusted data. + class Iterator { + public: + using value_type = VariantView; + + Iterator(const VariantArrayView* arr, int32_t index); + + value_type operator*() const; + Iterator& operator++(); + bool operator!=(const Iterator& other) const; + + private: + const VariantArrayView* arr_; + int32_t index_; + }; + + Iterator begin() const; + Iterator end() const; + /// @} + + private: + VariantArrayView(const VariantMetadata* metadata, const uint8_t* data, int64_t length, + int32_t num_elements, int8_t offset_size, int64_t offset_start, + int64_t data_start); + + /// \brief Read the element offset at position i. + int64_t element_offset_at(int32_t i) const; + + const VariantMetadata* metadata_; + const uint8_t* data_; + int64_t length_; + int32_t num_elements_; + int8_t offset_size_; + int64_t offset_start_; + int64_t data_start_; +}; + +// --------------------------------------------------------------------------- +// Visitor interface (for full tree traversal) +// --------------------------------------------------------------------------- + +/// \brief Visitor interface for variant value traversal (SAX-style). +/// +/// Implement this interface to receive callbacks during recursive variant +/// value traversal. Use VariantView::Visit() to drive the traversal. +/// +/// For point-queries (reading a specific field), use the view classes +/// directly instead. The visitor is for bulk operations that need to +/// process every node in the tree. +/// +/// \note String values are raw bytes without UTF-8 validation. +class ARROW_EXPORT VariantVisitor { + public: + virtual ~VariantVisitor() = default; + + /// @name Primitive value callbacks + /// @{ + virtual Status Null() = 0; + virtual Status Bool(bool value) = 0; + virtual Status Int8(int8_t value) = 0; + virtual Status Int16(int16_t value) = 0; + virtual Status Int32(int32_t value) = 0; + virtual Status Int64(int64_t value) = 0; + virtual Status Float(float value) = 0; + virtual Status Double(double value) = 0; + virtual Status Decimal4(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Decimal8(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Decimal16(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Date(int32_t days_since_epoch) = 0; + virtual Status TimestampMicros(int64_t micros_since_epoch) = 0; + virtual Status TimestampMicrosNTZ(int64_t micros_since_epoch) = 0; + virtual Status String(std::string_view value) = 0; + virtual Status Binary(std::string_view value) = 0; + virtual Status TimeNTZ(int64_t micros_since_midnight) = 0; + virtual Status TimestampNanos(int64_t nanos_since_epoch) = 0; + virtual Status TimestampNanosNTZ(int64_t nanos_since_epoch) = 0; + virtual Status UUID(const uint8_t* bytes) = 0; + /// @} + + /// @name Container callbacks + /// @{ + virtual Status StartObject(int32_t num_fields) = 0; + virtual Status FieldName(std::string_view name) = 0; + virtual Status EndObject() = 0; + virtual Status StartArray(int32_t num_elements) = 0; + virtual Status EndArray() = 0; + /// @} +}; + +} // namespace arrow::extension::variant diff --git a/cpp/src/arrow/extension/variant_internal_test_util.h b/cpp/src/arrow/extension/variant_internal_test_util.h new file mode 100644 index 000000000000..c93826324a57 --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal_test_util.h @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +// This file is for tests only and is not installed as a public header. + +#include +#include + +#include "arrow/extension/variant.h" + +namespace arrow::extension::variant { + +/// \brief A visitor that records all callbacks as a vector of strings +/// for easy assertion in tests. +class RecordingVisitor : public VariantVisitor { + public: + std::vector events; + + Status Null() override { + events.push_back("Null"); + return Status::OK(); + } + Status Bool(bool value) override { + events.push_back(std::string("Bool(") + (value ? "true" : "false") + ")"); + return Status::OK(); + } + Status Int8(int8_t value) override { + events.push_back("Int8(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int16(int16_t value) override { + events.push_back("Int16(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int32(int32_t value) override { + events.push_back("Int32(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int64(int64_t value) override { + events.push_back("Int64(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Float(float value) override { + events.push_back("Float(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Double(double value) override { + events.push_back("Double(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Decimal4(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal4(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Decimal8(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal8(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Decimal16(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal16(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Date(int32_t days) override { + events.push_back("Date(" + std::to_string(days) + ")"); + return Status::OK(); + } + Status TimestampMicros(int64_t micros) override { + events.push_back("TimestampMicros(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status TimestampMicrosNTZ(int64_t micros) override { + events.push_back("TimestampMicrosNTZ(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status String(std::string_view value) override { + events.push_back("String(\"" + std::string(value) + "\")"); + return Status::OK(); + } + Status Binary(std::string_view value) override { + events.push_back("Binary(len=" + std::to_string(value.size()) + ")"); + return Status::OK(); + } + Status TimeNTZ(int64_t micros) override { + events.push_back("TimeNTZ(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status TimestampNanos(int64_t nanos) override { + events.push_back("TimestampNanos(" + std::to_string(nanos) + ")"); + return Status::OK(); + } + Status TimestampNanosNTZ(int64_t nanos) override { + events.push_back("TimestampNanosNTZ(" + std::to_string(nanos) + ")"); + return Status::OK(); + } + Status UUID(const uint8_t* /*bytes*/) override { + events.push_back("UUID"); + return Status::OK(); + } + Status StartObject(int32_t num_fields) override { + events.push_back("StartObject(" + std::to_string(num_fields) + ")"); + return Status::OK(); + } + Status FieldName(std::string_view name) override { + events.push_back("FieldName(\"" + std::string(name) + "\")"); + return Status::OK(); + } + Status EndObject() override { + events.push_back("EndObject"); + return Status::OK(); + } + Status StartArray(int32_t num_elements) override { + events.push_back("StartArray(" + std::to_string(num_elements) + ")"); + return Status::OK(); + } + Status EndArray() override { + events.push_back("EndArray"); + return Status::OK(); + } +}; + +} // namespace arrow::extension::variant diff --git a/cpp/src/arrow/extension/variant_internal_util.h b/cpp/src/arrow/extension/variant_internal_util.h new file mode 100644 index 000000000000..2e517361fcd5 --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal_util.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +/// \file variant_internal_util.h +/// \brief Internal utilities shared by variant implementation files. +/// +/// NOT part of the public API — not installed. Used only by variant.cc +/// and variant_shredding.cc to avoid duplicating low-level helpers. + +#include +#include + +#include "arrow/util/endian.h" + +namespace arrow::extension::variant::internal { + +/// \brief Read an unsigned integer of 1-4 bytes in little-endian order. +/// +/// Reads exactly \p num_bytes from \p data, interprets them as an unsigned +/// little-endian integer, and returns the result as uint32_t. +/// +/// \pre num_bytes must be in [1, 4] +/// \pre data must point to at least num_bytes readable bytes +inline uint32_t ReadUnsignedLE(const uint8_t* data, int32_t num_bytes) { + uint32_t result = 0; + std::memcpy(&result, data, num_bytes); + result = ::arrow::bit_util::FromLittleEndian(result); + if (num_bytes < 4) { + result &= (static_cast(1) << (num_bytes * 8)) - 1; + } + return result; +} + +/// \brief Read an unsigned integer of 1-8 bytes in little-endian order. +/// +/// Extended version that supports reading up to 8-byte values (for int64 +/// extraction in shredding paths). Returns int64_t for compatibility with +/// Arrow's signed-offset conventions. +/// +/// \pre num_bytes must be in [1, 8] +/// \pre data must point to at least num_bytes readable bytes +inline int64_t ReadUnsignedLE64(const uint8_t* data, int32_t num_bytes) { + if (num_bytes <= 4) { + return static_cast(ReadUnsignedLE(data, num_bytes)); + } + uint64_t result = 0; + std::memcpy(&result, data, num_bytes); + result = ::arrow::bit_util::FromLittleEndian(result); + if (num_bytes < 8) { + result &= (static_cast(1) << (num_bytes * 8)) - 1; + } + return static_cast(result); +} + +} // namespace arrow::extension::variant::internal diff --git a/cpp/src/arrow/extension/variant_test.cc b/cpp/src/arrow/extension/variant_test.cc new file mode 100644 index 000000000000..4a3c1187c45d --- /dev/null +++ b/cpp/src/arrow/extension/variant_test.cc @@ -0,0 +1,2412 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant.h" +#include "arrow/extension/variant_internal_test_util.h" + +#include +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" + +namespace arrow::extension::variant { + +// =========================================================================== +// Test helpers +// =========================================================================== + +namespace { + +/// \brief Decode and visit a variant value (convenience wrapper for tests). +/// Replaces the old DecodeVariantValue free function. +Status DecodeAndVisit(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, VariantVisitor* visitor) { + ARROW_ASSIGN_OR_RAISE(auto view, VariantView::Make(metadata, data, length)); + return view.Visit(visitor); +} + +/// \brief Get the basic type of a value (convenience for tests). +Result GetValueBasicType(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("buffer is null or empty"); + } + return GetBasicType(data[0]); +} + +/// \brief Get object field count (convenience for tests). +Result GetObjectFieldCount(const uint8_t* data, int64_t length) { + VariantMetadata empty_meta; + ARROW_ASSIGN_OR_RAISE(auto obj, VariantObjectView::Make(empty_meta, data, length)); + return obj.num_fields(); +} + +/// \brief Get array element count (convenience for tests). +Result GetArrayElementCount(const uint8_t* data, int64_t length) { + VariantMetadata empty_meta; + ARROW_ASSIGN_OR_RAISE(auto arr, VariantArrayView::Make(empty_meta, data, length)); + return arr.num_elements(); +} + +/// \brief Find object field by name (convenience for tests). +Status FindObjectField(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, std::string_view field_name, int64_t* field_offset, + int64_t* field_size) { + *field_offset = -1; + *field_size = 0; + ARROW_ASSIGN_OR_RAISE(auto obj, VariantObjectView::Make(metadata, data, length)); + auto result = obj.get(field_name); + if (result.has_value()) { + *field_offset = result->data() - data; + *field_size = result->size_bytes(); + } + return Status::OK(); +} + +/// \brief Get array element by index (convenience for tests). +Status GetArrayElement(const uint8_t* data, int64_t length, int32_t index, + int64_t* element_offset, int64_t* element_size) { + VariantMetadata empty_meta; + ARROW_ASSIGN_OR_RAISE(auto arr, VariantArrayView::Make(empty_meta, data, length)); + ARROW_ASSIGN_OR_RAISE(auto elem, arr.get(index)); + *element_offset = elem.data() - data; + *element_size = elem.size_bytes(); + return Status::OK(); +} + +/// \brief Get object field at index (convenience for tests). +Status GetObjectFieldAt(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int32_t index, std::string_view* field_name, + int64_t* field_offset, int64_t* field_size) { + ARROW_ASSIGN_OR_RAISE(auto obj, VariantObjectView::Make(metadata, data, length)); + ARROW_ASSIGN_OR_RAISE(*field_name, obj.field_name(index)); + ARROW_ASSIGN_OR_RAISE(auto value, obj.field_value(index)); + *field_offset = value.data() - data; + *field_size = value.size_bytes(); + return Status::OK(); +} + +/// \brief Build a metadata buffer from a list of strings. +/// +/// Uses offset_size=1, version=1, sorted flag as specified. +std::vector BuildMetadataBuffer(const std::vector& strings, + bool sorted = false, int32_t offset_size = 1) { + std::vector buffer; + + // Header byte: version=1, sorted flag, offset_size + uint8_t header = kVariantVersion; + if (sorted) { + header |= (1 << 4); + } + header |= static_cast((offset_size - 1) << 6); + buffer.push_back(header); + + // Dictionary size + auto dict_size = static_cast(strings.size()); + for (int32_t b = 0; b < offset_size; ++b) { + buffer.push_back(static_cast((dict_size >> (b * 8)) & 0xFF)); + } + + // Compute string offsets + std::vector offsets(dict_size + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < dict_size; ++i) { + offsets[i + 1] = offsets[i] + static_cast(strings[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= dict_size; ++i) { + for (int32_t b = 0; b < offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write string data + for (const auto& s : strings) { + buffer.insert(buffer.end(), s.begin(), s.end()); + } + + return buffer; +} + +/// \brief Build a primitive value header byte. +uint8_t PrimitiveHeader(PrimitiveType type) { + return static_cast(BasicType::kPrimitive) | (static_cast(type) << 2); +} + +/// \brief Build a short string value buffer. +std::vector BuildShortString(const std::string& s) { + std::vector buffer; + auto len = static_cast(s.size()); + uint8_t header = static_cast(BasicType::kShortString) | (len << 2); + buffer.push_back(header); + buffer.insert(buffer.end(), s.begin(), s.end()); + return buffer; +} + +/// \brief Build an object value buffer. +/// +/// \param field_ids Dictionary indices for each field name +/// \param field_values Serialized variant values for each field +/// \param field_id_size Bytes per field ID (1-4) +/// \param field_offset_size Bytes per offset (1-4) +std::vector BuildObject(const std::vector& field_ids, + const std::vector>& field_values, + int32_t field_id_size = 1, + int32_t field_offset_size = 1) { + auto num_fields = static_cast(field_ids.size()); + bool is_large = (num_fields > 255); + + std::vector buffer; + + // Header per spec: basic_type=2 in bits 0-1, + // bits 2-3: field_offset_size-1 + // bits 4-5: field_id_size-1 + // bit 6: is_large + uint8_t header = static_cast(BasicType::kObject); + header |= static_cast((field_offset_size - 1) << 2); + header |= static_cast((field_id_size - 1) << 4); + if (is_large) { + header |= (1 << 6); + } + buffer.push_back(header); + + // num_fields: 1 byte or 4 bytes depending on is_large + int32_t num_fields_size = is_large ? 4 : 1; + for (int32_t b = 0; b < num_fields_size; ++b) { + buffer.push_back(static_cast((num_fields >> (b * 8)) & 0xFF)); + } + + // field_ids + for (auto fid : field_ids) { + for (int32_t b = 0; b < field_id_size; ++b) { + buffer.push_back(static_cast((fid >> (b * 8)) & 0xFF)); + } + } + + // Compute offsets + std::vector offsets(num_fields + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < num_fields; ++i) { + offsets[i + 1] = offsets[i] + static_cast(field_values[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= num_fields; ++i) { + for (int32_t b = 0; b < field_offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write field value data + for (const auto& fv : field_values) { + buffer.insert(buffer.end(), fv.begin(), fv.end()); + } + + return buffer; +} + +/// \brief Build an array value buffer. +/// +/// \param elements Serialized variant values for each element +/// \param field_offset_size Bytes per offset (1-4) +std::vector BuildArray(const std::vector>& elements, + int32_t field_offset_size = 1) { + auto num_elements = static_cast(elements.size()); + bool is_large = (num_elements > 255); + + std::vector buffer; + + // Header per spec: basic_type=3 in bits 0-1, + // bits 2-3: field_offset_size-1 + // bit 4: is_large + uint8_t header = static_cast(BasicType::kArray); + header |= static_cast((field_offset_size - 1) << 2); + if (is_large) { + header |= (1 << 4); + } + buffer.push_back(header); + + // num_elements: 1 byte or 4 bytes depending on is_large + int32_t num_elements_size = is_large ? 4 : 1; + for (int32_t b = 0; b < num_elements_size; ++b) { + buffer.push_back(static_cast((num_elements >> (b * 8)) & 0xFF)); + } + + // Compute offsets + std::vector offsets(num_elements + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < num_elements; ++i) { + offsets[i + 1] = offsets[i] + static_cast(elements[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= num_elements; ++i) { + for (int32_t b = 0; b < field_offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write element data + for (const auto& elem : elements) { + buffer.insert(buffer.end(), elem.begin(), elem.end()); + } + + return buffer; +} + +} // namespace + +// =========================================================================== +// Metadata decoding tests +// =========================================================================== + +class VariantMetadataTest : public ::testing::Test {}; + +TEST_F(VariantMetadataTest, EmptyDictionary) { + auto buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.version, 1); + ASSERT_FALSE(metadata.is_sorted); + ASSERT_EQ(metadata.offset_size, 1); + ASSERT_EQ(metadata.strings.size(), 0); +} + +TEST_F(VariantMetadataTest, SingleString) { + auto buf = BuildMetadataBuffer({"hello"}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 1); + ASSERT_EQ(metadata.strings[0], "hello"); +} + +TEST_F(VariantMetadataTest, MultipleStrings) { + auto buf = BuildMetadataBuffer({"name", "age", "scores"}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], "name"); + ASSERT_EQ(metadata.strings[1], "age"); + ASSERT_EQ(metadata.strings[2], "scores"); +} + +TEST_F(VariantMetadataTest, SortedFlag) { + auto buf = BuildMetadataBuffer({"age", "name", "score"}, true); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_TRUE(metadata.is_sorted); +} + +TEST_F(VariantMetadataTest, OffsetSize2) { + auto buf = BuildMetadataBuffer({"key1", "key2"}, false, 2); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 2); + ASSERT_EQ(metadata.strings.size(), 2); + ASSERT_EQ(metadata.strings[0], "key1"); + ASSERT_EQ(metadata.strings[1], "key2"); +} + +TEST_F(VariantMetadataTest, OffsetSize4) { + auto buf = BuildMetadataBuffer({"a", "bb", "ccc"}, false, 4); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 4); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], "a"); + ASSERT_EQ(metadata.strings[1], "bb"); + ASSERT_EQ(metadata.strings[2], "ccc"); +} + +TEST_F(VariantMetadataTest, EmptyStrings) { + auto buf = BuildMetadataBuffer({"", "nonempty", ""}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], ""); + ASSERT_EQ(metadata.strings[1], "nonempty"); + ASSERT_EQ(metadata.strings[2], ""); +} + +// Error cases + +TEST_F(VariantMetadataTest, NullBuffer) { + ASSERT_RAISES(Invalid, DecodeMetadata(nullptr, 0)); +} + +TEST_F(VariantMetadataTest, EmptyBuffer) { + uint8_t data = 0; + ASSERT_RAISES(Invalid, DecodeMetadata(&data, 0)); +} + +TEST_F(VariantMetadataTest, UnsupportedVersion) { + // Version 2 (unsupported) + uint8_t data[] = {0x02, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, TruncatedDictionarySize) { + // Header says offset_size=2 (bits 6-7 = 01), but only 1 byte follows + uint8_t data[] = {0x41, 0x00}; // version=1, offset_size=2 + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, TruncatedStringOffsets) { + // Claims dict_size=5 but buffer is too short for offsets + uint8_t data[] = {0x01, 0x05, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, OffsetSize3) { + auto buf = BuildMetadataBuffer({"foo", "bar"}, false, 3); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 3); + ASSERT_EQ(metadata.strings.size(), 2); + ASSERT_EQ(metadata.strings[0], "foo"); + ASSERT_EQ(metadata.strings[1], "bar"); +} + +TEST_F(VariantMetadataTest, ReservedBit5Set) { + // Header with bit 5 set: 0x21 = version=1, bit5=1 + uint8_t data[] = {0x21, 0x00, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, NonMonotonicStringOffsets) { + // Manually construct metadata where string offsets are NOT monotonically + // non-decreasing. ValidateOffsets should reject this. + // Header: version=1, offset_size=1 + // dict_size=2, offsets=[0, 5, 3] — 3 < 5, non-monotonic + // String data: "helloabc" (8 bytes, but offsets claim 3 as last) + uint8_t data[] = {0x01, // header: version=1, offset_size=1 + 0x02, // dict_size = 2 + 0x00, 0x05, 0x03, // offsets: [0, 5, 3] — non-monotonic + 'h', 'e', 'l', 'l', 'o', 'a', 'b', 'c'}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +// =========================================================================== +// Primitive value decoding tests +// =========================================================================== + +class VariantPrimitiveTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantPrimitiveTest, DecodeNull) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Null"); +} + +TEST_F(VariantPrimitiveTest, DecodeTrue) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kTrue)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Bool(true)"); +} + +TEST_F(VariantPrimitiveTest, DecodeFalse) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kFalse)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Bool(false)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt8) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x2A}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(42)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt8Negative) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0xD6}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(-42)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt16) { + // 300 = 0x012C in little-endian: 0x2C, 0x01 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt16), 0x2C, 0x01}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(300)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt32) { + // 100000 = 0x000186A0 in LE: A0 86 01 00 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0xA0, 0x86, 0x01, 0x00}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(100000)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt32Max) { + int32_t val = std::numeric_limits::max(); + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kInt32); + std::memcpy(data + 1, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt64) { + int64_t val = 1234567890123LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kInt64); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int64(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeFloat) { + float val = 3.14f; + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kFloat); + std::memcpy(data + 1, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + // Float string representation may vary; just check it starts with Float( + ASSERT_TRUE(visitor.events[0].find("Float(") == 0); +} + +TEST_F(VariantPrimitiveTest, DecodeDouble) { + double val = 2.718281828459045; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kDouble); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +TEST_F(VariantPrimitiveTest, DecodeDate) { + // Days since epoch: 19000 (approximately 2022-01-01) + int32_t days = 19000; + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kDate); + std::memcpy(data + 1, &days, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Date(19000)"); +} + +TEST_F(VariantPrimitiveTest, DecodeTimestampMicros) { + int64_t micros = 1654041600000000LL; // some timestamp + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampMicros); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampMicros(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeTimestampMicrosNTZ) { + int64_t micros = 1654041600000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampMicrosNTZ); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampMicrosNTZ(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal4) { + // Spec layout: 1 byte scale, then 4 bytes LE unscaled value + uint8_t data[6]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal4); + data[1] = 2; // scale = 2 + int32_t val = 12345; + std::memcpy(data + 2, &val, 4); // unscaled value + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal4(scale=2)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal4MaxScale) { + // Scale at maximum per spec: 38 + uint8_t data[6]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal4); + data[1] = 38; // scale = 38 (maximum per spec) + int32_t val = 12345; + std::memcpy(data + 2, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal4(scale=38)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal8) { + // Spec layout: 1 byte scale, then 8 bytes LE unscaled value + uint8_t data[10]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal8); + data[1] = 5; // scale = 5 + int64_t val = 123456789012345LL; + std::memcpy(data + 2, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal8(scale=5)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal16) { + // Spec layout: 1 byte scale, then 16 bytes LE unscaled value + uint8_t data[18]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal16); + data[1] = 10; // scale = 10 + std::memset(data + 2, 0, 16); + data[2] = 0x01; // low byte = 1 + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal16(scale=10)"); +} + +TEST_F(VariantPrimitiveTest, DecodeLongString) { + // Long string: primitive type kString with 4-byte length prefix + std::string test_str = "hello world, this is a long string"; + auto str_len = static_cast(test_str.size()); + + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + // 4-byte little-endian length + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((str_len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), test_str.begin(), test_str.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hello world, this is a long string\")"); +} + +TEST_F(VariantPrimitiveTest, DecodeBinary) { + std::vector bin_bytes = {0x00, 0x01, 0x02, 0x03}; + auto bin_len = static_cast(bin_bytes.size()); + + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((bin_len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), bin_bytes.begin(), bin_bytes.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "Binary(len=4)"); +} + +// Truncation errors + +TEST_F(VariantPrimitiveTest, TruncatedInt32) { + // Only 2 bytes after header, but Int32 needs 4 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0x00, 0x00}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); +} + +TEST_F(VariantPrimitiveTest, EmptyValueBuffer) { + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(empty_metadata_, nullptr, 0, &visitor)); +} + +// =========================================================================== +// Short string tests +// =========================================================================== + +class VariantShortStringTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantShortStringTest, EmptyShortString) { + auto data = BuildShortString(""); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"\")"); +} + +TEST_F(VariantShortStringTest, SimpleShortString) { + auto data = BuildShortString("hi"); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hi\")"); +} + +TEST_F(VariantShortStringTest, MaxLengthShortString) { + // Maximum short string is 63 bytes + std::string max_str(63, 'x'); + auto data = BuildShortString(max_str); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"" + max_str + "\")"); +} + +TEST_F(VariantShortStringTest, TruncatedShortString) { + // Header says length=10 but buffer only has 3 bytes total + uint8_t data[] = {static_cast(BasicType::kShortString) | (10 << 2), 'a', 'b'}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Object decoding tests +// =========================================================================== + +class VariantObjectTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "age", "scores"}; + } +}; + +TEST_F(VariantObjectTest, EmptyObject) { + auto data = BuildObject({}, {}); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 2); + ASSERT_EQ(visitor.events[0], "StartObject(0)"); + ASSERT_EQ(visitor.events[1], "EndObject"); +} + +TEST_F(VariantObjectTest, SingleField) { + // Object with one field: name -> "Alice" (short string) + auto value = BuildShortString("Alice"); + auto data = BuildObject({0}, {value}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 4); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "EndObject"); +} + +TEST_F(VariantObjectTest, MultipleFields) { + // Object: {name: "Bob", age: 30} + auto name_val = BuildShortString("Bob"); + // age: Int32(30) + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + + auto data = BuildObject({0, 1}, {name_val, age_val}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 6); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Bob\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(30)"); + ASSERT_EQ(visitor.events[5], "EndObject"); +} + +TEST_F(VariantObjectTest, InvalidFieldId) { + // field_id=99 exceeds dictionary size of 3 + auto value = BuildShortString("oops"); + auto data = BuildObject({99}, {value}); + + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(metadata_, data.data(), + static_cast(data.size()), &visitor)); +} + +TEST_F(VariantObjectTest, ThreeByteOffsetSize) { + // Exercises value decoding with 3-byte field_offset_size and field_id_size. + // Object with 2 fields: {name: "test", age: 42} + auto name_val = BuildShortString("test"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto data = BuildObject({0, 1}, {name_val, age_val}, + /*field_id_size=*/3, /*field_offset_size=*/3); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 6); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"test\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(42)"); + ASSERT_EQ(visitor.events[5], "EndObject"); +} + +// =========================================================================== +// Array decoding tests +// =========================================================================== + +class VariantArrayTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantArrayTest, EmptyArray) { + auto data = BuildArray({}); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 2); + ASSERT_EQ(visitor.events[0], "StartArray(0)"); + ASSERT_EQ(visitor.events[1], "EndArray"); +} + +TEST_F(VariantArrayTest, SingleElement) { + std::vector elem = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto data = BuildArray({elem}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 3); + ASSERT_EQ(visitor.events[0], "StartArray(1)"); + ASSERT_EQ(visitor.events[1], "Int32(42)"); + ASSERT_EQ(visitor.events[2], "EndArray"); +} + +TEST_F(VariantArrayTest, HeterogeneousElements) { + // Array with mixed types: [42, "hello", true] + std::vector int_elem = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto str_elem = BuildShortString("hello"); + std::vector bool_elem = {PrimitiveHeader(PrimitiveType::kTrue)}; + + auto data = BuildArray({int_elem, str_elem, bool_elem}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 5); + ASSERT_EQ(visitor.events[0], "StartArray(3)"); + ASSERT_EQ(visitor.events[1], "Int32(42)"); + ASSERT_EQ(visitor.events[2], "String(\"hello\")"); + ASSERT_EQ(visitor.events[3], "Bool(true)"); + ASSERT_EQ(visitor.events[4], "EndArray"); +} + +TEST_F(VariantArrayTest, LargeArrayIsLargeFlag) { + // Build an array with 256 elements to exercise is_large=true (4-byte + // num_elements). Each element is a Null primitive (1 byte each). + // Use field_offset_size=2 since total data (256 bytes) exceeds 1-byte max. + std::vector> elements; + elements.reserve(256); + for (int i = 0; i < 256; ++i) { + elements.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildArray(elements, /*field_offset_size=*/2); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + // StartArray(256) + 256 Nulls + EndArray = 258 events + ASSERT_EQ(visitor.events.size(), 258); + ASSERT_EQ(visitor.events[0], "StartArray(256)"); + ASSERT_EQ(visitor.events[1], "Null"); + ASSERT_EQ(visitor.events[256], "Null"); + ASSERT_EQ(visitor.events[257], "EndArray"); +} + +// =========================================================================== +// Nested structure tests +// =========================================================================== + +class VariantNestedTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "scores", "inner"}; + } +}; + +TEST_F(VariantNestedTest, ObjectWithNestedArray) { + // {name: "Alice", scores: [95, 87]} + auto name_val = BuildShortString("Alice"); + + // scores array: [Int32(95), Int32(87)] + std::vector score1 = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + std::vector score2 = {PrimitiveHeader(PrimitiveType::kInt32), 87, 0, 0, 0}; + auto scores_val = BuildArray({score1, score2}); + + auto data = BuildObject({0, 1}, {name_val, scores_val}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + // Expected events: + // StartObject(2), FieldName("name"), String("Alice"), + // FieldName("scores"), StartArray(2), Int32(95), Int32(87), EndArray, + // EndObject + ASSERT_EQ(visitor.events.size(), 9); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"scores\")"); + ASSERT_EQ(visitor.events[4], "StartArray(2)"); + ASSERT_EQ(visitor.events[5], "Int32(95)"); + ASSERT_EQ(visitor.events[6], "Int32(87)"); + ASSERT_EQ(visitor.events[7], "EndArray"); + ASSERT_EQ(visitor.events[8], "EndObject"); +} + +TEST_F(VariantNestedTest, NestedObjects) { + // {inner: {name: "deep"}} + auto deep_name = BuildShortString("deep"); + auto inner_obj = BuildObject({0}, {deep_name}); + auto data = BuildObject({2}, {inner_obj}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + ASSERT_EQ(visitor.events.size(), 7); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"inner\")"); + ASSERT_EQ(visitor.events[2], "StartObject(1)"); + ASSERT_EQ(visitor.events[3], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[4], "String(\"deep\")"); + ASSERT_EQ(visitor.events[5], "EndObject"); + ASSERT_EQ(visitor.events[6], "EndObject"); +} + +TEST_F(VariantNestedTest, ArrayOfObjects) { + // [{name: "a"}, {name: "b"}] + auto val_a = BuildShortString("a"); + auto obj_a = BuildObject({0}, {val_a}); + + auto val_b = BuildShortString("b"); + auto obj_b = BuildObject({0}, {val_b}); + + auto data = BuildArray({obj_a, obj_b}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + ASSERT_EQ(visitor.events.size(), 10); + ASSERT_EQ(visitor.events[0], "StartArray(2)"); + ASSERT_EQ(visitor.events[1], "StartObject(1)"); + ASSERT_EQ(visitor.events[2], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[3], "String(\"a\")"); + ASSERT_EQ(visitor.events[4], "EndObject"); + ASSERT_EQ(visitor.events[5], "StartObject(1)"); + ASSERT_EQ(visitor.events[6], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[7], "String(\"b\")"); + ASSERT_EQ(visitor.events[8], "EndObject"); + ASSERT_EQ(visitor.events[9], "EndArray"); +} + +// =========================================================================== +// Recursion depth limit test +// =========================================================================== + +class VariantDepthTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"x"}; + } +}; + +TEST_F(VariantDepthTest, ExceedsMaxNestingDepth) { + // Build a deeply nested array: [[[[...]]]] + // Each level wraps the inner in a 1-element array with offset_size=2 + // to allow buffers larger than 255 bytes. + std::vector inner = {PrimitiveHeader(PrimitiveType::kNull)}; + + // Wrap 130 times (exceeds kMaxNestingDepth=128) + for (int i = 0; i < 130; ++i) { + inner = BuildArray({inner}, /*field_offset_size=*/2); + } + + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(metadata_, inner.data(), + static_cast(inner.size()), &visitor)); +} + +TEST_F(VariantDepthTest, AtMaxNestingDepthSucceeds) { + // Build 50 levels of nesting — well within kMaxNestingDepth=128 + // and within offset_size=1 limits (each level adds ~4 bytes). + std::vector inner = {PrimitiveHeader(PrimitiveType::kNull)}; + + for (int i = 0; i < 50; ++i) { + inner = BuildArray({inner}); + } + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, inner.data(), static_cast(inner.size()), + &visitor)); +} + +// =========================================================================== +// Utility function tests +// =========================================================================== + +class VariantUtilTest : public ::testing::Test {}; + +TEST_F(VariantUtilTest, GetValueBasicTypePrimitive) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0, 0, 0, 0}; + ASSERT_OK_AND_ASSIGN(auto bt, GetValueBasicType(data, sizeof(data))); + ASSERT_EQ(bt, BasicType::kPrimitive); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeShortString) { + auto data = BuildShortString("test"); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kShortString); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeObject) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"key"}; + auto val = BuildShortString("val"); + auto data = BuildObject({0}, {val}); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kObject); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeArray) { + auto data = BuildArray({}); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kArray); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeEmptyBuffer) { + ASSERT_RAISES(Invalid, GetValueBasicType(nullptr, 0)); +} + +TEST_F(VariantUtilTest, GetObjectFieldCount) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"a", "b", "c"}; + auto v1 = BuildShortString("x"); + auto v2 = BuildShortString("y"); + auto data = BuildObject({0, 1}, {v1, v2}); + ASSERT_OK_AND_ASSIGN( + auto count, GetObjectFieldCount(data.data(), static_cast(data.size()))); + ASSERT_EQ(count, 2); +} + +TEST_F(VariantUtilTest, GetArrayElementCount) { + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + std::vector e3 = {PrimitiveHeader(PrimitiveType::kFalse)}; + auto data = BuildArray({e1, e2, e3}); + ASSERT_OK_AND_ASSIGN( + auto count, GetArrayElementCount(data.data(), static_cast(data.size()))); + ASSERT_EQ(count, 3); +} + +TEST_F(VariantUtilTest, PrimitiveValueSizes) { + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kNull), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTrue), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kFalse), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt8), 1); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt16), 2); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt32), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt64), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kFloat), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDouble), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDate), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampMicros), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampMicrosNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimeNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampNanos), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampNanosNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kUUID), 16); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal4), 5); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal8), 9); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal16), 17); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kBinary), -1); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kString), -1); +} + +// =========================================================================== +// Integration: Metadata + Value decoding together +// =========================================================================== + +class VariantIntegrationTest : public ::testing::Test {}; + +TEST_F(VariantIntegrationTest, FullRoundTrip) { + // Build a complete variant: {name: "Alice", age: 30, scores: [95, 87]} + auto meta_buf = BuildMetadataBuffer({"name", "age", "scores"}); + + auto name_val = BuildShortString("Alice"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + std::vector s1 = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + std::vector s2 = {PrimitiveHeader(PrimitiveType::kInt32), 87, 0, 0, 0}; + auto scores_val = BuildArray({s1, s2}); + + auto value_buf = BuildObject({0, 1, 2}, {name_val, age_val, scores_val}); + + // Decode metadata + ASSERT_OK_AND_ASSIGN( + auto metadata, + DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + ASSERT_EQ(metadata.strings.size(), 3); + + // Decode value + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata, value_buf.data(), + static_cast(value_buf.size()), &visitor)); + + // Verify full event sequence + ASSERT_EQ(visitor.events.size(), 11); + ASSERT_EQ(visitor.events[0], "StartObject(3)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(30)"); + ASSERT_EQ(visitor.events[5], "FieldName(\"scores\")"); + ASSERT_EQ(visitor.events[6], "StartArray(2)"); + ASSERT_EQ(visitor.events[7], "Int32(95)"); + ASSERT_EQ(visitor.events[8], "Int32(87)"); + ASSERT_EQ(visitor.events[9], "EndArray"); + ASSERT_EQ(visitor.events[10], "EndObject"); +} + +// =========================================================================== +// Visitor early abort test +// =========================================================================== + +/// \brief A visitor that aborts after receiving a specific number of events. +class AbortingVisitor : public VariantVisitor { + public: + int32_t abort_after; + int32_t count = 0; + + explicit AbortingVisitor(int32_t abort_after) : abort_after(abort_after) {} + + Status MaybeAbort() { + ++count; + if (count >= abort_after) { + return Status::Cancelled("Visitor aborted after ", count, " events"); + } + return Status::OK(); + } + + Status Null() override { return MaybeAbort(); } + Status Bool(bool /*value*/) override { return MaybeAbort(); } + Status Int8(int8_t /*value*/) override { return MaybeAbort(); } + Status Int16(int16_t /*value*/) override { return MaybeAbort(); } + Status Int32(int32_t /*value*/) override { return MaybeAbort(); } + Status Int64(int64_t /*value*/) override { return MaybeAbort(); } + Status Float(float /*value*/) override { return MaybeAbort(); } + Status Double(double /*value*/) override { return MaybeAbort(); } + Status Decimal4(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Decimal8(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Decimal16(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Date(int32_t /*days*/) override { return MaybeAbort(); } + Status TimestampMicros(int64_t /*micros*/) override { return MaybeAbort(); } + Status TimestampMicrosNTZ(int64_t /*micros*/) override { return MaybeAbort(); } + Status String(std::string_view /*value*/) override { return MaybeAbort(); } + Status Binary(std::string_view /*value*/) override { return MaybeAbort(); } + Status TimeNTZ(int64_t /*micros*/) override { return MaybeAbort(); } + Status TimestampNanos(int64_t /*nanos*/) override { return MaybeAbort(); } + Status TimestampNanosNTZ(int64_t /*nanos*/) override { return MaybeAbort(); } + Status UUID(const uint8_t* /*bytes*/) override { return MaybeAbort(); } + Status StartObject(int32_t /*num_fields*/) override { return MaybeAbort(); } + Status FieldName(std::string_view /*name*/) override { return MaybeAbort(); } + Status EndObject() override { return MaybeAbort(); } + Status StartArray(int32_t /*num_elements*/) override { return MaybeAbort(); } + Status EndArray() override { return MaybeAbort(); } +}; + +class VariantAbortTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "age"}; + } +}; + +TEST_F(VariantAbortTest, VisitorAbortsEarly) { + // Object: {name: "Alice", age: 30} + auto name_val = BuildShortString("Alice"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto data = BuildObject({0, 1}, {name_val, age_val}); + + // Abort after 3 events (StartObject, FieldName, String) + // Should NOT reach the second field + AbortingVisitor visitor(3); + auto status = + DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), &visitor); + ASSERT_TRUE(status.IsCancelled()); + ASSERT_EQ(visitor.count, 3); +} + +TEST_F(VariantAbortTest, VisitorAbortsOnFirstEvent) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + AbortingVisitor visitor(1); + auto status = DecodeAndVisit(metadata_, data, sizeof(data), &visitor); + ASSERT_TRUE(status.IsCancelled()); +} + +// =========================================================================== +// Spec-conformance test with hardcoded byte sequences +// =========================================================================== + +class VariantSpecTest : public ::testing::Test {}; + +TEST_F(VariantSpecTest, HandcraftedNullValue) { + // Variant Encoding Spec: Null is basic_type=0, primitive_type=0 + // Header byte: 0x00 (bits 0-1=00 for primitive, bits 2-7=000000 for null) + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; // v1, 0 strings, offset[0]=0 + uint8_t value_bytes[] = {0x00}; // null + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Null"); +} + +TEST_F(VariantSpecTest, HandcraftedInt32Value) { + // Int32(42): basic_type=0, primitive_type=5 + // Header: (5 << 2) | 0 = 0x14 + // Value: 42 as LE int32 = 2A 00 00 00 + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[] = {0x14, 0x2A, 0x00, 0x00, 0x00}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(42)"); +} + +TEST_F(VariantSpecTest, HandcraftedShortString) { + // Short string "hello": basic_type=1, length=5 + // Header: (5 << 2) | 1 = 0x15 + // Followed by 5 bytes of UTF-8 "hello" + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[] = {0x15, 'h', 'e', 'l', 'l', 'o'}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hello\")"); +} + +TEST_F(VariantSpecTest, HandcraftedSimpleObject) { + // Object {"a": 1} with metadata dictionary ["a"] + // + // Metadata: version=1, sorted=false, offset_size=1 + // header=0x01, dict_size=0x01, offsets=[0x00, 0x01], data="a" + uint8_t metadata_bytes[] = {0x01, 0x01, 0x00, 0x01, 'a'}; + // + // Value: object with 1 field + // header: basic_type=2, field_id_size=1(bits2-3=00), + // offset_size=1(bits4-5=00), num_fields_size=1(bits6-7=00) + // = 0x02 + // num_fields: 0x01 + // field_ids: [0x00] (index into metadata for "a") + // offsets: [0x00, 0x05] (field 0 at offset 0, total size 5) + // field value: Int32(1) = header 0x14 + LE bytes 01 00 00 00 + uint8_t value_bytes[] = { + 0x02, // object header + 0x01, // num_fields = 1 + 0x00, // field_id[0] = 0 + 0x00, 0x05, // offsets: [0, 5] + 0x14, 0x01, 0x00, 0x00, 0x00 // Int32(1) + }; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events.size(), 4); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"a\")"); + ASSERT_EQ(visitor.events[2], "Int32(1)"); + ASSERT_EQ(visitor.events[3], "EndObject"); +} + +TEST_F(VariantSpecTest, HandcraftedTrueAndFalse) { + // True: basic_type=0, primitive_type=1 → header = (1<<2)|0 = 0x04 + // False: basic_type=0, primitive_type=2 → header = (2<<2)|0 = 0x08 + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + + uint8_t true_bytes[] = {0x04}; + uint8_t false_bytes[] = {0x08}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + + RecordingVisitor v1; + ASSERT_OK(DecodeAndVisit(metadata, true_bytes, sizeof(true_bytes), &v1)); + ASSERT_EQ(v1.events[0], "Bool(true)"); + + RecordingVisitor v2; + ASSERT_OK(DecodeAndVisit(metadata, false_bytes, sizeof(false_bytes), &v2)); + ASSERT_EQ(v2.events[0], "Bool(false)"); +} + +TEST_F(VariantSpecTest, HandcraftedDouble) { + // Double: basic_type=0, primitive_type=7 → header = (7<<2)|0 = 0x1C + // Value: 3.14 as IEEE 754 double LE + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[9]; + value_bytes[0] = 0x1C; + double val = 3.14; + std::memcpy(value_bytes + 1, &val, 8); + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +// =========================================================================== +// ValueSize tests +// =========================================================================== + +class VariantValueSizeTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeTest, NullSize) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 1); +} + +TEST_F(VariantValueSizeTest, Int32Size) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0, 0, 0, 0}; + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 5); +} + +TEST_F(VariantValueSizeTest, ShortStringSize) { + auto data = BuildShortString("hello"); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 6); // 1 header + 5 chars +} + +TEST_F(VariantValueSizeTest, ObjectSize) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"key"}; + auto val = BuildShortString("val"); + auto data = BuildObject({0}, {val}); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeTest, ArraySize) { + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + auto data = BuildArray({e1, e2}); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeTest, UUIDSize) { + uint8_t data[17]; + data[0] = PrimitiveHeader(PrimitiveType::kUUID); + std::memset(data + 1, 0, 16); + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 17); +} + +// =========================================================================== +// Random access tests +// =========================================================================== + +class VariantRandomAccessTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // Sorted lexicographically for binary search + metadata_.strings = {"age", "name", "score"}; + } +}; + +TEST_F(VariantRandomAccessTest, FindObjectFieldExists) { + // Object: {age: 30, name: "Alice", score: 95} + // field_ids must be in lex order of keys: age=0, name=1, score=2 + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto name_val = BuildShortString("Alice"); + std::vector score_val = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + auto data = BuildObject({0, 1, 2}, {age_val, name_val, score_val}); + + int64_t offset = -1, size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "name", &offset, &size)); + ASSERT_GT(offset, 0); + ASSERT_EQ(size, 6); // short string "Alice" = 1 + 5 + + // Verify we can decode just that field + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data() + offset, size, &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"Alice\")"); +} + +TEST_F(VariantRandomAccessTest, FindObjectFieldNotFound) { + auto val = BuildShortString("x"); + auto data = BuildObject({0}, {val}); + + int64_t offset = -1, size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "nonexistent", &offset, &size)); + ASSERT_EQ(offset, -1); + ASSERT_EQ(size, 0); +} + +TEST_F(VariantRandomAccessTest, GetArrayElementFirst) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0, e1}); + + int64_t offset = 0, size = 0; + ASSERT_OK( + GetArrayElement(data.data(), static_cast(data.size()), 0, &offset, &size)); + ASSERT_EQ(size, 5); // Int32 = 5 bytes + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data() + offset, size, &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(42)"); +} + +TEST_F(VariantRandomAccessTest, GetArrayElementLast) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0, e1}); + + int64_t offset = 0, size = 0; + ASSERT_OK( + GetArrayElement(data.data(), static_cast(data.size()), 1, &offset, &size)); + ASSERT_EQ(size, 1); // Null = 1 byte +} + +TEST_F(VariantRandomAccessTest, GetArrayElementOutOfRange) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0}); + + int64_t offset = 0, size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + 5, &offset, &size)); +} + +TEST_F(VariantRandomAccessTest, GetObjectFieldAtByIndex) { + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto name_val = BuildShortString("Bob"); + auto data = BuildObject({0, 1}, {age_val, name_val}); + + std::string_view name; + int64_t offset = 0, size = 0; + ASSERT_OK(GetObjectFieldAt(metadata_, data.data(), static_cast(data.size()), 1, + &name, &offset, &size)); + ASSERT_EQ(name, "name"); + ASSERT_EQ(size, 4); // short string "Bob" = 1 + 3 +} + +TEST_F(VariantRandomAccessTest, GetObjectFieldAtOutOfRange) { + auto val = BuildShortString("x"); + auto data = BuildObject({0}, {val}); + + std::string_view name; + int64_t offset = 0, size = 0; + ASSERT_RAISES( + Invalid, GetObjectFieldAt(metadata_, data.data(), static_cast(data.size()), + 99, &name, &offset, &size)); +} + +// =========================================================================== +// FindMetadataKey tests +// =========================================================================== + +class VariantFindMetadataKeyTest : public ::testing::Test {}; + +TEST_F(VariantFindMetadataKeyTest, SortedFound) { + VariantMetadata meta; + meta.is_sorted = true; + meta.strings = {"age", "name", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "name"), 1); + ASSERT_EQ(FindMetadataKey(meta, "age"), 0); + ASSERT_EQ(FindMetadataKey(meta, "score"), 2); +} + +TEST_F(VariantFindMetadataKeyTest, SortedNotFound) { + VariantMetadata meta; + meta.is_sorted = true; + meta.strings = {"age", "name", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "missing"), -1); +} + +TEST_F(VariantFindMetadataKeyTest, UnsortedFound) { + VariantMetadata meta; + meta.is_sorted = false; + meta.strings = {"name", "age", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "age"), 1); +} + +TEST_F(VariantFindMetadataKeyTest, UnsortedNotFound) { + VariantMetadata meta; + meta.is_sorted = false; + meta.strings = {"name", "age"}; + ASSERT_EQ(FindMetadataKey(meta, "missing"), -1); +} + +// =========================================================================== +// ValueSize regression tests (Go bug: array is_large bit position) +// =========================================================================== + +class VariantValueSizeRegressionTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeRegressionTest, LargeArrayIsLargeBit) { + // Build a large array with 300 elements (>255) to trigger is_large=true. + // This verifies the is_large bit is read at bit 2 of type_info (bit 4 of + // full byte), NOT bit 4 of type_info (bit 6 of full byte) which was the + // Go bug (apache/arrow-go#839). + std::vector> elements; + elements.reserve(300); + for (int i = 0; i < 300; ++i) { + elements.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildArray(elements, /*field_offset_size=*/2); + + // Verify the header byte is correctly structured + uint8_t header = data[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kArray); + // is_large should be set at bit 4 of the full header byte + ASSERT_TRUE(((header >> 4) & 0x01) != 0); + + // ValueSize must return the total size of the buffer + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeRegressionTest, SmallArrayIsLargeFalse) { + // Array with 3 elements — is_large=false + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + std::vector e3 = {PrimitiveHeader(PrimitiveType::kFalse)}; + auto data = BuildArray({e1, e2, e3}); + + // Verify is_large is NOT set + uint8_t header = data[0]; + ASSERT_FALSE(((header >> 4) & 0x01) != 0); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeRegressionTest, LargeObjectIsLargeBit) { + // Object with 300 fields to trigger is_large=true (bit 6 of full byte) + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 300; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = + BuildObject(field_ids, values, /*field_id_size=*/2, /*field_offset_size=*/2); + + // Verify is_large is set at bit 6 of the full header byte + uint8_t header = data[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kObject); + ASSERT_TRUE(((header >> 6) & 0x01) != 0); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +// =========================================================================== +// Additional primitive decoding tests +// =========================================================================== + +class VariantPrimitiveExtraTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantPrimitiveExtraTest, DecodeTimeNTZ) { + int64_t micros = 43200000000LL; // 12:00:00 in microseconds + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimeNTZ); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimeNTZ(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeTimestampNanos) { + int64_t nanos = 1654041600000000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampNanos); + std::memcpy(data + 1, &nanos, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampNanos(" + std::to_string(nanos) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeTimestampNanosNTZ) { + int64_t nanos = 1654041600000000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampNanosNTZ); + std::memcpy(data + 1, &nanos, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampNanosNTZ(" + std::to_string(nanos) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeUUID) { + uint8_t data[17]; + data[0] = PrimitiveHeader(PrimitiveType::kUUID); + // Fill UUID with recognizable pattern (big-endian per spec) + for (int i = 0; i < 16; ++i) { + data[1 + i] = static_cast(i + 1); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "UUID"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt8Boundaries) { + // INT8_MIN = -128 + { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x80}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(-128)"); + } + // INT8_MAX = 127 + { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x7F}; + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(127)"); + } +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt16Boundaries) { + // INT16_MIN = -32768 + { + int16_t val = std::numeric_limits::min(); + uint8_t data[3]; + data[0] = PrimitiveHeader(PrimitiveType::kInt16); + std::memcpy(data + 1, &val, 2); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(-32768)"); + } + // INT16_MAX = 32767 + { + int16_t val = std::numeric_limits::max(); + uint8_t data[3]; + data[0] = PrimitiveHeader(PrimitiveType::kInt16); + std::memcpy(data + 1, &val, 2); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(32767)"); + } +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt64Min) { + int64_t val = std::numeric_limits::min(); + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kInt64); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int64(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeEmptyBinary) { + // Binary with zero length + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + uint32_t len = 0; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "Binary(len=0)"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeEmptyLongString) { + // Long string with zero length + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + uint32_t len = 0; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"\")"); +} + +// =========================================================================== +// Object with non-monotonic offsets (spec-compliant) +// =========================================================================== + +class VariantObjectNonMonotonicTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // Sorted lexicographically + metadata_.strings = {"a", "b", "c"}; + } +}; + +TEST_F(VariantObjectNonMonotonicTest, NonMonotonicObjectOffsets) { + // Per spec: "field IDs and offsets must be listed in the order of the + // corresponding field names, sorted lexicographically" but "the actual + // value entries do not need to be in any particular order" and "the + // field_offset values may not be monotonically increasing." + // + // Construct: {a: 1, b: 2, c: 3} where values are stored as [3, 1, 2] + // in the data area but offsets point to them in key-sorted order. + std::vector val_a = {PrimitiveHeader(PrimitiveType::kInt8), 1}; + std::vector val_b = {PrimitiveHeader(PrimitiveType::kInt8), 2}; + std::vector val_c = {PrimitiveHeader(PrimitiveType::kInt8), 3}; + + // Data area stores: val_c (2 bytes) | val_a (2 bytes) | val_b (2 bytes) + // Offsets: a->2, b->4, c->0, end->6 + uint8_t header = static_cast(BasicType::kObject); // offset_size=1, id_size=1 + std::vector data; + data.push_back(header); + data.push_back(3); // num_fields = 3 + data.push_back(0); // field_id[0] = 0 ("a") + data.push_back(1); // field_id[1] = 1 ("b") + data.push_back(2); // field_id[2] = 2 ("c") + data.push_back(2); // offset[0] = 2 (val_a starts at byte 2) + data.push_back(4); // offset[1] = 4 (val_b starts at byte 4) + data.push_back(0); // offset[2] = 0 (val_c starts at byte 0) + data.push_back(6); // offset[3] = 6 (total data size) + // Data area: val_c, val_a, val_b + data.insert(data.end(), val_c.begin(), val_c.end()); + data.insert(data.end(), val_a.begin(), val_a.end()); + data.insert(data.end(), val_b.begin(), val_b.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeAndVisit(metadata_, data.data(), static_cast(data.size()), + &visitor)); + // Field iteration order follows field_ids (sorted by key): a, b, c + ASSERT_EQ(visitor.events.size(), 8); + ASSERT_EQ(visitor.events[0], "StartObject(3)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"a\")"); + ASSERT_EQ(visitor.events[2], "Int8(1)"); + ASSERT_EQ(visitor.events[3], "FieldName(\"b\")"); + ASSERT_EQ(visitor.events[4], "Int8(2)"); + ASSERT_EQ(visitor.events[5], "FieldName(\"c\")"); + ASSERT_EQ(visitor.events[6], "Int8(3)"); + ASSERT_EQ(visitor.events[7], "EndObject"); +} + +TEST_F(VariantObjectNonMonotonicTest, FindFieldWithNonMonotonicOffsets) { + // Same layout as above: values stored out-of-order + uint8_t header = static_cast(BasicType::kObject); + std::vector data; + data.push_back(header); + data.push_back(3); + data.push_back(0); + data.push_back(1); + data.push_back(2); + data.push_back(2); // a -> offset 2 + data.push_back(4); // b -> offset 4 + data.push_back(0); // c -> offset 0 + data.push_back(6); // end = 6 + // Data: [Int8(3), Int8(1), Int8(2)] + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(3); + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(1); + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(2); + + // FindObjectField should find "c" at offset 0 of data area + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "c", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_EQ(field_size, 2); // Int8 = 2 bytes + + // Decode the value at that offset and verify it's 3 (val_c) + RecordingVisitor v; + ASSERT_OK(DecodeAndVisit(metadata_, data.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "Int8(3)"); +} + +// =========================================================================== +// ValueSize for variable-length primitives +// =========================================================================== + +class VariantValueSizeVarLenTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeVarLenTest, LongStringSize) { + // Long string "hello" (5 chars): header(1) + length(4) + data(5) = 10 + std::string s = "hello"; + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + auto len = static_cast(s.size()); + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), s.begin(), s.end()); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 10); +} + +TEST_F(VariantValueSizeVarLenTest, BinarySize) { + // Binary with 4 bytes: header(1) + length(4) + data(4) = 9 + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + uint32_t len = 4; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + data.push_back(0x00); + data.push_back(0x01); + data.push_back(0x02); + data.push_back(0x03); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 9); +} + +TEST_F(VariantValueSizeVarLenTest, TruncatedLongString) { + // Only header byte, no length field + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kString)}; + ASSERT_RAISES(Invalid, ValueSize(data, sizeof(data))); +} + +// =========================================================================== +// Unknown/invalid type tests +// =========================================================================== + +class VariantUnknownTypeTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantUnknownTypeTest, UnknownPrimitiveType) { + // Primitive type ID 25 (beyond kUUID=20) should produce an error. + // Header: (25 << 2) | 0 = 0x64 + uint8_t data[] = {0x64}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); +} + +TEST_F(VariantUnknownTypeTest, UnknownPrimitiveTypeValueSize) { + // ValueSize on an unknown primitive type should still return a value + // (PrimitiveValueSize returns -1, triggering variable-length path). + // With only 1 byte, variable-length path requires 5 bytes → truncated. + uint8_t data[] = {0x64}; + ASSERT_RAISES(Invalid, ValueSize(data, sizeof(data))); +} + +// =========================================================================== +// Array non-monotonic offset rejection test +// =========================================================================== + +class VariantArrayNonMonotonicTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantArrayNonMonotonicTest, RejectsNonMonotonicOffsets) { + // Manually craft an array with 2 elements where offsets go [0, 3, 1] + // (non-monotonic: 1 < 3). This should be rejected. + // header: basic_type=3, offset_size=1, is_large=false → 0x03 + // num_elements: 2 + // offsets: [0, 3, 1] — non-monotonic + // data: 3 bytes of nulls + uint8_t data[] = { + 0x03, // array header: basic_type=3, offset_size=1, is_large=false + 0x02, // num_elements = 2 + 0x00, + 0x03, + 0x01, // offsets: [0, 3, 1] — non-monotonic! + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + }; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(empty_metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Object field offset out-of-bounds test +// =========================================================================== + +class VariantObjectOffsetBoundsTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"key"}; + } +}; + +TEST_F(VariantObjectOffsetBoundsTest, FieldOffsetExceedsDataSize) { + // Object with 1 field where field_offset[0] = 99 (beyond total_data_size). + // header: basic_type=2, offset_size=1, id_size=1, is_large=false → 0x02 + // num_fields: 1 + // field_ids: [0] + // offsets: [99, 2] — field 0 at offset 99, total=2 + // data: 2 bytes (Null) + uint8_t data[] = { + 0x02, // object header + 0x01, // num_fields = 1 + 0x00, // field_id[0] = 0 + 0x63, + 0x02, // offsets: [99, 2] — 99 > total_data_size(2) + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + }; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeAndVisit(metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Empty metadata with various offset sizes +// =========================================================================== + +class VariantMetadataOffsetSizeTest : public ::testing::Test {}; + +TEST_F(VariantMetadataOffsetSizeTest, EmptyDictionaryOffsetSize4) { + // Valid metadata with 0 strings but offset_size=4. + auto buf = BuildMetadataBuffer({}, false, 4); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.version, 1); + ASSERT_EQ(metadata.offset_size, 4); + ASSERT_EQ(metadata.strings.size(), 0); +} + +// =========================================================================== +// FindObjectField with binary search (large object >= 32 fields) +// =========================================================================== + +class VariantFindFieldBinarySearchTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + // Backing storage for string_views in metadata (must outlive metadata_). + // Do NOT modify key_storage_ after SetUp(); reallocation invalidates + // the string_views stored in metadata_.strings. + std::vector key_storage_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // 40 keys in sorted order to trigger binary search path + key_storage_.reserve(40); + for (int i = 0; i < 40; ++i) { + std::string key = "k" + std::string(i < 10 ? "0" : "") + std::to_string(i); + key_storage_.emplace_back(key); + } + for (const auto& k : key_storage_) { + metadata_.strings.push_back(k); + } + } +}; + +TEST_F(VariantFindFieldBinarySearchTest, FindMiddleField) { + // Build object with 40 fields, all null values + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + // Search for "k20" (middle of the sorted range) + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k20", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_EQ(field_size, 1); // Null = 1 byte +} + +TEST_F(VariantFindFieldBinarySearchTest, FindFirstField) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k00", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); +} + +TEST_F(VariantFindFieldBinarySearchTest, FindLastField) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k39", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); +} + +TEST_F(VariantFindFieldBinarySearchTest, NotFoundInLargeObject) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "zzz", &field_offset, &field_size)); + ASSERT_EQ(field_offset, -1); +} + +// =========================================================================== +// GetArrayElement middle index +// =========================================================================== + +class VariantGetArrayElementExtraTest : public ::testing::Test {}; + +TEST_F(VariantGetArrayElementExtraTest, MiddleElement) { + // Array of [Int32(10), Int32(20), Int32(30)] + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 10, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kInt32), 20, 0, 0, 0}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto data = BuildArray({e0, e1, e2}); + + int64_t elem_offset = 0, elem_size = 0; + ASSERT_OK(GetArrayElement(data.data(), static_cast(data.size()), 1, + &elem_offset, &elem_size)); + ASSERT_EQ(elem_size, 5); // Int32 = 5 bytes + + // Decode the middle element + VariantMetadata meta; + meta.version = 1; + RecordingVisitor v; + ASSERT_OK(DecodeAndVisit(meta, data.data() + elem_offset, elem_size, &v)); + ASSERT_EQ(v.events[0], "Int32(20)"); +} + +TEST_F(VariantGetArrayElementExtraTest, EmptyArrayOutOfRange) { + auto data = BuildArray({}); + int64_t elem_offset = 0, elem_size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + 0, &elem_offset, &elem_size)); +} + +// =========================================================================== +// Additional error case tests (missing coverage) +// =========================================================================== + +class VariantErrorCaseTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantErrorCaseTest, MetadataVersionZero) { + // Version 0 is not supported (only version 1 is valid per spec) + uint8_t data[] = {0x00, 0x00, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetObjectFieldCountOnArray) { + // Calling GetObjectFieldCount on an array value should produce an error + auto data = BuildArray({}); + ASSERT_RAISES(Invalid, + GetObjectFieldCount(data.data(), static_cast(data.size()))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementCountOnObject) { + // Calling GetArrayElementCount on an object value should produce an error + auto data = BuildObject({}, {}); + ASSERT_RAISES(Invalid, + GetArrayElementCount(data.data(), static_cast(data.size()))); +} + +TEST_F(VariantErrorCaseTest, GetObjectFieldCountOnPrimitive) { + // Calling GetObjectFieldCount on a primitive should produce an error + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_RAISES(Invalid, GetObjectFieldCount(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementCountOnPrimitive) { + // Calling GetArrayElementCount on a primitive should produce an error + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_RAISES(Invalid, GetArrayElementCount(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, MetadataStringOffsetExceedsBuffer) { + // Metadata where the last string offset claims more data than the buffer + // contains. This exercises the ValidateOffsets check for offsets.back() > + // data_length. + // Header: version=1, offset_size=1 + // dict_size=1, offsets=[0, 100] — but only 3 bytes of string data + uint8_t data[] = {0x01, // header: version=1, offset_size=1 + 0x01, // dict_size = 1 + 0x00, 0x64, // offsets: [0, 100] — 100 exceeds available string data + 'a', 'b', 'c'}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementNegativeIndex) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0}); + int64_t elem_offset = 0, elem_size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + -1, &elem_offset, &elem_size)); +} + +TEST_F(VariantErrorCaseTest, FindObjectFieldOnNonObject) { + // Calling FindObjectField on an array should produce an error + auto data = BuildArray({}); + int64_t field_offset = -1, field_size = 0; + ASSERT_RAISES(Invalid, FindObjectField(empty_metadata_, data.data(), + static_cast(data.size()), "key", + &field_offset, &field_size)); +} + +// TODO: Add fuzz targets for DecodeMetadata and DecodeVariantValue to exercise +// adversarial/malformed input. Fuzz tests in Arrow are typically registered as +// separate executables under cpp/src/arrow/testing/fuzzing/ — see GH-45948. + +// =========================================================================== +// View API tests (new — demonstrates C++ ergonomic approach) +// =========================================================================== + +class VariantViewTest : public ::testing::Test {}; + +TEST_F(VariantViewTest, PrimitiveInt32) { + auto meta_buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + // Build an Int32 value: header + 4 bytes LE + std::vector data = {PrimitiveHeader(PrimitiveType::kInt32), 0x2A, 0x00, 0x00, + 0x00}; + + ASSERT_OK_AND_ASSIGN( + auto view, VariantView::Make(meta, data.data(), static_cast(data.size()))); + ASSERT_EQ(view.type(), BasicType::kPrimitive); + ASSERT_FALSE(view.is_null()); + ASSERT_OK_AND_ASSIGN(auto val, view.as_int32()); + ASSERT_EQ(val, 42); + ASSERT_EQ(view.size_bytes(), 5); +} + +TEST_F(VariantViewTest, PrimitiveNull) { + auto meta_buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + std::vector data = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_OK_AND_ASSIGN( + auto view, VariantView::Make(meta, data.data(), static_cast(data.size()))); + ASSERT_TRUE(view.is_null()); + ASSERT_EQ(view.size_bytes(), 1); +} + +TEST_F(VariantViewTest, ShortString) { + auto meta_buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + auto data = BuildShortString("hello"); + ASSERT_OK_AND_ASSIGN( + auto view, VariantView::Make(meta, data.data(), static_cast(data.size()))); + ASSERT_EQ(view.type(), BasicType::kShortString); + ASSERT_OK_AND_ASSIGN(auto str, view.as_string()); + ASSERT_EQ(str, "hello"); +} + +TEST_F(VariantViewTest, TypeMismatchReturnsError) { + auto meta_buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + // Build a boolean true value + std::vector data = {PrimitiveHeader(PrimitiveType::kTrue)}; + ASSERT_OK_AND_ASSIGN( + auto view, VariantView::Make(meta, data.data(), static_cast(data.size()))); + + // Accessing as wrong type should fail + ASSERT_RAISES(Invalid, view.as_int32()); + ASSERT_RAISES(Invalid, view.as_string()); + ASSERT_RAISES(Invalid, view.as_object()); +} + +class VariantObjectViewTest : public ::testing::Test {}; + +TEST_F(VariantObjectViewTest, SimpleObject) { + // Build {"name": "Alice", "age": 42} + auto meta_buf = BuildMetadataBuffer({"age", "name"}, /*sorted=*/true); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + // Field values + auto val_name = BuildShortString("Alice"); + std::vector val_age = {PrimitiveHeader(PrimitiveType::kInt8), 42}; + + // Build object: field IDs sorted by key name → age=0, name=1 + auto data = BuildObject({0, 1}, {val_age, val_name}); + + ASSERT_OK_AND_ASSIGN( + auto view, VariantView::Make(meta, data.data(), static_cast(data.size()))); + ASSERT_EQ(view.type(), BasicType::kObject); + + ASSERT_OK_AND_ASSIGN(auto obj, view.as_object()); + ASSERT_EQ(obj.num_fields(), 2); + + // Lookup by name + auto age = obj.get("age"); + ASSERT_TRUE(age.has_value()); + ASSERT_OK_AND_ASSIGN(auto age_val, age->as_int8()); + ASSERT_EQ(age_val, 42); + + auto name = obj.get("name"); + ASSERT_TRUE(name.has_value()); + ASSERT_OK_AND_ASSIGN(auto name_val, name->as_string()); + ASSERT_EQ(name_val, "Alice"); + + // Not found + auto missing = obj.get("nonexistent"); + ASSERT_FALSE(missing.has_value()); +} + +TEST_F(VariantObjectViewTest, NestedNavigation) { + // Build: {"addresses": {"postal": {"city": "New York"}}} + // This test addresses reviewer comment #4 (nested field navigation) + auto meta_buf = BuildMetadataBuffer({"addresses", "city", "postal"}, /*sorted=*/true); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + // innermost: {"city": "New York"} — field_id for "city" = 1 + auto val_city = BuildShortString("New York"); + auto inner_obj = BuildObject({1}, {val_city}); + + // middle: {"postal": } — field_id for "postal" = 2 + auto mid_obj = BuildObject({2}, {inner_obj}); + + // outer: {"addresses": } — field_id for "addresses" = 0 + auto outer_obj = BuildObject({0}, {mid_obj}); + + ASSERT_OK_AND_ASSIGN( + auto root, + VariantView::Make(meta, outer_obj.data(), static_cast(outer_obj.size()))); + + // Navigate: root -> addresses -> postal -> city + ASSERT_OK_AND_ASSIGN(auto root_obj, root.as_object()); + auto addresses = root_obj.get("addresses"); + ASSERT_TRUE(addresses.has_value()); + + ASSERT_OK_AND_ASSIGN(auto addr_obj, addresses->as_object()); + auto postal = addr_obj.get("postal"); + ASSERT_TRUE(postal.has_value()); + + ASSERT_OK_AND_ASSIGN(auto postal_obj, postal->as_object()); + auto city = postal_obj.get("city"); + ASSERT_TRUE(city.has_value()); + + ASSERT_OK_AND_ASSIGN(auto city_val, city->as_string()); + ASSERT_EQ(city_val, "New York"); +} + +TEST_F(VariantObjectViewTest, IterateFields) { + auto meta_buf = BuildMetadataBuffer({"a", "b", "c"}, /*sorted=*/true); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + std::vector val_a = {PrimitiveHeader(PrimitiveType::kInt8), 1}; + std::vector val_b = {PrimitiveHeader(PrimitiveType::kInt8), 2}; + std::vector val_c = {PrimitiveHeader(PrimitiveType::kInt8), 3}; + auto data = BuildObject({0, 1, 2}, {val_a, val_b, val_c}); + + ASSERT_OK_AND_ASSIGN( + auto obj, + VariantObjectView::Make(meta, data.data(), static_cast(data.size()))); + + std::vector names; + for (auto [name, value] : obj) { + names.push_back(std::string(name)); + } + ASSERT_EQ(names.size(), 3); + ASSERT_EQ(names[0], "a"); + ASSERT_EQ(names[1], "b"); + ASSERT_EQ(names[2], "c"); +} + +class VariantArrayViewTest : public ::testing::Test {}; + +TEST_F(VariantArrayViewTest, SimpleArray) { + auto meta_buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + // Build array [42, 100] + std::vector val_a = {PrimitiveHeader(PrimitiveType::kInt8), 42}; + std::vector val_b = {PrimitiveHeader(PrimitiveType::kInt8), 100}; + auto data = BuildArray({val_a, val_b}); + + ASSERT_OK_AND_ASSIGN( + auto view, VariantView::Make(meta, data.data(), static_cast(data.size()))); + ASSERT_EQ(view.type(), BasicType::kArray); + + ASSERT_OK_AND_ASSIGN(auto arr, view.as_array()); + ASSERT_EQ(arr.num_elements(), 2); + + ASSERT_OK_AND_ASSIGN(auto elem0, arr.get(0)); + ASSERT_OK_AND_ASSIGN(auto v0, elem0.as_int8()); + ASSERT_EQ(v0, 42); + + ASSERT_OK_AND_ASSIGN(auto elem1, arr.get(1)); + ASSERT_OK_AND_ASSIGN(auto v1, elem1.as_int8()); + ASSERT_EQ(v1, 100); + + // Out of range + ASSERT_RAISES(Invalid, arr.get(2)); + ASSERT_RAISES(Invalid, arr.get(-1)); +} + +TEST_F(VariantArrayViewTest, IterateElements) { + auto meta_buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN( + auto meta, DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + + std::vector val_a = {PrimitiveHeader(PrimitiveType::kInt8), 10}; + std::vector val_b = {PrimitiveHeader(PrimitiveType::kInt8), 20}; + std::vector val_c = {PrimitiveHeader(PrimitiveType::kInt8), 30}; + auto data = BuildArray({val_a, val_b, val_c}); + + ASSERT_OK_AND_ASSIGN( + auto arr, + VariantArrayView::Make(meta, data.data(), static_cast(data.size()))); + + int count = 0; + for ([[maybe_unused]] auto elem : arr) { + ++count; + } + ASSERT_EQ(count, 3); +} + +} // namespace arrow::extension::variant diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index 4b8faebecfd7..36ea0f615740 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -142,6 +142,7 @@ arrow_components = { 'extension/bool8.cc', 'extension/json.cc', 'extension/parquet_variant.cc', + 'extension/variant.cc', 'extension/uuid.cc', 'pretty_print.cc', 'record_batch.cc', From f6b8e6609b5cd2b346b162150dab8c84b003e4e1 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Thu, 25 Jun 2026 22:50:10 -0700 Subject: [PATCH 2/2] GH-45947: [C++][Parquet] Variant encoding with RAII builders --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 2 +- cpp/src/arrow/extension/meson.build | 2 +- cpp/src/arrow/extension/variant.h | 231 ++++ cpp/src/arrow/extension/variant_builder.cc | 635 +++++++++ .../arrow/extension/variant_builder_test.cc | 1228 +++++++++++++++++ cpp/src/arrow/extension/variant_test.cc | 190 +++ cpp/src/arrow/meson.build | 1 + 8 files changed, 2288 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/extension/variant_builder.cc create mode 100644 cpp/src/arrow/extension/variant_builder_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 149ec9c6ff19..6e4673cafba4 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -392,6 +392,7 @@ set(ARROW_SRCS extension/json.cc extension/parquet_variant.cc extension/variant.cc + extension/variant_builder.cc extension/uuid.cc pretty_print.cc record_batch.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 283a328a9098..dbd09523ae32 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -16,7 +16,7 @@ # under the License. set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc - variant_test.cc) + variant_test.cc variant_builder_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/meson.build b/cpp/src/arrow/extension/meson.build index 6d2222698c12..5820fea2cf67 100644 --- a/cpp/src/arrow/extension/meson.build +++ b/cpp/src/arrow/extension/meson.build @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', 'variant_test.cc'] +canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', 'variant_test.cc', 'variant_builder_test.cc'] if needs_json canonical_extension_tests += [ diff --git a/cpp/src/arrow/extension/variant.h b/cpp/src/arrow/extension/variant.h index ec9bfab2000b..27a0ed9ebb1d 100644 --- a/cpp/src/arrow/extension/variant.h +++ b/cpp/src/arrow/extension/variant.h @@ -32,9 +32,11 @@ /// - O(log n) field lookup always (no threshold heuristics) #include +#include #include #include #include +#include #include #include "arrow/result.h" @@ -57,6 +59,17 @@ constexpr int32_t kMaxNestingDepth = 128; /// UUID values are always 16 bytes (128-bit, big-endian per RFC 4122). constexpr int32_t kUUIDByteLength = 16; +/// Maximum length for short-string encoding (6 bits → 0..63). +/// Strings longer than this use the long-string (4-byte length prefix) format. +constexpr int32_t kMaxShortStringLength = 63; + +/// Maximum supported decimal scale per the variant encoding spec. +constexpr uint8_t kMaxDecimalScale = 38; + +/// Container element count threshold for large encoding. +/// Objects/arrays with more than 255 elements use 4-byte num_elements fields. +constexpr int32_t kLargeContainerThreshold = 255; + // --------------------------------------------------------------------------- // Enumerations // --------------------------------------------------------------------------- @@ -580,4 +593,222 @@ class ARROW_EXPORT VariantVisitor { /// @} }; +// --------------------------------------------------------------------------- +// VariantBuilder (encoder) +// --------------------------------------------------------------------------- + +class ObjectScope; +class ListScope; + +/// \brief Builder for constructing Variant binary values. +/// +/// Provides both low-level (Offset/NextField/FinishObject) and high-level +/// (StartObject/StartList returning RAII scopes) APIs for encoding. +class ARROW_EXPORT VariantBuilder { + public: + VariantBuilder(); + explicit VariantBuilder(const VariantMetadata& existing_metadata); + ~VariantBuilder() = default; + + VariantBuilder(VariantBuilder&&) noexcept = default; + VariantBuilder& operator=(VariantBuilder&&) noexcept = default; + VariantBuilder(const VariantBuilder&) = delete; + VariantBuilder& operator=(const VariantBuilder&) = delete; + + /// @name Primitive value setters + /// @{ + Status Null(); + Status Bool(bool value); + Status Int(int64_t value); ///< Auto-selects smallest int type + Status Int8(int8_t value); + Status Int16(int16_t value); + Status Int32(int32_t value); + Status Int64(int64_t value); + Status Float(float value); + Status Double(double value); + Status Decimal4(uint8_t scale, const uint8_t* value_bytes); + Status Decimal8(uint8_t scale, const uint8_t* value_bytes); + Status Decimal16(uint8_t scale, const uint8_t* value_bytes); + Status Date(int32_t days_since_epoch); + Status TimestampMicros(int64_t micros); + Status TimestampMicrosNTZ(int64_t micros); + Status TimeNTZ(int64_t micros); + Status TimestampNanos(int64_t nanos); + Status TimestampNanosNTZ(int64_t nanos); + Status String(std::string_view value); ///< Auto short-string for <=63 bytes + Status Binary(std::string_view value); + Status UUID(const uint8_t* bytes); + /// @} + + /// @name Low-level container construction + /// @{ + int64_t Offset() const; + int64_t NextElement(int64_t start) const; + + struct FieldEntry { + std::string key; + uint32_t id; + int64_t offset; + }; + + FieldEntry NextField(int64_t start, std::string_view key); + Status FinishArray(int64_t start, const std::vector& offsets); + Status FinishObject(int64_t start, std::vector& fields); + /// @} + + /// @name RAII container construction + /// @{ + + /// \brief Start building an object. Returns an RAII scope that auto-rolls + /// back if Finish() is not called (e.g., on exception/early return). + [[nodiscard]] ObjectScope StartObject(); + + /// \brief Start building a list/array. Returns an RAII scope. + [[nodiscard]] ListScope StartList(); + /// @} + + /// @name Output + /// @{ + struct EncodedVariant { + std::vector metadata; + std::vector value; + }; + + Result Finish(); + void Reset(); + /// @} + + /// @name Internal (used by scopes and shredding) + /// @name Internal (used by scopes) + void Truncate(int64_t offset); + /// @} + + private: + friend class ObjectScope; + friend class ListScope; + + uint32_t AddKey(std::string_view key); + + std::vector buffer_; + /// \brief Dictionary mapping key names to their IDs. + /// Uses a custom transparent hasher to allow lookups with string_view + /// without constructing a std::string (C++17 heterogeneous lookup). + struct StringHash { + using is_transparent = void; + size_t operator()(std::string_view sv) const noexcept { + return std::hash{}(sv); + } + }; + struct StringEqual { + using is_transparent = void; + bool operator()(std::string_view a, std::string_view b) const noexcept { + return a == b; + } + }; + std::unordered_map dict_; + std::vector dict_keys_; + bool allow_duplicates_ = false; +}; + +// --------------------------------------------------------------------------- +// ObjectScope — RAII scoped object builder +// --------------------------------------------------------------------------- + +/// \brief RAII scope for building an object. Destructor rolls back if not committed. +/// +/// Usage: +/// auto obj = builder.StartObject(); +/// obj.Insert("name", "Alice"); +/// obj.Insert("age", 30); +/// obj.Finish(); // commits +/// // If Finish() not called, destructor truncates buffer to pre-scope state +class ARROW_EXPORT ObjectScope { + public: + ~ObjectScope(); + + ObjectScope(const ObjectScope&) = delete; + ObjectScope& operator=(const ObjectScope&) = delete; + ObjectScope(ObjectScope&& other) noexcept; + ObjectScope& operator=(ObjectScope&&) = delete; + + /// @name Insert fields (delegates to VariantBuilder primitives) + /// @{ + Status Insert(std::string_view key, std::nullptr_t); + Status Insert(std::string_view key, bool value); + Status Insert(std::string_view key, int64_t value); + Status Insert(std::string_view key, double value); + Status Insert(std::string_view key, std::string_view value); + + /// \brief Insert a nested object. Returns an RAII scope for the sub-object. + [[nodiscard]] ObjectScope InsertObject(std::string_view key); + + /// \brief Insert a nested list. Returns an RAII scope for the sub-list. + [[nodiscard]] ListScope InsertList(std::string_view key); + /// @} + + /// \brief Commit the object (sorts fields, writes header). + Status Finish(); + + private: + friend class VariantBuilder; + friend class ListScope; + + explicit ObjectScope(VariantBuilder& parent); + + VariantBuilder* parent_; + int64_t start_offset_; + std::vector fields_; + bool committed_ = false; +}; + +// --------------------------------------------------------------------------- +// ListScope — RAII scoped list/array builder +// --------------------------------------------------------------------------- + +/// \brief RAII scope for building a list. Destructor rolls back if not committed. +/// +/// Usage: +/// auto list = builder.StartList(); +/// list.Append(1); +/// list.Append(2); +/// list.Finish(); +class ARROW_EXPORT ListScope { + public: + ~ListScope(); + + ListScope(const ListScope&) = delete; + ListScope& operator=(const ListScope&) = delete; + ListScope(ListScope&& other) noexcept; + ListScope& operator=(ListScope&&) = delete; + + /// @name Append elements + /// @{ + Status Append(std::nullptr_t); + Status Append(bool value); + Status Append(int64_t value); + Status Append(double value); + Status Append(std::string_view value); + + /// \brief Append a nested object. Returns an RAII scope. + [[nodiscard]] ObjectScope AppendObject(); + + /// \brief Append a nested list. Returns an RAII scope. + [[nodiscard]] ListScope AppendList(); + /// @} + + /// \brief Commit the list (writes header with offsets). + Status Finish(); + + private: + friend class VariantBuilder; + friend class ObjectScope; + + explicit ListScope(VariantBuilder& parent); + + VariantBuilder* parent_; + int64_t start_offset_; + std::vector offsets_; + bool committed_ = false; +}; + } // namespace arrow::extension::variant diff --git a/cpp/src/arrow/extension/variant_builder.cc b/cpp/src/arrow/extension/variant_builder.cc new file mode 100644 index 000000000000..e8465f48eecf --- /dev/null +++ b/cpp/src/arrow/extension/variant_builder.cc @@ -0,0 +1,635 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant.h" + +#include +#include +#include + +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" + +namespace arrow::extension::variant { + +namespace { + +/// \brief Compute the minimum number of bytes needed to represent a value. +/// \param[in] value Must be non-negative and fit in 4 bytes (represents a size or ID). +int32_t IntSize(int64_t value) { + DCHECK_GE(value, 0); + DCHECK_LE(value, static_cast(std::numeric_limits::max())); + if (value <= 0xFF) return 1; + if (value <= 0xFFFF) return 2; + if (value <= 0xFFFFFF) return 3; + return 4; +} + +/// \brief Write an unsigned integer in little-endian using nbytes bytes. +void WriteUnsignedLE(uint8_t* buf, int64_t value, int32_t nbytes) { + for (int32_t i = 0; i < nbytes; ++i) { + buf[i] = static_cast((value >> (i * 8)) & 0xFF); + } +} + +/// \brief Write a little-endian value into a vector at a given position. +void WriteUnsignedLEAt(std::vector& buf, int64_t pos, int64_t value, + int32_t nbytes) { + for (int32_t i = 0; i < nbytes; ++i) { + buf[pos + i] = static_cast((value >> (i * 8)) & 0xFF); + } +} + +/// \brief Construct a primitive header byte. +uint8_t MakePrimitiveHeader(PrimitiveType type) { + return static_cast(BasicType::kPrimitive) | (static_cast(type) << 2); +} + +/// \brief Write a fixed-size numeric primitive into the buffer. +template +void WritePrimitive(std::vector& buf, PrimitiveType type, T value) { + buf.push_back(MakePrimitiveHeader(type)); + value = bit_util::ToLittleEndian(value); + auto ptr = reinterpret_cast(&value); + buf.insert(buf.end(), ptr, ptr + sizeof(T)); +} + +} // namespace + +// --------------------------------------------------------------------------- +// VariantBuilder implementation +// --------------------------------------------------------------------------- + +VariantBuilder::VariantBuilder() = default; + +VariantBuilder::VariantBuilder(const VariantMetadata& existing_metadata) { + for (int32_t i = 0; i < static_cast(existing_metadata.strings.size()); ++i) { + std::string key(existing_metadata.strings[i]); + dict_[key] = static_cast(i); + dict_keys_.push_back(std::move(key)); + } +} + +uint32_t VariantBuilder::AddKey(std::string_view key) { + // Transparent hasher allows direct string_view lookup without constructing + // a std::string. This eliminates per-call allocation/copy for existing keys. + auto it = dict_.find(key); + if (it != dict_.end()) { + return it->second; + } + // Key is new — insert into the dictionary. + auto id = static_cast(dict_keys_.size()); + dict_keys_.emplace_back(key); + dict_.emplace(dict_keys_.back(), id); + return id; +} + +void VariantBuilder::Reset() { + buffer_.clear(); + dict_.clear(); + dict_keys_.clear(); +} + +int64_t VariantBuilder::Offset() const { return static_cast(buffer_.size()); } + +int64_t VariantBuilder::NextElement(int64_t start) const { return Offset() - start; } + +VariantBuilder::FieldEntry VariantBuilder::NextField(int64_t start, + std::string_view key) { + auto id = AddKey(key); + return FieldEntry{std::string(key), id, Offset() - start}; +} + +// --- Primitive setters --- + +Status VariantBuilder::Null() { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kNull)); + return Status::OK(); +} + +Status VariantBuilder::Bool(bool value) { + buffer_.push_back( + MakePrimitiveHeader(value ? PrimitiveType::kTrue : PrimitiveType::kFalse)); + return Status::OK(); +} + +Status VariantBuilder::Int(int64_t value) { + if (value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()) { + return Int8(static_cast(value)); + } + if (value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()) { + return Int16(static_cast(value)); + } + if (value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()) { + return Int32(static_cast(value)); + } + return Int64(value); +} + +Status VariantBuilder::Int8(int8_t value) { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kInt8)); + buffer_.push_back(static_cast(value)); + return Status::OK(); +} + +Status VariantBuilder::Int16(int16_t value) { + WritePrimitive(buffer_, PrimitiveType::kInt16, value); + return Status::OK(); +} + +Status VariantBuilder::Int32(int32_t value) { + WritePrimitive(buffer_, PrimitiveType::kInt32, value); + return Status::OK(); +} + +Status VariantBuilder::Int64(int64_t value) { + WritePrimitive(buffer_, PrimitiveType::kInt64, value); + return Status::OK(); +} + +Status VariantBuilder::Float(float value) { + WritePrimitive(buffer_, PrimitiveType::kFloat, value); + return Status::OK(); +} + +Status VariantBuilder::Double(double value) { + WritePrimitive(buffer_, PrimitiveType::kDouble, value); + return Status::OK(); +} + +Status VariantBuilder::Date(int32_t days_since_epoch) { + WritePrimitive(buffer_, PrimitiveType::kDate, days_since_epoch); + return Status::OK(); +} + +Status VariantBuilder::TimestampMicros(int64_t micros) { + WritePrimitive(buffer_, PrimitiveType::kTimestampMicros, micros); + return Status::OK(); +} + +Status VariantBuilder::TimestampMicrosNTZ(int64_t micros) { + WritePrimitive(buffer_, PrimitiveType::kTimestampMicrosNTZ, micros); + return Status::OK(); +} + +Status VariantBuilder::TimeNTZ(int64_t micros) { + WritePrimitive(buffer_, PrimitiveType::kTimeNTZ, micros); + return Status::OK(); +} + +Status VariantBuilder::TimestampNanos(int64_t nanos) { + WritePrimitive(buffer_, PrimitiveType::kTimestampNanos, nanos); + return Status::OK(); +} + +Status VariantBuilder::TimestampNanosNTZ(int64_t nanos) { + WritePrimitive(buffer_, PrimitiveType::kTimestampNanosNTZ, nanos); + return Status::OK(); +} + +Status VariantBuilder::Decimal4(uint8_t scale, const uint8_t* value_bytes) { + if (scale > kMaxDecimalScale) { + return Status::Invalid("Variant decimal scale must be in range [0, ", + static_cast(kMaxDecimalScale), "], got ", + static_cast(scale)); + } + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kDecimal4)); + buffer_.push_back(scale); + buffer_.insert(buffer_.end(), value_bytes, value_bytes + 4); + return Status::OK(); +} + +Status VariantBuilder::Decimal8(uint8_t scale, const uint8_t* value_bytes) { + if (scale > kMaxDecimalScale) { + return Status::Invalid("Variant decimal scale must be in range [0, ", + static_cast(kMaxDecimalScale), "], got ", + static_cast(scale)); + } + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kDecimal8)); + buffer_.push_back(scale); + buffer_.insert(buffer_.end(), value_bytes, value_bytes + 8); + return Status::OK(); +} + +Status VariantBuilder::Decimal16(uint8_t scale, const uint8_t* value_bytes) { + if (scale > kMaxDecimalScale) { + return Status::Invalid("Variant decimal scale must be in range [0, ", + static_cast(kMaxDecimalScale), "], got ", + static_cast(scale)); + } + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kDecimal16)); + buffer_.push_back(scale); + buffer_.insert(buffer_.end(), value_bytes, value_bytes + 16); // 16-byte unscaled value + return Status::OK(); +} + +Status VariantBuilder::String(std::string_view value) { + if (value.size() <= static_cast(kMaxShortStringLength)) { + // Short string: length encoded in header bits 2-7 + uint8_t header = static_cast(BasicType::kShortString) | + (static_cast(value.size()) << 2); + buffer_.push_back(header); + } else { + // Long string: primitive type kString + 4-byte LE length + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kString)); + auto len = static_cast(value.size()); + len = bit_util::ToLittleEndian(len); + auto ptr = reinterpret_cast(&len); + buffer_.insert(buffer_.end(), ptr, ptr + 4); + } + buffer_.insert(buffer_.end(), value.begin(), value.end()); + return Status::OK(); +} + +Status VariantBuilder::Binary(std::string_view value) { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kBinary)); + auto len = static_cast(value.size()); + len = bit_util::ToLittleEndian(len); + auto ptr = reinterpret_cast(&len); + buffer_.insert(buffer_.end(), ptr, ptr + 4); + buffer_.insert(buffer_.end(), value.begin(), value.end()); + return Status::OK(); +} + +Status VariantBuilder::UUID(const uint8_t* bytes) { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kUUID)); + buffer_.insert(buffer_.end(), bytes, bytes + kUUIDByteLength); + return Status::OK(); +} + +// --- Container construction --- + +Status VariantBuilder::FinishArray(int64_t start, const std::vector& offsets) { + // Note: offset fields are at most 4 bytes, so individual variant values + // cannot exceed ~4GB. This is not validated here; such values are not + // practically expected (Parquet row group sizes are bounded well below this). + auto data_size = Offset() - start; + if (data_size < 0) { + return Status::Invalid("VariantBuilder::FinishArray: invalid start position"); + } + + auto num_elements = static_cast(offsets.size()); + bool is_large = num_elements > kLargeContainerThreshold; + int32_t size_bytes = is_large ? 4 : 1; + int32_t offset_size = IntSize(data_size); + int64_t header_size = 1 + size_bytes + (num_elements + 1) * offset_size; + + // Validate offsets are non-negative (caller-provided) + for (int64_t i = 0; i < num_elements; ++i) { + if (offsets[i] < 0) { + return Status::Invalid("VariantBuilder::FinishArray: negative offset at index ", i); + } + } + + // Shift existing data to make room for the header + buffer_.resize(buffer_.size() + header_size); + std::memmove(buffer_.data() + start + header_size, buffer_.data() + start, data_size); + + // Write header byte + uint8_t header = static_cast(BasicType::kArray) | + (static_cast(offset_size - 1) << 2); + if (is_large) { + header |= (1 << 4); + } + buffer_[start] = header; + + // Write num_elements + WriteUnsignedLEAt(buffer_, start + 1, num_elements, size_bytes); + + // Write offsets + int64_t offset_pos = start + 1 + size_bytes; + for (int64_t i = 0; i < num_elements; ++i) { + WriteUnsignedLEAt(buffer_, offset_pos + i * offset_size, offsets[i], offset_size); + } + // Last offset = total data size + WriteUnsignedLEAt(buffer_, offset_pos + num_elements * offset_size, data_size, + offset_size); + + return Status::OK(); +} + +Status VariantBuilder::FinishObject(int64_t start, std::vector& fields) { + auto data_size = Offset() - start; + if (data_size < 0) { + return Status::Invalid("VariantBuilder::FinishObject: invalid start position"); + } + + // Sort fields by key name lexicographically (spec requirement). + // Skip the sort if fields are already in order (common for schema-driven insertion). + if (!std::is_sorted( + fields.begin(), fields.end(), + [](const FieldEntry& a, const FieldEntry& b) { return a.key < b.key; })) { + std::sort(fields.begin(), fields.end(), + [](const FieldEntry& a, const FieldEntry& b) { return a.key < b.key; }); + } + + // Handle duplicate keys: reject by default, deduplicate if allowed. + // When allow_duplicates_ is true (used by shredding reconstruction where + // shredded fields and residual fields may overlap on malformed input), + // last-value-wins semantics are applied. After sort, duplicates are adjacent; + // we keep the last entry for each key group. + if (!allow_duplicates_) { + for (size_t i = 1; i < fields.size(); ++i) { + if (fields[i].key == fields[i - 1].key) { + return Status::Invalid("VariantBuilder: duplicate key '", fields[i].key, "'"); + } + } + } else { + // Last-value-wins: for adjacent duplicates after sort, keep the last one. + // Since std::unique keeps the FIRST of each group, we reverse-iterate. + size_t write = 0; + for (size_t i = 0; i < fields.size(); ++i) { + // If next entry has same key, skip this one (keep the later one) + if (i + 1 < fields.size() && fields[i].key == fields[i + 1].key) { + continue; + } + if (write != i) { + fields[write] = std::move(fields[i]); + } + ++write; + } + fields.resize(write); + } + + auto num_fields = static_cast(fields.size()); + bool is_large = num_fields > kLargeContainerThreshold; + int32_t size_bytes = is_large ? 4 : 1; + + // Compute id_size from max dictionary ID + uint32_t max_id = 0; + for (const auto& f : fields) { + max_id = std::max(max_id, f.id); + } + int32_t id_size = IntSize(static_cast(max_id)); + int32_t offset_size = IntSize(data_size); + + int64_t header_size = + 1 + size_bytes + num_fields * id_size + (num_fields + 1) * offset_size; + + // Shift existing data to make room for the header + buffer_.resize(buffer_.size() + header_size); + std::memmove(buffer_.data() + start + header_size, buffer_.data() + start, data_size); + + // Write header byte: basic_type=2, offset_size in bits 2-3, id_size in bits 4-5, + // is_large in bit 6 + uint8_t header = static_cast(BasicType::kObject) | + (static_cast(offset_size - 1) << 2) | + (static_cast(id_size - 1) << 4); + if (is_large) { + header |= (1 << 6); + } + buffer_[start] = header; + + // Write num_fields + WriteUnsignedLEAt(buffer_, start + 1, num_fields, size_bytes); + + // Write field IDs (sorted by key) + int64_t id_pos = start + 1 + size_bytes; + for (int64_t i = 0; i < num_fields; ++i) { + WriteUnsignedLEAt(buffer_, id_pos + i * id_size, fields[i].id, id_size); + } + + // Write field offsets (sorted by key) + int64_t offset_pos = id_pos + num_fields * id_size; + for (int64_t i = 0; i < num_fields; ++i) { + WriteUnsignedLEAt(buffer_, offset_pos + i * offset_size, fields[i].offset, + offset_size); + } + // Last offset = total data size + WriteUnsignedLEAt(buffer_, offset_pos + num_fields * offset_size, data_size, + offset_size); + + return Status::OK(); +} + +Result VariantBuilder::Finish() { + // Build metadata + auto num_keys = static_cast(dict_keys_.size()); + + // Compute total string data size + int64_t total_string_size = 0; + for (const auto& k : dict_keys_) { + total_string_size += static_cast(k.size()); + } + + // Validate sizes fit within the spec's 4-byte offset limit. + // Note: Go implementation enforces a stricter 128MB limit (metadataMaxSizeLimit). + // We only enforce the spec's 4-byte offset maximum (~4GB), which is the correct + // upper bound per the encoding format. + if (total_string_size > static_cast(std::numeric_limits::max())) { + return Status::Invalid("VariantBuilder: total dictionary string data (", + total_string_size, + " bytes) exceeds maximum representable by 4-byte offsets"); + } + + // Compute the offset_size: must accommodate both the largest string offset + // (total_string_size) and the dictionary_size field itself, since both use + // offset_size bytes in the metadata encoding. + int32_t offset_size = + IntSize(std::max(total_string_size, static_cast(num_keys))); + + // Check if dictionary is sorted. + // Uniqueness is guaranteed by dict_ (AddKey prevents duplicates), + // so std::is_sorted with default < is sufficient for the "sorted and unique" + // semantics required by the spec. + // TODO: Cache the sorted state incrementally (check only newly-added keys + // against the previous last key) to avoid O(n) rescan on every Finish() call. + bool is_sorted = std::is_sorted(dict_keys_.begin(), dict_keys_.end()); + + // Build metadata buffer + std::vector metadata; + // Header byte + uint8_t meta_header = kVariantVersion; + if (is_sorted) { + meta_header |= (1 << 4); + } + meta_header |= static_cast((offset_size - 1) << 6); + metadata.push_back(meta_header); + + // Dictionary size + metadata.resize(metadata.size() + offset_size); + WriteUnsignedLE(metadata.data() + 1, num_keys, offset_size); + + // String offsets + int64_t cur_offset = 0; + for (int32_t i = 0; i <= num_keys; ++i) { + size_t pos = metadata.size(); + metadata.resize(pos + offset_size); + WriteUnsignedLE(metadata.data() + pos, cur_offset, offset_size); + if (i < num_keys) { + cur_offset += static_cast(dict_keys_[i].size()); + } + } + + // String data + for (const auto& k : dict_keys_) { + metadata.insert(metadata.end(), k.begin(), k.end()); + } + + EncodedVariant result; + result.metadata = std::move(metadata); + result.value = std::move(buffer_); + + // Note: dict_ and dict_keys_ are intentionally NOT cleared here. + // The dictionary is preserved so the builder can encode multiple values + // sharing the same key schema without re-adding keys. Call Reset() + // explicitly to clear everything. + buffer_.clear(); + + return result; +} + +// --------------------------------------------------------------------------- +// VariantBuilder RAII support +// --------------------------------------------------------------------------- + +void VariantBuilder::Truncate(int64_t offset) { + buffer_.resize(static_cast(offset)); +} + +ObjectScope VariantBuilder::StartObject() { return ObjectScope(*this); } + +ListScope VariantBuilder::StartList() { return ListScope(*this); } + +// --------------------------------------------------------------------------- +// ObjectScope +// --------------------------------------------------------------------------- + +ObjectScope::ObjectScope(VariantBuilder& parent) + : parent_(&parent), start_offset_(parent.Offset()), committed_(false) {} + +ObjectScope::ObjectScope(ObjectScope&& other) noexcept + : parent_(other.parent_), + start_offset_(other.start_offset_), + fields_(std::move(other.fields_)), + committed_(other.committed_) { + other.committed_ = true; // prevent double-rollback +} + +ObjectScope::~ObjectScope() { + if (!committed_ && parent_) { + parent_->Truncate(start_offset_); + } +} + +Status ObjectScope::Insert(std::string_view key, std::nullptr_t) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return parent_->Null(); +} + +Status ObjectScope::Insert(std::string_view key, bool value) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return parent_->Bool(value); +} + +Status ObjectScope::Insert(std::string_view key, int64_t value) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return parent_->Int(value); +} + +Status ObjectScope::Insert(std::string_view key, double value) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return parent_->Double(value); +} + +Status ObjectScope::Insert(std::string_view key, std::string_view value) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return parent_->String(value); +} + +ObjectScope ObjectScope::InsertObject(std::string_view key) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return ObjectScope(*parent_); +} + +ListScope ObjectScope::InsertList(std::string_view key) { + fields_.push_back(parent_->NextField(start_offset_, key)); + return ListScope(*parent_); +} + +Status ObjectScope::Finish() { + ARROW_RETURN_NOT_OK(parent_->FinishObject(start_offset_, fields_)); + committed_ = true; + return Status::OK(); +} + +// --------------------------------------------------------------------------- +// ListScope +// --------------------------------------------------------------------------- + +ListScope::ListScope(VariantBuilder& parent) + : parent_(&parent), start_offset_(parent.Offset()), committed_(false) {} + +ListScope::ListScope(ListScope&& other) noexcept + : parent_(other.parent_), + start_offset_(other.start_offset_), + offsets_(std::move(other.offsets_)), + committed_(other.committed_) { + other.committed_ = true; // prevent double-rollback +} + +ListScope::~ListScope() { + if (!committed_ && parent_) { + parent_->Truncate(start_offset_); + } +} + +Status ListScope::Append(std::nullptr_t) { + offsets_.push_back(parent_->NextElement(start_offset_)); + return parent_->Null(); +} + +Status ListScope::Append(bool value) { + offsets_.push_back(parent_->NextElement(start_offset_)); + return parent_->Bool(value); +} + +Status ListScope::Append(int64_t value) { + offsets_.push_back(parent_->NextElement(start_offset_)); + return parent_->Int(value); +} + +Status ListScope::Append(double value) { + offsets_.push_back(parent_->NextElement(start_offset_)); + return parent_->Double(value); +} + +Status ListScope::Append(std::string_view value) { + offsets_.push_back(parent_->NextElement(start_offset_)); + return parent_->String(value); +} + +ObjectScope ListScope::AppendObject() { + offsets_.push_back(parent_->NextElement(start_offset_)); + return ObjectScope(*parent_); +} + +ListScope ListScope::AppendList() { + offsets_.push_back(parent_->NextElement(start_offset_)); + return ListScope(*parent_); +} + +Status ListScope::Finish() { + ARROW_RETURN_NOT_OK(parent_->FinishArray(start_offset_, offsets_)); + committed_ = true; + return Status::OK(); +} + +} // namespace arrow::extension::variant diff --git a/cpp/src/arrow/extension/variant_builder_test.cc b/cpp/src/arrow/extension/variant_builder_test.cc new file mode 100644 index 000000000000..bd6d0fab34b4 --- /dev/null +++ b/cpp/src/arrow/extension/variant_builder_test.cc @@ -0,0 +1,1228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant.h" +#include "arrow/extension/variant_internal_test_util.h" + +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" + +namespace arrow::extension::variant { + +namespace { + +// Test-local helpers +Status DecodeVariantValue(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, VariantVisitor* visitor) { + ARROW_ASSIGN_OR_RAISE(auto view, VariantView::Make(metadata, data, length)); + return view.Visit(visitor); +} + +Status FindObjectField(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, std::string_view field_name, int64_t* field_offset, + int64_t* field_size) { + *field_offset = -1; + *field_size = 0; + ARROW_ASSIGN_OR_RAISE(auto obj, VariantObjectView::Make(metadata, data, length)); + auto result = obj.get(field_name); + if (result.has_value()) { + *field_offset = result->data() - data; + *field_size = result->size_bytes(); + } + return Status::OK(); +} + +Status GetArrayElement(const uint8_t* data, int64_t length, int32_t index, + int64_t* element_offset, int64_t* element_size) { + VariantMetadata empty_meta; + ARROW_ASSIGN_OR_RAISE(auto arr, VariantArrayView::Make(empty_meta, data, length)); + ARROW_ASSIGN_OR_RAISE(auto elem, arr.get(index)); + *element_offset = elem.data() - data; + *element_size = elem.size_bytes(); + return Status::OK(); +} + +Status GetObjectFieldAt(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int32_t index, std::string_view* field_name, + int64_t* field_offset, int64_t* field_size) { + ARROW_ASSIGN_OR_RAISE(auto obj, VariantObjectView::Make(metadata, data, length)); + ARROW_ASSIGN_OR_RAISE(*field_name, obj.field_name(index)); + ARROW_ASSIGN_OR_RAISE(auto value, obj.field_value(index)); + *field_offset = value.data() - data; + *field_size = value.size_bytes(); + return Status::OK(); +} + +Result GetObjectFieldCount(const uint8_t* data, int64_t length) { + VariantMetadata empty_meta; + ARROW_ASSIGN_OR_RAISE(auto obj, VariantObjectView::Make(empty_meta, data, length)); + return obj.num_fields(); +} + +} // namespace + +// =========================================================================== +// Helper: decode an EncodedVariant and return visitor events +// =========================================================================== + +/// Encode with builder, decode, return events. +/// Note: Uses .ValueOrDie() because ASSERT_OK_AND_ASSIGN cannot be used +/// in a non-void function. Test-only; will crash with a descriptive message +/// on failure rather than producing a clean test failure. +std::vector RoundTrip(VariantBuilder& builder) { + auto result = builder.Finish().ValueOrDie(); + auto metadata = + DecodeMetadata(result.metadata.data(), static_cast(result.metadata.size())) + .ValueOrDie(); + RecordingVisitor visitor; + auto status = DecodeVariantValue(metadata, result.value.data(), + static_cast(result.value.size()), &visitor); + EXPECT_TRUE(status.ok()) << "DecodeVariantValue failed: " << status.ToString(); + return visitor.events; +} + +// =========================================================================== +// Primitive round-trip tests +// =========================================================================== + +class VariantBuilderPrimitiveTest : public ::testing::Test {}; + +TEST_F(VariantBuilderPrimitiveTest, Null) { + VariantBuilder b; + ASSERT_OK(b.Null()); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Null"); +} + +TEST_F(VariantBuilderPrimitiveTest, BoolTrue) { + VariantBuilder b; + ASSERT_OK(b.Bool(true)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Bool(true)"); +} + +TEST_F(VariantBuilderPrimitiveTest, BoolFalse) { + VariantBuilder b; + ASSERT_OK(b.Bool(false)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Bool(false)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt8) { + VariantBuilder b; + ASSERT_OK(b.Int(42)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(42)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(300)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(300)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt32) { + VariantBuilder b; + ASSERT_OK(b.Int(100000)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(100000)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt64) { + VariantBuilder b; + ASSERT_OK(b.Int(5000000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int64(5000000000)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntNegative) { + VariantBuilder b; + ASSERT_OK(b.Int(-42)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(-42)"); +} + +TEST_F(VariantBuilderPrimitiveTest, ShortString) { + VariantBuilder b; + ASSERT_OK(b.String("hello")); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "String(\"hello\")"); +} + +TEST_F(VariantBuilderPrimitiveTest, LongString) { + std::string long_str(100, 'x'); + VariantBuilder b; + ASSERT_OK(b.String(long_str)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "String(\"" + long_str + "\")"); +} + +TEST_F(VariantBuilderPrimitiveTest, ShortStringBoundary63) { + std::string str63(63, 'a'); + VariantBuilder b; + ASSERT_OK(b.String(str63)); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + // Should use short string encoding: 1 byte header + 63 bytes + ASSERT_EQ(result.value.size(), 64); +} + +TEST_F(VariantBuilderPrimitiveTest, LongStringBoundary64) { + std::string str64(64, 'a'); + VariantBuilder b; + ASSERT_OK(b.String(str64)); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + // Should use long string encoding: 1 byte header + 4 byte length + 64 bytes + ASSERT_EQ(result.value.size(), 69); +} + +TEST_F(VariantBuilderPrimitiveTest, Date) { + VariantBuilder b; + ASSERT_OK(b.Date(19000)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Date(19000)"); +} + +TEST_F(VariantBuilderPrimitiveTest, Double) { + VariantBuilder b; + ASSERT_OK(b.Double(3.14)); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Double(") == 0); +} + +// =========================================================================== +// Array round-trip tests +// =========================================================================== + +class VariantBuilderArrayTest : public ::testing::Test {}; + +TEST_F(VariantBuilderArrayTest, EmptyArray) { + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + ASSERT_OK(b.FinishArray(start, offsets)); + auto events = RoundTrip(b); + ASSERT_EQ(events.size(), 2); + ASSERT_EQ(events[0], "StartArray(0)"); + ASSERT_EQ(events[1], "EndArray"); +} + +TEST_F(VariantBuilderArrayTest, SimpleArray) { + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(1)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(2)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(3)); + ASSERT_OK(b.FinishArray(start, offsets)); + + auto events = RoundTrip(b); + ASSERT_EQ(events.size(), 5); + ASSERT_EQ(events[0], "StartArray(3)"); + ASSERT_EQ(events[1], "Int8(1)"); + ASSERT_EQ(events[2], "Int8(2)"); + ASSERT_EQ(events[3], "Int8(3)"); + ASSERT_EQ(events[4], "EndArray"); +} + +TEST_F(VariantBuilderArrayTest, NestedArray) { + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + + // First element: nested array [10, 20] + offsets.push_back(b.NextElement(start)); + auto inner_start = b.Offset(); + std::vector inner_offsets; + inner_offsets.push_back(b.NextElement(inner_start)); + ASSERT_OK(b.Int(10)); + inner_offsets.push_back(b.NextElement(inner_start)); + ASSERT_OK(b.Int(20)); + ASSERT_OK(b.FinishArray(inner_start, inner_offsets)); + + // Second element: 30 + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(30)); + + ASSERT_OK(b.FinishArray(start, offsets)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "StartArray(2)"); + ASSERT_EQ(events[1], "StartArray(2)"); + ASSERT_EQ(events[2], "Int8(10)"); + ASSERT_EQ(events[3], "Int8(20)"); + ASSERT_EQ(events[4], "EndArray"); + ASSERT_EQ(events[5], "Int8(30)"); + ASSERT_EQ(events[6], "EndArray"); +} + +// =========================================================================== +// Object round-trip tests +// =========================================================================== + +class VariantBuilderObjectTest : public ::testing::Test {}; + +TEST_F(VariantBuilderObjectTest, EmptyObject) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + ASSERT_OK(b.FinishObject(start, fields)); + auto events = RoundTrip(b); + ASSERT_EQ(events.size(), 2); + ASSERT_EQ(events[0], "StartObject(0)"); + ASSERT_EQ(events[1], "EndObject"); +} + +TEST_F(VariantBuilderObjectTest, SimpleObject) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "name")); + ASSERT_OK(b.String("Alice")); + fields.push_back(b.NextField(start, "age")); + ASSERT_OK(b.Int(30)); + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + // Fields sorted by key: "age" before "name" + ASSERT_EQ(events[0], "StartObject(2)"); + ASSERT_EQ(events[1], "FieldName(\"age\")"); + ASSERT_EQ(events[2], "Int8(30)"); + ASSERT_EQ(events[3], "FieldName(\"name\")"); + ASSERT_EQ(events[4], "String(\"Alice\")"); + ASSERT_EQ(events[5], "EndObject"); +} + +TEST_F(VariantBuilderObjectTest, NestedObject) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "inner")); + { + auto inner_start = b.Offset(); + std::vector inner_fields; + inner_fields.push_back(b.NextField(inner_start, "key")); + ASSERT_OK(b.String("value")); + ASSERT_OK(b.FinishObject(inner_start, inner_fields)); + } + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "StartObject(1)"); + ASSERT_EQ(events[1], "FieldName(\"inner\")"); + ASSERT_EQ(events[2], "StartObject(1)"); + ASSERT_EQ(events[3], "FieldName(\"key\")"); + ASSERT_EQ(events[4], "String(\"value\")"); + ASSERT_EQ(events[5], "EndObject"); + ASSERT_EQ(events[6], "EndObject"); +} + +TEST_F(VariantBuilderObjectTest, DuplicateKeyError) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.Int(2)); + ASSERT_RAISES(Invalid, b.FinishObject(start, fields)); +} + +TEST_F(VariantBuilderObjectTest, FieldsSortedByKey) { + // Insert fields in reverse order; verify they come out sorted + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "z_last")); + ASSERT_OK(b.Int(3)); + fields.push_back(b.NextField(start, "a_first")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "m_middle")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[1], "FieldName(\"a_first\")"); + ASSERT_EQ(events[2], "Int8(1)"); + ASSERT_EQ(events[3], "FieldName(\"m_middle\")"); + ASSERT_EQ(events[4], "Int8(2)"); + ASSERT_EQ(events[5], "FieldName(\"z_last\")"); + ASSERT_EQ(events[6], "Int8(3)"); +} + +// =========================================================================== +// Builder features +// =========================================================================== + +class VariantBuilderFeatureTest : public ::testing::Test {}; + +TEST_F(VariantBuilderFeatureTest, Reset) { + VariantBuilder b; + ASSERT_OK(b.Int(42)); + auto events1 = RoundTrip(b); + ASSERT_EQ(events1[0], "Int8(42)"); + + b.Reset(); + ASSERT_OK(b.String("hello")); + auto events2 = RoundTrip(b); + ASSERT_EQ(events2[0], "String(\"hello\")"); +} + +TEST_F(VariantBuilderFeatureTest, BuilderFromExistingMetadata) { + // First, build a variant to get metadata + VariantBuilder b1; + auto start = b1.Offset(); + std::vector fields; + fields.push_back(b1.NextField(start, "name")); + ASSERT_OK(b1.String("Alice")); + ASSERT_OK(b1.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded1, b1.Finish()); + + // Decode the metadata + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); + + // Build a new variant reusing the same metadata + VariantBuilder b2(meta); + auto start2 = b2.Offset(); + std::vector fields2; + fields2.push_back(b2.NextField(start2, "name")); + ASSERT_OK(b2.String("Bob")); + ASSERT_OK(b2.FinishObject(start2, fields2)); + + auto events = RoundTrip(b2); + ASSERT_EQ(events[1], "FieldName(\"name\")"); + ASSERT_EQ(events[2], "String(\"Bob\")"); +} + +TEST_F(VariantBuilderFeatureTest, MetadataSortedFlag) { + // If keys are inserted in sorted order, metadata should have sorted flag + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "alpha")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "beta")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_TRUE(meta.is_sorted); +} + +TEST_F(VariantBuilderFeatureTest, MetadataUnsortedFlag) { + // If keys are inserted out of order, sorted flag should be false + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "beta")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "alpha")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_FALSE(meta.is_sorted); +} + +// =========================================================================== +// Integration: full round-trip of complex structure +// =========================================================================== + +class VariantBuilderIntegrationTest : public ::testing::Test {}; + +TEST_F(VariantBuilderIntegrationTest, ComplexObject) { + // {"name": "Alice", "scores": [95, 87, 92], "active": true} + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + + fields.push_back(b.NextField(start, "name")); + ASSERT_OK(b.String("Alice")); + + fields.push_back(b.NextField(start, "scores")); + { + auto arr_start = b.Offset(); + std::vector arr_offsets; + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(95)); + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(87)); + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(92)); + ASSERT_OK(b.FinishArray(arr_start, arr_offsets)); + } + + fields.push_back(b.NextField(start, "active")); + ASSERT_OK(b.Bool(true)); + + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + // Fields sorted: "active", "name", "scores" + ASSERT_EQ(events[0], "StartObject(3)"); + ASSERT_EQ(events[1], "FieldName(\"active\")"); + ASSERT_EQ(events[2], "Bool(true)"); + ASSERT_EQ(events[3], "FieldName(\"name\")"); + ASSERT_EQ(events[4], "String(\"Alice\")"); + ASSERT_EQ(events[5], "FieldName(\"scores\")"); + ASSERT_EQ(events[6], "StartArray(3)"); + ASSERT_EQ(events[7], "Int8(95)"); + ASSERT_EQ(events[8], "Int8(87)"); + ASSERT_EQ(events[9], "Int8(92)"); + ASSERT_EQ(events[10], "EndArray"); + ASSERT_EQ(events[11], "EndObject"); +} + +TEST_F(VariantBuilderIntegrationTest, LargeMetadataOffsetSize) { + // Build an object with enough unique keys to trigger 2-byte metadata offsets. + // 300 keys of ~4 chars each = ~1200 bytes total string data > 255. + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + for (int i = 0; i < 300; ++i) { + std::string key = "k" + std::to_string(i); + fields.push_back(b.NextField(start, key)); + ASSERT_OK(b.Int(i)); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // Verify metadata can be decoded + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_EQ(static_cast(meta.strings.size()), 300); + // offset_size should be >= 2 (total string data > 255 bytes) + ASSERT_GE(meta.offset_size, 2); + + // Verify value can be decoded + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + // StartObject(300) + 300*(FieldName + Int8) + EndObject = 602 events + ASSERT_EQ(visitor.events.size(), 602); + ASSERT_EQ(visitor.events[0], "StartObject(300)"); + ASSERT_EQ(visitor.events[601], "EndObject"); +} + +TEST_F(VariantBuilderIntegrationTest, MetadataOffsetSizeFromKeyCount) { + // Verify that offset_size is computed from max(total_string_size, num_keys). + // Use 260 single-character keys: total_string_size=260 (>255, needs 2 bytes) + // but num_keys=260 also exceeds 255. This ensures the formula handles both. + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + // Generate 260 unique 1-char keys using characters + numeric suffixes + for (int i = 0; i < 260; ++i) { + // Use 2-char keys to guarantee uniqueness: "a0" through "z9", then "A0"... + char c1 = (i < 260) ? static_cast('a' + (i / 10) % 26) : 'A'; + char c2 = static_cast('0' + (i % 10)); + std::string key = {c1, c2}; + fields.push_back(b.NextField(start, key)); + ASSERT_OK(b.Null()); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + // 260 keys of 2 chars = 520 bytes total string data > 255, needs 2-byte offsets + ASSERT_GE(meta.offset_size, 2); + // Also verify num_keys is correctly stored + ASSERT_EQ(static_cast(meta.strings.size()), 260); + + // Verify round-trip + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "StartObject(260)"); +} + +TEST_F(VariantBuilderIntegrationTest, InvalidStartPosition) { + VariantBuilder b; + ASSERT_OK(b.Int(42)); + // start=999 is beyond the buffer — should fail + std::vector offsets; + ASSERT_RAISES(Invalid, b.FinishArray(999, offsets)); + + std::vector fields; + ASSERT_RAISES(Invalid, b.FinishObject(999, fields)); +} + +TEST_F(VariantBuilderIntegrationTest, NegativeArrayOffsetRejected) { + VariantBuilder b; + auto start = b.Offset(); + ASSERT_OK(b.Int(1)); + std::vector offsets = {-1}; + ASSERT_RAISES(Invalid, b.FinishArray(start, offsets)); +} + +// =========================================================================== +// Additional primitive round-trip tests (coverage gaps) +// =========================================================================== + +class VariantBuilderPrimitiveExtraTest : public ::testing::Test {}; + +TEST_F(VariantBuilderPrimitiveExtraTest, FloatRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.Float(2.5f)); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, BinaryRoundTrip) { + std::string_view bin_data("\x00\x01\x02\x03", 4); + VariantBuilder b; + ASSERT_OK(b.Binary(bin_data)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Binary(len=4)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, EmptyBinaryRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.Binary("")); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Binary(len=0)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, UUIDRoundTrip) { + uint8_t uuid_bytes[16]; + for (int i = 0; i < 16; ++i) uuid_bytes[i] = static_cast(i + 1); + VariantBuilder b; + ASSERT_OK(b.UUID(uuid_bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "UUID"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampMicrosRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampMicros(1654041600000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampMicros(1654041600000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampMicrosNTZRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampMicrosNTZ(1654041600000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampMicrosNTZ(1654041600000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampNanosRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampNanos(1654041600000000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampNanos(1654041600000000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampNanosNTZRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampNanosNTZ(1654041600000000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampNanosNTZ(1654041600000000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimeNTZRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimeNTZ(43200000000LL)); // 12:00:00 in microseconds + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimeNTZ(43200000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, Decimal4RoundTrip) { + int32_t val = 12345; + uint8_t bytes[4]; + std::memcpy(bytes, &val, 4); + VariantBuilder b; + ASSERT_OK(b.Decimal4(2, bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Decimal4(scale=2)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, Decimal8RoundTrip) { + int64_t val = 123456789012345LL; + uint8_t bytes[8]; + std::memcpy(bytes, &val, 8); + VariantBuilder b; + ASSERT_OK(b.Decimal8(5, bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Decimal8(scale=5)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, Decimal16RoundTrip) { + uint8_t bytes[16] = {}; + bytes[0] = 0x01; // value = 1 in low byte + VariantBuilder b; + ASSERT_OK(b.Decimal16(10, bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Decimal16(scale=10)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, DecimalScaleValidation) { + uint8_t bytes[16] = {}; + VariantBuilder b; + // Scale 39 exceeds spec maximum of 38 + ASSERT_RAISES(Invalid, b.Decimal4(39, bytes)); + ASSERT_RAISES(Invalid, b.Decimal8(39, bytes)); + ASSERT_RAISES(Invalid, b.Decimal16(39, bytes)); + // Scale 38 is valid + ASSERT_OK(b.Decimal4(38, bytes)); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, EmptyString) { + VariantBuilder b; + ASSERT_OK(b.String("")); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "String(\"\")"); +} + +// =========================================================================== +// Special float/double values: NaN, ±Inf +// =========================================================================== + +class VariantBuilderSpecialFloatTest : public ::testing::Test {}; + +TEST_F(VariantBuilderSpecialFloatTest, FloatNaN) { + VariantBuilder b; + ASSERT_OK(b.Float(std::numeric_limits::quiet_NaN())); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + // Verify it round-trips (NaN != NaN, so just check we get a Float event) + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(result.metadata.data(), + static_cast(result.metadata.size()))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, result.value.data(), + static_cast(result.value.size()), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, FloatPositiveInf) { + VariantBuilder b; + ASSERT_OK(b.Float(std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, FloatNegativeInf) { + VariantBuilder b; + ASSERT_OK(b.Float(-std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, DoubleNaN) { + VariantBuilder b; + ASSERT_OK(b.Double(std::numeric_limits::quiet_NaN())); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(result.metadata.data(), + static_cast(result.metadata.size()))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, result.value.data(), + static_cast(result.value.size()), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, DoublePositiveInf) { + VariantBuilder b; + ASSERT_OK(b.Double(std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Double(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, DoubleNegativeInf) { + VariantBuilder b; + ASSERT_OK(b.Double(-std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Double(") == 0); +} + +// =========================================================================== +// Int auto-sizing boundary tests +// =========================================================================== + +class VariantBuilderIntBoundaryTest : public ::testing::Test {}; + +TEST_F(VariantBuilderIntBoundaryTest, Int8MaxBecomesInt8) { + VariantBuilder b; + ASSERT_OK(b.Int(127)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(127)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int8MaxPlusOneBecomesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(128)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(128)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int8MinBecomesInt8) { + VariantBuilder b; + ASSERT_OK(b.Int(-128)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(-128)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int8MinMinusOneBecomesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(-129)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(-129)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int16MaxBecomesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(32767)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(32767)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int16MaxPlusOneBecomesInt32) { + VariantBuilder b; + ASSERT_OK(b.Int(32768)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(32768)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int32MaxBecomesInt32) { + VariantBuilder b; + ASSERT_OK(b.Int(2147483647LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(2147483647)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int32MaxPlusOneBecomesInt64) { + VariantBuilder b; + ASSERT_OK(b.Int(2147483648LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int64(2147483648)"); +} + +// =========================================================================== +// Large array round-trip (is_large flag) +// =========================================================================== + +class VariantBuilderLargeContainerTest : public ::testing::Test {}; + +TEST_F(VariantBuilderLargeContainerTest, LargeArrayIsLarge) { + // Build an array with 300 elements (>255) to trigger is_large=true. + // This exercises the same code path as the Go bug (apache/arrow-go#839). + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + for (int i = 0; i < 300; ++i) { + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Null()); + } + ASSERT_OK(b.FinishArray(start, offsets)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // Verify the header byte has is_large set correctly + ASSERT_FALSE(encoded.value.empty()); + uint8_t header = encoded.value[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kArray); + // is_large at bit 4 of full byte + ASSERT_TRUE(((header >> 4) & 0x01) != 0); + + // Verify round-trip: decode and check element count + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + // StartArray(300) + 300 Nulls + EndArray = 302 events + ASSERT_EQ(visitor.events.size(), 302); + ASSERT_EQ(visitor.events[0], "StartArray(300)"); + ASSERT_EQ(visitor.events[301], "EndArray"); + + // Also verify ValueSize works correctly on this large array + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(size, static_cast(encoded.value.size())); +} + +TEST_F(VariantBuilderLargeContainerTest, LargeObjectIsLarge) { + // Build an object with 300 fields (>255) to trigger is_large=true. + // Verifies that the encoder correctly sets is_large at bit 6 of the + // full header byte (bit 4 of the 6-bit type_info / value_header). + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + for (int i = 0; i < 300; ++i) { + std::string key = "field_" + std::to_string(i); + fields.push_back(b.NextField(start, key)); + ASSERT_OK(b.Null()); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // Verify the header byte has is_large set correctly at bit 6 + ASSERT_FALSE(encoded.value.empty()); + uint8_t header = encoded.value[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kObject); + // Object is_large at bit 6 of full byte (bit 4 of type_info) + ASSERT_TRUE(((header >> 6) & 0x01) != 0); + + // Verify round-trip: decode and check field count + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto field_count, + GetObjectFieldCount(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(field_count, 300); + + // Verify full decode + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + // StartObject(300) + 300*(FieldName + Null) + EndObject = 602 events + ASSERT_EQ(visitor.events.size(), 602); + ASSERT_EQ(visitor.events[0], "StartObject(300)"); + ASSERT_EQ(visitor.events[601], "EndObject"); + + // Verify ValueSize matches buffer size + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(size, static_cast(encoded.value.size())); +} + +// =========================================================================== +// Decoder utility round-trips through builder output +// =========================================================================== + +class VariantBuilderDecoderUtilTest : public ::testing::Test {}; + +TEST_F(VariantBuilderDecoderUtilTest, FindObjectFieldOnBuilderOutput) { + // Build {alpha: 1, beta: "two", gamma: true} and verify FindObjectField works + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "alpha")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "beta")); + ASSERT_OK(b.String("two")); + fields.push_back(b.NextField(start, "gamma")); + ASSERT_OK(b.Bool(true)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + + // Find "beta" + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(meta, encoded.value.data(), + static_cast(encoded.value.size()), "beta", + &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_GT(field_size, 0); + + // Decode the field value + RecordingVisitor v; + ASSERT_OK( + DecodeVariantValue(meta, encoded.value.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "String(\"two\")"); + + // Find non-existent key + int64_t nf_offset = -1, nf_size = 0; + ASSERT_OK(FindObjectField(meta, encoded.value.data(), + static_cast(encoded.value.size()), "missing", + &nf_offset, &nf_size)); + ASSERT_EQ(nf_offset, -1); +} + +TEST_F(VariantBuilderDecoderUtilTest, GetArrayElementOnBuilderOutput) { + // Build [10, 20, 30] and verify GetArrayElement works + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(10)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(20)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(30)); + ASSERT_OK(b.FinishArray(start, offsets)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + + // Access element at index 2 + int64_t elem_offset = 0, elem_size = 0; + ASSERT_OK(GetArrayElement(encoded.value.data(), + static_cast(encoded.value.size()), 2, &elem_offset, + &elem_size)); + ASSERT_GT(elem_offset, 0); + ASSERT_EQ(elem_size, 2); // Int8(30) = 2 bytes + + RecordingVisitor v; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data() + elem_offset, elem_size, &v)); + ASSERT_EQ(v.events[0], "Int8(30)"); +} + +TEST_F(VariantBuilderDecoderUtilTest, GetObjectFieldAtOnBuilderOutput) { + // Build {x: 100, y: 200} and access by positional index + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "x")); + ASSERT_OK(b.Int(100)); + fields.push_back(b.NextField(start, "y")); + ASSERT_OK(b.Int(200)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + + // Fields are sorted by key: "x" at index 0, "y" at index 1 + std::string_view field_name; + int64_t field_offset = 0, field_size = 0; + ASSERT_OK(GetObjectFieldAt(meta, encoded.value.data(), + static_cast(encoded.value.size()), 0, &field_name, + &field_offset, &field_size)); + ASSERT_EQ(field_name, "x"); + + RecordingVisitor v; + ASSERT_OK( + DecodeVariantValue(meta, encoded.value.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "Int8(100)"); +} + +TEST_F(VariantBuilderDecoderUtilTest, ValueSizeOnBuilderOutput) { + // Build a nested structure and verify ValueSize matches buffer size + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "data")); + { + auto arr_start = b.Offset(); + std::vector arr_offsets; + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.String("hello")); + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(42)); + ASSERT_OK(b.FinishArray(arr_start, arr_offsets)); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // ValueSize of the top-level value should equal the total buffer size + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(size, static_cast(encoded.value.size())); +} + +// =========================================================================== +// Direct integer type method tests (verify explicit types not auto-sized) +// =========================================================================== + +class VariantBuilderDirectIntTest : public ::testing::Test {}; + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt8) { + VariantBuilder b; + ASSERT_OK(b.Int8(42)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(42)"); +} + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt16) { + VariantBuilder b; + ASSERT_OK(b.Int16(42)); // Would be Int8 if auto-sized + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(42)"); +} + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt32) { + VariantBuilder b; + ASSERT_OK(b.Int32(42)); // Would be Int8 if auto-sized + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(42)"); +} + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt64) { + VariantBuilder b; + ASSERT_OK(b.Int64(42)); // Would be Int8 if auto-sized + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int64(42)"); +} + +// =========================================================================== +// Builder reuse: multiple Finish() calls with preserved dictionary +// =========================================================================== + +class VariantBuilderReuseTest : public ::testing::Test {}; + +TEST_F(VariantBuilderReuseTest, MultipleFinishPreservesDictionary) { + VariantBuilder b; + + // Build first value: {name: "Alice"} + auto start1 = b.Offset(); + std::vector fields1; + fields1.push_back(b.NextField(start1, "name")); + ASSERT_OK(b.String("Alice")); + ASSERT_OK(b.FinishObject(start1, fields1)); + ASSERT_OK_AND_ASSIGN(auto encoded1, b.Finish()); + + // Build second value: {name: "Bob"} — reuses dictionary from first build + auto start2 = b.Offset(); + std::vector fields2; + fields2.push_back(b.NextField(start2, "name")); + ASSERT_OK(b.String("Bob")); + ASSERT_OK(b.FinishObject(start2, fields2)); + ASSERT_OK_AND_ASSIGN(auto encoded2, b.Finish()); + + // Verify first value decodes correctly + ASSERT_OK_AND_ASSIGN(auto meta1, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); + RecordingVisitor v1; + ASSERT_OK(DecodeVariantValue(meta1, encoded1.value.data(), + static_cast(encoded1.value.size()), &v1)); + ASSERT_EQ(v1.events[1], "FieldName(\"name\")"); + ASSERT_EQ(v1.events[2], "String(\"Alice\")"); + + // Verify second value decodes correctly + ASSERT_OK_AND_ASSIGN(auto meta2, + DecodeMetadata(encoded2.metadata.data(), + static_cast(encoded2.metadata.size()))); + RecordingVisitor v2; + ASSERT_OK(DecodeVariantValue(meta2, encoded2.value.data(), + static_cast(encoded2.value.size()), &v2)); + ASSERT_EQ(v2.events[1], "FieldName(\"name\")"); + ASSERT_EQ(v2.events[2], "String(\"Bob\")"); + + // Both should have the same dictionary content (same metadata structure) + ASSERT_EQ(meta1.strings.size(), meta2.strings.size()); + ASSERT_EQ(meta1.strings[0], "name"); + ASSERT_EQ(meta2.strings[0], "name"); +} + +TEST_F(VariantBuilderReuseTest, DictionaryGrowsAcrossFinishCalls) { + VariantBuilder b; + + // Build first value with key "x" + auto start1 = b.Offset(); + std::vector fields1; + fields1.push_back(b.NextField(start1, "x")); + ASSERT_OK(b.Int(1)); + ASSERT_OK(b.FinishObject(start1, fields1)); + ASSERT_OK_AND_ASSIGN(auto encoded1, b.Finish()); + + // Build second value with keys "x" and "y" — dictionary should grow + auto start2 = b.Offset(); + std::vector fields2; + fields2.push_back(b.NextField(start2, "x")); + ASSERT_OK(b.Int(2)); + fields2.push_back(b.NextField(start2, "y")); + ASSERT_OK(b.Int(3)); + ASSERT_OK(b.FinishObject(start2, fields2)); + ASSERT_OK_AND_ASSIGN(auto encoded2, b.Finish()); + + // First metadata has 1 key + ASSERT_OK_AND_ASSIGN(auto meta1, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); + ASSERT_EQ(meta1.strings.size(), 1); + + // Second metadata has 2 keys (dictionary grew) + ASSERT_OK_AND_ASSIGN(auto meta2, + DecodeMetadata(encoded2.metadata.data(), + static_cast(encoded2.metadata.size()))); + ASSERT_EQ(meta2.strings.size(), 2); + + // Verify second value decodes correctly + RecordingVisitor v2; + ASSERT_OK(DecodeVariantValue(meta2, encoded2.value.data(), + static_cast(encoded2.value.size()), &v2)); + ASSERT_EQ(v2.events[0], "StartObject(2)"); + // Fields sorted: "x" before "y" + ASSERT_EQ(v2.events[1], "FieldName(\"x\")"); + ASSERT_EQ(v2.events[2], "Int8(2)"); + ASSERT_EQ(v2.events[3], "FieldName(\"y\")"); + ASSERT_EQ(v2.events[4], "Int8(3)"); +} + +// =========================================================================== +// Edge case: FinishObject/FinishArray with pre-existing buffer content +// =========================================================================== + +class VariantBuilderPreExistingBufferTest : public ::testing::Test {}; + +TEST_F(VariantBuilderPreExistingBufferTest, ObjectAfterPrimitive) { + // Write a primitive value first, then build an object. This exercises + // the case where start > 0 (data_size = buffer.size() - start). + // The builder is designed for single top-level values, but this tests + // the internal arithmetic correctness. + VariantBuilder b; + // Write a "prefix" value that occupies buffer space before our object + ASSERT_OK(b.Int(99)); + int64_t prefix_size = b.Offset(); // should be 2 (Int8 header + 1 byte) + ASSERT_EQ(prefix_size, 2); + + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.String("val")); + ASSERT_OK(b.FinishObject(start, fields)); + + // The buffer now contains [Int8(99)] + [Object{key: "val"}]. + // We can't call Finish() meaningfully for a two-value buffer, + // but verify no crash or corruption occurred and the object portion + // is correctly sized. + ASSERT_GT(b.Offset(), prefix_size); +} + +TEST_F(VariantBuilderPreExistingBufferTest, ArrayAfterPrimitive) { + // Same as above but for arrays. + VariantBuilder b; + ASSERT_OK(b.Int(99)); + int64_t prefix_size = b.Offset(); + + auto start = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Null()); + ASSERT_OK(b.FinishArray(start, offsets)); + + ASSERT_GT(b.Offset(), prefix_size); +} + +} // namespace arrow::extension::variant diff --git a/cpp/src/arrow/extension/variant_test.cc b/cpp/src/arrow/extension/variant_test.cc index 4a3c1187c45d..f4e791ee2099 100644 --- a/cpp/src/arrow/extension/variant_test.cc +++ b/cpp/src/arrow/extension/variant_test.cc @@ -2408,5 +2408,195 @@ TEST_F(VariantArrayViewTest, IterateElements) { } ASSERT_EQ(count, 3); } +// =========================================================================== +// Widening numeric accessor tests +// =========================================================================== + +class VariantCoercionTest : public ::testing::Test {}; + +TEST_F(VariantCoercionTest, Int8CoercesToInt64) { + VariantBuilder builder; + ASSERT_OK(builder.Int8(42)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_int64_coerced()); + ASSERT_EQ(val, 42); +} + +TEST_F(VariantCoercionTest, Int16CoercesToInt64) { + VariantBuilder builder; + ASSERT_OK(builder.Int16(1000)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_int64_coerced()); + ASSERT_EQ(val, 1000); +} + +TEST_F(VariantCoercionTest, Int32CoercesToInt64) { + VariantBuilder builder; + ASSERT_OK(builder.Int32(100000)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_int64_coerced()); + ASSERT_EQ(val, 100000); +} + +TEST_F(VariantCoercionTest, Int64CoercesToInt64Identity) { + VariantBuilder builder; + ASSERT_OK(builder.Int64(9876543210LL)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_int64_coerced()); + ASSERT_EQ(val, 9876543210LL); +} + +TEST_F(VariantCoercionTest, NegativeInt8CoercesToInt64) { + VariantBuilder builder; + ASSERT_OK(builder.Int8(-42)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_int64_coerced()); + ASSERT_EQ(val, -42); +} + +TEST_F(VariantCoercionTest, StringDoesNotCoerceToInt64) { + VariantBuilder builder; + ASSERT_OK(builder.String("hello")); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_NOT_OK(view.as_int64_coerced()); +} + +TEST_F(VariantCoercionTest, Int32CoercedRejectsInt64) { + VariantBuilder builder; + ASSERT_OK(builder.Int64(9876543210LL)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_NOT_OK(view.as_int32_coerced()); +} + +TEST_F(VariantCoercionTest, DoubleCoercedFromFloat) { + VariantBuilder builder; + ASSERT_OK(builder.Float(3.14f)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_double_coerced()); + ASSERT_NEAR(val, 3.14, 0.001); +} + +TEST_F(VariantCoercionTest, DoubleCoercedFromInt32) { + VariantBuilder builder; + ASSERT_OK(builder.Int32(42)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK_AND_ASSIGN( + auto view, + VariantView::Make(meta, enc.value.data(), static_cast(enc.value.size()))); + ASSERT_OK_AND_ASSIGN(auto val, view.as_double_coerced()); + ASSERT_EQ(val, 42.0); +} + +// =========================================================================== +// ValidateVariant tests +// =========================================================================== + +class VariantValidationTest : public ::testing::Test {}; + +TEST_F(VariantValidationTest, ValidPrimitive) { + VariantBuilder builder; + ASSERT_OK(builder.Int(42)); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK( + ValidateVariant(meta, enc.value.data(), static_cast(enc.value.size()))); +} + +TEST_F(VariantValidationTest, ValidNestedObject) { + VariantBuilder builder; + auto obj = builder.StartObject(); + ASSERT_OK(obj.Insert("name", std::string_view("Alice"))); + ASSERT_OK(obj.Insert("age", static_cast(30))); + auto inner = obj.InsertObject("address"); + ASSERT_OK(inner.Insert("city", std::string_view("NYC"))); + ASSERT_OK(inner.Finish()); + ASSERT_OK(obj.Finish()); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK( + ValidateVariant(meta, enc.value.data(), static_cast(enc.value.size()))); +} + +TEST_F(VariantValidationTest, ValidArray) { + VariantBuilder builder; + auto list = builder.StartList(); + ASSERT_OK(list.Append(static_cast(1))); + ASSERT_OK(list.Append(static_cast(2))); + ASSERT_OK(list.Append(static_cast(3))); + ASSERT_OK(list.Finish()); + ASSERT_OK_AND_ASSIGN(auto enc, builder.Finish()); + ASSERT_OK_AND_ASSIGN( + auto meta, + DecodeMetadata(enc.metadata.data(), static_cast(enc.metadata.size()))); + ASSERT_OK( + ValidateVariant(meta, enc.value.data(), static_cast(enc.value.size()))); +} + +TEST_F(VariantValidationTest, NullBuffer) { + ASSERT_NOT_OK(ValidateVariant(VariantMetadata{}, nullptr, 0)); +} + +TEST_F(VariantValidationTest, TruncatedPrimitive) { + // A valid Int32 header (type=5 << 2 | 0 = 20 = 0x14) but no payload + uint8_t data[] = {0x14}; + VariantMetadata meta; + meta.version = 1; + ASSERT_NOT_OK(ValidateVariant(meta, data, 1)); +} } // namespace arrow::extension::variant diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index 36ea0f615740..3e623d05403b 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -143,6 +143,7 @@ arrow_components = { 'extension/json.cc', 'extension/parquet_variant.cc', 'extension/variant.cc', + 'extension/variant_builder.cc', 'extension/uuid.cc', 'pretty_print.cc', 'record_batch.cc',