diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 321fa40ec3ae..32a582af04de 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -40,7 +40,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: 3.8 - name: Audit licenses @@ -51,7 +51,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v5 with: node-version: "14" - name: Prettier check diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 76ecd7d29a90..4d81716395b3 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -44,7 +44,7 @@ jobs: github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v5.0.0 + uses: actions/labeler@v6.0.1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 624910a10e23..4eaf62d95de2 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -56,7 +56,7 @@ jobs: echo "::warning title=Invalid file permissions automatically fixed::$line" done - name: Upload artifacts - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: name: crate-docs path: target/doc diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 4118c43db093..923da88eb580 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -63,6 +63,7 @@ jobs: ARROW_INTEGRATION_CPP: ON ARROW_INTEGRATION_CSHARP: ON ARCHERY_INTEGRATION_TARGET_IMPLEMENTATIONS: "rust" + ARCHERY_INTEGRATION_WITH_DOTNET: "1" ARCHERY_INTEGRATION_WITH_GO: "1" ARCHERY_INTEGRATION_WITH_JAVA: "1" ARCHERY_INTEGRATION_WITH_JS: "1" @@ -98,6 +99,11 @@ jobs: with: path: rust fetch-depth: 0 + - name: Checkout Arrow .NET + uses: actions/checkout@v5 + with: + repository: apache/arrow-dotnet + path: dotnet - name: Checkout Arrow Go uses: actions/checkout@v5 with: @@ -152,7 +158,7 @@ jobs: path: /home/runner/target # this key is not equal because maturin uses different compilation flags. 
key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.8' - name: Upgrade pip and setuptools diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index dc398f5a8a32..92c432dc893b 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -52,12 +52,8 @@ jobs: submodules: true - name: Setup Rust toolchain run: | - # Temp pin to nightly-2025-08-18 until https://github.com/rust-lang/rust/issues/145652 is resolved - # See https://github.com/apache/arrow-rs/issues/8181 for more details - rustup toolchain install nightly-2025-08-18 --component miri - rustup override set nightly-2025-08-18 - # rustup toolchain install nightly --component miri - # rustup override set nightly + rustup toolchain install nightly --component miri + rustup override set nightly cargo miri setup - name: Run Miri Checks env: diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index 8a2301acd90c..126e4aa3a614 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -153,7 +153,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" cache: "pip" diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml index dd21c794960e..94a95f6e31a2 100644 --- a/.github/workflows/take.yml +++ b/.github/workflows/take.yml @@ -28,7 +28,7 @@ jobs: if: (!github.event.issue.pull_request) && github.event.comment.body == 'take' runs-on: ubuntu-latest steps: - - uses: actions/github-script@v7 + - uses: actions/github-script@v8 with: script: | github.rest.issues.addAssignees({ diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index 8ebe21c70772..9fffe3b6bbe2 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -46,7 +46,7 @@ chrono = { workspace = true } chrono-tz = { version = "0.10", optional = true } num = { version = "0.4.1", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false, features = ["num-traits"] } -hashbrown = { version = "0.15.1", default-features = false } +hashbrown = { version = "0.16.0", default-features = false } [package.metadata.docs.rs] all-features = true diff --git a/arrow-array/src/timezone.rs b/arrow-array/src/timezone.rs index b4df77deb4f5..bcf582152146 100644 --- a/arrow-array/src/timezone.rs +++ b/arrow-array/src/timezone.rs @@ -53,6 +53,7 @@ mod private { use super::*; use chrono::offset::TimeZone; use chrono::{LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::fmt::Display; use std::str::FromStr; /// An [`Offset`] for [`Tz`] @@ -97,6 +98,15 @@ mod private { } } + impl Display for Tz { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + TzInner::Timezone(tz) => tz.fmt(f), + TzInner::Offset(offset) => offset.fmt(f), + } + } + } + macro_rules! 
tz { ($s:ident, $tz:ident, $b:block) => { match $s.0 { @@ -228,6 +238,15 @@ mod private { sydney_offset_with_dst ); } + + #[test] + fn test_timezone_display() { + let test_cases = ["UTC", "America/Los_Angeles", "-08:00", "+05:30"]; + for &case in &test_cases { + let tz: Tz = case.parse().unwrap(); + assert_eq!(tz.to_string(), case); + } + } } } diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index 5cdef83a2d45..30c23e1932ae 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -40,6 +40,9 @@ default = ["deflate", "snappy", "zstd", "bzip2", "xz"] deflate = ["flate2"] snappy = ["snap", "crc"] canonical_extension_types = ["arrow-schema/canonical_extension_types"] +md5 = ["dep:md5"] +sha256 = ["dep:sha2"] +small_decimals = [] [dependencies] arrow-schema = { workspace = true } @@ -59,6 +62,8 @@ strum_macros = "0.27" uuid = "1.17" indexmap = "2.10" rand = "0.9" +md5 = { version = "0.8", optional = true } +sha2 = { version = "0.10", optional = true } [dev-dependencies] arrow-data = { workspace = true } @@ -73,7 +78,7 @@ arrow = { workspace = true } futures = "0.3.31" bytes = "1.10.1" async-stream = "0.3.6" -apache-avro = "0.14.0" +apache-avro = "0.20.0" num-bigint = "0.4" once_cell = "1.21.3" @@ -83,4 +88,8 @@ harness = false [[bench]] name = "decoder" +harness = false + +[[bench]] +name = "avro_writer" harness = false \ No newline at end of file diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs new file mode 100644 index 000000000000..924cbbdc84bd --- /dev/null +++ b/arrow-avro/benches/avro_writer.rs @@ -0,0 +1,324 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `arrow‑avro` **Writer** (Avro Object Container Files) +//! 
+
+extern crate arrow_avro;
+extern crate criterion;
+extern crate once_cell;
+
+use arrow_array::{
+    types::{Int32Type, Int64Type, TimestampMicrosecondType},
+    ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, PrimitiveArray, RecordBatch,
+};
+use arrow_avro::writer::AvroWriter;
+use arrow_schema::{DataType, Field, Schema, TimeUnit};
+use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
+use once_cell::sync::Lazy;
+use rand::{
+    distr::uniform::{SampleRange, SampleUniform},
+    rngs::StdRng,
+    Rng, SeedableRng,
+};
+use std::io::Cursor;
+use std::sync::Arc;
+use std::time::Duration;
+use tempfile::tempfile;
+
+const SIZES: [usize; 4] = [4_096, 8_192, 100_000, 1_000_000];
+const BASE_SEED: u64 = 0x5EED_1234_ABCD_EF01;
+const MIX_CONST_1: u64 = 0x9E37_79B1_85EB_CA87;
+const MIX_CONST_2: u64 = 0xC2B2_AE3D_27D4_EB4F;
+
+#[inline]
+fn rng_for(tag: u64, n: usize) -> StdRng {
+    let seed = BASE_SEED ^ tag.wrapping_mul(MIX_CONST_1) ^ (n as u64).wrapping_mul(MIX_CONST_2);
+    StdRng::seed_from_u64(seed)
+}
+
+#[inline]
+fn sample_in<T, Rg>(rng: &mut StdRng, range: Rg) -> T
+where
+    T: SampleUniform,
+    Rg: SampleRange<T>,
+{
+    rng.random_range(range)
+}
+
+#[inline]
+fn make_bool_array_with_tag(n: usize, tag: u64) -> BooleanArray {
+    let mut rng = rng_for(tag, n);
+    let values = (0..n).map(|_| rng.random_bool(0.5));
+    BooleanArray::from_iter(values.map(Some))
+}
+
+#[inline]
+fn make_i32_array_with_tag(n: usize, tag: u64) -> PrimitiveArray<Int32Type> {
+    let mut rng = rng_for(tag, n);
+    let values = (0..n).map(|_| rng.random::<i32>());
+    PrimitiveArray::<Int32Type>::from_iter_values(values)
+}
+
+#[inline]
+fn make_i64_array_with_tag(n: usize, tag: u64) -> PrimitiveArray<Int64Type> {
+    let mut rng = rng_for(tag, n);
+    let values = (0..n).map(|_| rng.random::<i64>());
+    PrimitiveArray::<Int64Type>::from_iter_values(values)
+}
+
+#[inline]
+fn make_f32_array_with_tag(n: usize, tag: u64) -> Float32Array {
+    let mut rng = rng_for(tag, n);
+    let values = (0..n).map(|_| rng.random::<f32>());
+    Float32Array::from_iter_values(values)
+}
+
+#[inline]
+fn make_f64_array_with_tag(n: usize, tag: u64) -> Float64Array {
+    let mut rng = rng_for(tag, n);
+    let values = (0..n).map(|_| rng.random::<f64>());
+    Float64Array::from_iter_values(values)
+}
+
+#[inline]
+fn make_binary_array_with_tag(n: usize, tag: u64) -> BinaryArray {
+    let mut rng = rng_for(tag, n);
+    let mut payloads: Vec<[u8; 16]> = vec![[0; 16]; n];
+    for p in payloads.iter_mut() {
+        rng.fill(&mut p[..]);
+    }
+    let views: Vec<&[u8]> = payloads.iter().map(|p| &p[..]).collect();
+    BinaryArray::from_vec(views)
+}
+
+#[inline]
+fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray<TimestampMicrosecondType> {
+    let mut rng = rng_for(tag, n);
+    let base: i64 = 1_600_000_000_000_000;
+    let year_us: i64 = 31_536_000_000_000;
+    let values = (0..n).map(|_| base + sample_in::<i64, _>(&mut rng, 0..year_us));
+    PrimitiveArray::<TimestampMicrosecondType>::from_iter_values(values)
+}
+
+#[inline]
+fn make_bool_array(n: usize) -> BooleanArray {
+    make_bool_array_with_tag(n, 0xB001)
+}
+#[inline]
+fn make_i32_array(n: usize) -> PrimitiveArray<Int32Type> {
+    make_i32_array_with_tag(n, 0x1337_0032)
+}
+#[inline]
+fn make_i64_array(n: usize) -> PrimitiveArray<Int64Type> {
+    make_i64_array_with_tag(n, 0x1337_0064)
+}
+#[inline]
+fn make_f32_array(n: usize) -> Float32Array {
+    make_f32_array_with_tag(n, 0xF0_0032)
+}
+#[inline]
+fn make_f64_array(n: usize) -> Float64Array {
+    make_f64_array_with_tag(n, 0xF0_0064)
+}
+#[inline]
+fn make_binary_array(n: usize) -> BinaryArray {
+    make_binary_array_with_tag(n, 0xB1_0001)
+}
+#[inline]
+fn make_ts_micros_array(n: usize) -> PrimitiveArray<TimestampMicrosecondType> {
+    make_ts_micros_array_with_tag(n, 0x7157_0001)
+}
+
+#[inline]
+fn schema_single(name: &str, dt: DataType) -> Arc<Schema> {
+    Arc::new(Schema::new(vec![Field::new(name, dt, false)]))
+}
+
+#[inline]
+fn schema_mixed() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("f1", DataType::Int32, false),
+        Field::new("f2", DataType::Int64, false),
+        Field::new("f3", DataType::Binary, false),
+        Field::new("f4", DataType::Float64, false),
+    ]))
+}
+
+static BOOLEAN_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Boolean);
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_bool_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static INT32_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Int32);
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_i32_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static INT64_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Int64);
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_i64_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static FLOAT32_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Float32);
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_f32_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static FLOAT64_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Float64);
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_f64_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static BINARY_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Binary);
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_binary_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static TIMESTAMP_US_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_single("field1", DataType::Timestamp(TimeUnit::Microsecond, None));
+    SIZES
+        .iter()
+        .map(|&n| {
+            let col: ArrayRef = Arc::new(make_ts_micros_array(n));
+            RecordBatch::try_new(schema.clone(), vec![col]).unwrap()
+        })
+        .collect()
+});
+
+static MIXED_DATA: Lazy<Vec<RecordBatch>> = Lazy::new(|| {
+    let schema = schema_mixed();
+    SIZES
+        .iter()
+        .map(|&n| {
+            let f1: ArrayRef = Arc::new(make_i32_array_with_tag(n, 0xA1));
+            let f2: ArrayRef = Arc::new(make_i64_array_with_tag(n, 0xA2));
+            let f3: ArrayRef = Arc::new(make_binary_array_with_tag(n, 0xA3));
+            let f4: ArrayRef = Arc::new(make_f64_array_with_tag(n, 0xA4));
+            RecordBatch::try_new(schema.clone(), vec![f1, f2, f3, f4]).unwrap()
+        })
+        .collect()
+});
+
+fn ocf_size_for_batch(batch: &RecordBatch) -> usize {
+    let schema_owned: Schema = (*batch.schema()).clone();
+    let cursor = Cursor::new(Vec::<u8>::with_capacity(1024));
+    let mut writer = AvroWriter::new(cursor, schema_owned).expect("create writer");
+    writer.write(batch).expect("write batch");
+    writer.finish().expect("finish writer");
+    let inner = writer.into_inner();
+    inner.into_inner().len()
+}
+
+fn bench_writer_scenario(c: &mut Criterion, name: &str, data_sets: &[RecordBatch]) {
+    let mut group = c.benchmark_group(name);
+    let schema_owned: Schema = (*data_sets[0].schema()).clone();
+    for (idx, &rows) in SIZES.iter().enumerate() {
+        let batch = &data_sets[idx];
+        let bytes = ocf_size_for_batch(batch);
+        group.throughput(Throughput::Bytes(bytes as u64));
+        match rows {
+            4_096 | 8_192 => {
+                group
+                    .sample_size(40)
+                    .measurement_time(Duration::from_secs(10))
+                    .warm_up_time(Duration::from_secs(3));
+            }
+            100_000 => {
+                group
+                    .sample_size(20)
+                    .measurement_time(Duration::from_secs(10))
+                    .warm_up_time(Duration::from_secs(3));
+            }
+            1_000_000 => {
+                group
+                    .sample_size(10)
+                    .measurement_time(Duration::from_secs(10))
+                    .warm_up_time(Duration::from_secs(3));
+            }
+            _ => {}
+        }
+        group.bench_function(BenchmarkId::from_parameter(rows), |b| {
+            b.iter_batched_ref(
+                || {
+                    let file = tempfile().expect("create temp file");
+                    AvroWriter::new(file, schema_owned.clone()).expect("create writer")
+                },
+                |writer| {
+                    writer.write(batch).unwrap();
+                    writer.finish().unwrap();
+                },
+                BatchSize::SmallInput,
+            )
+        });
+    }
+    group.finish();
+}
+
+fn criterion_benches(c: &mut Criterion) {
+    bench_writer_scenario(c, "write-Boolean", &BOOLEAN_DATA);
+    bench_writer_scenario(c, "write-Int32", &INT32_DATA);
+    bench_writer_scenario(c, "write-Int64", &INT64_DATA);
+    bench_writer_scenario(c, "write-Float32", &FLOAT32_DATA);
+    bench_writer_scenario(c, "write-Float64", &FLOAT64_DATA);
+    bench_writer_scenario(c, "write-Binary(Bytes)", &BINARY_DATA);
+    bench_writer_scenario(c, "write-TimestampMicros", &TIMESTAMP_US_DATA);
+    bench_writer_scenario(c, "write-Mixed", &MIXED_DATA);
+}
+
+criterion_group! {
+    name = avro_writer;
+    config = Criterion::default().configure_from_args();
+    targets = criterion_benches
+}
+criterion_main!(avro_writer);
diff --git a/arrow-avro/benches/decoder.rs b/arrow-avro/benches/decoder.rs
index df802daea154..0ca240d12fc9 100644
--- a/arrow-avro/benches/decoder.rs
+++ b/arrow-avro/benches/decoder.rs
@@ -27,19 +27,42 @@ extern crate uuid;
 
 use apache_avro::types::Value;
 use apache_avro::{to_avro_datum, Decimal, Schema as ApacheSchema};
-use arrow_avro::schema::{Fingerprint, SINGLE_OBJECT_MAGIC};
+use arrow_avro::schema::{Fingerprint, FingerprintAlgorithm, CONFLUENT_MAGIC, SINGLE_OBJECT_MAGIC};
 use arrow_avro::{reader::ReaderBuilder, schema::AvroSchema};
 use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
 use once_cell::sync::Lazy;
 use std::{hint::black_box, time::Duration};
 use uuid::Uuid;
 
-fn make_prefix(fp: Fingerprint) -> [u8; 10] {
-    let Fingerprint::Rabin(val) = fp;
-    let mut buf = [0u8; 10];
-    buf[..2].copy_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01
-    buf[2..].copy_from_slice(&val.to_le_bytes()); // little-endian 64-bit
-    buf
+fn make_prefix(fp: Fingerprint) -> Vec<u8> {
+    match fp {
+        Fingerprint::Rabin(val) => {
+            let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of::<u64>());
+            buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01
+            buf.extend_from_slice(&val.to_le_bytes()); // little-endian
+            buf
+        }
+        Fingerprint::Id(id) => {
+            let mut buf = Vec::with_capacity(CONFLUENT_MAGIC.len() + size_of::<u32>());
+            buf.extend_from_slice(&CONFLUENT_MAGIC); // 00
+            buf.extend_from_slice(&id.to_be_bytes()); // big-endian
+            buf
+        }
+        #[cfg(feature = "md5")]
+        Fingerprint::MD5(val) => {
+            let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of_val(&val));
+            buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01
+            buf.extend_from_slice(&val);
+            buf
+        }
+        #[cfg(feature = "sha256")]
+        Fingerprint::SHA256(val) => {
+            let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of_val(&val));
+            buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01
+            buf.extend_from_slice(&val);
+            buf
+        }
+    }
 }
 
 fn encode_records_with_prefix(
@@ -336,6 +359,27 @@ fn new_decoder(
         .expect("failed to build decoder")
 }
 
+fn new_decoder_id(
+    schema_json: &'static str,
+    batch_size: usize,
+    utf8view: bool,
+    id: u32,
+) -> arrow_avro::reader::Decoder {
+    let schema = AvroSchema::new(schema_json.parse().unwrap());
+    let mut store = arrow_avro::schema::SchemaStore::new_with_type(FingerprintAlgorithm::None);
+    // Register the schema with a provided Confluent-style ID
+    store
+        .set(Fingerprint::Id(id), schema.clone())
+        .expect("failed to set schema with id");
+    ReaderBuilder::new()
+        .with_writer_schema_store(store)
+        .with_active_fingerprint(Fingerprint::Id(id))
+        .with_batch_size(batch_size)
+        .with_utf8_view(utf8view)
+        .build_decoder()
+        .expect("failed to build decoder for id")
+}
+
 const SIZES: [usize; 3] = [100, 10_000, 1_000_000];
 
 const INT_SCHEMA: &str =
@@ -373,7 +417,7 @@ macro_rules! dataset {
         static $name: Lazy<Vec<Vec<u8>>> = Lazy::new(|| {
             let schema =
                 ApacheSchema::parse_str($schema_json).expect("invalid schema for generator");
-            let arrow_schema = AvroSchema::new($schema_json.to_string());
+            let arrow_schema = AvroSchema::new($schema_json.parse().unwrap());
             let fingerprint = arrow_schema.fingerprint().expect("fingerprint failed");
             let prefix = make_prefix(fingerprint);
             SIZES
@@ -384,6 +428,24 @@ macro_rules! dataset {
     };
 }
 
+/// Additional helper for Confluent's ID-based wire format (00 + BE u32).
+macro_rules! dataset_id {
+    ($name:ident, $schema_json:expr, $gen_fn:ident, $id:expr) => {
+        static $name: Lazy<Vec<Vec<u8>>> = Lazy::new(|| {
+            let schema =
+                ApacheSchema::parse_str($schema_json).expect("invalid schema for generator");
+            let prefix = make_prefix(Fingerprint::Id($id));
+            SIZES
+                .iter()
+                .map(|&n| $gen_fn(&schema, n, &prefix))
+                .collect()
+        });
+    };
+}
+
+const ID_BENCH_ID: u32 = 7;
+
+dataset_id!(INT_DATA_ID, INT_SCHEMA, gen_int, ID_BENCH_ID);
 dataset!(INT_DATA, INT_SCHEMA, gen_int);
 dataset!(LONG_DATA, LONG_SCHEMA, gen_long);
 dataset!(FLOAT_DATA, FLOAT_SCHEMA, gen_float);
@@ -406,19 +468,20 @@ dataset!(ENUM_DATA, ENUM_SCHEMA, gen_enum);
 dataset!(MIX_DATA, MIX_SCHEMA, gen_mixed);
 dataset!(NEST_DATA, NEST_SCHEMA, gen_nested);
 
-fn bench_scenario(
+fn bench_with_decoder<F>(
     c: &mut Criterion,
     name: &str,
-    schema_json: &'static str,
     data_sets: &[Vec<Vec<u8>>],
-    utf8view: bool,
-    batch_size: usize,
-) {
+    rows: &[usize],
+    mut new_decoder: F,
+) where
+    F: FnMut() -> arrow_avro::reader::Decoder,
+{
     let mut group = c.benchmark_group(name);
-    for (idx, &rows) in SIZES.iter().enumerate() {
+    for (idx, &row_count) in rows.iter().enumerate() {
         let datum = &data_sets[idx];
         group.throughput(Throughput::Bytes(datum.len() as u64));
-        match rows {
+        match row_count {
             10_000 => {
                 group
                     .sample_size(25)
@@ -433,9 +496,9 @@
             }
             _ => {}
        }
-        group.bench_function(BenchmarkId::from_parameter(rows), |b| {
+        group.bench_function(BenchmarkId::from_parameter(row_count), |b| {
             b.iter_batched_ref(
-                || new_decoder(schema_json, batch_size, utf8view),
+                &mut new_decoder,
                 |decoder| {
                     black_box(decoder.decode(datum).unwrap());
                     black_box(decoder.flush().unwrap().unwrap());
@@ -449,105 +512,75 @@
 
 fn criterion_benches(c: &mut Criterion) {
     for &batch_size in &[SMALL_BATCH, LARGE_BATCH] {
-        bench_scenario(
-            c,
-            "Interval",
-            INTERVAL_SCHEMA,
-            &INTERVAL_DATA,
-            false,
-            batch_size,
-        );
-        bench_scenario(c, "Int32", INT_SCHEMA, &INT_DATA, false, batch_size);
-        bench_scenario(c, "Int64", LONG_SCHEMA, &LONG_DATA, false, batch_size);
-        bench_scenario(c, "Float32", FLOAT_SCHEMA, &FLOAT_DATA,
false, batch_size); - bench_scenario(c, "Boolean", BOOL_SCHEMA, &BOOL_DATA, false, batch_size); - bench_scenario(c, "Float64", DOUBLE_SCHEMA, &DOUBLE_DATA, false, batch_size); - bench_scenario( - c, - "Binary(Bytes)", - BYTES_SCHEMA, - &BYTES_DATA, - false, - batch_size, - ); - bench_scenario(c, "String", STRING_SCHEMA, &STRING_DATA, false, batch_size); - bench_scenario( - c, - "StringView", - STRING_SCHEMA, - &STRING_DATA, - true, - batch_size, - ); - bench_scenario(c, "Date32", DATE_SCHEMA, &DATE_DATA, false, batch_size); - bench_scenario( - c, - "TimeMillis", - TMILLIS_SCHEMA, - &TMILLIS_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "TimeMicros", - TMICROS_SCHEMA, - &TMICROS_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "TimestampMillis", - TSMILLIS_SCHEMA, - &TSMILLIS_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "TimestampMicros", - TSMICROS_SCHEMA, - &TSMICROS_DATA, - false, - batch_size, - ); - bench_scenario(c, "Map", MAP_SCHEMA, &MAP_DATA, false, batch_size); - bench_scenario(c, "Array", ARRAY_SCHEMA, &ARRAY_DATA, false, batch_size); - bench_scenario( - c, - "Decimal128", - DECIMAL_SCHEMA, - &DECIMAL_DATA, - false, - batch_size, - ); - bench_scenario(c, "UUID", UUID_SCHEMA, &UUID_DATA, false, batch_size); - bench_scenario( - c, - "FixedSizeBinary", - FIXED_SCHEMA, - &FIXED_DATA, - false, - batch_size, - ); - bench_scenario( - c, - "Enum(Dictionary)", - ENUM_SCHEMA, - &ENUM_DATA, - false, - batch_size, - ); - bench_scenario(c, "Mixed", MIX_SCHEMA, &MIX_DATA, false, batch_size); - bench_scenario( - c, - "Nested(Struct)", - NEST_SCHEMA, - &NEST_DATA, - false, - batch_size, - ); + bench_with_decoder(c, "Interval", &INTERVAL_DATA, &SIZES, || { + new_decoder(INTERVAL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Int32", &INT_DATA, &SIZES, || { + new_decoder(INT_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Int32_Id", &INT_DATA_ID, &SIZES, || { + new_decoder_id(INT_SCHEMA, batch_size, false, ID_BENCH_ID) + }); + bench_with_decoder(c, "Int64", &LONG_DATA, &SIZES, || { + new_decoder(LONG_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Float32", &FLOAT_DATA, &SIZES, || { + new_decoder(FLOAT_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Boolean", &BOOL_DATA, &SIZES, || { + new_decoder(BOOL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Float64", &DOUBLE_DATA, &SIZES, || { + new_decoder(DOUBLE_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Binary(Bytes)", &BYTES_DATA, &SIZES, || { + new_decoder(BYTES_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "String", &STRING_DATA, &SIZES, || { + new_decoder(STRING_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "StringView", &STRING_DATA, &SIZES, || { + new_decoder(STRING_SCHEMA, batch_size, true) + }); + bench_with_decoder(c, "Date32", &DATE_DATA, &SIZES, || { + new_decoder(DATE_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimeMillis", &TMILLIS_DATA, &SIZES, || { + new_decoder(TMILLIS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimeMicros", &TMICROS_DATA, &SIZES, || { + new_decoder(TMICROS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimestampMillis", &TSMILLIS_DATA, &SIZES, || { + new_decoder(TSMILLIS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimestampMicros", &TSMICROS_DATA, &SIZES, || { + new_decoder(TSMICROS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Map", &MAP_DATA, &SIZES, || { + new_decoder(MAP_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, 
"Array", &ARRAY_DATA, &SIZES, || { + new_decoder(ARRAY_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Decimal128", &DECIMAL_DATA, &SIZES, || { + new_decoder(DECIMAL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "UUID", &UUID_DATA, &SIZES, || { + new_decoder(UUID_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "FixedSizeBinary", &FIXED_DATA, &SIZES, || { + new_decoder(FIXED_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Enum(Dictionary)", &ENUM_DATA, &SIZES, || { + new_decoder(ENUM_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Mixed", &MIX_DATA, &SIZES, || { + new_decoder(MIX_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Nested(Struct)", &NEST_DATA, &SIZES, || { + new_decoder(NEST_SCHEMA, batch_size, false) + }); } } diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 89a66ddbaa85..0cac8c578680 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,31 +16,18 @@ // under the License. use crate::schema::{ - Attributes, AvroSchema, ComplexType, PrimitiveType, Record, Schema, Type, TypeName, - AVRO_ENUM_SYMBOLS_METADATA_KEY, + Attributes, AvroSchema, ComplexType, Enum, Nullability, PrimitiveType, Record, Schema, Type, + TypeName, AVRO_ENUM_SYMBOLS_METADATA_KEY, }; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, }; -use serde_json::Value; -use std::borrow::Cow; +#[cfg(feature = "small_decimals")] +use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use std::collections::HashMap; use std::sync::Arc; -/// Avro types are not nullable, with nullability instead encoded as a union -/// where one of the variants is the null type. -/// -/// To accommodate this we special case two-variant unions where one of the -/// variants is the null type, and use this to derive arrow's notion of nullability -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum Nullability { - /// The nulls are encoded as the first union variant - NullFirst, - /// The nulls are encoded as the second union variant - NullSecond, -} - /// Contains information about how to resolve differences between a writer's and a reader's schema. #[derive(Debug, Clone, PartialEq)] pub(crate) enum ResolutionInfo { @@ -48,7 +35,7 @@ pub(crate) enum ResolutionInfo { Promotion(Promotion), /// Indicates that a default value should be used for a field. (Implemented in a Follow-up PR) DefaultValue(AvroLiteral), - /// Provides mapping information for resolving enums. (Implemented in a Follow-up PR) + /// Provides mapping information for resolving enums. EnumMapping(EnumMapping), /// Provides resolution information for record fields. (Implemented in a Follow-up PR) Record(ResolvedRecord), @@ -401,7 +388,7 @@ pub enum Codec { /// Represents Avro fixed type, maps to Arrow's FixedSizeBinary data type /// The i32 parameter indicates the fixed binary size Fixed(i32), - /// Represents Avro decimal type, maps to Arrow's Decimal128 or Decimal256 data types + /// Represents Avro decimal type, maps to Arrow's Decimal32, Decimal64, Decimal128, or Decimal256 data types /// /// The fields are `(precision, scale, fixed_size)`. /// - `precision` (`usize`): Total number of digits. 
@@ -447,20 +434,28 @@ impl Codec {
             }
             Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
             Self::Fixed(size) => DataType::FixedSizeBinary(*size),
-            Self::Decimal(precision, scale, size) => {
+            Self::Decimal(precision, scale, _size) => {
                 let p = *precision as u8;
                 let s = scale.unwrap_or(0) as i8;
-                let too_large_for_128 = match *size {
-                    Some(sz) => sz > 16,
-                    None => {
-                        (p as usize) > DECIMAL128_MAX_PRECISION as usize
-                            || (s as usize) > DECIMAL128_MAX_SCALE as usize
+                #[cfg(feature = "small_decimals")]
+                {
+                    if *precision <= DECIMAL32_MAX_PRECISION as usize {
+                        DataType::Decimal32(p, s)
+                    } else if *precision <= DECIMAL64_MAX_PRECISION as usize {
+                        DataType::Decimal64(p, s)
+                    } else if *precision <= DECIMAL128_MAX_PRECISION as usize {
+                        DataType::Decimal128(p, s)
+                    } else {
+                        DataType::Decimal256(p, s)
+                    }
+                }
+                #[cfg(not(feature = "small_decimals"))]
+                {
+                    if *precision <= DECIMAL128_MAX_PRECISION as usize {
+                        DataType::Decimal128(p, s)
+                    } else {
+                        DataType::Decimal256(p, s)
                     }
-                };
-                if too_large_for_128 {
-                    DataType::Decimal256(p, s)
-                } else {
-                    DataType::Decimal128(p, s)
                 }
             }
             Self::Uuid => DataType::FixedSizeBinary(16),
@@ -506,6 +501,29 @@ impl From<PrimitiveType> for Codec {
     }
 }
 
+/// Compute the exact maximum base-10 precision that fits in `n` bytes for Avro
+/// `fixed` decimals stored as two's-complement unscaled integers (big-endian).
+///
+/// Per Avro spec (Decimal logical type), for a fixed length `n`:
+/// max precision = ⌊log₁₀(2^(8n − 1) − 1)⌋.
+///
+/// This function returns `None` if `n` is 0 or greater than 32 (Arrow supports
+/// Decimal256, which is 32 bytes and has max precision 76).
+const fn max_precision_for_fixed_bytes(n: usize) -> Option<usize> {
+    // Precomputed exact table for n = 1..=32
+    // 1:2, 2:4, 3:6, 4:9, 5:11, 6:14, 7:16, 8:18, 9:21, 10:23, 11:26, 12:28,
+    // 13:31, 14:33, 15:35, 16:38, 17:40, 18:43, 19:45, 20:47, 21:50, 22:52,
+    // 23:55, 24:57, 25:59, 26:62, 27:64, 28:67, 29:69, 30:71, 31:74, 32:76
+    const MAX_P: [usize; 32] = [
+        2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26, 28, 31, 33, 35, 38, 40, 43, 45, 47, 50, 52, 55, 57,
+        59, 62, 64, 67, 69, 71, 74, 76,
+    ];
+    match n {
+        1..=32 => Some(MAX_P[n - 1]),
+        _ => None,
+    }
+}
+
 fn parse_decimal_attributes(
     attributes: &Attributes,
     fallback_size: Option<usize>,
@@ -529,6 +547,34 @@
         .and_then(|v| v.as_u64())
         .map(|s| s as usize)
         .or(fallback_size);
+    if precision == 0 {
+        return Err(ArrowError::ParseError(
+            "Decimal requires precision > 0".to_string(),
+        ));
+    }
+    if scale > precision {
+        return Err(ArrowError::ParseError(format!(
+            "Decimal has invalid scale > precision: scale={scale}, precision={precision}"
+        )));
+    }
+    if precision > DECIMAL256_MAX_PRECISION as usize {
+        return Err(ArrowError::ParseError(format!(
+            "Decimal precision {precision} exceeds maximum supported by Arrow ({})",
+            DECIMAL256_MAX_PRECISION
+        )));
+    }
+    if let Some(sz) = size {
+        let max_p = max_precision_for_fixed_bytes(sz).ok_or_else(|| {
+            ArrowError::ParseError(format!(
+                "Invalid fixed size for decimal: {sz}, must be between 1 and 32 bytes"
+            ))
+        })?;
+        if precision > max_p {
+            return Err(ArrowError::ParseError(format!(
+                "Decimal precision {precision} exceeds capacity of fixed size {sz} bytes (max {max_p})"
+            )));
+        }
+    }
     Ok((precision, scale, size))
 }
 
@@ -587,6 +633,63 @@ impl<'a> Resolver<'a> {
     }
 }
 
+fn names_match(
+    writer_name: &str,
+    writer_aliases: &[&str],
+    reader_name: &str,
+    reader_aliases: &[&str],
+) -> bool {
+    writer_name == reader_name
+        || reader_aliases.contains(&writer_name)
+        || writer_aliases.contains(&reader_name)
+}
+
+fn ensure_names_match(
+    data_type: &str,
+    writer_name: &str,
+    writer_aliases: &[&str],
+    reader_name: &str,
+    reader_aliases: &[&str],
+) -> Result<(), ArrowError> {
+    if names_match(writer_name, writer_aliases, reader_name, reader_aliases) {
+        Ok(())
+    } else {
+        Err(ArrowError::ParseError(format!(
+            "{data_type} name mismatch writer={writer_name}, reader={reader_name}"
+        )))
+    }
+}
+
+fn primitive_of(schema: &Schema) -> Option<PrimitiveType> {
+    match schema {
+        Schema::TypeName(TypeName::Primitive(primitive)) => Some(*primitive),
+        Schema::Type(Type {
+            r#type: TypeName::Primitive(primitive),
+            ..
+        }) => Some(*primitive),
+        _ => None,
+    }
+}
+
+fn nullable_union_variants<'x, 'y>(
+    variant: &'y [Schema<'x>],
+) -> Option<(Nullability, &'y Schema<'x>)> {
+    if variant.len() != 2 {
+        return None;
+    }
+    let is_null = |schema: &Schema<'x>| {
+        matches!(
+            schema,
+            Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))
+        )
+    };
+    match (is_null(&variant[0]), is_null(&variant[1])) {
+        (true, false) => Some((Nullability::NullFirst, &variant[1])),
+        (false, true) => Some((Nullability::NullSecond, &variant[0])),
+        _ => None,
+    }
+}
+
 /// Resolves Avro type names to [`AvroDataType`]
 ///
 /// See <https://avro.apache.org/docs/1.11.1/specification/#names>
@@ -690,7 +793,7 @@ impl<'a> Maker<'a> {
                 Ok(field)
             }
             ComplexType::Array(a) => {
-                let mut field = self.parse_type(a.items.as_ref(), namespace)?;
+                let field = self.parse_type(a.items.as_ref(), namespace)?;
                 Ok(AvroDataType {
                     nullability: None,
                     metadata: a.attributes.field_metadata(),
@@ -815,41 +918,35 @@ impl<'a> Maker<'a> {
         reader_schema: &'s Schema<'a>,
         namespace: Option<&'a str>,
     ) -> Result<AvroDataType, ArrowError> {
+        if let (Some(write_primitive), Some(read_primitive)) =
+            (primitive_of(writer_schema), primitive_of(reader_schema))
+        {
+            return self.resolve_primitives(write_primitive, read_primitive, reader_schema);
+        }
         match (writer_schema, reader_schema) {
-            (
-                Schema::TypeName(TypeName::Primitive(writer_primitive)),
-                Schema::TypeName(TypeName::Primitive(reader_primitive)),
-            ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
-            (
-                Schema::Type(Type {
-                    r#type: TypeName::Primitive(writer_primitive),
-                    ..
-                }),
-                Schema::Type(Type {
-                    r#type: TypeName::Primitive(reader_primitive),
-                    ..
-                }),
-            ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
-            (
-                Schema::TypeName(TypeName::Primitive(writer_primitive)),
-                Schema::Type(Type {
-                    r#type: TypeName::Primitive(reader_primitive),
-                    ..
-                }),
-            ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
-            (
-                Schema::Type(Type {
-                    r#type: TypeName::Primitive(writer_primitive),
-                    ..
-                }),
-                Schema::TypeName(TypeName::Primitive(reader_primitive)),
-            ) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
             (
                 Schema::Complex(ComplexType::Record(writer_record)),
                 Schema::Complex(ComplexType::Record(reader_record)),
             ) => self.resolve_records(writer_record, reader_record, namespace),
-            (Schema::Union(writer_variants), Schema::Union(reader_variants)) => {
-                self.resolve_nullable_union(writer_variants, reader_variants, namespace)
+            (
+                Schema::Complex(ComplexType::Enum(writer_enum)),
+                Schema::Complex(ComplexType::Enum(reader_enum)),
+            ) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace),
+            (Schema::Union(writer_variants), Schema::Union(reader_variants)) => self
+                .resolve_nullable_union(
+                    writer_variants.as_slice(),
+                    reader_variants.as_slice(),
+                    namespace,
+                ),
+            (Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace),
+            (_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace),
+            // if both sides are the same complex kind (non-record), adopt the reader type.
+            // This aligns with Avro spec: arrays, maps, and enums resolve recursively;
+            // for identical shapes we can just parse the reader schema.
+            (Schema::Complex(ComplexType::Array(_)), Schema::Complex(ComplexType::Array(_)))
+            | (Schema::Complex(ComplexType::Map(_)), Schema::Complex(ComplexType::Map(_)))
+            | (Schema::Complex(ComplexType::Fixed(_)), Schema::Complex(ComplexType::Fixed(_))) => {
+                self.parse_type(reader_schema, namespace)
             }
             _ => Err(ArrowError::NotYetImplemented(
                 "Other resolutions not yet implemented".to_string(),
@@ -886,64 +983,156 @@ impl<'a> Maker<'a> {
         Ok(datatype)
     }
 
-    fn resolve_nullable_union(
+    fn resolve_nullable_union<'s>(
         &mut self,
-        writer_variants: &[Schema<'a>],
-        reader_variants: &[Schema<'a>],
+        writer_variants: &'s [Schema<'a>],
+        reader_variants: &'s [Schema<'a>],
         namespace: Option<&'a str>,
     ) -> Result<AvroDataType, ArrowError> {
-        // Only support unions with exactly two branches, one of which is `null` on both sides
-        if writer_variants.len() != 2 || reader_variants.len() != 2 {
-            return Err(ArrowError::NotYetImplemented(
-                "Only 2-branch unions are supported for schema resolution".to_string(),
-            ));
-        }
-        let is_null = |s: &Schema<'a>| {
-            matches!(
-                s,
-                Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))
-            )
-        };
-        let w_null_pos = writer_variants.iter().position(is_null);
-        let r_null_pos = reader_variants.iter().position(is_null);
-        match (w_null_pos, r_null_pos) {
-            (Some(wp), Some(rp)) => {
-                // Extract a non-null branch on each side
-                let w_nonnull = &writer_variants[1 - wp];
-                let r_nonnull = &reader_variants[1 - rp];
-                // Resolve the non-null branch
-                let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?;
+        match (
+            nullable_union_variants(writer_variants),
+            nullable_union_variants(reader_variants),
+        ) {
+            (Some((_, write_nonnull)), Some((read_nb, read_nonnull))) => {
+                let mut dt = self.make_data_type(write_nonnull, Some(read_nonnull), namespace)?;
                 // Adopt reader union null ordering
-                dt.nullability = Some(match rp {
-                    0 => Nullability::NullFirst,
-                    1 => Nullability::NullSecond,
-                    _ => unreachable!(),
-                });
+                dt.nullability = Some(read_nb);
                 Ok(dt)
             }
             _ => Err(ArrowError::NotYetImplemented(
-                "Union resolution requires both writer and reader to be nullable unions"
+                "Union resolution requires both writer and reader to be 2-branch nullable unions"
                     .to_string(),
             )),
        }
    }

+    // Resolve writer vs. reader enum schemas according to Avro 1.11.1.
+    //
+    // # How enums resolve (writer to reader)
+    // Per “Schema Resolution”:
+    // * The two schemas must refer to the same (unqualified) enum name (or match
+    //   via alias rewriting).
+    // * If the writer’s symbol is not present in the reader’s enum and the reader
+    //   enum has a `default`, that `default` symbol must be used; otherwise,
+    //   error.
+    //   https://avro.apache.org/docs/1.11.1/specification/#schema-resolution
+    // * Avro “Aliases” are applied from the reader side to rewrite the writer’s
+    //   names during resolution. For robustness across ecosystems, we also accept
+    //   symmetry here (see note below).
+    //   https://avro.apache.org/docs/1.11.1/specification/#aliases
+    //
+    // # Rationale for this code path
+    // 1. Do the work once at schema-resolution time. Avro serializes an enum as a
+    //    writer-side position. Mapping positions on the hot decoder path is expensive
+    //    if done with string lookups. This method builds a `writer_index to reader_index`
+    //    vector once, so decoding just does an O(1) table lookup.
+    // 2. Adopt the reader’s symbol set and order. We return an Arrow
+    //    `Dictionary(Int32, Utf8)` whose dictionary values are the reader enum
+    //    symbols. This makes downstream semantics match the reader schema, including
+    //    Avro’s sort order rule that orders enums by symbol position in the schema.
+    //    https://avro.apache.org/docs/1.11.1/specification/#sort-order
+    // 3. Honor Avro’s `default` for enums. Avro 1.9+ allows a type-level default
+    //    on the enum. When the writer emits a symbol unknown to the reader, we map it
+    //    to the reader’s validated `default` symbol if present; otherwise we signal an
+    //    error at decoding time.
+    //    https://avro.apache.org/docs/1.11.1/specification/#enums
+    //
+    // # Implementation notes
+    // * We first check that enum names match or are alias-equivalent. The Avro
+    //   spec describes alias rewriting using reader aliases; this implementation
+    //   additionally treats writer aliases as acceptable for name matching, to stay
+    //   resilient with schemas produced by different tooling.
+    // * We build `EnumMapping`:
+    //   - `mapping[i]` = reader index of the writer symbol at writer index `i`.
+    //   - If the writer symbol is absent and the reader has a default, we store the
+    //     reader index of that default.
+    //   - Otherwise we store `-1` as a sentinel meaning unresolvable; the decoder
+    //     must treat encountering such a value as an error, per the spec.
+    // * We persist the reader symbol list in field metadata under
+    //   `AVRO_ENUM_SYMBOLS_METADATA_KEY`, so consumers can inspect the dictionary
+    //   without needing the original Avro schema.
+    // * The Arrow representation is `Dictionary(Int32, Utf8)`, which aligns with
+    //   Avro’s integer index encoding for enums.
+    //
+    // # Examples
+    // * Writer `["A","B","C"]`, Reader `["A","B"]`, Reader default `"A"`
+    //   `mapping = [0, 1, 0]`, `default_index = 0`.
+    // * Writer `["A","B"]`, Reader `["B","A"]` (no default)
+    //   `mapping = [1, 0]`, `default_index = -1`.
+    // * Writer `["A","B","C"]`, Reader `["A","B"]` (no default)
+    //   `mapping = [0, 1, -1]` (decode must error on `"C"`).
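Editor's note: the worked examples above can be reproduced with a small standalone sketch of the table construction (a hypothetical free function for illustration only; the real logic, including strict-mode handling and `EnumMapping`, follows in `resolve_enums` below):

```rust
use std::collections::HashMap;

/// Map each writer symbol index to a reader symbol index, falling back to the
/// reader's default symbol (if any) and finally to the `-1` sentinel.
fn build_enum_mapping(
    writer_symbols: &[&str],
    reader_symbols: &[&str],
    reader_default: Option<&str>,
) -> Vec<i32> {
    let reader_index: HashMap<&str, i32> = reader_symbols
        .iter()
        .enumerate()
        .map(|(i, &s)| (s, i as i32))
        .collect();
    let default_index = reader_default
        .and_then(|d| reader_index.get(d).copied())
        .unwrap_or(-1);
    writer_symbols
        .iter()
        .map(|s| reader_index.get(s).copied().unwrap_or(default_index))
        .collect()
}

#[test]
fn enum_mapping_examples() {
    // Writer ["A","B","C"], reader ["A","B"], default "A" -> [0, 1, 0]
    assert_eq!(
        build_enum_mapping(&["A", "B", "C"], &["A", "B"], Some("A")),
        vec![0, 1, 0]
    );
    // Writer ["A","B"], reader ["B","A"], no default -> [1, 0]
    assert_eq!(build_enum_mapping(&["A", "B"], &["B", "A"], None), vec![1, 0]);
    // Writer ["A","B","C"], reader ["A","B"], no default -> [0, 1, -1]
    assert_eq!(
        build_enum_mapping(&["A", "B", "C"], &["A", "B"], None),
        vec![0, 1, -1]
    );
}
```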
+    fn resolve_enums(
+        &mut self,
+        writer_enum: &Enum<'a>,
+        reader_enum: &Enum<'a>,
+        reader_schema: &Schema<'a>,
+        namespace: Option<&'a str>,
+    ) -> Result<AvroDataType, ArrowError> {
+        ensure_names_match(
+            "Enum",
+            writer_enum.name,
+            &writer_enum.aliases,
+            reader_enum.name,
+            &reader_enum.aliases,
+        )?;
+        if writer_enum.symbols == reader_enum.symbols {
+            return self.parse_type(reader_schema, namespace);
+        }
+        let reader_index: HashMap<&str, i32> = reader_enum
+            .symbols
+            .iter()
+            .enumerate()
+            .map(|(index, &symbol)| (symbol, index as i32))
+            .collect();
+        let default_index: i32 = match reader_enum.default {
+            Some(symbol) => *reader_index.get(symbol).ok_or_else(|| {
+                ArrowError::SchemaError(format!(
+                    "Reader enum '{}' default symbol '{symbol}' not found in symbols list",
+                    reader_enum.name,
+                ))
+            })?,
+            None => -1,
+        };
+        let mapping: Vec<i32> = writer_enum
+            .symbols
+            .iter()
+            .map(|&write_symbol| {
+                reader_index
+                    .get(write_symbol)
+                    .copied()
+                    .unwrap_or(default_index)
+            })
+            .collect();
+        if self.strict_mode && mapping.iter().any(|&m| m < 0) {
+            return Err(ArrowError::SchemaError(format!(
+                "Reader enum '{}' does not cover all writer symbols and no default is provided",
+                reader_enum.name
+            )));
+        }
+        let mut dt = self.parse_type(reader_schema, namespace)?;
+        dt.resolution = Some(ResolutionInfo::EnumMapping(EnumMapping {
+            mapping: Arc::from(mapping),
+            default_index,
+        }));
+        let reader_ns = reader_enum.namespace.or(namespace);
+        self.resolver
+            .register(reader_enum.name, reader_ns, dt.clone());
+        Ok(dt)
+    }
+
     fn resolve_records(
         &mut self,
         writer_record: &Record<'a>,
         reader_record: &Record<'a>,
         namespace: Option<&'a str>,
     ) -> Result<AvroDataType, ArrowError> {
-        // Names must match or be aliased
-        let names_match = writer_record.name == reader_record.name
-            || reader_record.aliases.contains(&writer_record.name)
-            || writer_record.aliases.contains(&reader_record.name);
-        if !names_match {
-            return Err(ArrowError::ParseError(format!(
-                "Record name mismatch writer={}, reader={}",
-                writer_record.name, reader_record.name
-            )));
-        }
+        ensure_names_match(
+            "Record",
+            writer_record.name,
+            &writer_record.aliases,
+            reader_record.name,
+            &reader_record.aliases,
+        )?;
         let writer_ns = writer_record.namespace.or(namespace);
         let reader_ns = reader_record.namespace.or(namespace);
         // Map writer field name -> index
@@ -955,12 +1144,12 @@ impl<'a> Maker<'a> {
         // Prepare outputs
         let mut reader_fields: Vec<AvroField> = Vec::with_capacity(reader_record.fields.len());
         let mut writer_to_reader: Vec<Option<usize>> = vec![None; writer_record.fields.len()];
-        //let mut skip_fields: Vec<Option<AvroDataType>> = vec![None; writer_record.fields.len()];
+        let mut skip_fields: Vec<Option<AvroDataType>> = vec![None; writer_record.fields.len()];
         //let mut default_fields: Vec<usize> = Vec::new();
         // Build reader fields and mapping
         for (reader_idx, r_field) in reader_record.fields.iter().enumerate() {
             if let Some(&writer_idx) = writer_index_map.get(r_field.name) {
-                // Field exists in writer: resolve types (including promotions and union-of-null)
+                // Field exists in the writer: resolve types (including promotions and union-of-null)
                 let w_schema = &writer_record.fields[writer_idx].r#type;
                 let resolved_dt =
                     self.make_data_type(w_schema, Some(&r_field.r#type), reader_ns)?;
@@ -975,6 +1164,14 @@ impl<'a> Maker<'a> {
                 ));
             }
         }
+        // Any writer fields not mapped should be skipped
+        for (writer_idx, writer_field) in writer_record.fields.iter().enumerate() {
+            if writer_to_reader[writer_idx].is_none() {
+                // Parse the writer field type so we know how to skip its data
+                let writer_dt = self.parse_type(&writer_field.r#type, writer_ns)?;
+                skip_fields[writer_idx] = Some(writer_dt);
+            }
+        }
         // Implement writer-only fields to skip in Follow-up PR here
         // Build resolved record AvroDataType
         let resolved = AvroDataType::new_with_resolution(
@@ -984,7 +1181,7 @@ impl<'a> Maker<'a> {
             Some(ResolutionInfo::Record(ResolvedRecord {
                 writer_to_reader: Arc::from(writer_to_reader),
                 default_fields: Arc::default(),
-                skip_fields: Arc::default(),
+                skip_fields: Arc::from(skip_fields),
             })),
         );
         // Register a resolved record by reader name+namespace for potential named type refs
diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs
index 3f2daff0a3b1..13e0f07b4544 100644
--- a/arrow-avro/src/reader/mod.rs
+++ b/arrow-avro/src/reader/mod.rs
@@ -91,8 +91,8 @@
 //!
 use crate::codec::{AvroField, AvroFieldBuilder};
 use crate::schema::{
-    compare_schemas, generate_fingerprint, AvroSchema, Fingerprint, FingerprintAlgorithm, Schema,
-    SchemaStore, SINGLE_OBJECT_MAGIC,
+    compare_schemas, AvroSchema, Fingerprint, FingerprintAlgorithm, Schema, SchemaStore,
+    CONFLUENT_MAGIC, SINGLE_OBJECT_MAGIC,
 };
 use arrow_array::{Array, RecordBatch, RecordBatchReader};
 use arrow_schema::{ArrowError, SchemaRef};
@@ -185,7 +185,7 @@ impl Decoder {
             };
         }
         match self.handle_prefix(&data[total_consumed..])? {
-            Some(0) => break, // insufficient bytes
+            Some(0) => break, // Insufficient bytes
             Some(n) => {
                 total_consumed += n;
                 self.apply_pending_schema_if_batch_empty();
@@ -201,31 +201,60 @@ impl Decoder {
         Ok(total_consumed)
     }
 
-    // Attempt to handle a single-object-encoding prefix at the current position.
-    //
+    // Attempt to handle a prefix at the current position.
     // * Ok(None) – buffer does not start with the prefix.
     // * Ok(Some(0)) – prefix detected, but the buffer is too short; caller should await more bytes.
     // * Ok(Some(n)) – consumed `n > 0` bytes of a complete prefix (magic and fingerprint).
     fn handle_prefix(&mut self, buf: &[u8]) -> Result<Option<usize>, ArrowError> {
-        // Need at least the magic bytes to decide (2 bytes).
-        let Some(magic_bytes) = buf.get(..SINGLE_OBJECT_MAGIC.len()) else {
-            return Ok(Some(0)); // Get more bytes
-        };
+        match self.fingerprint_algorithm {
+            FingerprintAlgorithm::Rabin => {
+                self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| {
+                    Fingerprint::Rabin(u64::from_le_bytes(bytes))
+                })
+            }
+            FingerprintAlgorithm::None => {
+                self.handle_prefix_common(buf, &CONFLUENT_MAGIC, |bytes| {
+                    Fingerprint::Id(u32::from_be_bytes(bytes))
+                })
+            }
+            #[cfg(feature = "md5")]
+            FingerprintAlgorithm::MD5 => {
+                self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| {
+                    Fingerprint::MD5(bytes)
+                })
+            }
+            #[cfg(feature = "sha256")]
+            FingerprintAlgorithm::SHA256 => {
+                self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| {
+                    Fingerprint::SHA256(bytes)
+                })
+            }
+        }
+    }
+
+    /// This method checks for the provided `magic` bytes at the start of `buf` and, if present,
+    /// attempts to read the following fingerprint of `N` bytes, converting it to a
+    /// [`Fingerprint`] using `fingerprint_from`.
+    fn handle_prefix_common<const MAGIC_LEN: usize, const N: usize>(
+        &mut self,
+        buf: &[u8],
+        magic: &[u8; MAGIC_LEN],
+        fingerprint_from: impl FnOnce([u8; N]) -> Fingerprint,
+    ) -> Result<Option<usize>, ArrowError> {
+        // Need at least the magic bytes to decide:
+        // 2 bytes for the Avro spec, 1 byte for the Confluent wire protocol.
+        if buf.len() < MAGIC_LEN {
+            return Ok(Some(0));
+        }
         // Bail out early if the magic does not match.
-        if magic_bytes != SINGLE_OBJECT_MAGIC {
-            return Ok(None); // Continue to decode the next record
+        if &buf[..MAGIC_LEN] != magic {
+            return Ok(None);
         }
         // Try to parse the fingerprint that follows the magic.
-        let fingerprint_size = match self.fingerprint_algorithm {
-            FingerprintAlgorithm::Rabin => self
-                .handle_fingerprint(&buf[SINGLE_OBJECT_MAGIC.len()..], |bytes| {
-                    Fingerprint::Rabin(u64::from_le_bytes(bytes))
-                })?,
-        };
+        let consumed_fp = self.handle_fingerprint(&buf[MAGIC_LEN..], fingerprint_from)?;
         // Convert the inner result into a “bytes consumed” count.
         // NOTE: Incomplete fingerprint consumes no bytes.
-        let consumed = fingerprint_size.map_or(0, |n| n + SINGLE_OBJECT_MAGIC.len());
-        Ok(Some(consumed))
+        Ok(Some(consumed_fp.map_or(0, |n| n + MAGIC_LEN)))
     }
 
     // Attempts to read and install a new fingerprint of `N` bytes.
@@ -239,7 +268,7 @@
     ) -> Result<Option<usize>, ArrowError> {
         // Need enough bytes to get fingerprint (next N bytes)
         let Some(fingerprint_bytes) = buf.get(..N) else {
-            return Ok(None); // Insufficient bytes
+            return Ok(None); // insufficient bytes
         };
         // SAFETY: length checked above.
         let new_fingerprint = fingerprint_from(fingerprint_bytes.try_into().unwrap());
@@ -658,7 +687,7 @@ mod test {
     use crate::reader::{read_header, Decoder, Reader, ReaderBuilder};
     use crate::schema::{
         AvroSchema, Fingerprint, FingerprintAlgorithm, PrimitiveType, Schema as AvroRaw,
-        SchemaStore, AVRO_ENUM_SYMBOLS_METADATA_KEY, SINGLE_OBJECT_MAGIC,
+        SchemaStore, AVRO_ENUM_SYMBOLS_METADATA_KEY, CONFLUENT_MAGIC, SINGLE_OBJECT_MAGIC,
     };
     use crate::test_util::arrow_test_data;
     use arrow::array::ArrayDataBuilder;
@@ -668,7 +697,7 @@ mod test {
     };
     use arrow_array::types::{Int32Type, IntervalMonthDayNanoType};
     use arrow_array::*;
-    use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer};
+    use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer};
     use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema};
     use bytes::{Buf, BufMut, Bytes};
     use futures::executor::block_on;
@@ -760,6 +789,17 @@ mod test {
                 out.extend_from_slice(&v.to_le_bytes());
                 out
             }
+            Fingerprint::Id(v) => {
+                panic!("make_prefix expects a Rabin fingerprint, got ({v})");
+            }
+            #[cfg(feature = "md5")]
+            Fingerprint::MD5(v) => {
+                panic!("make_prefix expects a Rabin fingerprint, got ({v:?})");
+            }
+            #[cfg(feature = "sha256")]
+            Fingerprint::SHA256(id) => {
+                panic!("make_prefix expects a Rabin fingerprint, got ({id:?})");
+            }
         }
     }
 
@@ -773,6 +813,21 @@ mod test {
             .expect("decoder")
     }
 
+    fn make_id_prefix(id: u32, additional: usize) -> Vec<u8> {
+        let capacity = CONFLUENT_MAGIC.len() + size_of::<u32>() + additional;
+        let mut out = Vec::with_capacity(capacity);
+        out.extend_from_slice(&CONFLUENT_MAGIC);
+        out.extend_from_slice(&id.to_be_bytes());
+        out
+    }
+
+    fn make_message_id(id: u32, value: i64) -> Vec<u8> {
+        let encoded_value = encode_zigzag(value);
+        let mut msg = make_id_prefix(id, encoded_value.len());
+        msg.extend_from_slice(&encoded_value);
+        msg
+    }
+
     fn make_value_schema(pt: PrimitiveType) -> AvroSchema {
         let json_schema = format!(
             r#"{{"type":"record","name":"S","fields":[{{"name":"v","type":"{}"}}]}}"#,
@@ -855,6 +910,53 @@ mod test {
         AvroSchema::new(root.to_string())
     }
 
+    fn make_reader_schema_with_enum_remap(
+        path: &str,
+        remap: &HashMap<&str, Vec<&str>>,
+    ) -> AvroSchema {
+        let mut root = load_writer_schema_json(path);
+        assert_eq!(root["type"], "record", "writer schema must be a record");
+        let fields = root
+            .get_mut("fields")
+            .and_then(|f| f.as_array_mut())
+            .expect("record has fields");
+
+        fn to_symbols_array(symbols: &[&str]) -> Value {
+            Value::Array(symbols.iter().map(|s| Value::String((*s).into())).collect())
+        }
+
+        fn update_enum_symbols(ty: &mut Value, symbols: &Value) {
+            match ty {
+                Value::Object(map) => {
+                    if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") {
+                        map.insert("symbols".to_string(), symbols.clone());
+                    }
+                }
+                Value::Array(arr) => {
+                    for b in arr.iter_mut() {
+                        if let Value::Object(map) = b {
+                            if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") {
+                                map.insert("symbols".to_string(), symbols.clone());
+                            }
+                        }
+                    }
+                }
+                _ => {}
+            }
+        }
+        for f in fields.iter_mut() {
+            let Some(name) = f.get("name").and_then(|n| n.as_str()) else {
+                continue;
+            };
+            if let Some(new_symbols) = remap.get(name) {
+                let symbols_val = to_symbols_array(new_symbols);
+                let ty = f.get_mut("type").expect("field has a type");
+                update_enum_symbols(ty, &symbols_val);
+            }
+        }
+        AvroSchema::new(root.to_string())
+    }
+
     fn read_alltypes_with_reader_schema(path: &str, reader_schema: AvroSchema) -> RecordBatch {
         let file = File::open(path).unwrap();
         let reader = ReaderBuilder::new()
@@ -863,12 +965,39 @@ mod test {
             .with_reader_schema(reader_schema)
             .build(BufReader::new(file))
             .unwrap();
-
         let schema = reader.schema();
         let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
         arrow::compute::concat_batches(&schema, &batches).unwrap()
     }
 
+    fn make_reader_schema_with_selected_fields_in_order(
+        path: &str,
+        selected: &[&str],
+    ) -> AvroSchema {
+        let mut root = load_writer_schema_json(path);
+        assert_eq!(root["type"], "record", "writer schema must be a record");
+        let writer_fields = root
+            .get("fields")
+            .and_then(|f| f.as_array())
+            .expect("record has fields");
+        let mut field_map: HashMap<String, Value> = HashMap::with_capacity(writer_fields.len());
+        for f in writer_fields {
+            if let Some(name) = f.get("name").and_then(|n| n.as_str()) {
+                field_map.insert(name.to_string(), f.clone());
+            }
+        }
+        let mut new_fields = Vec::with_capacity(selected.len());
+        for name in selected {
+            let f = field_map
+                .get(*name)
+                .unwrap_or_else(|| panic!("field '{name}' not found in writer schema"))
+                .clone();
+            new_fields.push(f);
+        }
+        root["fields"] = Value::Array(new_fields);
+        AvroSchema::new(root.to_string())
+    }
+
     #[test]
     fn test_alltypes_schema_promotion_mixed() {
         let files = [
@@ -1207,6 +1336,52 @@ mod test {
         );
     }
 
+    #[test]
+    fn test_simple_enum_with_reader_schema_mapping() {
+        let file = arrow_test_data("avro/simple_enum.avro");
+        let mut remap: HashMap<&str, Vec<&str>> = HashMap::new();
+        remap.insert("f1", vec!["d", "c", "b", "a"]);
+        remap.insert("f2", vec!["h", "g", "f", "e"]);
+        remap.insert("f3", vec!["k", "i", "j"]);
+        let reader_schema = make_reader_schema_with_enum_remap(&file, &remap);
+        let actual = read_alltypes_with_reader_schema(&file, reader_schema);
+        let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
+        let f1_keys = Int32Array::from(vec![3, 2, 1, 0]);
+        let f1_vals = StringArray::from(vec!["d", "c", "b", "a"]);
+        let f1 = DictionaryArray::<Int32Type>::try_new(f1_keys, Arc::new(f1_vals)).unwrap();
+        let mut md_f1 = HashMap::new();
+        md_f1.insert(
+            AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(),
+            r#"["d","c","b","a"]"#.to_string(),
+        );
+        let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1);
+        let f2_keys = Int32Array::from(vec![1, 0, 3, 2]);
+        let f2_vals = StringArray::from(vec!["h", "g", "f", "e"]);
+        let f2 = DictionaryArray::<Int32Type>::try_new(f2_keys, Arc::new(f2_vals)).unwrap();
+        let mut md_f2 = HashMap::new();
+        md_f2.insert(
+            AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(),
+            r#"["h","g","f","e"]"#.to_string(),
+        );
+        let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2);
+        let f3_keys = Int32Array::from(vec![Some(2), Some(0), None, Some(1)]);
+        let f3_vals = StringArray::from(vec!["k", "i", "j"]);
+        let f3 = DictionaryArray::<Int32Type>::try_new(f3_keys, Arc::new(f3_vals)).unwrap();
+        let mut md_f3 = HashMap::new();
+        md_f3.insert(
+            AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(),
+            r#"["k","i","j"]"#.to_string(),
+        );
+        let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3);
+        let expected_schema = Arc::new(Schema::new(vec![f1_field, f2_field, f3_field]));
+        let expected = RecordBatch::try_new(
+            expected_schema,
+            vec![Arc::new(f1) as ArrayRef, Arc::new(f2), Arc::new(f3)],
+        )
+        .unwrap();
+        assert_eq!(actual, expected);
+    }
+
     #[test]
     fn test_schema_store_register_lookup() {
         let schema_int = make_record_schema(PrimitiveType::Int);
@@ -1258,6 +1433,11 @@ mod test {
         let mut decoder = make_decoder(&store, fp_int, &schema_long);
         let long_bytes = match fp_long {
             Fingerprint::Rabin(v) => v.to_le_bytes(),
+            Fingerprint::Id(id) => panic!("expected Rabin fingerprint, got ({id})"),
+            #[cfg(feature = "md5")]
+            Fingerprint::MD5(v) => panic!("expected Rabin fingerprint, got ({v:?})"),
+            #[cfg(feature = "sha256")]
+            Fingerprint::SHA256(v) => panic!("expected Rabin fingerprint, got ({v:?})"),
         };
         let mut buf = Vec::from(SINGLE_OBJECT_MAGIC);
         buf.extend_from_slice(&long_bytes[..4]);
@@ -1276,8 +1456,14 @@ mod test {
             RecordDecoder::try_new_with_options(root_long.data_type(), decoder.utf8_view).unwrap();
         let _ = decoder.cache.insert(fp_long, long_decoder);
         let mut buf = Vec::from(SINGLE_OBJECT_MAGIC);
-        let Fingerprint::Rabin(v) = fp_long;
-        buf.extend_from_slice(&v.to_le_bytes());
+        match fp_long {
+            Fingerprint::Rabin(v) => buf.extend_from_slice(&v.to_le_bytes()),
+            Fingerprint::Id(id) => panic!("expected Rabin fingerprint, got ({id})"),
+            #[cfg(feature = "md5")]
+            Fingerprint::MD5(v) => panic!("expected Rabin fingerprint, got ({v:?})"),
+            #[cfg(feature = "sha256")]
+            Fingerprint::SHA256(v) => panic!("expected Rabin fingerprint, got ({v:?})"),
+        }
         let consumed = decoder.handle_prefix(&buf).unwrap().unwrap();
         assert_eq!(consumed, buf.len());
         assert!(decoder.pending_schema.is_some());
@@ -1355,6 +1541,83 @@ mod test {
     }
 
     #[test]
+    fn test_two_messages_same_schema_id() {
+        let writer_schema = make_value_schema(PrimitiveType::Int);
+        let reader_schema = writer_schema.clone();
+        let id = 100u32;
+        // Set up store with None fingerprint algorithm and register schema by id
+        let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None);
+        let _ = store
+            .set(Fingerprint::Id(id), writer_schema.clone())
+            .expect("set id schema");
+        let msg1 = make_message_id(id, 21);
+        let msg2 = make_message_id(id, 22);
+        let input = [msg1.clone(), msg2.clone()].concat();
+        let mut decoder = ReaderBuilder::new()
+            .with_batch_size(8)
+            .with_reader_schema(reader_schema)
+            .with_writer_schema_store(store)
+            .with_active_fingerprint(Fingerprint::Id(id))
+            .build_decoder()
+            .unwrap();
+        let _ = decoder.decode(&input).unwrap();
+        let batch = decoder.flush().unwrap().expect("batch");
+        assert_eq!(batch.num_rows(), 2);
+        let col = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(col.value(0), 21);
+        assert_eq!(col.value(1), 22);
+    }
+
+    #[test]
+    fn test_unknown_id_fingerprint_is_error() {
+        let writer_schema =
make_value_schema(PrimitiveType::Int); + let id_known = 7u32; + let id_unknown = 9u32; + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); + let _ = store + .set(Fingerprint::Id(id_known), writer_schema.clone()) + .expect("set id schema"); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(writer_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id_known)) + .build_decoder() + .unwrap(); + let prefix = make_id_prefix(id_unknown, 0); + let err = decoder.decode(&prefix).expect_err("decode should error"); + let msg = err.to_string(); + assert!( + msg.contains("Unknown fingerprint"), + "unexpected message: {msg}" + ); + } + + #[test] + fn test_handle_prefix_id_incomplete_magic() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let id = 5u32; + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::None); + let _ = store + .set(Fingerprint::Id(id), writer_schema.clone()) + .expect("set id schema"); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(writer_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .build_decoder() + .unwrap(); + let buf = &crate::schema::CONFLUENT_MAGIC[..0]; // empty incomplete magic + let res = decoder.handle_prefix(buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + fn test_split_message_across_chunks() { let writer_schema = make_value_schema(PrimitiveType::Int); let reader_schema = writer_schema.clone(); @@ -1537,6 +1800,107 @@ mod test { assert!(batch.column(0).as_any().is::()); } + #[test] + fn test_alltypes_skip_writer_fields_keep_double_only() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let reader_schema = + make_reader_schema_with_selected_fields_in_order(&file, &["double_col"]); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_alltypes_skip_writer_fields_reorder_and_skip_many() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let reader_schema = + make_reader_schema_with_selected_fields_in_order(&file, &["timestamp_col", "id"]); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_skippable_types_project_each_field_individually() { + let path = "test/data/skippable_types.avro"; + let full = read_file(path, 1024, false); + let schema_full = full.schema(); + let num_rows = full.num_rows(); + let writer_json = load_writer_schema_json(path); + assert_eq!( + writer_json["type"], "record", 
+ "writer schema must be a record" + ); + let fields_json = writer_json + .get("fields") + .and_then(|f| f.as_array()) + .expect("record has fields"); + assert_eq!( + schema_full.fields().len(), + fields_json.len(), + "full read column count vs writer fields" + ); + for (idx, f) in fields_json.iter().enumerate() { + let name = f + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or_else(|| panic!("field at index {idx} has no name")); + let reader_schema = make_reader_schema_with_selected_fields_in_order(path, &[name]); + let projected = read_alltypes_with_reader_schema(path, reader_schema); + assert_eq!( + projected.num_columns(), + 1, + "projected batch should contain exactly the selected column '{name}'" + ); + assert_eq!( + projected.num_rows(), + num_rows, + "row count mismatch for projected column '{name}'" + ); + let field = schema_full.field(idx).clone(); + let col = full.column(idx).clone(); + let expected = + RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![col]).unwrap(); + // Equality means: (1) read the right column values; and (2) all other + // writer fields were skipped correctly for this projection (no misalignment). + assert_eq!( + projected, expected, + "projected column '{name}' mismatch vs full read column" + ); + } + } + #[test] fn test_read_zero_byte_avro_file() { let batch = read_file("test/data/zero_byte.avro", 3, false); @@ -1791,18 +2155,18 @@ mod test { let expected = RecordBatch::try_from_iter_with_nullable([( "foo", Arc::new(BinaryArray::from_iter_values(vec![ - b"\x00".as_ref(), - b"\x01".as_ref(), - b"\x02".as_ref(), - b"\x03".as_ref(), - b"\x04".as_ref(), - b"\x05".as_ref(), - b"\x06".as_ref(), - b"\x07".as_ref(), - b"\x08".as_ref(), - b"\t".as_ref(), - b"\n".as_ref(), - b"\x0b".as_ref(), + b"\x00" as &[u8], + b"\x01" as &[u8], + b"\x02" as &[u8], + b"\x03" as &[u8], + b"\x04" as &[u8], + b"\x05" as &[u8], + b"\x06" as &[u8], + b"\x07" as &[u8], + b"\x08" as &[u8], + b"\t" as &[u8], + b"\n" as &[u8], + b"\x0b" as &[u8], ])) as Arc, true, )]) @@ -1812,37 +2176,137 @@ mod test { #[test] fn test_decimal() { - let files = [ - ("avro/fixed_length_decimal.avro", 25, 2), - ("avro/fixed_length_decimal_legacy.avro", 13, 2), - ("avro/int32_decimal.avro", 4, 2), - ("avro/int64_decimal.avro", 10, 2), + // Choose expected Arrow types depending on the `small_decimals` feature flag. + // With `small_decimals` enabled, Decimal32/Decimal64 are used where their + // precision allows; otherwise, those cases resolve to Decimal128. 
+ #[cfg(feature = "small_decimals")] + let files: [(&str, DataType); 8] = [ + ( + "avro/fixed_length_decimal.avro", + DataType::Decimal128(25, 2), + ), + ( + "avro/fixed_length_decimal_legacy.avro", + DataType::Decimal64(13, 2), + ), + ("avro/int32_decimal.avro", DataType::Decimal32(4, 2)), + ("avro/int64_decimal.avro", DataType::Decimal64(10, 2)), + ( + "test/data/int256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed_length_decimal_legacy_32.avro", + DataType::Decimal32(9, 2), + ), + ("test/data/int128_decimal.avro", DataType::Decimal128(38, 2)), + ]; + #[cfg(not(feature = "small_decimals"))] + let files: [(&str, DataType); 8] = [ + ( + "avro/fixed_length_decimal.avro", + DataType::Decimal128(25, 2), + ), + ( + "avro/fixed_length_decimal_legacy.avro", + DataType::Decimal128(13, 2), + ), + ("avro/int32_decimal.avro", DataType::Decimal128(4, 2)), + ("avro/int64_decimal.avro", DataType::Decimal128(10, 2)), + ( + "test/data/int256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed256_decimal.avro", + DataType::Decimal256(76, 10), + ), + ( + "test/data/fixed_length_decimal_legacy_32.avro", + DataType::Decimal128(9, 2), + ), + ("test/data/int128_decimal.avro", DataType::Decimal128(38, 2)), ]; - let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); - for (file, precision, scale) in files { - let file_path = arrow_test_data(file); + for (file, expected_dt) in files { + let (precision, scale) = match expected_dt { + DataType::Decimal32(p, s) + | DataType::Decimal64(p, s) + | DataType::Decimal128(p, s) + | DataType::Decimal256(p, s) => (p, s), + _ => unreachable!("Unexpected decimal type in test inputs"), + }; + assert!(scale >= 0, "test data uses non-negative scales only"); + let scale_u32 = scale as u32; + let file_path: String = if file.starts_with("avro/") { + arrow_test_data(file) + } else { + std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(file) + .to_string_lossy() + .into_owned() + }; + let pow10: i128 = 10i128.pow(scale_u32); + let values_i128: Vec = (1..=24).map(|n| (n as i128) * pow10).collect(); + let build_expected = |dt: &DataType, values: &[i128]| -> ArrayRef { + match *dt { + DataType::Decimal32(p, s) => { + let it = values.iter().map(|&v| v as i32); + Arc::new( + Decimal32Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + DataType::Decimal64(p, s) => { + let it = values.iter().map(|&v| v as i64); + Arc::new( + Decimal64Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + DataType::Decimal128(p, s) => { + let it = values.iter().copied(); + Arc::new( + Decimal128Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + DataType::Decimal256(p, s) => { + let it = values.iter().map(|&v| i256::from_i128(v)); + Arc::new( + Decimal256Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + _ => unreachable!("Unexpected decimal type in test"), + } + }; let actual_batch = read_file(&file_path, 8, false); - let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) - .with_precision_and_scale(precision, scale) - .unwrap(); + let actual_nullable = actual_batch.schema().field(0).is_nullable(); + let expected_array = build_expected(&expected_dt, &values_i128); let mut meta = HashMap::new(); meta.insert("precision".to_string(), precision.to_string()); meta.insert("scale".to_string(), 
scale.to_string()); - let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) - .with_metadata(meta); - let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); + let field = + Field::new("value", expected_dt.clone(), actual_nullable).with_metadata(meta); + let expected_schema = Arc::new(Schema::new(vec![field])); let expected_batch = - RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) - .expect("Failed to build expected RecordBatch"); + RecordBatch::try_new(expected_schema.clone(), vec![expected_array]).unwrap(); assert_eq!( actual_batch, expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data for file {file}" + "Decoded RecordBatch does not match for {file}" ); let actual_batch_small = read_file(&file_path, 3, false); assert_eq!( - actual_batch_small, - expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data for file {file} with batch size 3" + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match for {file} with batch size 3" ); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index a51e4c78740f..48eb601024b5 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::codec::{AvroDataType, Codec, Nullability, Promotion, ResolutionInfo}; +use crate::codec::{AvroDataType, Codec, Promotion, ResolutionInfo}; use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; use crate::reader::header::Header; use crate::schema::*; use arrow_array::builder::{ - ArrayBuilder, Decimal128Builder, Decimal256Builder, IntervalMonthDayNanoBuilder, - PrimitiveBuilder, + ArrayBuilder, Decimal128Builder, Decimal256Builder, Decimal32Builder, Decimal64Builder, + IntervalMonthDayNanoBuilder, PrimitiveBuilder, }; use arrow_array::types::*; use arrow_array::*; @@ -31,6 +31,8 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; +#[cfg(feature = "small_decimals")] +use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; use std::cmp::Ordering; use std::collections::HashMap; use std::io::Read; @@ -39,6 +41,25 @@ use uuid::Uuid; const DEFAULT_CAPACITY: usize = 1024; +/// Macro to decode a decimal payload for a given width and integer type. +macro_rules! decode_decimal { + ($size:expr, $buf:expr, $builder:expr, $N:expr, $Int:ty) => {{ + let bytes = read_decimal_bytes_be::<{ $N }>($buf, $size)?; + $builder.append_value(<$Int>::from_be_bytes(bytes)); + }}; +} + +/// Macro to finish a decimal builder into an array with precision/scale and nulls. +macro_rules! 
flush_decimal { + ($builder:expr, $precision:expr, $scale:expr, $nulls:expr, $ArrayTy:ty) => {{ + let (_, vals, _) = $builder.finish().into_parts(); + let dec = <$ArrayTy>::new(vals, $nulls) + .with_precision_and_scale(*$precision as u8, $scale.unwrap_or(0) as i8) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(dec) as ArrayRef + }}; +} + #[derive(Debug)] pub(crate) struct RecordDecoderBuilder<'a> { data_type: &'a AvroDataType, @@ -70,6 +91,15 @@ pub(crate) struct RecordDecoder { schema: SchemaRef, fields: Vec, use_utf8view: bool, + resolved: Option, +} + +#[derive(Debug)] +struct ResolvedRuntime { + /// writer field index -> reader field index (or None if writer-only) + writer_to_reader: Arc<[Option]>, + /// per-writer-field skipper (Some only when writer-only) + skip_decoders: Vec>, } impl RecordDecoder { @@ -92,8 +122,6 @@ impl RecordDecoder { /// # Arguments /// * `data_type` - The Avro data type to decode. /// * `use_utf8view` - A flag indicating whether to use `Utf8View` for string types. - /// * `strict_mode` - A flag to enable strict decoding, returning an error if the data - /// does not conform to the schema. /// /// # Errors /// This function will return an error if the provided `data_type` is not a `Record`. @@ -101,14 +129,35 @@ impl RecordDecoder { data_type: &AvroDataType, use_utf8view: bool, ) -> Result { - match Decoder::try_new(data_type)? { - Decoder::Record(fields, encodings) => Ok(Self { - schema: Arc::new(ArrowSchema::new(fields)), - fields: encodings, - use_utf8view, - }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + match data_type.codec() { + Codec::Struct(reader_fields) => { + // Build Arrow schema fields and per-child decoders + let mut arrow_fields = Vec::with_capacity(reader_fields.len()); + let mut encodings = Vec::with_capacity(reader_fields.len()); + for avro_field in reader_fields.iter() { + arrow_fields.push(avro_field.field()); + encodings.push(Decoder::try_new(avro_field.data_type())?); + } + // If this record carries resolution metadata, prepare top-level runtime helpers + let resolved = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Record(rec)) => { + let skip_decoders = build_skip_decoders(&rec.skip_fields)?; + Some(ResolvedRuntime { + writer_to_reader: rec.writer_to_reader.clone(), + skip_decoders, + }) + } + _ => None, + }; + Ok(Self { + schema: Arc::new(ArrowSchema::new(arrow_fields)), + fields: encodings, + use_utf8view, + resolved, + }) + } + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } @@ -121,9 +170,25 @@ impl RecordDecoder { /// Decode `count` records from `buf` pub(crate) fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); - for _ in 0..count { - for field in &mut self.fields { - field.decode(&mut cursor)?; + match self.resolved.as_mut() { + Some(runtime) => { + // Top-level resolved record: read writer fields in writer order, + // project into reader fields, and skip writer-only fields + for _ in 0..count { + decode_with_resolution( + &mut cursor, + &mut self.fields, + &runtime.writer_to_reader, + &mut runtime.skip_decoders, + )?; + } + } + None => { + for _ in 0..count { + for field in &mut self.fields { + field.decode(&mut cursor)?; + } + } } } Ok(cursor.position()) @@ -136,11 +201,30 @@ impl RecordDecoder { .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), arrays) } } +fn decode_with_resolution( + buf: &mut AvroCursor<'_>, + 
encodings: &mut [Decoder], + writer_to_reader: &[Option], + skippers: &mut [Option], +) -> Result<(), ArrowError> { + for (w_idx, (target, skipper_opt)) in writer_to_reader.iter().zip(skippers).enumerate() { + match (*target, skipper_opt.as_mut()) { + (Some(r_idx), _) => encodings[r_idx].decode(buf)?, + (None, Some(sk)) => sk.skip(buf)?, + (None, None) => { + return Err(ArrowError::SchemaError(format!( + "No skipper available for writer-only field at index {w_idx}", + ))); + } + } + } + Ok(()) +} + #[derive(Debug)] enum Decoder { Null(usize), @@ -180,9 +264,24 @@ enum Decoder { Enum(Vec, Arc<[String]>), Duration(IntervalMonthDayNanoBuilder), Uuid(Vec), + Decimal32(usize, Option, Option, Decimal32Builder), + Decimal64(usize, Option, Option, Decimal64Builder), Decimal128(usize, Option, Option, Decimal128Builder), Decimal256(usize, Option, Option, Decimal256Builder), Nullable(Nullability, NullBufferBuilder, Box), + EnumResolved { + indices: Vec, + symbols: Arc<[String]>, + mapping: Arc<[i32]>, + default_index: i32, + }, + /// Resolved record that needs writer->reader projection and skipping writer-only fields + RecordResolved { + fields: Fields, + encodings: Vec, + writer_to_reader: Arc<[Option]>, + skip_decoders: Vec>, + }, } impl Decoder { @@ -251,36 +350,43 @@ impl Decoder { (Codec::Decimal(precision, scale, size), _) => { let p = *precision; let s = *scale; - let sz = *size; let prec = p as u8; let scl = s.unwrap_or(0) as i8; - match (sz, p) { - (Some(fixed_size), _) if fixed_size <= 16 => { - let builder = - Decimal128Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal128(p, s, sz, builder) - } - (Some(fixed_size), _) if fixed_size <= 32 => { - let builder = - Decimal256Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal256(p, s, sz, builder) - } - (Some(fixed_size), _) => { + #[cfg(feature = "small_decimals")] + { + if p <= DECIMAL32_MAX_PRECISION as usize { + let builder = Decimal32Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal32(p, s, *size, builder) + } else if p <= DECIMAL64_MAX_PRECISION as usize { + let builder = Decimal64Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal64(p, s, *size, builder) + } else if p <= DECIMAL128_MAX_PRECISION as usize { + let builder = Decimal128Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal128(p, s, *size, builder) + } else if p <= DECIMAL256_MAX_PRECISION as usize { + let builder = Decimal256Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal256(p, s, *size, builder) + } else { return Err(ArrowError::ParseError(format!( - "Unsupported decimal size: {fixed_size:?}" + "Decimal precision {p} exceeds maximum supported" ))); } - (None, p) if p <= DECIMAL128_MAX_PRECISION as usize => { - let builder = - Decimal128Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal128(p, s, sz, builder) - } - (None, p) if p <= DECIMAL256_MAX_PRECISION as usize => { - let builder = - Decimal256Builder::new().with_precision_and_scale(prec, scl)?; - Self::Decimal256(p, s, sz, builder) - } - (None, _) => { + } + #[cfg(not(feature = "small_decimals"))] + { + if p <= DECIMAL128_MAX_PRECISION as usize { + let builder = Decimal128Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal128(p, s, *size, builder) + } else if p <= DECIMAL256_MAX_PRECISION as usize { + let builder = 
Decimal256Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal256(p, s, *size, builder) + } else { return Err(ArrowError::ParseError(format!( "Decimal precision {p} exceeds maximum supported" ))); @@ -297,7 +403,16 @@ impl Decoder { ) } (Codec::Enum(symbols), _) => { - Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) + if let Some(ResolutionInfo::EnumMapping(mapping)) = data_type.resolution.as_ref() { + Self::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: symbols.clone(), + mapping: mapping.mapping.clone(), + default_index: mapping.default_index, + } + } else { + Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone()) + } } (Codec::Struct(fields), _) => { let mut arrow_fields = Vec::with_capacity(fields.len()); @@ -307,10 +422,20 @@ impl Decoder { arrow_fields.push(avro_field.field()); encodings.push(encoding); } - Self::Record(arrow_fields.into(), encodings) + if let Some(ResolutionInfo::Record(rec)) = data_type.resolution.as_ref() { + let skip_decoders = build_skip_decoders(&rec.skip_fields)?; + Self::RecordResolved { + fields: arrow_fields.into(), + encodings, + writer_to_reader: rec.writer_to_reader.clone(), + skip_decoders, + } + } else { + Self::Record(arrow_fields.into(), encodings) + } } (Codec::Map(child), _) => { - let val_field = child.field_with_name("value").with_nullable(true); + let val_field = child.field_with_name("value"); let map_field = Arc::new(ArrowField::new( "entries", DataType::Struct(Fields::from(vec![ @@ -376,14 +501,20 @@ impl Decoder { Self::Fixed(sz, accum) => { accum.extend(std::iter::repeat_n(0u8, *sz as usize)); } + Self::Decimal32(_, _, _, builder) => builder.append_value(0), + Self::Decimal64(_, _, _, builder) => builder.append_value(0), Self::Decimal128(_, _, _, builder) => builder.append_value(0), Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO), Self::Enum(indices, _) => indices.push(0), + Self::EnumResolved { indices, .. } => indices.push(0), Self::Duration(builder) => builder.append_null(), Self::Nullable(_, null_buffer, inner) => { null_buffer.append(false); inner.append_null(); } + Self::RecordResolved { encodings, .. } => { + encodings.iter_mut().for_each(|e| e.append_null()); + } } } @@ -447,29 +578,41 @@ impl Decoder { let fx = buf.get_fixed(*sz as usize)?; accum.extend_from_slice(fx); } + Self::Decimal32(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 4, i32); + } + Self::Decimal64(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 8, i64); + } Self::Decimal128(_, _, size, builder) => { - let raw = if let Some(s) = size { - buf.get_fixed(*s)? - } else { - buf.get_bytes()? - }; - let ext = sign_extend_to::<16>(raw)?; - let val = i128::from_be_bytes(ext); - builder.append_value(val); + decode_decimal!(size, buf, builder, 16, i128); } Self::Decimal256(_, _, size, builder) => { - let raw = if let Some(s) = size { - buf.get_fixed(*s)? - } else { - buf.get_bytes()? - }; - let ext = sign_extend_to::<32>(raw)?; - let val = i256::from_be_bytes(ext); - builder.append_value(val); + decode_decimal!(size, buf, builder, 32, i256); } Self::Enum(indices, _) => { indices.push(buf.get_int()?); } + Self::EnumResolved { + indices, + mapping, + default_index, + .. 
+ } => { + let raw = buf.get_int()?; + let resolved = usize::try_from(raw) + .ok() + .and_then(|idx| mapping.get(idx).copied()) + .filter(|&idx| idx >= 0) + .unwrap_or(*default_index); + if resolved >= 0 { + indices.push(resolved); + } else { + return Err(ArrowError::ParseError(format!( + "Enum symbol index {raw} not resolvable and no default provided", + ))); + } + } Self::Duration(builder) => { let b = buf.get_fixed(12)?; let months = u32::from_le_bytes(b[0..4].try_into().unwrap()); @@ -484,12 +627,21 @@ impl Decoder { Nullability::NullFirst => branch != 0, Nullability::NullSecond => branch == 0, }; - nb.append(is_not_null); if is_not_null { + // It is important to decode before appending to null buffer in case of decode error encoding.decode(buf)?; } else { encoding.append_null(); } + nb.append(is_not_null); + } + Self::RecordResolved { + encodings, + writer_to_reader, + skip_decoders, + .. + } => { + decode_with_resolution(buf, encodings, writer_to_reader, skip_decoders)?; } } Ok(()) @@ -588,14 +740,16 @@ impl Decoder { ))); } } - let entries_struct = StructArray::new( - Fields::from(vec![ - Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)), - ]), - vec![Arc::new(key_arr), val_arr], - None, - ); + let entries_fields = match map_field.data_type() { + DataType::Struct(fields) => fields.clone(), + other => { + return Err(ArrowError::InvalidArgumentError(format!( + "Map entries field must be a Struct, got {other:?}" + ))) + } + }; + let entries_struct = + StructArray::new(entries_fields, vec![Arc::new(key_arr), val_arr], None); let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false); Arc::new(map_arr) } @@ -610,52 +764,87 @@ impl Decoder { .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(arr) } + Self::Decimal32(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal32Array) + } + Self::Decimal64(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal64Array) + } Self::Decimal128(precision, scale, _, builder) => { - let (_, vals, _) = builder.finish().into_parts(); - let scl = scale.unwrap_or(0); - let dec = Decimal128Array::new(vals, nulls) - .with_precision_and_scale(*precision as u8, scl as i8) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Arc::new(dec) + flush_decimal!(builder, precision, scale, nulls, Decimal128Array) } Self::Decimal256(precision, scale, _, builder) => { - let (_, vals, _) = builder.finish().into_parts(); - let scl = scale.unwrap_or(0); - let dec = Decimal256Array::new(vals, nulls) - .with_precision_and_scale(*precision as u8, scl as i8) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Arc::new(dec) - } - Self::Enum(indices, symbols) => { - let keys = flush_primitive::(indices, nulls); - let values = Arc::new(StringArray::from( - symbols.iter().map(|s| s.as_str()).collect::>(), - )); - Arc::new(DictionaryArray::try_new(keys, values)?) + flush_decimal!(builder, precision, scale, nulls, Decimal256Array) } + Self::Enum(indices, symbols) => flush_dict(indices, symbols, nulls)?, + Self::EnumResolved { + indices, symbols, .. + } => flush_dict(indices, symbols, nulls)?, Self::Duration(builder) => { let (_, vals, _) = builder.finish().into_parts(); let vals = IntervalMonthDayNanoArray::try_new(vals, nulls) .map_err(|e| ArrowError::ParseError(e.to_string()))?; Arc::new(vals) } + Self::RecordResolved { + fields, encodings, .. 
+ } => { + let arrays = encodings + .iter_mut() + .map(|x| x.flush(None)) + .collect::, _>>()?; + Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + } }) } } +#[derive(Debug, Copy, Clone)] +enum NegativeBlockBehavior { + ProcessItems, + SkipBySize, +} + +#[inline] +fn skip_blocks( + buf: &mut AvroCursor, + mut skip_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + process_blockwise( + buf, + move |c| skip_item(c), + NegativeBlockBehavior::SkipBySize, + ) +} + +#[inline] +fn flush_dict( + indices: &mut Vec, + symbols: &[String], + nulls: Option, +) -> Result { + let keys = flush_primitive::(indices, nulls); + let values = Arc::new(StringArray::from_iter_values( + symbols.iter().map(|s| s.as_str()), + )); + DictionaryArray::try_new(keys, values) + .map_err(|e| ArrowError::ParseError(e.to_string())) + .map(|arr| Arc::new(arr) as ArrayRef) +} + #[inline] fn read_blocks( buf: &mut AvroCursor, decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { - read_blockwise_items(buf, true, decode_entry) + process_blockwise(buf, decode_entry, NegativeBlockBehavior::ProcessItems) } #[inline] -fn read_blockwise_items( +fn process_blockwise( buf: &mut AvroCursor, - read_size_after_negative: bool, - mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + mut on_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + negative_behavior: NegativeBlockBehavior, ) -> Result { let mut total = 0usize; loop { @@ -667,22 +856,27 @@ fn read_blockwise_items( match block_count.cmp(&0) { Ordering::Equal => break, Ordering::Less => { - // If block_count is negative, read the absolute value of count, - // then read the block size as a long and discard let count = (-block_count) as usize; - if read_size_after_negative { - let _size_in_bytes = buf.get_long()?; - } - for _ in 0..count { - decode_fn(buf)?; + // A negative count is followed by a long of the size in bytes + let size_in_bytes = buf.get_long()? as usize; + match negative_behavior { + NegativeBlockBehavior::ProcessItems => { + // Process items one-by-one after reading size + for _ in 0..count { + on_item(buf)?; + } + } + NegativeBlockBehavior::SkipBySize => { + // Skip the entire block payload at once + let _ = buf.get_fixed(size_in_bytes)?; + } } total += count; } Ordering::Greater => { - // If block_count is positive, decode that many items let count = block_count as usize; - for _i in 0..count { - decode_fn(buf)?; + for _ in 0..count { + on_item(buf)?; } total += count; } @@ -709,29 +903,237 @@ fn flush_primitive( PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Sign extends a byte slice to a fixed-size array of N bytes. -/// This is done by filling the leading bytes with 0x00 for positive numbers -/// or 0xFF for negative numbers. #[inline] -fn sign_extend_to(raw: &[u8]) -> Result<[u8; N], ArrowError> { - if raw.len() > N { - return Err(ArrowError::ParseError(format!( - "Cannot extend a slice of length {} to {} bytes.", - raw.len(), - N - ))); - } - let mut arr = [0u8; N]; - let pad_len = N - raw.len(); - // Determine the byte to use for padding based on the sign bit of the raw data. 
- let extension_byte = if raw.is_empty() || (raw[0] & 0x80 == 0) { - 0x00 - } else { - 0xFF - }; - arr[..pad_len].fill(extension_byte); - arr[pad_len..].copy_from_slice(raw); - Ok(arr) +fn read_decimal_bytes_be( + buf: &mut AvroCursor<'_>, + size: &Option, +) -> Result<[u8; N], ArrowError> { + match size { + Some(n) if *n == N => { + let raw = buf.get_fixed(N)?; + let mut arr = [0u8; N]; + arr.copy_from_slice(raw); + Ok(arr) + } + Some(n) => { + let raw = buf.get_fixed(*n)?; + sign_cast_to::(raw) + } + None => { + let raw = buf.get_bytes()?; + sign_cast_to::(raw) + } + } +} + +/// Sign-extend or (when larger) validate-and-truncate a big-endian two's-complement +/// integer into exactly `N` bytes. This matches Avro's decimal binary encoding: +/// the payload is a big-endian two's-complement integer, and when narrowing it must +/// be representable without changing sign or value. +/// +/// If `raw.len() < N`, the value is sign-extended. +/// If `raw.len() > N`, all truncated leading bytes must match the sign-extension byte +/// and the MSB of the first kept byte must match the sign (to avoid silent overflow). +#[inline] +fn sign_cast_to(raw: &[u8]) -> Result<[u8; N], ArrowError> { + let len = raw.len(); + // Fast path: exact width, just copy + if len == N { + let mut out = [0u8; N]; + out.copy_from_slice(raw); + return Ok(out); + } + // Determine sign byte from MSB of first byte (empty => positive) + let first = raw.first().copied().unwrap_or(0u8); + let sign_byte = if (first & 0x80) == 0 { 0x00 } else { 0xFF }; + // Pre-fill with sign byte to support sign extension + let mut out = [sign_byte; N]; + if len > N { + // Validate truncation: all dropped leading bytes must equal sign_byte, + // and the MSB of the first kept byte must match the sign. + let extra = len - N; + // Any non-sign byte in the truncated prefix indicates overflow + if raw[..extra].iter().any(|&b| b != sign_byte) { + return Err(ArrowError::ParseError(format!( + "Decimal value with {} bytes cannot be represented in {} bytes without overflow", + len, N + ))); + } + if N > 0 { + let first_kept = raw[extra]; + let sign_bit_mismatch = ((first_kept ^ sign_byte) & 0x80) != 0; + if sign_bit_mismatch { + return Err(ArrowError::ParseError(format!( + "Decimal value with {} bytes cannot be represented in {} bytes without overflow", + len, N + ))); + } + } + out.copy_from_slice(&raw[extra..]); + return Ok(out); + } + out[N - len..].copy_from_slice(raw); + Ok(out) +} + +/// Lightweight skipper for non‑projected writer fields +/// (fields present in the writer schema but omitted by the reader/projection); +/// per Avro 1.11.1 schema resolution these fields are ignored. 
+/// +/// +#[derive(Debug)] +enum Skipper { + Null, + Boolean, + Int32, + Int64, + Float32, + Float64, + Bytes, + String, + Date32, + TimeMillis, + TimeMicros, + TimestampMillis, + TimestampMicros, + Fixed(usize), + Decimal(Option), + UuidString, + Enum, + DurationFixed12, + List(Box), + Map(Box), + Struct(Vec), + Nullable(Nullability, Box), +} + +impl Skipper { + fn from_avro(dt: &AvroDataType) -> Result { + let mut base = match dt.codec() { + Codec::Null => Self::Null, + Codec::Boolean => Self::Boolean, + Codec::Int32 | Codec::Date32 | Codec::TimeMillis => Self::Int32, + Codec::Int64 => Self::Int64, + Codec::TimeMicros => Self::TimeMicros, + Codec::TimestampMillis(_) => Self::TimestampMillis, + Codec::TimestampMicros(_) => Self::TimestampMicros, + Codec::Float32 => Self::Float32, + Codec::Float64 => Self::Float64, + Codec::Binary => Self::Bytes, + Codec::Utf8 | Codec::Utf8View => Self::String, + Codec::Fixed(sz) => Self::Fixed(*sz as usize), + Codec::Decimal(_, _, size) => Self::Decimal(*size), + Codec::Uuid => Self::UuidString, // encoded as string + Codec::Enum(_) => Self::Enum, + Codec::List(item) => Self::List(Box::new(Skipper::from_avro(item)?)), + Codec::Struct(fields) => Self::Struct( + fields + .iter() + .map(|f| Skipper::from_avro(f.data_type())) + .collect::>()?, + ), + Codec::Map(values) => Self::Map(Box::new(Skipper::from_avro(values)?)), + Codec::Interval => Self::DurationFixed12, + _ => { + return Err(ArrowError::NotYetImplemented(format!( + "Skipper not implemented for codec {:?}", + dt.codec() + ))); + } + }; + if let Some(n) = dt.nullability() { + base = Self::Nullable(n, Box::new(base)); + } + Ok(base) + } + + fn skip(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + match self { + Self::Null => Ok(()), + Self::Boolean => { + buf.get_bool()?; + Ok(()) + } + Self::Int32 | Self::Date32 | Self::TimeMillis => { + buf.get_int()?; + Ok(()) + } + Self::Int64 | Self::TimeMicros | Self::TimestampMillis | Self::TimestampMicros => { + buf.get_long()?; + Ok(()) + } + Self::Float32 => { + buf.get_float()?; + Ok(()) + } + Self::Float64 => { + buf.get_double()?; + Ok(()) + } + Self::Bytes | Self::String | Self::UuidString => { + buf.get_bytes()?; + Ok(()) + } + Self::Fixed(sz) => { + buf.get_fixed(*sz)?; + Ok(()) + } + Self::Decimal(size) => { + if let Some(s) = size { + buf.get_fixed(*s) + } else { + buf.get_bytes() + }?; + Ok(()) + } + Self::Enum => { + buf.get_int()?; + Ok(()) + } + Self::DurationFixed12 => { + buf.get_fixed(12)?; + Ok(()) + } + Self::List(item) => { + skip_blocks(buf, |c| item.skip(c))?; + Ok(()) + } + Self::Map(value) => { + skip_blocks(buf, |c| { + c.get_bytes()?; // key + value.skip(c) + })?; + Ok(()) + } + Self::Struct(fields) => { + for f in fields.iter_mut() { + f.skip(buf)? 
+ } + Ok(()) + } + Self::Nullable(order, inner) => { + let branch = buf.read_vlq()?; + let is_not_null = match *order { + Nullability::NullFirst => branch != 0, + Nullability::NullSecond => branch == 0, + }; + if is_not_null { + inner.skip(buf)?; + } + Ok(()) + } + } + } +} + +#[inline] +fn build_skip_decoders( + skip_fields: &[Option], +) -> Result>, ArrowError> { + skip_fields + .iter() + .map(|opt| opt.as_ref().map(Skipper::from_avro).transpose()) + .collect() } #[cfg(test)] @@ -739,8 +1141,9 @@ mod tests { use super::*; use crate::codec::AvroField; use arrow_array::{ - cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, - IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, + cast::AsArray, Array, Decimal128Array, Decimal256Array, Decimal32Array, DictionaryArray, + FixedSizeBinaryArray, IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, + StructArray, }; fn encode_avro_int(value: i32) -> Vec { @@ -1187,7 +1590,7 @@ mod tests { #[test] fn test_decimal_decoding_fixed256() { - let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(32))); + let dt = avro_from_codec(Codec::Decimal(50, Some(2), Some(32))); let mut decoder = Decoder::try_new(&dt).unwrap(); let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -1214,7 +1617,7 @@ mod tests { #[test] fn test_decimal_decoding_fixed128() { - let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(16))); + let dt = avro_from_codec(Codec::Decimal(28, Some(2), Some(16))); let mut decoder = Decoder::try_new(&dt).unwrap(); let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -1237,6 +1640,79 @@ mod tests { assert_eq!(dec.value_as_string(1), "-1.23"); } + #[test] + fn test_decimal_decoding_fixed32_from_32byte_fixed_storage() { + let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(32))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + } + + #[test] + fn test_decimal_decoding_fixed32_from_16byte_fixed_storage() { + let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, + ]; + let mut 
data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + } + #[test] fn test_decimal_decoding_bytes_with_nulls() { let dt = avro_from_codec(Codec::Decimal(4, Some(1), None)); @@ -1253,21 +1729,34 @@ mod tests { data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); // row1 - decoder.decode(&mut cursor).unwrap(); // row2 - decoder.decode(&mut cursor).unwrap(); // row3 + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); - assert_eq!(dec_arr.value_as_string(0), "123.4"); - assert_eq!(dec_arr.value_as_string(2), "-123.4"); + #[cfg(feature = "small_decimals")] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } } #[test] - fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + fn test_decimal_decoding_bytes_with_nulls_fixed_size_narrow_result() { let dt = avro_from_codec(Codec::Decimal(6, Some(2), Some(16))); let inner = Decoder::try_new(&dt).unwrap(); let mut decoder = Decoder::Nullable( @@ -1294,13 +1783,26 @@ mod tests { decoder.decode(&mut cursor).unwrap(); decoder.decode(&mut cursor).unwrap(); let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); - assert_eq!(dec_arr.value_as_string(0), "1234.56"); - assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + #[cfg(feature = "small_decimals")] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + 
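// ---------------------------------------------------------------------------
// Editor's aside (not part of the diff): these fixed-storage cases rely on the
// narrowing rule in `sign_cast_to` above. A big-endian two's-complement
// payload may be cut down to N bytes only if every dropped leading byte equals
// the sign byte and the most significant kept bit preserves the sign. A
// standalone illustration (assumes N > 0; not the crate's API):
fn narrows_cleanly<const N: usize>(raw: &[u8]) -> bool {
    let sign: u8 = if raw.first().map_or(false, |b| b & 0x80 != 0) {
        0xFF
    } else {
        0x00
    };
    match raw.len().checked_sub(N) {
        // Dropped prefix must be all sign bytes and the kept MSB must match.
        Some(extra) => raw[..extra].iter().all(|&b| b == sign) && (raw[extra] ^ sign) & 0x80 == 0,
        // Shorter payloads always sign-extend losslessly.
        None => true,
    }
}
// E.g. the 16-byte row2 above (0xFF repeated, ending in 0x85, i.e. -123)
// narrows to a 4-byte i32 because the twelve dropped bytes are all 0xFF and
// the sign bit of the first kept byte stays set.
// ---------------------------------------------------------------------------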
assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } } #[test] @@ -1321,7 +1823,6 @@ mod tests { .as_any() .downcast_ref::>() .unwrap(); - assert_eq!(dict_array.len(), 3); let values = dict_array .values() @@ -1433,4 +1934,327 @@ mod tests { let array = decoder.flush(None).unwrap(); assert_eq!(array.len(), 0); } + + #[test] + fn test_nullable_decode_error_bitmap_corruption() { + // Nullable Int32 with ['T','null'] encoding (NullSecond) + let avro_type = AvroDataType::new( + Codec::Int32, + Default::default(), + Some(Nullability::NullSecond), + ); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + + // Row 1: union branch 1 (null) + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_int(1)); + + // Row 2: union branch 0 (non-null) but missing the int payload -> decode error + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_int(0)); // branch = 0 => non-null + + // Row 3: union branch 0 (non-null) with correct int payload -> should succeed + let mut row3 = Vec::new(); + row3.extend_from_slice(&encode_avro_int(0)); // branch + row3.extend_from_slice(&encode_avro_int(42)); // actual value + + decoder.decode(&mut AvroCursor::new(&row1)).unwrap(); + assert!(decoder.decode(&mut AvroCursor::new(&row2)).is_err()); // decode error + decoder.decode(&mut AvroCursor::new(&row3)).unwrap(); + + let array = decoder.flush(None).unwrap(); + + // Should contain 2 elements: row1 (null) and row3 (42) + assert_eq!(array.len(), 2); + let int_array = array.as_any().downcast_ref::().unwrap(); + assert!(int_array.is_null(0)); // row1 is null + assert_eq!(int_array.value(1), 42); // row3 value is 42 + } + + #[test] + fn test_enum_mapping_reordered_symbols() { + let reader_symbols: Arc<[String]> = + vec!["B".to_string(), "C".to_string(), "A".to_string()].into(); + let mapping: Arc<[i32]> = Arc::from(vec![2, 0, 1]); + let default_index: i32 = -1; + let mut dec = Decoder::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: reader_symbols.clone(), + mapping, + default_index, + }; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + let expected_keys = Int32Array::from(vec![2, 0, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "B"); + assert_eq!(values.value(1), "C"); + assert_eq!(values.value(2), "A"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_and_out_of_range_fall_back_to_default() { + let reader_symbols: Arc<[String]> = vec!["A".to_string(), "B".to_string()].into(); + let default_index: i32 = 1; + let mapping: Arc<[i32]> = Arc::from(vec![0, 1]); + let mut dec = Decoder::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: reader_symbols.clone(), + mapping, + default_index, + }; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(99)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut 
cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + let expected_keys = Int32Array::from(vec![0, 1, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "A"); + assert_eq!(values.value(1), "B"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_without_default_errors() { + let reader_symbols: Arc<[String]> = vec!["A".to_string()].into(); + let default_index: i32 = -1; // indicates no default at type-level + let mapping: Arc<[i32]> = Arc::from(vec![-1]); + let mut dec = Decoder::EnumResolved { + indices: Vec::with_capacity(DEFAULT_CAPACITY), + symbols: reader_symbols, + mapping, + default_index, + }; + let data = encode_avro_int(0); + let mut cur = AvroCursor::new(&data); + let err = dec + .decode(&mut cur) + .expect_err("expected decode error for unresolved enum without default"); + let msg = err.to_string(); + assert!( + msg.contains("not resolvable") && msg.contains("no default"), + "unexpected error message: {msg}" + ); + } + + fn make_record_resolved_decoder( + reader_fields: &[(&str, DataType, bool)], + writer_to_reader: Vec>, + mut skip_decoders: Vec>, + ) -> Decoder { + let mut field_refs: Vec = Vec::with_capacity(reader_fields.len()); + let mut encodings: Vec = Vec::with_capacity(reader_fields.len()); + for (name, dt, nullable) in reader_fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + let enc = match dt { + DataType::Int32 => Decoder::Int32(Vec::new()), + DataType::Int64 => Decoder::Int64(Vec::new()), + DataType::Utf8 => { + Decoder::String(OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::new()) + } + other => panic!("Unsupported test reader field type: {other:?}"), + }; + encodings.push(enc); + } + let fields: Fields = field_refs.into(); + Decoder::RecordResolved { + fields, + encodings, + writer_to_reader: Arc::from(writer_to_reader), + skip_decoders, + } + } + + #[test] + fn test_skip_writer_trailing_field_int32() { + let mut dec = make_record_resolved_decoder( + &[("id", arrow_schema::DataType::Int32, false)], + vec![Some(0), None], + vec![None, Some(super::Skipper::Int32)], + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(7)); + data.extend_from_slice(&encode_avro_int(999)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let struct_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_arr.len(), 1); + let id = struct_arr + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 7); + } + + #[test] + fn test_skip_writer_middle_field_string() { + let mut dec = make_record_resolved_decoder( + &[ + ("id", DataType::Int32, false), + ("score", DataType::Int64, false), + ], + vec![Some(0), None, Some(1)], + vec![None, Some(Skipper::String), None], + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(42)); + data.extend_from_slice(&encode_avro_bytes(b"abcdef")); + data.extend_from_slice(&encode_avro_long(1000)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let score = s + .column_by_name("score") + .unwrap() + 
.as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 42); + assert_eq!(score.value(0), 1000); + } + + #[test] + fn test_skip_writer_array_with_negative_block_count_fast() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![Some(super::Skipper::List(Box::new(Skipper::Int32))), None], + ); + let mut array_payload = Vec::new(); + array_payload.extend_from_slice(&encode_avro_int(1)); + array_payload.extend_from_slice(&encode_avro_int(2)); + array_payload.extend_from_slice(&encode_avro_int(3)); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(-3)); + data.extend_from_slice(&encode_avro_long(array_payload.len() as i64)); + data.extend_from_slice(&array_payload); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_int(5)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 1); + assert_eq!(id.value(0), 5); + } + + #[test] + fn test_skip_writer_map_with_negative_block_count_fast() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![Some(Skipper::Map(Box::new(Skipper::Int32))), None], + ); + let mut entries = Vec::new(); + entries.extend_from_slice(&encode_avro_bytes(b"k1")); + entries.extend_from_slice(&encode_avro_int(10)); + entries.extend_from_slice(&encode_avro_bytes(b"k2")); + entries.extend_from_slice(&encode_avro_int(20)); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(-2)); + data.extend_from_slice(&encode_avro_long(entries.len() as i64)); + data.extend_from_slice(&entries); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_int(123)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 1); + assert_eq!(id.value(0), 123); + } + + #[test] + fn test_skip_writer_nullable_field_union_nullfirst() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![ + Some(super::Skipper::Nullable( + Nullability::NullFirst, + Box::new(super::Skipper::Int32), + )), + None, + ], + ); + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(0)); + row1.extend_from_slice(&encode_avro_int(5)); + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_long(1)); + row2.extend_from_slice(&encode_avro_int(123)); + row2.extend_from_slice(&encode_avro_int(7)); + let mut cur1 = AvroCursor::new(&row1); + let mut cur2 = AvroCursor::new(&row2); + dec.decode(&mut cur1).unwrap(); + dec.decode(&mut cur2).unwrap(); + assert_eq!(cur1.position(), row1.len()); + assert_eq!(cur2.position(), row2.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 2); + assert_eq!(id.value(0), 5); + assert_eq!(id.value(1), 7); + } } diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 2f1c0a2bcffc..6e343736c1e9 
100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -20,6 +20,8 @@ use arrow_schema::{ }; use serde::{Deserialize, Serialize}; use serde_json::{json, Map as JsonMap, Value}; +#[cfg(feature = "sha256")] +use sha2::{Digest, Sha256}; use std::cmp::PartialEq; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; @@ -31,6 +33,9 @@ pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; /// The Avro single‑object encoding “magic” bytes (`0xC3 0x01`) pub const SINGLE_OBJECT_MAGIC: [u8; 2] = [0xC3, 0x01]; +/// The Confluent "magic" byte (`0x00`) +pub const CONFLUENT_MAGIC: [u8; 1] = [0x00]; + /// Metadata key used to represent Avro enum symbols in an Arrow schema. pub const AVRO_ENUM_SYMBOLS_METADATA_KEY: &str = "avro.enum.symbols"; @@ -49,11 +54,25 @@ pub const AVRO_DOC_METADATA_KEY: &str = "avro.doc"; /// Compare two Avro schemas for equality (identical schemas). /// Returns true if the schemas have the same parsing canonical form (i.e., logically identical). pub fn compare_schemas(writer: &Schema, reader: &Schema) -> Result { - let canon_writer = generate_canonical_form(writer)?; - let canon_reader = generate_canonical_form(reader)?; + let canon_writer = AvroSchema::generate_canonical_form(writer)?; + let canon_reader = AvroSchema::generate_canonical_form(reader)?; Ok(canon_writer == canon_reader) } +/// Avro types are not nullable, with nullability instead encoded as a union +/// where one of the variants is the null type. +/// +/// To accommodate this, we specially case two-variant unions where one of the +/// variants is the null type, and use this to derive arrow's notion of nullability +#[derive(Debug, Copy, Clone, PartialEq, Default)] +pub enum Nullability { + /// The nulls are encoded as the first union variant + #[default] + NullFirst, + /// The nulls are encoded as the second union variant + NullSecond, +} + /// Either a [`PrimitiveType`] or a reference to a previously defined named type /// /// @@ -108,7 +127,7 @@ pub struct Attributes<'a> { /// Additional JSON attributes #[serde(flatten)] - pub additional: HashMap<&'a str, serde_json::Value>, + pub additional: HashMap<&'a str, Value>, } impl Attributes<'_> { @@ -215,8 +234,8 @@ pub struct Field<'a> { #[serde(borrow)] pub r#type: Schema<'a>, /// Optional default value for this field - #[serde(borrow, default)] - pub default: Option<&'a str>, + #[serde(default)] + pub default: Option, } /// An enumeration @@ -304,18 +323,131 @@ pub struct AvroSchema { impl TryFrom<&ArrowSchema> for AvroSchema { type Error = ArrowError; + /// Converts an `ArrowSchema` to `AvroSchema`, delegating to + /// `AvroSchema::from_arrow_with_options` with `None` so that the + /// union null ordering is decided by `Nullability::default()`. fn try_from(schema: &ArrowSchema) -> Result { - // Fast‑path: schema already contains Avro JSON + AvroSchema::from_arrow_with_options(schema, None) + } +} + +impl AvroSchema { + /// Creates a new `AvroSchema` from a JSON string. + pub fn new(json_string: String) -> Self { + Self { json_string } + } + + /// Deserializes and returns the `AvroSchema`. + /// + /// The returned schema borrows from `self`. + pub fn schema(&self) -> Result, ArrowError> { + serde_json::from_str(self.json_string.as_str()) + .map_err(|e| ArrowError::ParseError(format!("Invalid Avro schema JSON: {e}"))) + } + + /// Returns the Rabin fingerprint of the schema. + pub fn fingerprint(&self) -> Result { + Self::generate_fingerprint_rabin(&self.schema()?) 
+ } + + /// Generates a fingerprint for the given `Schema` using the specified [`FingerprintAlgorithm`]. + /// + /// The fingerprint is computed over the schema's Parsed Canonical Form + /// as defined by the Avro specification. Depending on `hash_type`, this + /// will return one of the supported [`Fingerprint`] variants: + /// - [`Fingerprint::Rabin`] for [`FingerprintAlgorithm::Rabin`] + /// - [`Fingerprint::MD5`] for [`FingerprintAlgorithm::MD5`] + /// - [`Fingerprint::SHA256`] for [`FingerprintAlgorithm::SHA256`] + /// + /// Note: [`FingerprintAlgorithm::None`] cannot be used to generate a fingerprint + /// and will result in an error. If you intend to use a Schema Registry ID-based + /// wire format, load or set the [`Fingerprint::Id`] directly via [`Fingerprint::load_fingerprint_id`] + /// or [`SchemaStore::set`]. + /// + /// See also: + /// + /// # Errors + /// Returns an error if generating the canonical form of the schema fails, + /// or if `hash_type` is [`FingerprintAlgorithm::None`]. + /// + /// # Examples + /// ```no_run + /// use arrow_avro::schema::{AvroSchema, FingerprintAlgorithm}; + /// + /// let avro = AvroSchema::new("\"string\"".to_string()); + /// let schema = avro.schema().unwrap(); + /// let fp = AvroSchema::generate_fingerprint(&schema, FingerprintAlgorithm::Rabin).unwrap(); + /// ``` + pub fn generate_fingerprint( + schema: &Schema, + hash_type: FingerprintAlgorithm, + ) -> Result { + let canonical = Self::generate_canonical_form(schema).map_err(|e| { + ArrowError::ComputeError(format!("Failed to generate canonical form for schema: {e}")) + })?; + match hash_type { + FingerprintAlgorithm::Rabin => { + Ok(Fingerprint::Rabin(compute_fingerprint_rabin(&canonical))) + } + FingerprintAlgorithm::None => Err(ArrowError::SchemaError( + "FingerprintAlgorithm of None cannot be used to generate a fingerprint; \ + if using Fingerprint::Id, pass the registry ID in instead using the set method." + .to_string(), + )), + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => Ok(Fingerprint::MD5(compute_fingerprint_md5(&canonical))), + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => { + Ok(Fingerprint::SHA256(compute_fingerprint_sha256(&canonical))) + } + } + } + + /// Generates the 64-bit Rabin fingerprint for the given `Schema`. + /// + /// The fingerprint is computed from the canonical form of the schema. + /// This is also known as `CRC-64-AVRO`. + /// + /// # Returns + /// A `Fingerprint::Rabin` variant containing the 64-bit fingerprint. + pub fn generate_fingerprint_rabin(schema: &Schema) -> Result { + Self::generate_fingerprint(schema, FingerprintAlgorithm::Rabin) + } + + /// Generates the Parsed Canonical Form for the given [`Schema`]. + /// + /// The canonical form is a standardized JSON representation of the schema, + /// primarily used for generating a schema fingerprint for equality checking. + /// + /// This form strips attributes that do not affect the schema's identity, + /// such as `doc` fields, `aliases`, and any properties not defined in the + /// Avro specification. + /// + /// + pub fn generate_canonical_form(schema: &Schema) -> Result { + build_canonical(schema, None) + } + + /// Build Avro JSON from an Arrow [`ArrowSchema`], applying the given null‑union order. + /// + /// If the input Arrow schema already contains Avro JSON in + /// [`SCHEMA_METADATA_KEY`], that JSON is returned verbatim to preserve + /// the exact header encoding alignment; otherwise, a new JSON is generated + /// honoring `null_union_order` at **all nullable sites**. 
+ pub fn from_arrow_with_options( + schema: &ArrowSchema, + null_order: Option, + ) -> Result { if let Some(json) = schema.metadata.get(SCHEMA_METADATA_KEY) { return Ok(AvroSchema::new(json.clone())); } + let order = null_order.unwrap_or_default(); let mut name_gen = NameGenerator::default(); let fields_json = schema .fields() .iter() - .map(|f| arrow_field_to_avro(f, &mut name_gen)) + .map(|f| arrow_field_to_avro(f, &mut name_gen, order)) .collect::, _>>()?; - // Assemble top‑level record let record_name = schema .metadata .get(AVRO_NAME_METADATA_KEY) @@ -333,52 +465,42 @@ impl TryFrom<&ArrowSchema> for AvroSchema { record.insert("doc".into(), Value::String(doc.clone())); } record.insert("fields".into(), Value::Array(fields_json)); - let schema_prefix = format!("{SCHEMA_METADATA_KEY}."); - for (meta_key, meta_val) in &schema.metadata { - // Skip keys already handled or internal - if meta_key.starts_with("avro.") - || meta_key.starts_with(schema_prefix.as_str()) - || is_internal_arrow_key(meta_key) - { - continue; - } - let json_val = - serde_json::from_str(meta_val).unwrap_or_else(|_| Value::String(meta_val.clone())); - record.insert(meta_key.clone(), json_val); - } + extend_with_passthrough_metadata(&mut record, &schema.metadata); let json_string = serde_json::to_string(&Value::Object(record)) - .map_err(|e| ArrowError::SchemaError(format!("Serialising Avro JSON failed: {e}")))?; + .map_err(|e| ArrowError::SchemaError(format!("Serializing Avro JSON failed: {e}")))?; Ok(AvroSchema::new(json_string)) } } -impl AvroSchema { - /// Creates a new `AvroSchema` from a JSON string. - pub fn new(json_string: String) -> Self { - Self { json_string } - } - - /// Deserializes and returns the `AvroSchema`. - /// - /// The returned schema borrows from `self`. - pub fn schema(&self) -> Result, ArrowError> { - serde_json::from_str(self.json_string.as_str()) - .map_err(|e| ArrowError::ParseError(format!("Invalid Avro schema JSON: {e}"))) - } - - /// Returns the Rabin fingerprint of the schema. - pub fn fingerprint(&self) -> Result { - generate_fingerprint_rabin(&self.schema()?) - } -} - /// Supported fingerprint algorithms for Avro schema identification. -/// Currently only `Rabin` is supported, `SHA256` and `MD5` support will come in a future update +/// For use with Confluent Schema Registry IDs, set to None. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] pub enum FingerprintAlgorithm { /// 64‑bit CRC‑64‑AVRO Rabin fingerprint. #[default] Rabin, + /// Represents a fingerprint not based on a hash algorithm, (e.g., a 32-bit Schema Registry ID.) + None, + #[cfg(feature = "md5")] + /// 128-bit MD5 message digest. + MD5, + #[cfg(feature = "sha256")] + /// 256-bit SHA-256 digest. + SHA256, +} + +/// Allow easy extraction of the algorithm used to create a fingerprint. +impl From<&Fingerprint> for FingerprintAlgorithm { + fn from(fp: &Fingerprint) -> Self { + match fp { + Fingerprint::Rabin(_) => FingerprintAlgorithm::Rabin, + Fingerprint::Id(_) => FingerprintAlgorithm::None, + #[cfg(feature = "md5")] + Fingerprint::MD5(_) => FingerprintAlgorithm::MD5, + #[cfg(feature = "sha256")] + Fingerprint::SHA256(_) => FingerprintAlgorithm::SHA256, + } + } } /// A schema fingerprint in one of the supported formats. @@ -386,64 +508,36 @@ pub enum FingerprintAlgorithm { /// This is used as the key inside `SchemaStore` `HashMap`. Each `SchemaStore` /// instance always stores only one variant, matching its configured /// `FingerprintAlgorithm`, but the enum makes the API uniform. 
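+/// A short illustration of producing each kind of fingerprint (values are
+/// hypothetical; the hash-based variants beyond Rabin additionally require
+/// the `md5` / `sha256` features):
+///
+/// ```no_run
+/// use arrow_avro::schema::{AvroSchema, Fingerprint};
+///
+/// let avro = AvroSchema::new("\"int\"".to_string());
+/// // Hash-based: computed from the Parsed Canonical Form.
+/// let rabin: Fingerprint = avro.fingerprint().unwrap();
+/// // Registry-based: converted from big-endian wire order.
+/// assert_eq!(Fingerprint::load_fingerprint_id(42u32.to_be()), Fingerprint::Id(42));
+/// ```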
-/// Currently only `Rabin` is supported /// /// +/// #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Fingerprint { /// A 64-bit Rabin fingerprint. Rabin(u64), + /// A 32-bit Schema Registry ID. + Id(u32), + #[cfg(feature = "md5")] + /// A 128-bit MD5 fingerprint. + MD5([u8; 16]), + #[cfg(feature = "sha256")] + /// A 256-bit SHA-256 fingerprint. + SHA256([u8; 32]), } -/// Allow easy extraction of the algorithm used to create a fingerprint. -impl From<&Fingerprint> for FingerprintAlgorithm { - fn from(fp: &Fingerprint) -> Self { - match fp { - Fingerprint::Rabin(_) => FingerprintAlgorithm::Rabin, - } - } -} - -/// Generates a fingerprint for the given `Schema` using the specified `FingerprintAlgorithm`. -pub(crate) fn generate_fingerprint( - schema: &Schema, - hash_type: FingerprintAlgorithm, -) -> Result { - let canonical = generate_canonical_form(schema).map_err(|e| { - ArrowError::ComputeError(format!("Failed to generate canonical form for schema: {e}")) - })?; - match hash_type { - FingerprintAlgorithm::Rabin => { - Ok(Fingerprint::Rabin(compute_fingerprint_rabin(&canonical))) - } +impl Fingerprint { + /// Loads the 32-bit Schema Registry fingerprint (Confluent Schema Registry ID). + /// + /// The provided `id` is in big-endian wire order; this converts it to host order + /// and returns `Fingerprint::Id`. + /// + /// # Returns + /// A `Fingerprint::Id` variant containing the 32-bit fingerprint. + pub fn load_fingerprint_id(id: u32) -> Self { + Fingerprint::Id(u32::from_be(id)) } } -/// Generates the 64-bit Rabin fingerprint for the given `Schema`. -/// -/// The fingerprint is computed from the canonical form of the schema. -/// This is also known as `CRC-64-AVRO`. -/// -/// # Returns -/// A `Fingerprint::Rabin` variant containing the 64-bit fingerprint. -pub fn generate_fingerprint_rabin(schema: &Schema) -> Result { - generate_fingerprint(schema, FingerprintAlgorithm::Rabin) -} - -/// Generates the Parsed Canonical Form for the given [`Schema`]. -/// -/// The canonical form is a standardized JSON representation of the schema, -/// primarily used for generating a schema fingerprint for equality checking. -/// -/// This form strips attributes that do not affect the schema's identity, -/// such as `doc` fields, `aliases`, and any properties not defined in the -/// Avro specification. -/// -/// -pub fn generate_canonical_form(schema: &Schema) -> Result { - build_canonical(schema, None) -} - /// An in-memory cache of Avro schemas, indexed by their fingerprint. /// /// `SchemaStore` provides a mechanism to store and retrieve Avro schemas efficiently. @@ -478,17 +572,16 @@ pub struct SchemaStore { schemas: HashMap, } -impl TryFrom<&[AvroSchema]> for SchemaStore { +impl TryFrom> for SchemaStore { type Error = ArrowError; - /// Creates a `SchemaStore` from a slice of schemas. - /// Each schema in the slice is registered with the new store. - fn try_from(schemas: &[AvroSchema]) -> Result { - let mut store = SchemaStore::new(); - for schema in schemas { - store.register(schema.clone())?; - } - Ok(store) + /// Creates a `SchemaStore` from a HashMap of schemas. + /// Each schema in the HashMap is registered with the new store. + fn try_from(schemas: HashMap) -> Result { + Ok(Self { + schemas, + ..Self::default() + }) } } @@ -498,23 +591,35 @@ impl SchemaStore { Self::default() } - /// Registers a schema with the store and returns its fingerprint. + /// Creates an empty `SchemaStore` using the default fingerprinting algorithm (64-bit Rabin). 
+ pub fn new_with_type(fingerprint_algorithm: FingerprintAlgorithm) -> Self { + Self { + fingerprint_algorithm, + ..Self::default() + } + } + + /// Registers a schema with the store and the provided fingerprint. + /// Note: Confluent wire format implementations should leverage this method. /// - /// A fingerprint is calculated for the given schema using the store's configured - /// hash type. If a schema with the same fingerprint does not already exist in the - /// store, the new schema is inserted. If the fingerprint already exists, the - /// existing schema is not overwritten. + /// A schema is set in the store, using the provided fingerprint. If a schema + /// with the same fingerprint does not already exist in the store, the new schema + /// is inserted. If the fingerprint already exists, the existing schema is not overwritten. /// /// # Arguments /// + /// * `fingerprint` - A reference to the `Fingerprint` of the schema to register. /// * `schema` - The `AvroSchema` to register. /// /// # Returns /// - /// A `Result` containing the `Fingerprint` of the schema if successful, + /// A `Result` returning the provided `Fingerprint` of the schema if successful, /// or an `ArrowError` on failure. - pub fn register(&mut self, schema: AvroSchema) -> Result { - let fingerprint = generate_fingerprint(&schema.schema()?, self.fingerprint_algorithm)?; + pub fn set( + &mut self, + fingerprint: Fingerprint, + schema: AvroSchema, + ) -> Result { match self.schemas.entry(fingerprint) { Entry::Occupied(entry) => { if entry.get() != &schema { @@ -530,6 +635,37 @@ impl SchemaStore { Ok(fingerprint) } + /// Registers a schema with the store and returns its fingerprint. + /// + /// A fingerprint is calculated for the given schema using the store's configured + /// hash type. If a schema with the same fingerprint does not already exist in the + /// store, the new schema is inserted. If the fingerprint already exists, the + /// existing schema is not overwritten. If FingerprintAlgorithm is set to None, this + /// method will return an error. Confluent wire format implementations should leverage the + /// set method instead. + /// + /// # Arguments + /// + /// * `schema` - The `AvroSchema` to register. + /// + /// # Returns + /// + /// A `Result` containing the `Fingerprint` of the schema if successful, + /// or an `ArrowError` on failure. + pub fn register(&mut self, schema: AvroSchema) -> Result { + if self.fingerprint_algorithm == FingerprintAlgorithm::None { + return Err(ArrowError::SchemaError( + "Invalid FingerprintAlgorithm; unable to generate fingerprint. \ + Use the set method directly instead, providing a valid fingerprint" + .to_string(), + )); + } + let fingerprint = + AvroSchema::generate_fingerprint(&schema.schema()?, self.fingerprint_algorithm)?; + self.set(fingerprint, schema)?; + Ok(fingerprint) + } + /// Looks up a schema by its `Fingerprint`. /// /// # Arguments @@ -715,11 +851,52 @@ pub(crate) fn compute_fingerprint_rabin(canonical_form: &str) -> u64 { fp } +#[cfg(feature = "md5")] +/// Compute the **128‑bit MD5** fingerprint of the canonical form. +/// +/// Returns a 16‑byte array (`[u8; 16]`) containing the full MD5 digest, +/// exactly as required by the Avro specification. +#[inline] +pub(crate) fn compute_fingerprint_md5(canonical_form: &str) -> [u8; 16] { + let digest = md5::compute(canonical_form.as_bytes()); + digest.0 +} + +#[cfg(feature = "sha256")] +/// Compute the **256‑bit SHA‑256** fingerprint of the canonical form. 
+/// +/// Returns a 32‑byte array (`[u8; 32]`) containing the full SHA‑256 digest. +#[inline] +pub(crate) fn compute_fingerprint_sha256(canonical_form: &str) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(canonical_form.as_bytes()); + let digest = hasher.finalize(); + digest.into() +} + #[inline] fn is_internal_arrow_key(key: &str) -> bool { key.starts_with("ARROW:") || key == SCHEMA_METADATA_KEY } +/// Copies Arrow schema metadata entries to the provided JSON map, +/// skipping keys that are Avro-reserved, internal Arrow keys, or +/// nested under the `avro.schema.` namespace. Values that parse as +/// JSON are inserted as JSON; otherwise the raw string is preserved. +fn extend_with_passthrough_metadata( + target: &mut JsonMap, + metadata: &HashMap, +) { + for (meta_key, meta_val) in metadata { + if meta_key.starts_with("avro.") || is_internal_arrow_key(meta_key) { + continue; + } + let json_val = + serde_json::from_str(meta_val).unwrap_or_else(|_| Value::String(meta_val.clone())); + target.insert(meta_key.clone(), json_val); + } +} + // Sanitize an arbitrary string so it is a valid Avro field or type name fn sanitise_avro_name(base_name: &str) -> String { if base_name.is_empty() { @@ -790,12 +967,21 @@ fn merge_extras(schema: Value, mut extras: JsonMap) -> Value { } } -// Convert an Arrow `DataType` into an Avro schema `Value`. +fn wrap_nullable(inner: Value, null_order: Nullability) -> Value { + let null = Value::String("null".into()); + let elements = match null_order { + Nullability::NullFirst => vec![null, inner], + Nullability::NullSecond => vec![inner, null], + }; + Value::Array(elements) +} + fn datatype_to_avro( dt: &DataType, field_name: &str, metadata: &HashMap, name_gen: &mut NameGenerator, + null_order: Nullability, ) -> Result<(Value, JsonMap), ArrowError> { let mut extras = JsonMap::new(); let val = match dt { @@ -909,20 +1095,32 @@ fn datatype_to_avro( if matches!(dt, DataType::LargeList(_)) { extras.insert("arrowLargeList".into(), Value::Bool(true)); } - let (items, ie) = - datatype_to_avro(child.data_type(), child.name(), child.metadata(), name_gen)?; + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + )?; json!({ "type": "array", - "items": merge_extras(items, ie) + "items": items_schema }) } DataType::FixedSizeList(child, len) => { extras.insert("arrowFixedSize".into(), json!(len)); - let (items, ie) = - datatype_to_avro(child.data_type(), child.name(), child.metadata(), name_gen)?; + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + )?; json!({ "type": "array", - "items": merge_extras(items, ie) + "items": items_schema }) } DataType::Map(entries, _) => { @@ -934,21 +1132,23 @@ fn datatype_to_avro( )) } }; - let (val_schema, value_entry) = datatype_to_avro( + let values_schema = process_datatype( value_field.data_type(), value_field.name(), value_field.metadata(), name_gen, + null_order, + value_field.is_nullable(), )?; json!({ "type": "map", - "values": merge_extras(val_schema, value_entry) + "values": values_schema }) } DataType::Struct(fields) => { let avro_fields = fields .iter() - .map(|field| arrow_field_to_avro(field, name_gen)) + .map(|field| arrow_field_to_avro(field, name_gen, null_order)) .collect::, _>>()?; json!({ "type": "record", @@ -966,19 +1166,24 @@ fn datatype_to_avro( "symbols": symbols }) } else { - let (inner, ie) = 
datatype_to_avro(value.as_ref(), field_name, metadata, name_gen)?; - merge_extras(inner, ie) + process_datatype( + value.as_ref(), + field_name, + metadata, + name_gen, + null_order, + false, + )? } } - DataType::RunEndEncoded(_, values) => { - let (inner, ie) = datatype_to_avro( - values.data_type(), - values.name(), - values.metadata(), - name_gen, - )?; - merge_extras(inner, ie) - } + DataType::RunEndEncoded(_, values) => process_datatype( + values.data_type(), + values.name(), + values.metadata(), + name_gen, + null_order, + false, + )?, DataType::Union(_, _) => { return Err(ArrowError::NotYetImplemented( "Arrow Union to Avro Union not yet supported".into(), @@ -993,27 +1198,40 @@ fn datatype_to_avro( Ok((val, extras)) } +fn process_datatype( + dt: &DataType, + field_name: &str, + metadata: &HashMap, + name_gen: &mut NameGenerator, + null_order: Nullability, + is_nullable: bool, +) -> Result { + let (schema, extras) = datatype_to_avro(dt, field_name, metadata, name_gen, null_order)?; + let mut merged = merge_extras(schema, extras); + if is_nullable { + merged = wrap_nullable(merged, null_order) + } + Ok(merged) +} + fn arrow_field_to_avro( field: &ArrowField, name_gen: &mut NameGenerator, + null_order: Nullability, ) -> Result { - // Sanitize field name to ensure Avro validity but store the original in metadata let avro_name = sanitise_avro_name(field.name()); - let (schema, extras) = - datatype_to_avro(field.data_type(), &avro_name, field.metadata(), name_gen)?; - // If nullable, wrap `[ "null", ]`, NOTE: second order nullability to be added in a follow-up - let mut schema = if field.is_nullable() { - Value::Array(vec![ - Value::String("null".into()), - merge_extras(schema, extras), - ]) - } else { - merge_extras(schema, extras) - }; + let schema_value = process_datatype( + field.data_type(), + &avro_name, + field.metadata(), + name_gen, + null_order, + field.is_nullable(), + )?; // Build the field map let mut map = JsonMap::with_capacity(field.metadata().len() + 3); map.insert("name".into(), Value::String(avro_name)); - map.insert("type".into(), schema); + map.insert("type".into(), schema_value); // Transfer selected metadata for (meta_key, meta_val) in field.metadata() { if is_internal_arrow_key(meta_key) { @@ -1393,8 +1611,16 @@ mod tests { fn test_try_from_schemas_rabin() { let int_avro_schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); - let schemas = vec![int_avro_schema.clone(), record_avro_schema.clone()]; - let store = SchemaStore::try_from(schemas.as_slice()).unwrap(); + let mut schemas: HashMap = HashMap::new(); + schemas.insert( + int_avro_schema.fingerprint().unwrap(), + int_avro_schema.clone(), + ); + schemas.insert( + record_avro_schema.fingerprint().unwrap(), + record_avro_schema.clone(), + ); + let store = SchemaStore::try_from(schemas).unwrap(); let int_fp = int_avro_schema.fingerprint().unwrap(); assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema)); let rec_fp = record_avro_schema.fingerprint().unwrap(); @@ -1405,12 +1631,21 @@ mod tests { fn test_try_from_with_duplicates() { let int_avro_schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); - let schemas = vec![ + let mut schemas: HashMap = HashMap::new(); + schemas.insert( + int_avro_schema.fingerprint().unwrap(), int_avro_schema.clone(), - record_avro_schema, + ); + 
schemas.insert( + record_avro_schema.fingerprint().unwrap(), + record_avro_schema.clone(), + ); + // Insert duplicate of int schema + schemas.insert( + int_avro_schema.fingerprint().unwrap(), int_avro_schema.clone(), - ]; - let store = SchemaStore::try_from(schemas.as_slice()).unwrap(); + ); + let store = SchemaStore::try_from(schemas).unwrap(); assert_eq!(store.schemas.len(), 2); let int_fp = int_avro_schema.fingerprint().unwrap(); assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema)); @@ -1421,14 +1656,40 @@ mod tests { let mut store = SchemaStore::new(); let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); let fp_enum = store.register(schema.clone()).unwrap(); - let Fingerprint::Rabin(fp_val) = fp_enum; - assert_eq!( - store.lookup(&Fingerprint::Rabin(fp_val)).cloned(), - Some(schema.clone()) - ); - assert!(store - .lookup(&Fingerprint::Rabin(fp_val.wrapping_add(1))) - .is_none()); + match fp_enum { + Fingerprint::Rabin(fp_val) => { + assert_eq!( + store.lookup(&Fingerprint::Rabin(fp_val)).cloned(), + Some(schema.clone()) + ); + assert!(store + .lookup(&Fingerprint::Rabin(fp_val.wrapping_add(1))) + .is_none()); + } + Fingerprint::Id(id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + #[cfg(feature = "md5")] + Fingerprint::MD5(id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + } + } + + #[test] + fn test_set_and_lookup_id() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let id = 42u32; + let fp = Fingerprint::Id(id); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp); + assert_eq!(store.lookup(&fp).cloned(), Some(schema.clone())); + assert!(store.lookup(&Fingerprint::Id(id.wrapping_add(1))).is_none()); } #[test] @@ -1442,10 +1703,43 @@ mod tests { assert_eq!(store.schemas.len(), 1); } + #[test] + fn test_set_and_lookup_with_provided_fingerprint() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp = schema.fingerprint().unwrap(); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp); + assert_eq!(store.lookup(&fp).cloned(), Some(schema)); + } + + #[test] + fn test_set_duplicate_same_schema_ok() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp = schema.fingerprint().unwrap(); + let _ = store.set(fp, schema.clone()).unwrap(); + let _ = store.set(fp, schema.clone()).unwrap(); + assert_eq!(store.schemas.len(), 1); + } + + #[test] + fn test_set_duplicate_different_schema_collision_error() { + let mut store = SchemaStore::new(); + let schema1 = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let schema2 = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); + // Use the same Fingerprint::Id to simulate a collision across different schemas + let fp = Fingerprint::Id(123); + let _ = store.set(fp, schema1).unwrap(); + let err = store.set(fp, schema2).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("Schema fingerprint collision")); + } + #[test] fn test_canonical_form_generation_primitive() { let schema = int_schema(); - let canonical_form = generate_canonical_form(&schema).unwrap(); + let canonical_form = 
AvroSchema::generate_canonical_form(&schema).unwrap(); assert_eq!(canonical_form, r#""int""#); } @@ -1453,7 +1747,7 @@ mod tests { fn test_canonical_form_generation_record() { let schema = record_schema(); let expected_canonical_form = r#"{"name":"test.namespace.record1","type":"record","fields":[{"name":"field1","type":"int"},{"name":"field2","type":"string"}]}"#; - let canonical_form = generate_canonical_form(&schema).unwrap(); + let canonical_form = AvroSchema::generate_canonical_form(&schema).unwrap(); assert_eq!(canonical_form, expected_canonical_form); } @@ -1510,7 +1804,7 @@ mod tests { r#type: Schema::Type(Type { r#type: TypeName::Primitive(PrimitiveType::Bytes), attributes: Attributes { - logical_type: Some("decimal"), + logical_type: None, additional: HashMap::from([("precision", json!(4))]), }, }), @@ -1522,7 +1816,7 @@ mod tests { }, })); let expected_canonical_form = r#"{"name":"record_with_attrs","type":"record","fields":[{"name":"f1","type":"bytes"}]}"#; - let canonical_form = generate_canonical_form(&schema_with_attrs).unwrap(); + let canonical_form = AvroSchema::generate_canonical_form(&schema_with_attrs).unwrap(); assert_eq!(canonical_form, expected_canonical_form); } @@ -1767,4 +2061,83 @@ mod tests { let avro = AvroSchema::try_from(&schema).unwrap(); assert_json_contains(&avro.json_string, "\"arrowDurationUnit\":\"second\""); } + + #[test] + fn test_schema_with_non_string_defaults_decodes_successfully() { + let schema_json = r#"{ + "type": "record", + "name": "R", + "fields": [ + {"name": "a", "type": "int", "default": 0}, + {"name": "b", "type": {"type": "array", "items": "long"}, "default": [1, 2, 3]}, + {"name": "c", "type": {"type": "map", "values": "double"}, "default": {"x": 1.5, "y": 2.5}}, + {"name": "inner", "type": {"type": "record", "name": "Inner", "fields": [ + {"name": "flag", "type": "boolean", "default": true}, + {"name": "name", "type": "string", "default": "hi"} + ]}, "default": {"flag": false, "name": "d"}}, + {"name": "u", "type": ["int", "null"], "default": 42} + ] + }"#; + + let schema: Schema = serde_json::from_str(schema_json).expect("schema should parse"); + match &schema { + Schema::Complex(ComplexType::Record(_)) => {} + other => panic!("expected record schema, got: {:?}", other), + } + // Avro to Arrow conversion + let field = crate::codec::AvroField::try_from(&schema) + .expect("Avro->Arrow conversion should succeed"); + let arrow_field = field.field(); + + // Build expected Arrow field + let expected_list_item = ArrowField::new( + arrow_schema::Field::LIST_FIELD_DEFAULT_NAME, + DataType::Int64, + false, + ); + let expected_b = ArrowField::new("b", DataType::List(Arc::new(expected_list_item)), false); + + let expected_map_value = ArrowField::new("value", DataType::Float64, false); + let expected_entries = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + expected_map_value, + ])), + false, + ); + let expected_c = + ArrowField::new("c", DataType::Map(Arc::new(expected_entries), false), false); + + let expected_inner = ArrowField::new( + "inner", + DataType::Struct(Fields::from(vec![ + ArrowField::new("flag", DataType::Boolean, false), + ArrowField::new("name", DataType::Utf8, false), + ])), + false, + ); + + let expected = ArrowField::new( + "R", + DataType::Struct(Fields::from(vec![ + ArrowField::new("a", DataType::Int32, false), + expected_b, + expected_c, + expected_inner, + ArrowField::new("u", DataType::Int32, true), + ])), + false, + ); + + assert_eq!(arrow_field, 
expected); + } + + #[test] + fn default_order_is_consistent() { + let arrow_schema = ArrowSchema::new(vec![ArrowField::new("s", DataType::Utf8, true)]); + let a = AvroSchema::try_from(&arrow_schema).unwrap().json_string; + let b = AvroSchema::from_arrow_with_options(&arrow_schema, None); + assert_eq!(a, b.unwrap().json_string); + } } diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs index c45aa6cfcf9e..ccf80fd8d1ac 100644 --- a/arrow-avro/src/writer/encoder.rs +++ b/arrow-avro/src/writer/encoder.rs @@ -17,31 +17,25 @@ //! Avro Encoder for Arrow types. +use crate::codec::{AvroDataType, AvroField, Codec}; +use crate::schema::Nullability; use arrow_array::cast::AsArray; use arrow_array::types::{ ArrowPrimitiveType, Float32Type, Float64Type, Int32Type, Int64Type, TimestampMicrosecondType, }; -use arrow_array::OffsetSizeTrait; -use arrow_array::{Array, GenericBinaryArray, PrimitiveArray, RecordBatch}; +use arrow_array::{ + Array, GenericBinaryArray, GenericListArray, GenericStringArray, LargeListArray, ListArray, + OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, +}; use arrow_buffer::NullBuffer; -use arrow_schema::{ArrowError, DataType, FieldRef, TimeUnit}; +use arrow_schema::{ArrowError, DataType, Field, Schema as ArrowSchema, TimeUnit}; use std::io::Write; -/// Behavior knobs for the Avro encoder. -/// -/// When `impala_mode` is `true`, optional/nullable values are encoded -/// as Avro unions with **null second** (`[T, "null"]`). When `false` -/// (default), we use **null first** (`["null", T]`). -#[derive(Debug, Clone, Copy, Default)] -pub struct EncoderOptions { - impala_mode: bool, // Will be fully implemented in a follow-up PR -} - /// Encode a single Avro-`long` using ZigZag + variable length, buffered. /// /// Spec: #[inline] -pub fn write_long(writer: &mut W, value: i64) -> Result<(), ArrowError> { +pub fn write_long(out: &mut W, value: i64) -> Result<(), ArrowError> { let mut zz = ((value << 1) ^ (value >> 63)) as u64; // At most 10 bytes for 64-bit varint let mut buf = [0u8; 10]; @@ -53,28 +47,25 @@ pub fn write_long(writer: &mut W, value: i64) -> Result<(), A } buf[i] = (zz & 0x7F) as u8; i += 1; - writer - .write_all(&buf[..i]) + out.write_all(&buf[..i]) .map_err(|e| ArrowError::IoError(format!("write long: {e}"), e)) } #[inline] -fn write_int(writer: &mut W, value: i32) -> Result<(), ArrowError> { - write_long(writer, value as i64) +fn write_int(out: &mut W, value: i32) -> Result<(), ArrowError> { + write_long(out, value as i64) } #[inline] -fn write_len_prefixed(writer: &mut W, bytes: &[u8]) -> Result<(), ArrowError> { - write_long(writer, bytes.len() as i64)?; - writer - .write_all(bytes) +fn write_len_prefixed(out: &mut W, bytes: &[u8]) -> Result<(), ArrowError> { + write_long(out, bytes.len() as i64)?; + out.write_all(bytes) .map_err(|e| ArrowError::IoError(format!("write bytes: {e}"), e)) } #[inline] -fn write_bool(writer: &mut W, v: bool) -> Result<(), ArrowError> { - writer - .write_all(&[if v { 1 } else { 0 }]) +fn write_bool(out: &mut W, v: bool) -> Result<(), ArrowError> { + out.write_all(&[if v { 1 } else { 0 }]) .map_err(|e| ArrowError::IoError(format!("write bool: {e}"), e)) } @@ -83,146 +74,385 @@ fn write_bool(writer: &mut W, v: bool) -> Result<(), ArrowErr /// Branch index is 0-based per Avro unions: /// - Null-first (default): null => 0, value => 1 /// - Null-second (Impala): value => 0, null => 1 -#[inline] -fn write_optional_branch( - writer: &mut W, +fn write_optional_index( + out: &mut W, is_null: bool, - impala_mode: 
bool, + null_order: Nullability, ) -> Result<(), ArrowError> { - let branch = if impala_mode == is_null { 1 } else { 0 }; - write_int(writer, branch) + let byte = union_value_branch_byte(null_order, is_null); + out.write_all(&[byte]) + .map_err(|e| ArrowError::IoError(format!("write union branch: {e}"), e)) } -/// Encode a `RecordBatch` in Avro binary format using **default options**. -pub fn encode_record_batch(batch: &RecordBatch, out: &mut W) -> Result<(), ArrowError> { - encode_record_batch_with_options(batch, out, &EncoderOptions::default()) +#[derive(Debug, Clone)] +enum NullState { + NonNullable, + NullableNoNulls { + union_value_byte: u8, + }, + Nullable { + nulls: NullBuffer, + null_order: Nullability, + }, } -/// Encode a `RecordBatch` with explicit `EncoderOptions`. -pub fn encode_record_batch_with_options( - batch: &RecordBatch, - out: &mut W, - opts: &EncoderOptions, -) -> Result<(), ArrowError> { - let mut encoders = batch - .schema() - .fields() - .iter() - .zip(batch.columns()) - .map(|(field, array)| Ok((field.is_nullable(), make_encoder(array.as_ref())?))) - .collect::, ArrowError>>()?; - (0..batch.num_rows()).try_for_each(|row| { - encoders.iter_mut().try_for_each(|(is_nullable, enc)| { - if *is_nullable { - let is_null = enc.is_null(row); - write_optional_branch(out, is_null, opts.impala_mode)?; - if is_null { - return Ok(()); +/// Arrow to Avro FieldEncoder: +/// - Holds the inner `Encoder` (by value) +/// - Carries the per-site nullability **state** as a single enum that enforces invariants +pub struct FieldEncoder<'a> { + encoder: Encoder<'a>, + null_state: NullState, +} + +impl<'a> FieldEncoder<'a> { + fn make_encoder( + array: &'a dyn Array, + field: &Field, + plan: &FieldPlan, + nullability: Option, + ) -> Result { + let encoder = match plan { + FieldPlan::Struct { encoders } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected StructArray".into()))?; + Encoder::Struct(Box::new(StructEncoder::try_new(arr, encoders)?)) + } + FieldPlan::List { + items_nullability, + item_plan, + } => match array.data_type() { + DataType::List(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected ListArray".into()))?; + Encoder::List(Box::new(ListEncoder32::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + DataType::LargeList(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected LargeListArray".into()))?; + Encoder::LargeList(Box::new(ListEncoder64::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro array site requires Arrow List/LargeList, found: {other:?}" + ))) + } + }, + FieldPlan::Scalar => match array.data_type() { + DataType::Boolean => Encoder::Boolean(BooleanEncoder(array.as_boolean())), + DataType::Utf8 => { + Encoder::Utf8(Utf8GenericEncoder::(array.as_string::())) + } + DataType::LargeUtf8 => { + Encoder::Utf8Large(Utf8GenericEncoder::(array.as_string::())) + } + DataType::Int32 => Encoder::Int(IntEncoder(array.as_primitive::())), + DataType::Int64 => Encoder::Long(LongEncoder(array.as_primitive::())), + DataType::Float32 => { + Encoder::Float32(F32Encoder(array.as_primitive::())) + } + DataType::Float64 => { + Encoder::Float64(F64Encoder(array.as_primitive::())) + } + DataType::Binary => Encoder::Binary(BinaryEncoder(array.as_binary::())), + DataType::LargeBinary => { + 
Encoder::LargeBinary(BinaryEncoder(array.as_binary::())) } + DataType::Timestamp(TimeUnit::Microsecond, _) => Encoder::Timestamp(LongEncoder( + array.as_primitive::(), + )), + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Avro scalar type not yet supported: {other:?}" + ))); + } + }, + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Avro writer: {other:?} not yet supported", + ))); + } + }; + // Compute the effective null state from writer-declared nullability and data nulls. + let null_state = match (nullability, array.null_count() > 0) { + (None, false) => NullState::NonNullable, + (None, true) => { + return Err(ArrowError::InvalidArgumentError(format!( + "Avro site '{}' is non-nullable, but array contains nulls", + field.name() + ))); + } + (Some(order), false) => { + // Optimization: drop any bitmap; emit a constant "value" branch byte. + NullState::NullableNoNulls { + union_value_byte: union_value_branch_byte(order, false), + } + } + (Some(null_order), true) => { + let Some(nulls) = array.nulls().cloned() else { + return Err(ArrowError::InvalidArgumentError(format!( + "Array for Avro site '{}' reports nulls but has no null buffer", + field.name() + ))); + }; + NullState::Nullable { nulls, null_order } } - enc.encode(row, out) + }; + Ok(Self { + encoder, + null_state, }) - }) -} + } -/// Enum for static dispatch of concrete encoders. -enum Encoder<'a> { - Boolean(BooleanEncoder<'a>), - Int(IntEncoder<'a, Int32Type>), - Long(LongEncoder<'a, Int64Type>), - Timestamp(LongEncoder<'a, TimestampMicrosecondType>), - Float32(F32Encoder<'a>), - Float64(F64Encoder<'a>), - Binary(BinaryEncoder<'a, i32>), + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + match &self.null_state { + NullState::NonNullable => {} + NullState::NullableNoNulls { union_value_byte } => out + .write_all(&[*union_value_byte]) + .map_err(|e| ArrowError::IoError(format!("write union value branch: {e}"), e))?, + NullState::Nullable { nulls, null_order } if nulls.is_null(idx) => { + return write_optional_index(out, true, *null_order); // no value to write + } + NullState::Nullable { null_order, .. } => { + write_optional_index(out, false, *null_order)?; + } + } + self.encoder.encode(out, idx) + } } -impl<'a> Encoder<'a> { - /// Encode the value at `idx`. - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { - match self { - Encoder::Boolean(e) => e.encode(idx, out), - Encoder::Int(e) => e.encode(idx, out), - Encoder::Long(e) => e.encode(idx, out), - Encoder::Timestamp(e) => e.encode(idx, out), - Encoder::Float32(e) => e.encode(idx, out), - Encoder::Float64(e) => e.encode(idx, out), - Encoder::Binary(e) => e.encode(idx, out), - } +fn union_value_branch_byte(null_order: Nullability, is_null: bool) -> u8 { + let nulls_first = null_order == Nullability::default(); + if nulls_first == is_null { + 0x00 + } else { + 0x02 } } -/// An encoder + a null buffer for nullable fields. -pub struct NullableEncoder<'a> { - encoder: Encoder<'a>, - nulls: Option, +/// Per‑site encoder plan for a field. This mirrors the Avro structure, so nested +/// optional branch order can be honored exactly as declared by the schema. 
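+/// For example, an Avro item site declared as `["null", "long"]` encodes
+/// branch index 0 for null and 1 for a value, while `["long", "null"]`
+/// flips the two indices; each site's declared ordering is recorded in the
+/// plan rather than assumed globally.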
+#[derive(Debug, Clone)] +enum FieldPlan { + /// Non-nested scalar/logical type + Scalar, + /// Record/Struct with Avro‑ordered children + Struct { encoders: Vec }, + /// Array with item‑site nullability and nested plan + List { + items_nullability: Option, + item_plan: Box, + }, } -impl<'a> NullableEncoder<'a> { - /// Create a new nullable encoder, wrapping a non-null encoder and a null buffer. - #[inline] - fn new(encoder: Encoder<'a>, nulls: Option) -> Self { - Self { encoder, nulls } - } +#[derive(Debug, Clone)] +struct FieldBinding { + /// Index of the Arrow field/column associated with this Avro field site + arrow_index: usize, + /// Nullability/order for this site (None for required fields) + nullability: Option, + /// Nested plan for this site + plan: FieldPlan, +} - /// Encode the value at `idx`, assuming it's not-null. - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { - self.encoder.encode(idx, out) +/// Builder for `RecordEncoder` write plan +#[derive(Debug)] +pub struct RecordEncoderBuilder<'a> { + avro_root: &'a AvroField, + arrow_schema: &'a ArrowSchema, +} + +impl<'a> RecordEncoderBuilder<'a> { + /// Create a new builder from the Avro root and Arrow schema. + pub fn new(avro_root: &'a AvroField, arrow_schema: &'a ArrowSchema) -> Self { + Self { + avro_root, + arrow_schema, + } } - /// Check if the value at `idx` is null. - #[inline] - fn is_null(&self, idx: usize) -> bool { - self.nulls.as_ref().is_some_and(|nulls| nulls.is_null(idx)) + /// Build the `RecordEncoder` by walking the Avro **record** root in Avro order, + /// resolving each field to an Arrow index by name. + pub fn build(self) -> Result { + let avro_root_dt = self.avro_root.data_type(); + let Codec::Struct(root_fields) = avro_root_dt.codec() else { + return Err(ArrowError::SchemaError( + "Top-level Avro schema must be a record/struct".into(), + )); + }; + let mut columns = Vec::with_capacity(root_fields.len()); + for root_field in root_fields.as_ref() { + let name = root_field.name(); + let arrow_index = self.arrow_schema.index_of(name).map_err(|e| { + ArrowError::SchemaError(format!("Schema mismatch for field '{name}': {e}")) + })?; + columns.push(FieldBinding { + arrow_index, + nullability: root_field.data_type().nullability(), + plan: FieldPlan::build( + root_field.data_type(), + self.arrow_schema.field(arrow_index), + )?, + }); + } + Ok(RecordEncoder { columns }) } } -/// Creates an Avro encoder for the given `array`. -pub fn make_encoder<'a>(array: &'a dyn Array) -> Result, ArrowError> { - let nulls = array.nulls().cloned(); - let enc = match array.data_type() { - DataType::Boolean => { - let arr = array.as_boolean(); - NullableEncoder::new(Encoder::Boolean(BooleanEncoder(arr)), nulls) - } - DataType::Int32 => { - let arr = array.as_primitive::(); - NullableEncoder::new(Encoder::Int(IntEncoder(arr)), nulls) - } - DataType::Int64 => { - let arr = array.as_primitive::(); - NullableEncoder::new(Encoder::Long(LongEncoder(arr)), nulls) - } - DataType::Float32 => { - let arr = array.as_primitive::(); - NullableEncoder::new(Encoder::Float32(F32Encoder(arr)), nulls) - } - DataType::Float64 => { - let arr = array.as_primitive::(); - NullableEncoder::new(Encoder::Float64(F64Encoder(arr)), nulls) +/// A pre-computed plan for encoding a `RecordBatch` to Avro. +/// +/// Derived from an Avro schema and an Arrow schema. It maps +/// top-level Avro fields to Arrow columns and contains a nested encoding plan +/// for each column. 
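+///
+/// A minimal end-to-end sketch (error handling elided; assumes the Avro
+/// record and the `RecordBatch` schema describe the same single `int`
+/// field, and uses `AvroFieldBuilder` the same way the writer module does):
+///
+/// ```no_run
+/// use arrow_avro::codec::AvroFieldBuilder;
+/// use arrow_avro::schema::AvroSchema;
+/// use arrow_avro::writer::encoder::RecordEncoderBuilder;
+/// # fn demo(batch: &arrow_array::RecordBatch) -> Result<(), arrow_schema::ArrowError> {
+/// let json = r#"{"type":"record","name":"R","fields":[{"name":"a","type":"int"}]}"#;
+/// let avro = AvroSchema::new(json.to_string());
+/// let root = AvroFieldBuilder::new(&avro.schema()?).build()?;
+/// let encoder = RecordEncoderBuilder::new(&root, batch.schema().as_ref()).build()?;
+/// let mut out = Vec::new();
+/// encoder.encode(&mut out, batch)?;
+/// # Ok(()) }
+/// ```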
+#[derive(Debug, Clone)] +pub struct RecordEncoder { + columns: Vec, +} + +impl RecordEncoder { + fn prepare_for_batch<'a>( + &'a self, + batch: &'a RecordBatch, + ) -> Result>, ArrowError> { + let schema_binding = batch.schema(); + let fields = schema_binding.fields(); + let arrays = batch.columns(); + let mut out = Vec::with_capacity(self.columns.len()); + for col_plan in self.columns.iter() { + let arrow_index = col_plan.arrow_index; + let array = arrays.get(arrow_index).ok_or_else(|| { + ArrowError::SchemaError(format!("Column index {arrow_index} out of range")) + })?; + let field = fields[arrow_index].as_ref(); + let encoder = prepare_value_site_encoder( + array.as_ref(), + field, + col_plan.nullability, + &col_plan.plan, + )?; + out.push(encoder); } - DataType::Binary => { - let arr = array.as_binary::(); - NullableEncoder::new(Encoder::Binary(BinaryEncoder(arr)), nulls) + Ok(out) + } + + /// Encode a `RecordBatch` using this encoder plan. + /// + /// Tip: Wrap `out` in a `std::io::BufWriter` to reduce the overhead of many small writes. + pub fn encode(&self, out: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> { + let mut column_encoders = self.prepare_for_batch(batch)?; + for row in 0..batch.num_rows() { + for encoder in column_encoders.iter_mut() { + encoder.encode(out, row)?; + } } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let arr = array.as_primitive::(); - NullableEncoder::new(Encoder::Timestamp(LongEncoder(arr)), nulls) + Ok(()) + } +} + +fn find_struct_child_index(fields: &arrow_schema::Fields, name: &str) -> Option { + fields.iter().position(|f| f.name() == name) +} + +impl FieldPlan { + fn build(avro_dt: &AvroDataType, arrow_field: &Field) -> Result { + match avro_dt.codec() { + Codec::Struct(avro_fields) => { + let fields = match arrow_field.data_type() { + DataType::Struct(struct_fields) => struct_fields, + other => { + return Err(ArrowError::SchemaError(format!( + "Avro struct maps to Arrow Struct, found: {other:?}" + ))) + } + }; + let mut encoders = Vec::with_capacity(avro_fields.len()); + for avro_field in avro_fields.iter() { + let name = avro_field.name().to_string(); + let idx = find_struct_child_index(fields, &name).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Struct field '{name}' not present in Arrow field '{}'", + arrow_field.name() + )) + })?; + encoders.push(FieldBinding { + arrow_index: idx, + nullability: avro_field.data_type().nullability(), + plan: FieldPlan::build(avro_field.data_type(), fields[idx].as_ref())?, + }); + } + Ok(FieldPlan::Struct { encoders }) + } + Codec::List(items_dt) => match arrow_field.data_type() { + DataType::List(field_ref) => Ok(FieldPlan::List { + items_nullability: items_dt.nullability(), + item_plan: Box::new(FieldPlan::build(items_dt.as_ref(), field_ref.as_ref())?), + }), + DataType::LargeList(field_ref) => Ok(FieldPlan::List { + items_nullability: items_dt.nullability(), + item_plan: Box::new(FieldPlan::build(items_dt.as_ref(), field_ref.as_ref())?), + }), + other => Err(ArrowError::SchemaError(format!( + "Avro array maps to Arrow List/LargeList, found: {other:?}" + ))), + }, + _ => Ok(FieldPlan::Scalar), } - other => { - return Err(ArrowError::NotYetImplemented(format!( - "Unsupported data type for Avro encoding in slim build: {other:?}" - ))) + } +} + +enum Encoder<'a> { + Boolean(BooleanEncoder<'a>), + Int(IntEncoder<'a, Int32Type>), + Long(LongEncoder<'a, Int64Type>), + Timestamp(LongEncoder<'a, TimestampMicrosecondType>), + Float32(F32Encoder<'a>), + Float64(F64Encoder<'a>), + 
Binary(BinaryEncoder<'a, i32>), + LargeBinary(BinaryEncoder<'a, i64>), + Utf8(Utf8Encoder<'a>), + Utf8Large(Utf8LargeEncoder<'a>), + List(Box>), + LargeList(Box>), + Struct(Box>), +} + +impl<'a> Encoder<'a> { + /// Encode the value at `idx`. + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + match self { + Encoder::Boolean(e) => e.encode(out, idx), + Encoder::Int(e) => e.encode(out, idx), + Encoder::Long(e) => e.encode(out, idx), + Encoder::Timestamp(e) => e.encode(out, idx), + Encoder::Float32(e) => e.encode(out, idx), + Encoder::Float64(e) => e.encode(out, idx), + Encoder::Binary(e) => e.encode(out, idx), + Encoder::LargeBinary(e) => e.encode(out, idx), + Encoder::Utf8(e) => e.encode(out, idx), + Encoder::Utf8Large(e) => e.encode(out, idx), + Encoder::List(e) => e.encode(out, idx), + Encoder::LargeList(e) => e.encode(out, idx), + Encoder::Struct(e) => e.encode(out, idx), } - }; - Ok(enc) + } } struct BooleanEncoder<'a>(&'a arrow_array::BooleanArray); impl BooleanEncoder<'_> { - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { write_bool(out, self.0.value(idx)) } } @@ -230,8 +460,7 @@ impl BooleanEncoder<'_> { /// Generic Avro `int` encoder for primitive arrays with `i32` native values. struct IntEncoder<'a, P: ArrowPrimitiveType>(&'a PrimitiveArray
<P>
); impl<'a, P: ArrowPrimitiveType> IntEncoder<'a, P> { - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { write_int(out, self.0.value(idx)) } } @@ -239,8 +468,7 @@ impl<'a, P: ArrowPrimitiveType> IntEncoder<'a, P> { /// Generic Avro `long` encoder for primitive arrays with `i64` native values. struct LongEncoder<'a, P: ArrowPrimitiveType>(&'a PrimitiveArray
<P>
); impl<'a, P: ArrowPrimitiveType> LongEncoder<'a, P> { - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { write_long(out, self.0.value(idx)) } } @@ -248,16 +476,14 @@ impl<'a, P: ArrowPrimitiveType> LongEncoder<'a, P> { /// Unified binary encoder generic over offset size (i32/i64). struct BinaryEncoder<'a, O: OffsetSizeTrait>(&'a GenericBinaryArray); impl<'a, O: OffsetSizeTrait> BinaryEncoder<'a, O> { - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { write_len_prefixed(out, self.0.value(idx)) } } struct F32Encoder<'a>(&'a arrow_array::Float32Array); impl F32Encoder<'_> { - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { // Avro float: 4 bytes, IEEE-754 little-endian let bits = self.0.value(idx).to_bits(); out.write_all(&bits.to_le_bytes()) @@ -267,11 +493,274 @@ impl F32Encoder<'_> { struct F64Encoder<'a>(&'a arrow_array::Float64Array); impl F64Encoder<'_> { - #[inline] - fn encode(&mut self, idx: usize, out: &mut W) -> Result<(), ArrowError> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { // Avro double: 8 bytes, IEEE-754 little-endian let bits = self.0.value(idx).to_bits(); out.write_all(&bits.to_le_bytes()) .map_err(|e| ArrowError::IoError(format!("write f64: {e}"), e)) } } + +struct Utf8GenericEncoder<'a, O: OffsetSizeTrait>(&'a GenericStringArray); + +impl<'a, O: OffsetSizeTrait> Utf8GenericEncoder<'a, O> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx).as_bytes()) + } +} + +type Utf8Encoder<'a> = Utf8GenericEncoder<'a, i32>; +type Utf8LargeEncoder<'a> = Utf8GenericEncoder<'a, i64>; + +struct StructEncoder<'a> { + encoders: Vec>, +} + +impl<'a> StructEncoder<'a> { + fn try_new( + array: &'a StructArray, + field_bindings: &[FieldBinding], + ) -> Result { + let DataType::Struct(fields) = array.data_type() else { + return Err(ArrowError::SchemaError("Expected Struct".into())); + }; + let mut encoders = Vec::with_capacity(field_bindings.len()); + for field_binding in field_bindings { + let idx = field_binding.arrow_index; + let column = array.columns().get(idx).ok_or_else(|| { + ArrowError::SchemaError(format!("Struct child index {idx} out of range")) + })?; + let field = fields.get(idx).ok_or_else(|| { + ArrowError::SchemaError(format!("Struct child index {idx} out of range")) + })?; + let encoder = prepare_value_site_encoder( + column.as_ref(), + field, + field_binding.nullability, + &field_binding.plan, + )?; + encoders.push(encoder); + } + Ok(Self { encoders }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + for encoder in self.encoders.iter_mut() { + encoder.encode(out, idx)?; + } + Ok(()) + } +} + +/// Encode a blocked range of items with Avro array block framing. +/// +/// `write_item` must take `(out, index)` to maintain the "out-first" convention. +fn encode_blocked_range( + out: &mut W, + start: usize, + end: usize, + mut write_item: F, +) -> Result<(), ArrowError> +where + F: FnMut(&mut W, usize) -> Result<(), ArrowError>, +{ + let len = end.saturating_sub(start); + if len == 0 { + // Zero-length terminator per Avro spec. 
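+        // (An Avro array is a sequence of blocks: a count followed by that
+        // many items, ended by a count of 0. The spec also allows a negative
+        // count followed by a byte size for fast skipping; this writer never
+        // emits that form.)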
+ write_long(out, 0)?; + return Ok(()); + } + // Emit a single positive block for performance, then the end marker. + write_long(out, len as i64)?; + for row in start..end { + write_item(out, row)?; + } + write_long(out, 0)?; + Ok(()) +} + +struct ListEncoder<'a, O: OffsetSizeTrait> { + list: &'a GenericListArray, + values: FieldEncoder<'a>, + values_offset: usize, +} + +type ListEncoder32<'a> = ListEncoder<'a, i32>; +type ListEncoder64<'a> = ListEncoder<'a, i64>; + +impl<'a, O: OffsetSizeTrait> ListEncoder<'a, O> { + fn try_new( + list: &'a GenericListArray, + items_nullability: Option, + item_plan: &FieldPlan, + ) -> Result { + let child_field = match list.data_type() { + DataType::List(field) => field.as_ref(), + DataType::LargeList(field) => field.as_ref(), + _ => { + return Err(ArrowError::SchemaError( + "Expected List or LargeList for ListEncoder".into(), + )) + } + }; + let values_enc = prepare_value_site_encoder( + list.values().as_ref(), + child_field, + items_nullability, + item_plan, + )?; + Ok(Self { + list, + values: values_enc, + values_offset: list.values().offset(), + }) + } + + fn encode_list_range( + &mut self, + out: &mut W, + start: usize, + end: usize, + ) -> Result<(), ArrowError> { + encode_blocked_range(out, start, end, |out, row| { + self.values + .encode(out, row.saturating_sub(self.values_offset)) + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let offsets = self.list.offsets(); + let start = offsets[idx].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Error converting offset[{idx}] to usize")) + })?; + let end = offsets[idx + 1].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] to usize", + idx + 1 + )) + })?; + self.encode_list_range(out, start, end) + } +} + +fn prepare_value_site_encoder<'a>( + values_array: &'a dyn Array, + value_field: &Field, + nullability: Option, + plan: &FieldPlan, +) -> Result, ArrowError> { + // Effective nullability is computed here from the writer-declared site nullability and data. 
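+    // Concretely (see `FieldEncoder::make_encoder`):
+    //   required site, no nulls      => NullState::NonNullable
+    //   required site, nulls present => InvalidArgumentError
+    //   nullable site, no nulls      => NullState::NullableNoNulls (one constant branch byte)
+    //   nullable site, nulls present => NullState::Nullable (null buffer consulted per row)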
+ FieldEncoder::make_encoder(values_array, value_field, plan, nullability) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::types::Int32Type; + use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, + Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, StringArray, + TimestampMicrosecondArray, + }; + use arrow_schema::{DataType, Field, Fields}; + + fn zigzag_i64(v: i64) -> u64 { + ((v << 1) ^ (v >> 63)) as u64 + } + + fn varint(mut x: u64) -> Vec { + let mut out = Vec::new(); + while (x & !0x7f) != 0 { + out.push(((x & 0x7f) as u8) | 0x80); + x >>= 7; + } + out.push((x & 0x7f) as u8); + out + } + + fn avro_long_bytes(v: i64) -> Vec { + varint(zigzag_i64(v)) + } + + fn avro_len_prefixed_bytes(payload: &[u8]) -> Vec { + let mut out = avro_long_bytes(payload.len() as i64); + out.extend_from_slice(payload); + out + } + + fn encode_all( + array: &dyn Array, + plan: &FieldPlan, + nullability: Option, + ) -> Vec { + let field = Field::new("f", array.data_type().clone(), true); + let mut enc = FieldEncoder::make_encoder(array, &field, plan, nullability).unwrap(); + let mut out = Vec::new(); + for i in 0..array.len() { + enc.encode(&mut out, i).unwrap(); + } + out + } + + fn assert_bytes_eq(actual: &[u8], expected: &[u8]) { + if actual != expected { + let to_hex = |b: &[u8]| { + b.iter() + .map(|x| format!("{:02X}", x)) + .collect::>() + .join(" ") + }; + panic!( + "mismatch\n expected: [{}]\n actual: [{}]", + to_hex(expected), + to_hex(actual) + ); + } + } + + #[test] + fn binary_encoder() { + let values: Vec<&[u8]> = vec![b"", b"ab", b"\x00\xFF"]; + let arr = BinaryArray::from_vec(values); + let mut expected = Vec::new(); + for payload in [b"" as &[u8], b"ab", b"\x00\xFF"] { + expected.extend(avro_len_prefixed_bytes(payload)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_binary_encoder() { + let values: Vec<&[u8]> = vec![b"xyz", b""]; + let arr = LargeBinaryArray::from_vec(values); + let mut expected = Vec::new(); + for payload in [b"xyz" as &[u8], b""] { + expected.extend(avro_len_prefixed_bytes(payload)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn utf8_encoder() { + let arr = StringArray::from(vec!["", "A", "BC"]); + let mut expected = Vec::new(); + for s in ["", "A", "BC"] { + expected.extend(avro_len_prefixed_bytes(s.as_bytes())); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_utf8_encoder() { + let arr = LargeStringArray::from(vec!["hello", ""]); + let mut expected = Vec::new(); + for s in ["hello", ""] { + expected.extend(avro_len_prefixed_bytes(s.as_bytes())); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } +} diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs index 0ebc7a64b422..6fac9e8286a2 100644 --- a/arrow-avro/src/writer/format.rs +++ b/arrow-avro/src/writer/format.rs @@ -17,17 +17,15 @@ use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY}; -use crate::writer::encoder::{write_long, EncoderOptions}; +use crate::writer::encoder::write_long; use arrow_schema::{ArrowError, Schema}; use rand::RngCore; -use serde_json::{Map as JsonMap, Value as JsonValue}; use std::fmt::Debug; use std::io::Write; /// Format abstraction implemented by 
each container‐level writer. pub trait AvroFormat: Debug + Default { /// Write any bytes required at the very beginning of the output stream - /// (file header, etc.). /// Implementations **must not** write any record data. fn start_stream( &mut self, @@ -44,24 +42,6 @@ pub trait AvroFormat: Debug + Default { #[derive(Debug, Default)] pub struct AvroOcfFormat { sync_marker: [u8; 16], - /// Optional encoder behavior hints to keep file header schema ordering - /// consistent with value encoding (e.g. Impala null-second). - encoder_options: EncoderOptions, -} - -impl AvroOcfFormat { - /// Optional helper to attach encoder options (i.e., Impala null-second) to the format. - #[allow(dead_code)] - pub fn with_encoder_options(mut self, opts: EncoderOptions) -> Self { - self.encoder_options = opts; - self - } - - /// Access the options used by this format. - #[allow(dead_code)] - pub fn encoder_options(&self) -> &EncoderOptions { - &self.encoder_options - } } impl AvroFormat for AvroOcfFormat { diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs index 4c46289b52c5..a5b2691bb816 100644 --- a/arrow-avro/src/writer/mod.rs +++ b/arrow-avro/src/writer/mod.rs @@ -32,13 +32,14 @@ pub mod encoder; /// Logic for different Avro container file formats. pub mod format; +use crate::codec::AvroFieldBuilder; use crate::compression::CompressionCodec; -use crate::schema::AvroSchema; -use crate::writer::encoder::{encode_record_batch, write_long}; +use crate::schema::{AvroSchema, SCHEMA_METADATA_KEY}; +use crate::writer::encoder::{write_long, RecordEncoder, RecordEncoderBuilder}; use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat}; use arrow_array::RecordBatch; use arrow_schema::{ArrowError, Schema}; -use std::io::{self, Write}; +use std::io::Write; use std::sync::Arc; /// Builder to configure and create a `Writer`. @@ -46,6 +47,7 @@ use std::sync::Arc; pub struct WriterBuilder { schema: Schema, codec: Option, + capacity: usize, } impl WriterBuilder { @@ -54,6 +56,7 @@ impl WriterBuilder { Self { schema, codec: None, + capacity: 1024, } } @@ -63,19 +66,41 @@ impl WriterBuilder { self } + /// Sets the capacity for the given object and returns the modified instance. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + /// Create a new `Writer` with specified `AvroFormat` and builder options. - pub fn build(self, writer: W) -> Writer + /// Performs one‑time startup (header/stream init, encoder plan). 
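+    ///
+    /// A minimal OCF sketch (an in-memory `Vec<u8>` stands in for a file):
+    ///
+    /// ```no_run
+    /// use arrow_avro::writer::format::AvroOcfFormat;
+    /// use arrow_avro::writer::WriterBuilder;
+    /// use arrow_schema::{DataType, Field, Schema};
+    ///
+    /// let schema = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
+    /// let mut writer = WriterBuilder::new(schema)
+    ///     .build::<_, AvroOcfFormat>(Vec::new())
+    ///     .unwrap();
+    /// writer.finish().unwrap();
+    /// ```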
+ pub fn build(self, mut writer: W) -> Result, ArrowError> where W: Write, F: AvroFormat, { - Writer { + let mut format = F::default(); + let avro_schema = match self.schema.metadata.get(SCHEMA_METADATA_KEY) { + Some(json) => AvroSchema::new(json.clone()), + None => AvroSchema::try_from(&self.schema)?, + }; + let mut md = self.schema.metadata().clone(); + md.insert( + SCHEMA_METADATA_KEY.to_string(), + avro_schema.clone().json_string, + ); + let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md)); + format.start_stream(&mut writer, &schema, self.codec)?; + let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?; + let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref()).build()?; + Ok(Writer { writer, - schema: Arc::from(self.schema), - format: F::default(), + schema, + format, compression: self.codec, - started: false, - } + capacity: self.capacity, + encoder, + }) } } @@ -86,7 +111,8 @@ pub struct Writer { schema: Arc, format: F, compression: Option, - started: bool, + capacity: usize, + encoder: RecordEncoder, } /// Alias for an Avro **Object Container File** writer. @@ -95,15 +121,9 @@ pub type AvroWriter = Writer; pub type AvroStreamWriter = Writer; impl Writer { - /// Convenience constructor – same as + /// Convenience constructor – same as [`WriterBuilder::build`] with `AvroOcfFormat`. pub fn new(writer: W, schema: Schema) -> Result { - Ok(WriterBuilder::new(schema).build::(writer)) - } - - /// Change the compression codec after construction. - pub fn with_compression(mut self, codec: Option) -> Self { - self.compression = codec; - self + WriterBuilder::new(schema).build::(writer) } /// Return a reference to the 16‑byte sync marker generated for this file. @@ -115,19 +135,14 @@ impl Writer { impl Writer { /// Convenience constructor to create a new [`AvroStreamWriter`]. pub fn new(writer: W, schema: Schema) -> Result { - Ok(WriterBuilder::new(schema).build::(writer)) + WriterBuilder::new(schema).build::(writer) } } impl Writer { /// Serialize one [`RecordBatch`] to the output. pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { - if !self.started { - self.format - .start_stream(&mut self.writer, &self.schema, self.compression)?; - self.started = true; - } - if batch.schema() != self.schema { + if batch.schema().fields() != self.schema.fields() { return Err(ArrowError::SchemaError( "Schema of RecordBatch differs from Writer schema".to_string(), )); @@ -150,11 +165,6 @@ impl Writer { /// Flush remaining buffered data and (for OCF) ensure the header is present. 
pub fn finish(&mut self) -> Result<(), ArrowError> { - if !self.started { - self.format - .start_stream(&mut self.writer, &self.schema, self.compression)?; - self.started = true; - } self.writer .flush() .map_err(|e| ArrowError::IoError(format!("Error flushing writer: {e}"), e)) @@ -167,7 +177,7 @@ impl Writer { fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> { let mut buf = Vec::::with_capacity(1024); - encode_record_batch(batch, &mut buf)?; + self.encoder.encode(&mut buf, batch)?; let encoded = match self.compression { Some(codec) => codec.compress(&buf)?, None => buf, @@ -184,19 +194,22 @@ impl Writer { } fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { - encode_record_batch(batch, &mut self.writer) + self.encoder.encode(&mut self.writer, batch) } } #[cfg(test)] mod tests { use super::*; + use crate::compression::CompressionCodec; use crate::reader::ReaderBuilder; + use crate::schema::{AvroSchema, SchemaStore}; use crate::test_util::arrow_test_data; - use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch, StringArray}; - use arrow_schema::{DataType, Field, Schema}; + use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, IntervalUnit, Schema}; use std::fs::File; - use std::io::BufReader; + use std::io::{BufReader, Cursor}; + use std::path::PathBuf; use std::sync::Arc; use tempfile::NamedTempFile; @@ -217,10 +230,6 @@ mod tests { .expect("failed to build test RecordBatch") } - fn contains_ascii(haystack: &[u8], needle: &[u8]) -> bool { - haystack.windows(needle.len()).any(|w| w == needle) - } - #[test] fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> { let batch = make_batch(); @@ -230,12 +239,8 @@ mod tests { writer.finish()?; let out = writer.into_inner(); assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect"); - let sync = AvroWriter::new(Vec::new(), make_schema())? 
- .sync_marker() - .cloned(); let trailer = &out[out.len() - 16..]; assert_eq!(trailer.len(), 16, "expected 16‑byte sync marker"); - let _ = sync; Ok(()) } @@ -309,16 +314,20 @@ mod tests { let tmp = NamedTempFile::new().expect("create temp file"); let out_path = tmp.into_temp_path(); let out_file = File::create(&out_path).expect("create temp avro"); - let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; - if rel.contains(".snappy.") { - writer = writer.with_compression(Some(CompressionCodec::Snappy)); + let codec = if rel.contains(".snappy.") { + Some(CompressionCodec::Snappy) } else if rel.contains(".zstandard.") { - writer = writer.with_compression(Some(CompressionCodec::ZStandard)); + Some(CompressionCodec::ZStandard) } else if rel.contains(".bzip2.") { - writer = writer.with_compression(Some(CompressionCodec::Bzip2)); + Some(CompressionCodec::Bzip2) } else if rel.contains(".xz.") { - writer = writer.with_compression(Some(CompressionCodec::Xz)); - } + Some(CompressionCodec::Xz) + } else { + None + }; + let mut writer = WriterBuilder::new(original.schema().as_ref().clone()) + .with_compression(codec) + .build::<_, AvroOcfFormat>(out_file)?; writer.write(&original)?; writer.finish()?; drop(writer); @@ -338,4 +347,72 @@ mod tests { } Ok(()) } + + #[test] + fn test_roundtrip_nested_records_writer() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nested_records.avro"); + let rdr_file = File::open(&path).expect("open nested_records.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nested_records.avro"); + let schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + { + let out_file = File::create(&out_path).expect("create output avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + } + let rt_file = File::open(&out_path).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for nested_records.avro" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_nested_lists_writer() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nested_lists.snappy.avro"); + let rdr_file = File::open(&path).expect("open nested_lists.snappy.avro"); + let mut reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nested_lists.snappy.avro"); + let schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + { + let out_file = File::create(&out_path).expect("create output avro"); + let mut writer = WriterBuilder::new(original.schema().as_ref().clone()) + .with_compression(Some(CompressionCodec::Snappy)) + .build::<_, AvroOcfFormat>(out_file)?; + writer.write(&original)?; + writer.finish()?; + } + let rt_file = 
File::open(&out_path).expect("open round_trip avro"); + let mut rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for nested_lists.snappy.avro" + ); + Ok(()) + } } diff --git a/arrow-avro/test/data/README.md b/arrow-avro/test/data/README.md new file mode 100644 index 000000000000..51416c8416d4 --- /dev/null +++ b/arrow-avro/test/data/README.md @@ -0,0 +1,147 @@ + + +# Avro test files for `arrow-avro` + +This directory contains small Avro Object Container Files (OCF) used by +`arrow-avro` tests to validate the `Reader` implementation. These files are generated from +a set of python scripts and will gradually be removed as they are merged into `arrow-testing`. + +## Decimal Files + +This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type +across both `bytes` and `fixed` encodings, and to cover Arrow decimal widths ranging +from `Decimal32` up through `Decimal256`. The files were generated from a +script (see **How these files were created** below). + +> **Avro decimal recap.** Avro’s `decimal` logical type annotates either a +> `bytes` or `fixed` primitive and stores the **two’s‑complement big‑endian +> representation of the unscaled integer** (value × 10^scale). Implementations +> should reject invalid combinations such as `scale > precision`. + +> **Arrow decimal recap.** Arrow defines `Decimal32`, `Decimal64`, `Decimal128`, +> and `Decimal256` types with maximum precisions of 9, 18, 38, and 76 digits, +> respectively. Tests here validate that the Avro reader selects compatible +> Arrow decimal widths given the Avro decimal’s precision and storage. + +--- + +All files are one‑column Avro OCFs with a field named `value`. Each contains 24 +rows with the sequence `1 … 24` rendered at the file’s declared `scale` +(i.e., at scale 10: `1.0000000000`, `2.0000000000`). 
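The trimming rule in the decimal recap above is the one detail implementations tend to get wrong, so here is a short illustrative Rust sketch of it (the files themselves are produced by the Python generator described under **How these files were created**; this function is hypothetical and not part of any crate):

```rust
/// Encode an unscaled decimal integer (value * 10^scale) the way Avro's
/// `decimal` logical type stores it: two's-complement big-endian bytes.
/// For `bytes` storage, redundant leading sign-extension bytes may be
/// dropped; for `fixed(N)` storage the value is instead sign-extended
/// to exactly N bytes.
fn unscaled_to_decimal_bytes(unscaled: i128) -> Vec<u8> {
    let be = unscaled.to_be_bytes();
    // Trim leading 0x00 (positive) or 0xFF (negative) bytes while the next
    // byte still carries the correct sign bit.
    let mut start = 0;
    while start + 1 < be.len() {
        let redundant = (be[start] == 0x00 && be[start + 1] & 0x80 == 0)
            || (be[start] == 0xFF && be[start + 1] & 0x80 != 0);
        if !redundant {
            break;
        }
        start += 1;
    }
    be[start..].to_vec()
}

fn main() {
    // `1.00` at scale 2 has unscaled value 100 and is stored as one byte.
    assert_eq!(unscaled_to_decimal_bytes(100), vec![0x64]);
    // `-0.01` at scale 2 has unscaled value -1 and is stored as 0xFF.
    assert_eq!(unscaled_to_decimal_bytes(-1), vec![0xFF]);
}
```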
+ +| File | Avro storage | Decimal (precision, scale) | Intended Arrow width | +|---|---|---|---| +| `int256_decimal.avro` | `bytes` + `logicalType: decimal` | (76, 10) | `Decimal256` | +| `fixed256_decimal.avro` | `fixed[32]` + `logicalType: decimal` | (76, 10) | `Decimal256` | +| `fixed_length_decimal_legacy_32.avro` | `fixed[4]` + `logicalType: decimal` | (9, 2) | `Decimal32` (legacy fixed‑width path) | +| `int128_decimal.avro` | `bytes` + `logicalType: decimal` | (38, 2) | `Decimal128` | + +### Schemas (for reference) + +#### int256_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal256Bytes", + "fields": [{ + "name": "value", + "type": { "type": "bytes", "logicalType": "decimal", "precision": 76, "scale": 10 } + }] +} +``` + +#### fixed256_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal256Fixed", + "fields": [{ + "name": "value", + "type": { + "type": "fixed", "name": "Decimal256Fixed", "size": 32, + "logicalType": "decimal", "precision": 76, "scale": 10 + } + }] +} +``` + +#### fixed_length_decimal_legacy_32.avro + +```json +{ + "type": "record", + "name": "OneColDecimal32FixedLegacy", + "fields": [{ + "name": "value", + "type": { + "type": "fixed", "name": "Decimal32FixedLegacy", "size": 4, + "logicalType": "decimal", "precision": 9, "scale": 2 + } + }] +} +``` + +#### int128_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal128Bytes", + "fields": [{ + "name": "value", + "type": { "type": "bytes", "logicalType": "decimal", "precision": 38, "scale": 2 } + }] +} +``` + +### How these files were created + +All four files were generated by the Python script +`create_avro_decimal_files.py` authored for this purpose. The script uses +`fastavro` to write OCFs and encodes decimal values as required by the Avro +spec (two’s‑complement big‑endian of the unscaled integer). + +#### Re‑generation + +From the repository root (defaults write into arrow-avro/test/data): + +```bash +# 1) Ensure Python 3 is available, then install fastavro +python -m pip install --upgrade fastavro + +# 2) Fetch the script +curl -L -o create_avro_decimal_files.py \ +https://gist.githubusercontent.com/jecsand838/3890349bdb33082a3e8fdcae3257eef7/raw/create_avro_decimal_files.py + +# 3) Generate the files (prints a verification dump by default) +python create_avro_decimal_files.py -o arrow-avro/test/data +``` + +Options: +* --num-rows (default 24) — number of rows to emit per file +* --scale (default 10) — the decimal scale used for the 256 files +* --no-verify — skip reading the files back for printed verification + +## Other Files + +This directory contains other small OCF files used by `arrow-avro` tests. Details on these will be added in +follow-up PRs. 
\ No newline at end of file diff --git a/arrow-avro/test/data/fixed256_decimal.avro b/arrow-avro/test/data/fixed256_decimal.avro new file mode 100644 index 000000000000..d1fc97dd8c83 Binary files /dev/null and b/arrow-avro/test/data/fixed256_decimal.avro differ diff --git a/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro b/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro new file mode 100644 index 000000000000..b746df9619b5 Binary files /dev/null and b/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro differ diff --git a/arrow-avro/test/data/int128_decimal.avro b/arrow-avro/test/data/int128_decimal.avro new file mode 100644 index 000000000000..bd54d20ba487 Binary files /dev/null and b/arrow-avro/test/data/int128_decimal.avro differ diff --git a/arrow-avro/test/data/int256_decimal.avro b/arrow-avro/test/data/int256_decimal.avro new file mode 100644 index 000000000000..62ad7ea4df08 Binary files /dev/null and b/arrow-avro/test/data/int256_decimal.avro differ diff --git a/arrow-avro/test/data/skippable_types.avro b/arrow-avro/test/data/skippable_types.avro new file mode 100644 index 000000000000..b0518e0056b5 Binary files /dev/null and b/arrow-avro/test/data/skippable_types.avro differ diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 49145cf987f9..32bbd35e811d 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -50,7 +50,8 @@ half = { version = "2.1", default-features = false } num = { version = "0.4", default-features = false, features = ["std"] } lexical-core = { version = "1.0", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } atoi = "2.0.0" -comfy-table = { version = "7.0", optional = true, default-features = false } +# unpin after MSRV bump to 1.85 +comfy-table = { version = "=7.1.2", optional = true, default-features = false } base64 = "0.22" ryu = "1.0.16" diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index c2cb38a226b6..e10943a6a91c 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -102,7 +102,7 @@ impl Writer { WriterBuilder::new().with_delimiter(delimiter).build(writer) } - /// Write a vector of record batches to a writable object + /// Write a RecordBatch to a writable object pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { let num_columns = batch.num_columns(); if self.beginning { diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index ca0d1c5e4b3d..854a149473d1 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -70,7 +70,7 @@ tls-ring = ["tonic/tls-ring"] tls-webpki-roots = ["tonic/tls-webpki-roots"] # Enable CLI tools -cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber", "dep:tokio"] +cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "tonic/gzip", "tonic/deflate", "tonic/zstd", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber", "dep:tokio"] [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs index 7b9e34898ac8..154b59f5d379 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -21,11 +21,12 @@ use anyhow::{bail, Context, Result}; use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; use arrow_cast::{cast_with_options, 
pretty::pretty_format_batches, CastOptions}; use arrow_flight::{ + flight_service_client::FlightServiceClient, sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables}, FlightInfo, }; use arrow_schema::Schema; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use core::str; use futures::TryStreamExt; use tonic::{ @@ -53,6 +54,24 @@ pub struct LoggingArgs { log_verbose_count: u8, } +/// gRPC/HTTP compression algorithms. +#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] +pub enum CompressionEncoding { + Gzip, + Deflate, + Zstd, +} + +impl From for tonic::codec::CompressionEncoding { + fn from(encoding: CompressionEncoding) -> Self { + match encoding { + CompressionEncoding::Gzip => Self::Gzip, + CompressionEncoding::Deflate => Self::Deflate, + CompressionEncoding::Zstd => Self::Zstd, + } + } +} + #[derive(Debug, Parser)] struct ClientArgs { /// Additional headers. @@ -85,6 +104,14 @@ struct ClientArgs { #[clap(long)] tls: bool, + /// Dump TLS key log. + /// + /// The target file is specified by the `SSLKEYLOGFILE` environment variable. + /// + /// Requires `--tls`. + #[clap(long, requires = "tls")] + key_log: bool, + /// Server host. /// /// Required. @@ -96,6 +123,34 @@ struct ClientArgs { /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. #[clap(long)] port: Option, + + /// Compression accepted by the client for responses sent by the server. + /// + /// The client will send this information to the server as part of the request. The server is free to pick an + /// algorithm from that list or use no compression (called "identity" encoding). + /// + /// You may define multiple algorithms by using a comma-separated list. + #[clap(long, value_delimiter = ',')] + accept_compression: Vec, + + /// Compression of requests sent by the client to the server. + /// + /// Since the client needs to decide on the compression before sending the request, there is no client<->server + /// negotiation. If the server does NOT support the chosen compression, it will respond with an error a la: + /// + /// ``` + /// Ipc error: Status { + /// code: Unimplemented, + /// message: "Content is compressed with `zstd` which isn't supported", + /// metadata: MetadataMap { headers: {"grpc-accept-encoding": "identity", ...} }, + /// ... + /// } + /// ``` + /// + /// Based on the algorithms listed in the `grpc-accept-encoding` header, you may make a more educated guess for + /// your next request. Note that `identity` is a synonym for "no compression". + #[clap(long)] + send_compression: Option, } #[derive(Debug, Parser)] @@ -357,7 +412,11 @@ async fn setup_client(args: ClientArgs) -> Result Result {{ - let array = $input.as_primitive::<$t>(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - $builder.append_variant(Variant::from(array.value(i))); - } - }}; -} - -/// Convert the input array to a `VariantArray` row by row, using `method` -/// requiring a generic type to downcast the generic array to a specific -/// array type and `cast_fn` to transform each element to a type compatible with Variant -macro_rules! 
generic_conversion { - ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ - let array = $input.$method::<$t>(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - let cast_value = $cast_fn(array.value(i)); - $builder.append_variant(Variant::from(cast_value)); - } - }}; -} - -/// Convert the input array to a `VariantArray` row by row, using `method` -/// not requiring a generic type to downcast the generic array to a specific -/// array type and `cast_fn` to transform each element to a type compatible with Variant -macro_rules! non_generic_conversion { - ($method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ - let array = $input.$method(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - let cast_value = $cast_fn(array.value(i)); - $builder.append_variant(Variant::from(cast_value)); - } - }}; -} - -fn convert_timestamp( - time_unit: &TimeUnit, - time_zone: &Option>, - input: &dyn Array, - builder: &mut VariantArrayBuilder, -) { - let native_datetimes: Vec> = match time_unit { - arrow_schema::TimeUnit::Second => { - let ts_array = input - .as_any() - .downcast_ref::() - .expect("Array is not TimestampSecondArray"); - - ts_array - .iter() - .map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap())) - .collect() - } - arrow_schema::TimeUnit::Millisecond => { - let ts_array = input - .as_any() - .downcast_ref::() - .expect("Array is not TimestampMillisecondArray"); - - ts_array - .iter() - .map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap())) - .collect() - } - arrow_schema::TimeUnit::Microsecond => { - let ts_array = input - .as_any() - .downcast_ref::() - .expect("Array is not TimestampMicrosecondArray"); - ts_array - .iter() - .map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap())) - .collect() - } - arrow_schema::TimeUnit::Nanosecond => { - let ts_array = input - .as_any() - .downcast_ref::() - .expect("Array is not TimestampNanosecondArray"); - ts_array - .iter() - .map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap())) - .collect() - } - }; - - for x in native_datetimes { - match x { - Some(ndt) => { - if time_zone.is_none() { - builder.append_variant(ndt.into()); - } else { - let utc_dt: DateTime = Utc.from_utc_datetime(&ndt); - builder.append_variant(utc_dt.into()); - } - } - None => { - builder.append_null(); - } - } - } -} - -/// Convert a decimal value to a `VariantDecimal` -macro_rules! decimal_to_variant_decimal { - ($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => { - if *$scale < 0 { - // For negative scale, we need to multiply the value by 10^|scale| - // For example: 123 with scale -2 becomes 12300 - let multiplier = (10 as $value_type).pow((-*$scale) as u32); - // Check for overflow - if $v > 0 && $v > <$value_type>::MAX / multiplier { - return Variant::Null; - } - if $v < 0 && $v < <$value_type>::MIN / multiplier { - return Variant::Null; - } - <$variant_type>::try_new($v * multiplier, 0) - .map(|v| v.into()) - .unwrap_or(Variant::Null) - } else { - <$variant_type>::try_new($v, *$scale as u8) - .map(|v| v.into()) - .unwrap_or(Variant::Null) - } - }; -} - -/// Convert arrays that don't need generic type parameters -macro_rules! 
cast_conversion_nongeneric { - ($method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ - let array = $input.$method(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - let cast_value = $cast_fn(array.value(i)); - $builder.append_variant(Variant::from(cast_value)); - } - }}; -} - -/// Convert string arrays using the offset size as the type parameter -macro_rules! cast_conversion_string { - ($offset_type:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ - let array = $input.$method::<$offset_type>(); - for i in 0..array.len() { - if array.is_null(i) { - $builder.append_null(); - continue; - } - let cast_value = $cast_fn(array.value(i)); - $builder.append_variant(Variant::from(cast_value)); - } - }}; -} - /// Casts a typed arrow [`Array`] to a [`VariantArray`]. This is useful when you /// need to convert a specific data type /// @@ -245,62 +79,50 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { let mut builder = VariantArrayBuilder::new(input.len()); let input_type = input.data_type(); - // todo: handle other types like Boolean, Date, Timestamp, etc. match input_type { - DataType::Boolean => { - non_generic_conversion!(as_boolean, |v| v, input, builder); - } - - DataType::Binary => { - generic_conversion!(BinaryType, as_bytes, |v| v, input, builder); - } - DataType::LargeBinary => { - generic_conversion!(LargeBinaryType, as_bytes, |v| v, input, builder); + DataType::Null => { + for _ in 0..input.len() { + builder.append_null(); + } } - DataType::BinaryView => { - generic_conversion!(BinaryViewType, as_byte_view, |v| v, input, builder); + DataType::Boolean => { + non_generic_conversion_array!(input.as_boolean(), |v| v, builder); } DataType::Int8 => { - primitive_conversion!(Int8Type, input, builder); + primitive_conversion_array!(Int8Type, input, builder); } DataType::Int16 => { - primitive_conversion!(Int16Type, input, builder); + primitive_conversion_array!(Int16Type, input, builder); } DataType::Int32 => { - primitive_conversion!(Int32Type, input, builder); + primitive_conversion_array!(Int32Type, input, builder); } DataType::Int64 => { - primitive_conversion!(Int64Type, input, builder); + primitive_conversion_array!(Int64Type, input, builder); } DataType::UInt8 => { - primitive_conversion!(UInt8Type, input, builder); + primitive_conversion_array!(UInt8Type, input, builder); } DataType::UInt16 => { - primitive_conversion!(UInt16Type, input, builder); + primitive_conversion_array!(UInt16Type, input, builder); } DataType::UInt32 => { - primitive_conversion!(UInt32Type, input, builder); + primitive_conversion_array!(UInt32Type, input, builder); } DataType::UInt64 => { - primitive_conversion!(UInt64Type, input, builder); + primitive_conversion_array!(UInt64Type, input, builder); } DataType::Float16 => { - generic_conversion!( - Float16Type, - as_primitive, - |v: f16| -> f32 { v.into() }, - input, - builder - ); + generic_conversion_array!(Float16Type, as_primitive, f32::from, input, builder); } DataType::Float32 => { - primitive_conversion!(Float32Type, input, builder); + primitive_conversion_array!(Float32Type, input, builder); } DataType::Float64 => { - primitive_conversion!(Float64Type, input, builder); + primitive_conversion_array!(Float64Type, input, builder); } DataType::Decimal32(_, scale) => { - generic_conversion!( + generic_conversion_array!( Decimal32Type, as_primitive, |v| decimal_to_variant_decimal!(v, scale, i32, VariantDecimal4), @@ -309,7 +131,7 @@ pub fn cast_to_variant(input: &dyn Array) -> 
Result { ); } DataType::Decimal64(_, scale) => { - generic_conversion!( + generic_conversion_array!( Decimal64Type, as_primitive, |v| decimal_to_variant_decimal!(v, scale, i64, VariantDecimal8), @@ -318,7 +140,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } DataType::Decimal128(_, scale) => { - generic_conversion!( + generic_conversion_array!( Decimal128Type, as_primitive, |v| decimal_to_variant_decimal!(v, scale, i128, VariantDecimal16), @@ -327,7 +149,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } DataType::Decimal256(_, scale) => { - generic_conversion!( + generic_conversion_array!( Decimal256Type, as_primitive, |v: i256| { @@ -344,21 +166,31 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { builder ); } - DataType::FixedSizeBinary(_) => { - non_generic_conversion!(as_fixed_size_binary, |v| v, input, builder); - } - DataType::Null => { - for _ in 0..input.len() { - builder.append_null(); - } - } DataType::Timestamp(time_unit, time_zone) => { convert_timestamp(time_unit, time_zone, input, &mut builder); } + DataType::Date32 => { + generic_conversion_array!( + Date32Type, + as_primitive, + |v: i32| -> NaiveDate { Date32Type::to_naive_date(v) }, + input, + builder + ); + } + DataType::Date64 => { + generic_conversion_array!( + Date64Type, + as_primitive, + |v: i64| { Date64Type::to_naive_date_opt(v).unwrap() }, + input, + builder + ); + } DataType::Time32(unit) => { match *unit { TimeUnit::Second => { - generic_conversion!( + generic_conversion_array!( Time32SecondType, as_primitive, // nano second are always 0 @@ -368,7 +200,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } TimeUnit::Millisecond => { - generic_conversion!( + generic_conversion_array!( Time32MillisecondType, as_primitive, |v| NaiveTime::from_num_seconds_from_midnight_opt( @@ -391,7 +223,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { DataType::Time64(unit) => { match *unit { TimeUnit::Microsecond => { - generic_conversion!( + generic_conversion_array!( Time64MicrosecondType, as_primitive, |v| NaiveTime::from_num_seconds_from_midnight_opt( @@ -404,7 +236,7 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { ); } TimeUnit::Nanosecond => { - generic_conversion!( + generic_conversion_array!( Time64NanosecondType, as_primitive, |v| NaiveTime::from_num_seconds_from_midnight_opt( @@ -424,128 +256,339 @@ pub fn cast_to_variant(input: &dyn Array) -> Result { } }; } - DataType::Interval(_) => { + DataType::Duration(_) | DataType::Interval(_) => { return Err(ArrowError::InvalidArgumentError( - "Casting interval types to Variant is not supported. \ - The Variant format does not define interval/duration types." + "Casting duration/interval types to Variant is not supported. \ + The Variant format does not define duration/interval types." 
.to_string(), )); } + DataType::Binary => { + generic_conversion_array!(BinaryType, as_bytes, |v| v, input, builder); + } + DataType::LargeBinary => { + generic_conversion_array!(LargeBinaryType, as_bytes, |v| v, input, builder); + } + DataType::BinaryView => { + generic_conversion_array!(BinaryViewType, as_byte_view, |v| v, input, builder); + } + DataType::FixedSizeBinary(_) => { + non_generic_conversion_array!(input.as_fixed_size_binary(), |v| v, builder); + } DataType::Utf8 => { - cast_conversion_string!(i32, as_string, |v| v, input, builder); + generic_conversion_array!(i32, as_string, |v| v, input, builder); } DataType::LargeUtf8 => { - cast_conversion_string!(i64, as_string, |v| v, input, builder); + generic_conversion_array!(i64, as_string, |v| v, input, builder); } DataType::Utf8View => { - cast_conversion_nongeneric!(as_string_view, |v| v, input, builder); + non_generic_conversion_array!(input.as_string_view(), |v| v, builder); } - DataType::Struct(_) => { - let struct_array = input.as_struct(); - - // Pre-convert all field arrays once for better performance - // This avoids converting the same field array multiple times - // Alternative approach: Use slicing per row: field_array.slice(i, 1) - // However, pre-conversion is more efficient for typical use cases - let field_variant_arrays: Result, _> = struct_array - .columns() + DataType::List(_) => convert_list::(input, &mut builder)?, + DataType::LargeList(_) => convert_list::(input, &mut builder)?, + DataType::Struct(_) => convert_struct(input, &mut builder)?, + DataType::Map(field, _) => convert_map(field, input, &mut builder)?, + DataType::Union(fields, _) => convert_union(fields, input, &mut builder)?, + DataType::Dictionary(_, _) => convert_dictionary_encoded(input, &mut builder)?, + DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { + DataType::Int16 => convert_run_end_encoded::(input, &mut builder)?, + DataType::Int32 => convert_run_end_encoded::(input, &mut builder)?, + DataType::Int64 => convert_run_end_encoded::(input, &mut builder)?, + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported run ends type: {:?}", + run_ends.data_type() + ))); + } + }, + dt => { + return Err(ArrowError::CastError(format!( + "Unsupported data type for casting to Variant: {dt:?}", + ))); + } + }; + Ok(builder.build()) +} + +// TODO do we need a cast_with_options to allow specifying conversion behavior, +// e.g. how to handle overflows, whether to convert to Variant::Null or return +// an error, etc. ? 
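For orientation, here is a minimal usage sketch of the entry point above. The `example` wrapper is hypothetical; `cast_to_variant` and the `VariantArray` accessors are the ones exercised by the tests at the bottom of this file:

```rust
use arrow_array::Int32Array;
use arrow_schema::ArrowError;
use parquet_variant::Variant;

fn example() -> Result<(), ArrowError> {
    // Any supported Arrow array can be handed to the single entry point.
    let input = Int32Array::from(vec![Some(1), None, Some(3)]);
    let variants = cast_to_variant(&input)?;
    assert_eq!(variants.value(0), Variant::Int32(1));
    // Top-level nulls stay null rows; they are not encoded as Variant::Null.
    assert!(variants.is_null(1));
    assert_eq!(variants.value(2), Variant::Int32(3));
    Ok(())
}
```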
+ +/// Convert timestamp arrays to native datetimes +fn convert_timestamp( + time_unit: &TimeUnit, + time_zone: &Option>, + input: &dyn Array, + builder: &mut VariantArrayBuilder, +) { + let native_datetimes: Vec> = match time_unit { + arrow_schema::TimeUnit::Second => { + let ts_array = input + .as_any() + .downcast_ref::() + .expect("Array is not TimestampSecondArray"); + + ts_array .iter() - .map(|field_array| cast_to_variant(field_array.as_ref())) - .collect(); - let field_variant_arrays = field_variant_arrays?; + .map(|x| x.map(|y| timestamp_s_to_datetime(y).unwrap())) + .collect() + } + arrow_schema::TimeUnit::Millisecond => { + let ts_array = input + .as_any() + .downcast_ref::() + .expect("Array is not TimestampMillisecondArray"); - // Cache column names to avoid repeated calls - let column_names = struct_array.column_names(); + ts_array + .iter() + .map(|x| x.map(|y| timestamp_ms_to_datetime(y).unwrap())) + .collect() + } + arrow_schema::TimeUnit::Microsecond => { + let ts_array = input + .as_any() + .downcast_ref::() + .expect("Array is not TimestampMicrosecondArray"); + ts_array + .iter() + .map(|x| x.map(|y| timestamp_us_to_datetime(y).unwrap())) + .collect() + } + arrow_schema::TimeUnit::Nanosecond => { + let ts_array = input + .as_any() + .downcast_ref::() + .expect("Array is not TimestampNanosecondArray"); + ts_array + .iter() + .map(|x| x.map(|y| timestamp_ns_to_datetime(y).unwrap())) + .collect() + } + }; + + for x in native_datetimes { + match x { + Some(ndt) => { + if time_zone.is_none() { + builder.append_variant(ndt.into()); + } else { + let utc_dt: DateTime = Utc.from_utc_datetime(&ndt); + builder.append_variant(utc_dt.into()); + } + } + None => { + builder.append_null(); + } + } + } +} + +/// Generic function to convert list arrays (both List and LargeList) to variant arrays +fn convert_list( + input: &dyn Array, + builder: &mut VariantArrayBuilder, +) -> Result<(), ArrowError> { + let list_array = input.as_list::(); + let values = list_array.values(); + let offsets = list_array.offsets(); + + let first_offset = *offsets.first().expect("There should be an offset"); + let length = *offsets.last().expect("There should be an offset") - first_offset; + let sliced_values = values.slice(first_offset.as_usize(), length.as_usize()); + + let values_variant_array = cast_to_variant(sliced_values.as_ref())?; + let new_offsets = OffsetBuffer::new(ScalarBuffer::from_iter( + offsets.iter().map(|o| *o - first_offset), + )); + + for i in 0..list_array.len() { + if list_array.is_null(i) { + builder.append_null(); + continue; + } + + let start = new_offsets[i].as_usize(); + let end = new_offsets[i + 1].as_usize(); + + // Start building the inner VariantList + let mut variant_builder = VariantBuilder::new(); + let mut list_builder = variant_builder.new_list(); + + // Add all values from the slice + for j in start..end { + list_builder.append_value(values_variant_array.value(j)); + } + + list_builder.finish(); + + let (metadata, value) = variant_builder.finish(); + let variant = Variant::new(&metadata, &value); + builder.append_variant(variant) + } + + Ok(()) +} + +fn convert_struct(input: &dyn Array, builder: &mut VariantArrayBuilder) -> Result<(), ArrowError> { + let struct_array = input.as_struct(); + + // Pre-convert all field arrays once for better performance + // This avoids converting the same field array multiple times + // Alternative approach: Use slicing per row: field_array.slice(i, 1) + // However, pre-conversion is more efficient for typical use cases + let 
field_variant_arrays: Result, _> = struct_array + .columns() + .iter() + .map(|field_array| cast_to_variant(field_array.as_ref())) + .collect(); + let field_variant_arrays = field_variant_arrays?; + + // Cache column names to avoid repeated calls + let column_names = struct_array.column_names(); + + for i in 0..struct_array.len() { + if struct_array.is_null(i) { + builder.append_null(); + continue; + } - for i in 0..struct_array.len() { - if struct_array.is_null(i) { + // Create a VariantBuilder for this struct instance + let mut variant_builder = VariantBuilder::new(); + let mut object_builder = variant_builder.new_object(); + + // Iterate through all fields in the struct + for (field_idx, field_name) in column_names.iter().enumerate() { + // Use pre-converted field variant arrays for better performance + // Check nulls directly from the pre-converted arrays instead of accessing column again + if !field_variant_arrays[field_idx].is_null(i) { + let field_variant = field_variant_arrays[field_idx].value(i); + object_builder.insert(field_name, field_variant); + } + // Note: we skip null fields rather than inserting Variant::Null + // to match Arrow struct semantics where null fields are omitted + } + + object_builder.finish(); + let (metadata, value) = variant_builder.finish(); + let variant = Variant::try_new(&metadata, &value)?; + builder.append_variant(variant); + } + + Ok(()) +} + +fn convert_map( + field: &FieldRef, + input: &dyn Array, + builder: &mut VariantArrayBuilder, +) -> Result<(), ArrowError> { + match field.data_type() { + DataType::Struct(_) => { + let map_array = input.as_map(); + let keys = cast(map_array.keys(), &DataType::Utf8)?; + let key_strings = keys.as_string::(); + let values = cast_to_variant(map_array.values())?; + let offsets = map_array.offsets(); + + let mut start_offset = offsets[0]; + for end_offset in offsets.iter().skip(1) { + if start_offset >= *end_offset { builder.append_null(); continue; } - // Create a VariantBuilder for this struct instance + let length = (end_offset - start_offset) as usize; + let mut variant_builder = VariantBuilder::new(); let mut object_builder = variant_builder.new_object(); - // Iterate through all fields in the struct - for (field_idx, field_name) in column_names.iter().enumerate() { - // Use pre-converted field variant arrays for better performance - // Check nulls directly from the pre-converted arrays instead of accessing column again - if !field_variant_arrays[field_idx].is_null(i) { - let field_variant = field_variant_arrays[field_idx].value(i); - object_builder.insert(field_name, field_variant); - } - // Note: we skip null fields rather than inserting Variant::Null - // to match Arrow struct semantics where null fields are omitted + for i in start_offset..*end_offset { + let value = values.value(i as usize); + object_builder.insert(key_strings.value(i as usize), value); } - - object_builder.finish()?; + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; + builder.append_variant(variant); + + start_offset += length as i32; } } - DataType::Date32 => { - generic_conversion!( - Date32Type, - as_primitive, - |v: i32| -> NaiveDate { Date32Type::to_naive_date(v) }, - input, - builder - ); - } - DataType::Date64 => { - generic_conversion!( - Date64Type, - as_primitive, - |v: i64| { Date64Type::to_naive_date_opt(v).unwrap() }, - input, - builder - ); + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported map field type for casting to Variant: 
{field:?}", + ))); } - DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { - DataType::Int16 => process_run_end_encoded::(input, &mut builder)?, - DataType::Int32 => process_run_end_encoded::(input, &mut builder)?, - DataType::Int64 => process_run_end_encoded::(input, &mut builder)?, - _ => { - return Err(ArrowError::CastError(format!( - "Unsupported run ends type: {:?}", - run_ends.data_type() - ))); - } - }, - DataType::Dictionary(_, _) => { - let dict_array = input.as_any_dictionary(); - let values_variant_array = cast_to_variant(dict_array.values().as_ref())?; - let normalized_keys = dict_array.normalized_keys(); - let keys = dict_array.keys(); - - for (i, key_idx) in normalized_keys.iter().enumerate() { - if keys.is_null(i) { - builder.append_null(); - continue; - } + } - if values_variant_array.is_null(*key_idx) { - builder.append_null(); - continue; - } + Ok(()) +} + +fn convert_union( + fields: &UnionFields, + input: &dyn Array, + builder: &mut VariantArrayBuilder, +) -> Result<(), ArrowError> { + let union_array = input.as_union(); + + // Convert each child array to variant arrays + let mut child_variant_arrays = HashMap::new(); + for (type_id, _) in fields.iter() { + let child_array = union_array.child(type_id); + let child_variant_array = cast_to_variant(child_array.as_ref())?; + child_variant_arrays.insert(type_id, child_variant_array); + } - let value = values_variant_array.value(*key_idx); + // Process each element in the union array + for i in 0..union_array.len() { + let type_id = union_array.type_id(i); + let value_offset = union_array.value_offset(i); + + if let Some(child_variant_array) = child_variant_arrays.get(&type_id) { + if child_variant_array.is_null(value_offset) { + builder.append_null(); + } else { + let value = child_variant_array.value(value_offset); builder.append_variant(value); } + } else { + // This should not happen in a valid union, but handle gracefully + builder.append_null(); } - dt => { - return Err(ArrowError::CastError(format!( - "Unsupported data type for casting to Variant: {dt:?}", - ))); + } + + Ok(()) +} + +fn convert_dictionary_encoded( + input: &dyn Array, + builder: &mut VariantArrayBuilder, +) -> Result<(), ArrowError> { + let dict_array = input.as_any_dictionary(); + let values_variant_array = cast_to_variant(dict_array.values().as_ref())?; + let normalized_keys = dict_array.normalized_keys(); + let keys = dict_array.keys(); + + for (i, key_idx) in normalized_keys.iter().enumerate() { + if keys.is_null(i) { + builder.append_null(); + continue; } - }; - Ok(builder.build()) + + if values_variant_array.is_null(*key_idx) { + builder.append_null(); + continue; + } + + let value = values_variant_array.value(*key_idx); + builder.append_variant(value); + } + + Ok(()) } -/// Generic function to process run-end encoded arrays -fn process_run_end_encoded( +fn convert_run_end_encoded( input: &dyn Array, builder: &mut VariantArrayBuilder, ) -> Result<(), ArrowError> { @@ -579,27 +622,28 @@ fn process_run_end_encoded( Ok(()) } -// TODO do we need a cast_with_options to allow specifying conversion behavior, -// e.g. how to handle overflows, whether to convert to Variant::Null or return -// an error, etc. ? 
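Both encoded representations above follow the same shape: convert the underlying values array once, then fan the converted rows out per key (dictionary) or per run (run-end). A short sketch of the dictionary case, with hypothetical data mirroring `test_cast_to_variant_dictionary` below:

```rust
use std::sync::Arc;

use arrow_array::types::Int32Type;
use arrow_array::{DictionaryArray, Int32Array, StringArray};
use arrow_schema::ArrowError;
use parquet_variant::Variant;

fn example() -> Result<(), ArrowError> {
    let values = StringArray::from(vec!["apple", "banana"]);
    let keys = Int32Array::from(vec![Some(0), None, Some(1), Some(0)]);
    let dict = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values))?;
    let variants = cast_to_variant(&dict)?;
    assert_eq!(variants.value(0), Variant::from("apple"));
    // A null key yields a null output row.
    assert!(variants.is_null(1));
    assert_eq!(variants.value(2), Variant::from("banana"));
    // Repeated keys reuse the single converted values array.
    assert_eq!(variants.value(3), Variant::from("apple"));
    Ok(())
}
```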
- #[cfg(test)] mod tests { use super::*; use arrow::array::{ ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, - Decimal256Array, Decimal32Array, Decimal64Array, DictionaryArray, FixedSizeBinaryBuilder, - Float16Array, Float32Array, Float64Array, GenericByteBuilder, GenericByteViewBuilder, - Int16Array, Int32Array, Int64Array, Int8Array, IntervalYearMonthArray, LargeStringArray, - NullArray, StringArray, StringRunBuilder, StringViewArray, StructArray, - Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal256Array, Decimal32Array, Decimal64Array, DictionaryArray, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, + FixedSizeBinaryBuilder, Float16Array, Float32Array, Float64Array, GenericByteBuilder, + GenericByteViewBuilder, Int16Array, Int32Array, Int64Array, Int8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray, + LargeStringArray, ListArray, MapArray, NullArray, StringArray, StringRunBuilder, + StringViewArray, StructArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, UnionArray, }; - use arrow::buffer::NullBuffer; - use arrow_schema::{Field, Fields}; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano}; + use arrow_schema::{DataType, Field, Fields, UnionFields}; use arrow_schema::{ DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; + use half::f16; use parquet_variant::{Variant, VariantDecimal16}; use std::{sync::Arc, vec}; @@ -611,145 +655,13 @@ mod tests { (u64::pow(10, $precision as u32) - 1) as i64 }; (128, $precision:expr) => { - (u128::pow(10, $precision as u32) - 1) as i128 - }; - } - - #[test] - fn test_cast_to_variant_timestamp() { - let run_array_tests = - |microseconds: i64, array_ntz: Arc, array_tz: Arc| { - let timestamp = DateTime::from_timestamp_nanos(microseconds * 1000); - run_test( - array_tz, - vec![Some(Variant::TimestampMicros(timestamp)), None], - ); - run_test( - array_ntz, - vec![ - Some(Variant::TimestampNtzMicros(timestamp.naive_utc())), - None, - ], - ); - }; - - let nanosecond = 1234567890; - let microsecond = 1234567; - let millisecond = 1234; - let second = 1; - - let second_array = TimestampSecondArray::from(vec![Some(second), None]); - run_array_tests( - second * 1000 * 1000, - Arc::new(second_array.clone()), - Arc::new(second_array.with_timezone("+01:00".to_string())), - ); - - let millisecond_array = TimestampMillisecondArray::from(vec![Some(millisecond), None]); - run_array_tests( - millisecond * 1000, - Arc::new(millisecond_array.clone()), - Arc::new(millisecond_array.with_timezone("+01:00".to_string())), - ); - - let microsecond_array = TimestampMicrosecondArray::from(vec![Some(microsecond), None]); - run_array_tests( - microsecond, - Arc::new(microsecond_array.clone()), - Arc::new(microsecond_array.with_timezone("+01:00".to_string())), - ); - - let timestamp = DateTime::from_timestamp_nanos(nanosecond); - let nanosecond_array = TimestampNanosecondArray::from(vec![Some(nanosecond), None]); - run_test( - Arc::new(nanosecond_array.clone()), - vec![ - Some(Variant::TimestampNtzNanos(timestamp.naive_utc())), - None, - ], - ); - run_test( - Arc::new(nanosecond_array.with_timezone("+01:00".to_string())), - 
vec![Some(Variant::TimestampNanos(timestamp)), None], - ); - } - - #[test] - fn test_cast_to_variant_fixed_size_binary() { - let v1 = vec![1, 2]; - let v2 = vec![3, 4]; - let v3 = vec![5, 6]; - - let mut builder = FixedSizeBinaryBuilder::new(2); - builder.append_value(&v1).unwrap(); - builder.append_value(&v2).unwrap(); - builder.append_null(); - builder.append_value(&v3).unwrap(); - let array = builder.finish(); - - run_test( - Arc::new(array), - vec![ - Some(Variant::Binary(&v1)), - Some(Variant::Binary(&v2)), - None, - Some(Variant::Binary(&v3)), - ], - ); + (u128::pow(10, $precision as u32) - 1) as i128 + }; } #[test] - fn test_cast_to_variant_binary() { - // BinaryType - let mut builder = GenericByteBuilder::::new(); - builder.append_value(b"hello"); - builder.append_value(b""); - builder.append_null(); - builder.append_value(b"world"); - let binary_array = builder.finish(); - run_test( - Arc::new(binary_array), - vec![ - Some(Variant::Binary(b"hello")), - Some(Variant::Binary(b"")), - None, - Some(Variant::Binary(b"world")), - ], - ); - - // LargeBinaryType - let mut builder = GenericByteBuilder::::new(); - builder.append_value(b"hello"); - builder.append_value(b""); - builder.append_null(); - builder.append_value(b"world"); - let large_binary_array = builder.finish(); - run_test( - Arc::new(large_binary_array), - vec![ - Some(Variant::Binary(b"hello")), - Some(Variant::Binary(b"")), - None, - Some(Variant::Binary(b"world")), - ], - ); - - // BinaryViewType - let mut builder = GenericByteViewBuilder::::new(); - builder.append_value(b"hello"); - builder.append_value(b""); - builder.append_null(); - builder.append_value(b"world"); - let byte_view_array = builder.finish(); - run_test( - Arc::new(byte_view_array), - vec![ - Some(Variant::Binary(b"hello")), - Some(Variant::Binary(b"")), - None, - Some(Variant::Binary(b"world")), - ], - ); + fn test_cast_to_variant_null() { + run_test(Arc::new(NullArray::new(2)), vec![None, None]) } #[test] @@ -993,26 +905,6 @@ mod tests { ) } - #[test] - fn test_cast_to_variant_interval_error() { - let array = IntervalYearMonthArray::from(vec![Some(12), None, Some(-6)]); - let result = cast_to_variant(&array); - - assert!(result.is_err()); - match result.unwrap_err() { - ArrowError::InvalidArgumentError(msg) => { - assert!(msg.contains("Casting interval types to Variant is not supported")); - assert!(msg.contains("The Variant format does not define interval/duration types")); - } - _ => panic!("Expected InvalidArgumentError"), - } - } - - #[test] - fn test_cast_to_variant_null() { - run_test(Arc::new(NullArray::new(2)), vec![None, None]) - } - #[test] fn test_cast_to_variant_decimal32() { run_test( @@ -1406,7 +1298,105 @@ mod tests { } #[test] - fn test_cast_time32_second_to_variant_time() { + fn test_cast_to_variant_timestamp() { + let run_array_tests = + |microseconds: i64, array_ntz: Arc, array_tz: Arc| { + let timestamp = DateTime::from_timestamp_nanos(microseconds * 1000); + run_test( + array_tz, + vec![Some(Variant::TimestampMicros(timestamp)), None], + ); + run_test( + array_ntz, + vec![ + Some(Variant::TimestampNtzMicros(timestamp.naive_utc())), + None, + ], + ); + }; + + let nanosecond = 1234567890; + let microsecond = 1234567; + let millisecond = 1234; + let second = 1; + + let second_array = TimestampSecondArray::from(vec![Some(second), None]); + run_array_tests( + second * 1000 * 1000, + Arc::new(second_array.clone()), + Arc::new(second_array.with_timezone("+01:00".to_string())), + ); + + let millisecond_array = 
TimestampMillisecondArray::from(vec![Some(millisecond), None]); + run_array_tests( + millisecond * 1000, + Arc::new(millisecond_array.clone()), + Arc::new(millisecond_array.with_timezone("+01:00".to_string())), + ); + + let microsecond_array = TimestampMicrosecondArray::from(vec![Some(microsecond), None]); + run_array_tests( + microsecond, + Arc::new(microsecond_array.clone()), + Arc::new(microsecond_array.with_timezone("+01:00".to_string())), + ); + + let timestamp = DateTime::from_timestamp_nanos(nanosecond); + let nanosecond_array = TimestampNanosecondArray::from(vec![Some(nanosecond), None]); + run_test( + Arc::new(nanosecond_array.clone()), + vec![ + Some(Variant::TimestampNtzNanos(timestamp.naive_utc())), + None, + ], + ); + run_test( + Arc::new(nanosecond_array.with_timezone("+01:00".to_string())), + vec![Some(Variant::TimestampNanos(timestamp)), None], + ); + } + + #[test] + fn test_cast_to_variant_date() { + // Date32Array + run_test( + Arc::new(Date32Array::from(vec![ + Some(Date32Type::from_naive_date(NaiveDate::MIN)), + None, + Some(Date32Type::from_naive_date( + NaiveDate::from_ymd_opt(2025, 8, 1).unwrap(), + )), + Some(Date32Type::from_naive_date(NaiveDate::MAX)), + ])), + vec![ + Some(Variant::Date(NaiveDate::MIN)), + None, + Some(Variant::Date(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap())), + Some(Variant::Date(NaiveDate::MAX)), + ], + ); + + // Date64Array + run_test( + Arc::new(Date64Array::from(vec![ + Some(Date64Type::from_naive_date(NaiveDate::MIN)), + None, + Some(Date64Type::from_naive_date( + NaiveDate::from_ymd_opt(2025, 8, 1).unwrap(), + )), + Some(Date64Type::from_naive_date(NaiveDate::MAX)), + ])), + vec![ + Some(Variant::Date(NaiveDate::MIN)), + None, + Some(Variant::Date(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap())), + Some(Variant::Date(NaiveDate::MAX)), + ], + ); + } + + #[test] + fn test_cast_to_variant_time32_second() { let array: Time32SecondArray = vec![Some(1), Some(86_399), None].into(); let values = Arc::new(array); run_test( @@ -1420,65 +1410,194 @@ mod tests { )), None, ], - ) - } + ) + } + + #[test] + fn test_cast_to_variant_time32_millisecond() { + let array: Time32MillisecondArray = vec![Some(123_456), Some(456_000), None].into(); + let values = Arc::new(array); + run_test( + values, + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(123, 456_000_000).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(456, 0).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn test_cast_to_variant_time64_micro() { + let array: Time64MicrosecondArray = vec![Some(1), Some(123_456_789), None].into(); + let values = Arc::new(array); + run_test( + values, + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(0, 1_000).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(123, 456_789_000).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn test_cast_to_variant_time64_nano() { + let array: Time64NanosecondArray = + vec![Some(1), Some(1001), Some(123_456_789_012), None].into(); + run_test( + Arc::new(array), + // as we can only present with micro second, so the nano second will round donw to 0 + vec![ + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(0, 0).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(0, 1_000).unwrap(), + )), + Some(Variant::Time( + NaiveTime::from_num_seconds_from_midnight_opt(123, 456_789_000).unwrap(), + )), + None, + ], + ) + } + + #[test] + fn 
test_cast_to_variant_duration_or_interval_errors() { + let arrays: Vec> = vec![ + // Duration types + Box::new(DurationSecondArray::from(vec![Some(10), None, Some(-5)])), + Box::new(DurationMillisecondArray::from(vec![ + Some(10), + None, + Some(-5), + ])), + Box::new(DurationMicrosecondArray::from(vec![ + Some(10), + None, + Some(-5), + ])), + Box::new(DurationNanosecondArray::from(vec![ + Some(10), + None, + Some(-5), + ])), + // Interval types + Box::new(IntervalYearMonthArray::from(vec![Some(12), None, Some(-6)])), + Box::new(IntervalDayTimeArray::from(vec![ + Some(IntervalDayTime::new(12, 0)), + None, + Some(IntervalDayTime::new(-6, 0)), + ])), + Box::new(IntervalMonthDayNanoArray::from(vec![ + Some(IntervalMonthDayNano::new(12, 0, 0)), + None, + Some(IntervalMonthDayNano::new(-6, 0, 0)), + ])), + ]; + + for array in arrays { + let result = cast_to_variant(array.as_ref()); + assert!(result.is_err()); + match result.unwrap_err() { + ArrowError::InvalidArgumentError(msg) => { + assert!( + msg.contains("Casting duration/interval types to Variant is not supported") + ); + assert!( + msg.contains("The Variant format does not define duration/interval types") + ); + } + _ => panic!("Expected InvalidArgumentError"), + } + } + } + + #[test] + fn test_cast_to_variant_binary() { + // BinaryType + let mut builder = GenericByteBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"world"); + let binary_array = builder.finish(); + run_test( + Arc::new(binary_array), + vec![ + Some(Variant::Binary(b"hello")), + Some(Variant::Binary(b"")), + None, + Some(Variant::Binary(b"world")), + ], + ); - #[test] - fn test_cast_time32_millisecond_to_variant_time() { - let array: Time32MillisecondArray = vec![Some(123_456), Some(456_000), None].into(); - let values = Arc::new(array); + // LargeBinaryType + let mut builder = GenericByteBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"world"); + let large_binary_array = builder.finish(); run_test( - values, + Arc::new(large_binary_array), vec![ - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(123, 456_000_000).unwrap(), - )), - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(456, 0).unwrap(), - )), + Some(Variant::Binary(b"hello")), + Some(Variant::Binary(b"")), None, + Some(Variant::Binary(b"world")), ], - ) - } + ); - #[test] - fn test_cast_time64_micro_to_variant_time() { - let array: Time64MicrosecondArray = vec![Some(1), Some(123_456_789), None].into(); - let values = Arc::new(array); + // BinaryViewType + let mut builder = GenericByteViewBuilder::::new(); + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"world"); + let byte_view_array = builder.finish(); run_test( - values, + Arc::new(byte_view_array), vec![ - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(0, 1_000).unwrap(), - )), - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(123, 456_789_000).unwrap(), - )), + Some(Variant::Binary(b"hello")), + Some(Variant::Binary(b"")), None, + Some(Variant::Binary(b"world")), ], - ) + ); } #[test] - fn test_cast_time64_nano_to_variant_time() { - let array: Time64NanosecondArray = - vec![Some(1), Some(1001), Some(123_456_789_012), None].into(); + fn test_cast_to_variant_fixed_size_binary() { + let v1 = vec![1, 2]; + let v2 = vec![3, 4]; + let v3 = vec![5, 6]; + + let mut 
builder = FixedSizeBinaryBuilder::new(2); + builder.append_value(&v1).unwrap(); + builder.append_value(&v2).unwrap(); + builder.append_null(); + builder.append_value(&v3).unwrap(); + let array = builder.finish(); + run_test( Arc::new(array), - // as we can only present with micro second, so the nano second will round donw to 0 vec![ - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(0, 0).unwrap(), - )), - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(0, 1_000).unwrap(), - )), - Some(Variant::Time( - NaiveTime::from_num_seconds_from_midnight_opt(123, 456_789_000).unwrap(), - )), + Some(Variant::Binary(&v1)), + Some(Variant::Binary(&v2)), None, + Some(Variant::Binary(&v3)), ], - ) + ); } #[test] @@ -1575,6 +1694,101 @@ mod tests { ); } + #[test] + fn test_cast_to_variant_list() { + // List Array + let data = vec![Some(vec![Some(0), Some(1), Some(2)]), None]; + let list_array = ListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(0); + list.append_value(1); + list.append_value(2); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test(Arc::new(list_array), vec![Some(variant), None]); + } + + #[test] + fn test_cast_to_variant_sliced_list() { + // List Array + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + None, + ]; + let list_array = ListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(3); + list.append_value(4); + list.append_value(5); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test(Arc::new(list_array.slice(1, 2)), vec![Some(variant), None]); + } + + #[test] + fn test_cast_to_variant_large_list() { + // Large List Array + let data = vec![Some(vec![Some(0), Some(1), Some(2)]), None]; + let large_list_array = LargeListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(0i64); + list.append_value(1i64); + list.append_value(2i64); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test(Arc::new(large_list_array), vec![Some(variant), None]); + } + + #[test] + fn test_cast_to_variant_sliced_large_list() { + // List Array + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(3), Some(4), Some(5)]), + None, + ]; + let large_list_array = ListArray::from_iter_primitive::(data); + + // Expected value + let (metadata, value) = { + let mut builder = VariantBuilder::new(); + let mut list = builder.new_list(); + list.append_value(3i64); + list.append_value(4i64); + list.append_value(5i64); + list.finish(); + builder.finish() + }; + let variant = Variant::new(&metadata, &value); + + run_test( + Arc::new(large_list_array.slice(1, 2)), + vec![Some(variant), None], + ); + } + #[test] fn test_cast_to_variant_struct() { // Test a simple struct with two fields: id (int64) and age (int32) @@ -1856,40 +2070,226 @@ mod tests { } #[test] - fn test_cast_to_variant_date() { - // Date32Array + fn test_cast_to_variant_map() { + let keys = vec!["key1", "key2", "key3"]; + let values_data = Int32Array::from(vec![1, 2, 3]); + let entry_offsets = vec![0, 1, 3]; 
+ let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + + let result = cast_to_variant(&map_array).unwrap(); + // [{"key1":1}] + let variant1 = result.value(0); + assert_eq!( + variant1.as_object().unwrap().get("key1").unwrap(), + Variant::from(1) + ); + + // [{"key2":2},{"key3":3}] + let variant2 = result.value(1); + assert_eq!( + variant2.as_object().unwrap().get("key2").unwrap(), + Variant::from(2) + ); + assert_eq!( + variant2.as_object().unwrap().get("key3").unwrap(), + Variant::from(3) + ); + } + + #[test] + fn test_cast_to_variant_map_with_nulls() { + let keys = vec!["key1", "key2", "key3"]; + let values_data = Int32Array::from(vec![1, 2, 3]); + let entry_offsets = vec![0, 1, 1, 3]; + let map_array = + MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) + .unwrap(); + + let result = cast_to_variant(&map_array).unwrap(); + // [{"key1":1}] + let variant1 = result.value(0); + assert_eq!( + variant1.as_object().unwrap().get("key1").unwrap(), + Variant::from(1) + ); + + // None + assert!(result.is_null(1)); + + // [{"key2":2},{"key3":3}] + let variant2 = result.value(2); + assert_eq!( + variant2.as_object().unwrap().get("key2").unwrap(), + Variant::from(2) + ); + assert_eq!( + variant2.as_object().unwrap().get("key3").unwrap(), + Variant::from(3) + ); + } + + #[test] + fn test_cast_to_variant_map_with_non_string_keys() { + let offsets = OffsetBuffer::new(vec![0, 1, 3].into()); + let fields = Fields::from(vec![ + Field::new("key", DataType::Int32, false), + Field::new("values", DataType::Int32, false), + ]); + let columns = vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + ]; + + let entries = StructArray::new(fields.clone(), columns, None); + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + + let map_array = MapArray::new(field.clone(), offsets.clone(), entries.clone(), None, false); + + let result = cast_to_variant(&map_array).unwrap(); + + let variant1 = result.value(0); + assert_eq!( + variant1.as_object().unwrap().get("1").unwrap(), + Variant::from(1) + ); + + let variant2 = result.value(1); + assert_eq!( + variant2.as_object().unwrap().get("2").unwrap(), + Variant::from(2) + ); + assert_eq!( + variant2.as_object().unwrap().get("3").unwrap(), + Variant::from(3) + ); + } + + #[test] + fn test_cast_to_variant_union_sparse() { + // Create a sparse union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), None, None, None, Some(34), None]); + let float_array = Float64Array::from(vec![None, Some(3.2), None, Some(32.5), None, None]); + let string_array = StringArray::from(vec![None, None, Some("hello"), None, None, None]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::<ScalarBuffer<i8>>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec<ArrayRef> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + None, // Sparse union + children, + ) + .unwrap(); + run_test( - Arc::new(Date32Array::from(vec![ - Some(Date32Type::from_naive_date(NaiveDate::MIN)), - None, - Some(Date32Type::from_naive_date( - NaiveDate::from_ymd_opt(2025, 8, 1).unwrap(), - )), -
Some(Date32Type::from_naive_date(NaiveDate::MAX)), - ])), + Arc::new(union_array), vec![ - Some(Variant::Date(NaiveDate::MIN)), + Some(Variant::Int32(1)), + Some(Variant::Double(3.2)), + Some(Variant::from("hello")), + Some(Variant::Double(32.5)), + Some(Variant::Int32(34)), None, - Some(Variant::Date(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap())), - Some(Variant::Date(NaiveDate::MAX)), ], ); + } + + #[test] + fn test_cast_to_variant_union_dense() { + // Create a dense union array with mixed types (int, float, string) + let int_array = Int32Array::from(vec![Some(1), Some(34), None]); + let float_array = Float64Array::from(vec![3.2, 32.5]); + let string_array = StringArray::from(vec!["hello"]); + let type_ids = [0, 1, 2, 1, 0, 0].into_iter().collect::<ScalarBuffer<i8>>(); + let offsets = [0, 0, 0, 1, 1, 2] + .into_iter() + .collect::<ScalarBuffer<i32>>(); + + let union_fields = UnionFields::new( + vec![0, 1, 2], + vec![ + Field::new("int_field", DataType::Int32, false), + Field::new("float_field", DataType::Float64, false), + Field::new("string_field", DataType::Utf8, false), + ], + ); + + let children: Vec<ArrayRef> = vec![ + Arc::new(int_array), + Arc::new(float_array), + Arc::new(string_array), + ]; + + let union_array = UnionArray::try_new( + union_fields, + type_ids, + Some(offsets), // Dense union + children, + ) + .unwrap(); - // Date64Array run_test( - Arc::new(Date64Array::from(vec![ - Some(Date64Type::from_naive_date(NaiveDate::MIN)), + Arc::new(union_array), + vec![ + Some(Variant::Int32(1)), + Some(Variant::Double(3.2)), + Some(Variant::from("hello")), + Some(Variant::Double(32.5)), + Some(Variant::Int32(34)), None, - Some(Date64Type::from_naive_date( - NaiveDate::from_ymd_opt(2025, 8, 1).unwrap(), - )), - Some(Date64Type::from_naive_date(NaiveDate::MAX)), - ])), + ], + ); + } + + #[test] + fn test_cast_to_variant_dictionary() { + let values = StringArray::from(vec!["apple", "banana", "cherry", "date"]); + let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(2), Some(0), Some(3)]); + let dict_array = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap(); + + run_test( + Arc::new(dict_array), vec![ - Some(Variant::Date(NaiveDate::MIN)), + Some(Variant::from("apple")), + Some(Variant::from("banana")), None, - Some(Variant::Date(NaiveDate::from_ymd_opt(2025, 8, 1).unwrap())), - Some(Variant::Date(NaiveDate::MAX)), + Some(Variant::from("cherry")), + Some(Variant::from("apple")), + Some(Variant::from("date")), + ], + ); + } + + #[test] + fn test_cast_to_variant_dictionary_with_nulls() { + // Test dictionary with null values in the values array + let values = StringArray::from(vec![Some("a"), None, Some("c")]); + let keys = Int8Array::from(vec![Some(0), Some(1), Some(2), Some(0)]); + let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap(); + + run_test( + Arc::new(dict_array), + vec![ + Some(Variant::from("a")), + None, // key 1 points to null value + Some(Variant::from("c")), + Some(Variant::from("a")), ], ); } @@ -1946,43 +2346,6 @@ mod tests { ); } - #[test] - fn test_cast_to_variant_dictionary() { - let values = StringArray::from(vec!["apple", "banana", "cherry", "date"]); - let keys = Int32Array::from(vec![Some(0), Some(1), None, Some(2), Some(0), Some(3)]); - let dict_array = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap(); - - run_test( - Arc::new(dict_array), - vec![ - Some(Variant::from("apple")), - Some(Variant::from("banana")), - None, - Some(Variant::from("cherry")), - Some(Variant::from("apple")), - Some(Variant::from("date")), - ], - ); - } - - #[test] - fn
test_cast_to_variant_dictionary_with_nulls() { - // Test dictionary with null values in the values array - let values = StringArray::from(vec![Some("a"), None, Some("c")]); - let keys = Int8Array::from(vec![Some(0), Some(1), Some(2), Some(0)]); - let dict_array = DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap(); - - run_test( - Arc::new(dict_array), - vec![ - Some(Variant::from("a")), - None, // key 1 points to null value - Some(Variant::from("c")), - Some(Variant::from("a")), - ], - ); - } - /// Converts the given `Array` to a `VariantArray` and tests the conversion /// against the expected values. It also tests the handling of nulls by /// setting one element to null and verifying the output. diff --git a/parquet-variant-compute/src/from_json.rs b/parquet-variant-compute/src/from_json.rs index 8512620f4631..fb5fe320733f 100644 --- a/parquet-variant-compute/src/from_json.rs +++ b/parquet-variant-compute/src/from_json.rs @@ -102,7 +102,7 @@ mod test { let mut vb = VariantBuilder::new(); let mut ob = vb.new_object(); ob.insert("a", Variant::Int8(32)); - ob.finish()?; + ob.finish(); let (object_metadata, object_value) = vb.finish(); let expected = Variant::new(&object_metadata, &object_value); assert_eq!(variant_array.value(2), expected); @@ -151,7 +151,7 @@ mod test { let mut vb = VariantBuilder::new(); let mut ob = vb.new_object(); ob.insert("a", Variant::Int8(32)); - ob.finish()?; + ob.finish(); let (object_metadata, object_value) = vb.finish(); let expected = Variant::new(&object_metadata, &object_value); assert_eq!(variant_array.value(2), expected); @@ -200,7 +200,7 @@ mod test { let mut vb = VariantBuilder::new(); let mut ob = vb.new_object(); ob.insert("a", Variant::Int8(32)); - ob.finish()?; + ob.finish(); let (object_metadata, object_value) = vb.finish(); let expected = Variant::new(&object_metadata, &object_value); assert_eq!(variant_array.value(2), expected); diff --git a/parquet-variant-compute/src/lib.rs b/parquet-variant-compute/src/lib.rs index 245e344488ce..ef674d9614b5 100644 --- a/parquet-variant-compute/src/lib.rs +++ b/parquet-variant-compute/src/lib.rs @@ -38,6 +38,7 @@ pub mod cast_to_variant; mod from_json; mod to_json; +mod type_conversion; mod variant_array; mod variant_array_builder; pub mod variant_get; diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs new file mode 100644 index 000000000000..647d2c705ff0 --- /dev/null +++ b/parquet-variant-compute/src/type_conversion.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module for transforming a typed arrow `Array` to `VariantArray`.
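The macros that follow expand to plain downcast-and-append loops. As a rough sketch (not part of this diff; it assumes arrow's `AsArray::as_primitive` and this crate's `VariantArrayBuilder` are in scope), `primitive_conversion_array!(Int32Type, input, builder)` expands to approximately:

```rust
// Approximate expansion of primitive_conversion_array!(Int32Type, input, builder):
// downcast once, then convert each non-null element into a Variant.
let array = input.as_primitive::<Int32Type>();
for i in 0..array.len() {
    if array.is_null(i) {
        builder.append_null();
        continue;
    }
    // the identity cast_fn (|v| v) is applied before Variant::from
    let cast_value = array.value(i);
    builder.append_variant(Variant::from(cast_value));
}
```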
+ +/// Convert the input array to a `VariantArray` row by row, using `method` +/// not requiring a generic type to downcast the generic array to a specific +/// array type and `cast_fn` to transform each element to a type compatible with Variant +macro_rules! non_generic_conversion_array { + ($array:expr, $cast_fn:expr, $builder:expr) => {{ + let array = $array; + for i in 0..array.len() { + if array.is_null(i) { + $builder.append_null(); + continue; + } + let cast_value = $cast_fn(array.value(i)); + $builder.append_variant(Variant::from(cast_value)); + } + }}; +} +pub(crate) use non_generic_conversion_array; + +/// Convert the value at a specific index in the given array into a `Variant`. +macro_rules! non_generic_conversion_single_value { + ($array:expr, $cast_fn:expr, $index:expr) => {{ + let array = $array; + if array.is_null($index) { + Variant::Null + } else { + let cast_value = $cast_fn(array.value($index)); + Variant::from(cast_value) + } + }}; +} +pub(crate) use non_generic_conversion_single_value; + +/// Convert the input array to a `VariantArray` row by row, using `method` +/// requiring a generic type to downcast the generic array to a specific +/// array type and `cast_fn` to transform each element to a type compatible with Variant +macro_rules! generic_conversion_array { + ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $builder:expr) => {{ + $crate::type_conversion::non_generic_conversion_array!( + $input.$method::<$t>(), + $cast_fn, + $builder + ) + }}; +} +pub(crate) use generic_conversion_array; + +/// Convert the value at a specific index in the given array into a `Variant`, +/// using `method` requiring a generic type to downcast the generic array +/// to a specific array type and `cast_fn` to transform the element. +macro_rules! generic_conversion_single_value { + ($t:ty, $method:ident, $cast_fn:expr, $input:expr, $index:expr) => {{ + $crate::type_conversion::non_generic_conversion_single_value!( + $input.$method::<$t>(), + $cast_fn, + $index + ) + }}; +} +pub(crate) use generic_conversion_single_value; + +/// Convert the input array of a specific primitive type to a `VariantArray` +/// row by row +macro_rules! primitive_conversion_array { + ($t:ty, $input:expr, $builder:expr) => {{ + $crate::type_conversion::generic_conversion_array!( + $t, + as_primitive, + |v| v, + $input, + $builder + ) + }}; +} +pub(crate) use primitive_conversion_array; + +/// Convert the value at a specific index in the given array into a `Variant`. +macro_rules! primitive_conversion_single_value { + ($t:ty, $input:expr, $index:expr) => {{ + $crate::type_conversion::generic_conversion_single_value!( + $t, + as_primitive, + |v| v, + $input, + $index + ) + }}; +} +pub(crate) use primitive_conversion_single_value; + +/// Convert a decimal value to a `VariantDecimal` +macro_rules! 
decimal_to_variant_decimal { + ($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => {{ + let (v, scale) = if *$scale < 0 { + // For negative scale, we need to multiply the value by 10^|scale| + // For example: 123 with scale -2 becomes 12300 with scale 0 + let multiplier = <$value_type>::pow(10, (-*$scale) as u32); + (<$value_type>::checked_mul($v, multiplier), 0u8) + } else { + (Some($v), *$scale as u8) + }; + + v.and_then(|v| <$variant_type>::try_new(v, scale).ok()) + .map_or(Variant::Null, Variant::from) + }}; +} +pub(crate) use decimal_to_variant_decimal; diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index 2facaaafa59b..17b0adbdd086 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -19,12 +19,17 @@ use arrow::array::{Array, ArrayData, ArrayRef, AsArray, BinaryViewArray, StructArray}; use arrow::buffer::NullBuffer; -use arrow::datatypes::Int32Type; +use arrow::datatypes::{ + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields}; use parquet_variant::Variant; use std::any::Any; use std::sync::Arc; +use crate::type_conversion::primitive_conversion_single_value; + /// An array of Parquet [`Variant`] values /// /// A [`VariantArray`] wraps an Arrow [`StructArray`] that stores the underlying /// @@ -128,7 +133,6 @@ impl VariantArray { }) } - #[allow(unused)] pub(crate) fn from_parts( metadata: BinaryViewArray, value: Option<BinaryViewArray>, typed_value: Option<ArrayRef>, @@ -150,7 +154,8 @@ impl VariantArray { // This would be a lot simpler if ShreddingState were just a pair of Option... we already // have everything we need. let inner = builder.build(); - let shredding_state = ShreddingState::try_new(metadata.clone(), value, typed_value).unwrap(); // valid by construction + let shredding_state = + ShreddingState::try_new(metadata.clone(), value, typed_value).unwrap(); // valid by construction Self { inner, metadata, @@ -207,7 +212,9 @@ impl VariantArray { typed_value_to_variant(typed_value, index) } } - ShreddingState::PartiallyShredded { value, typed_value, .. } => { + ShreddingState::PartiallyShredded { + value, typed_value, ..
+ } => { // PartiallyShredded case (formerly ImperfectlyShredded) if typed_value.is_null(index) { Variant::new(self.metadata.value(index), value.value(index)) @@ -313,9 +320,11 @@ impl ShreddedVariantFieldArray { }; // Extract value and typed_value fields (metadata is not expected in ShreddedVariantFieldArray) - let value = inner_struct.column_by_name("value").and_then(|col| col.as_binary_view_opt().cloned()); + let value = inner_struct + .column_by_name("value") + .and_then(|col| col.as_binary_view_opt().cloned()); let typed_value = inner_struct.column_by_name("typed_value").cloned(); - + // Use a dummy metadata for the constructor (ShreddedVariantFieldArray doesn't have metadata) let dummy_metadata = arrow::array::BinaryViewArray::new_null(inner_struct.len()); @@ -387,8 +396,8 @@ impl Array for ShreddedVariantFieldArray { } fn nulls(&self) -> Option<&NullBuffer> { - // According to the shredding spec, ShreddedVariantFieldArray should be - // physically non-nullable - SQL NULL is inferred by both value and + // According to the shredding spec, ShreddedVariantFieldArray should be + // physically non-nullable - SQL NULL is inferred by both value and // typed_value being physically NULL None } @@ -423,13 +432,13 @@ impl Array for ShreddedVariantFieldArray { #[derive(Debug)] pub enum ShreddingState { /// This variant has no typed_value field - Unshredded { + Unshredded { metadata: BinaryViewArray, value: BinaryViewArray, }, /// This variant has a typed_value field and no value field /// meaning it is the shredded type - Typed { + Typed { metadata: BinaryViewArray, typed_value: ArrayRef, }, @@ -454,9 +463,7 @@ pub enum ShreddingState { /// Note: By strict spec interpretation, this should only be valid for shredded object fields, /// not top-level variants. However, we allow it and treat as Variant::Null for pragmatic /// handling of missing data. 
- AllNull { - metadata: BinaryViewArray, - }, + AllNull { metadata: BinaryViewArray }, } impl ShreddingState { @@ -583,9 +590,38 @@ impl StructArrayBuilder { /// returns the non-null element at index as a Variant fn typed_value_to_variant(typed_value: &ArrayRef, index: usize) -> Variant<'_, '_> { match typed_value.data_type() { + DataType::Int8 => { + primitive_conversion_single_value!(Int8Type, typed_value, index) + } + DataType::Int16 => { + primitive_conversion_single_value!(Int16Type, typed_value, index) + } DataType::Int32 => { - let typed_value = typed_value.as_primitive::<Int32Type>(); - Variant::from(typed_value.value(index)) + primitive_conversion_single_value!(Int32Type, typed_value, index) + } + DataType::Int64 => { + primitive_conversion_single_value!(Int64Type, typed_value, index) + } + DataType::UInt8 => { + primitive_conversion_single_value!(UInt8Type, typed_value, index) + } + DataType::UInt16 => { + primitive_conversion_single_value!(UInt16Type, typed_value, index) + } + DataType::UInt32 => { + primitive_conversion_single_value!(UInt32Type, typed_value, index) + } + DataType::UInt64 => { + primitive_conversion_single_value!(UInt64Type, typed_value, index) + } + DataType::Float16 => { + primitive_conversion_single_value!(Float16Type, typed_value, index) + } + DataType::Float32 => { + primitive_conversion_single_value!(Float32Type, typed_value, index) + } + DataType::Float64 => { + primitive_conversion_single_value!(Float64Type, typed_value, index) } // todo other types here (note this is very similar to cast_to_variant.rs) // so it would be great to figure out how to share this code diff --git a/parquet-variant-compute/src/variant_array_builder.rs b/parquet-variant-compute/src/variant_array_builder.rs index 969dc3776a81..d5f578421ed3 100644 --- a/parquet-variant-compute/src/variant_array_builder.rs +++ b/parquet-variant-compute/src/variant_array_builder.rs @@ -20,7 +20,8 @@ use crate::VariantArray; use arrow::array::{ArrayRef, BinaryViewArray, BinaryViewBuilder, NullBufferBuilder, StructArray}; use arrow_schema::{ArrowError, DataType, Field, Fields}; -use parquet_variant::{ListBuilder, ObjectBuilder, Variant, VariantBuilder, VariantBuilderExt}; +use parquet_variant::{ListBuilder, ObjectBuilder, Variant, VariantBuilderExt}; +use parquet_variant::{ParentState, ValueBuilder, WritableMetadataBuilder}; use std::sync::Arc; /// A builder for [`VariantArray`] @@ -49,8 +50,7 @@ use std::sync::Arc; /// let mut vb = builder.variant_builder(); /// vb.new_object() /// .with_field("foo", "bar") -/// .finish() -/// .unwrap(); +/// .finish(); /// vb.finish(); // must call finish to write the variant to the buffers /// /// // create the final VariantArray @@ -72,12 +72,12 @@ use std::sync::Arc; pub struct VariantArrayBuilder { /// Nulls nulls: NullBufferBuilder, - /// buffer for all the metadata - metadata_buffer: Vec<u8>, + /// builder for all the metadata + metadata_builder: WritableMetadataBuilder, /// ending offset for each serialized metadata dictionary in the buffer metadata_offsets: Vec<usize>, - /// buffer for values - value_buffer: Vec<u8>, + /// builder for values + value_builder: ValueBuilder, /// ending offset for each serialized variant value in the buffer value_offsets: Vec<usize>, /// The fields of the final `StructArray` @@ -95,9 +95,9 @@ impl VariantArrayBuilder { Self { nulls: NullBufferBuilder::new(row_capacity), - metadata_buffer: Vec::new(), // todo allocation capacity + metadata_builder: WritableMetadataBuilder::default(), metadata_offsets: Vec::with_capacity(row_capacity), - value_buffer: Vec::new(),
+ value_builder: ValueBuilder::new(), value_offsets: Vec::with_capacity(row_capacity), fields: Fields::from(vec![metadata_field, value_field]), } @@ -107,15 +107,17 @@ impl VariantArrayBuilder { pub fn build(self) -> VariantArray { let Self { mut nulls, - metadata_buffer, + metadata_builder, metadata_offsets, - value_buffer, + value_builder, value_offsets, fields, } = self; + let metadata_buffer = metadata_builder.into_inner(); let metadata_array = binary_view_array_from_buffers(metadata_buffer, metadata_offsets); + let value_buffer = value_builder.into_inner(); let value_array = binary_view_array_from_buffers(value_buffer, value_offsets); // The build the final struct array @@ -136,14 +138,14 @@ impl VariantArrayBuilder { pub fn append_null(&mut self) { self.nulls.append_null(); // The subfields are expected to be non-nullable according to the parquet variant spec. - self.metadata_offsets.push(self.metadata_buffer.len()); - self.value_offsets.push(self.value_buffer.len()); + self.metadata_offsets.push(self.metadata_builder.offset()); + self.value_offsets.push(self.value_builder.offset()); } /// Append the [`Variant`] to the builder as the next row pub fn append_variant(&mut self, variant: Variant) { let mut direct_builder = self.variant_builder(); - direct_builder.variant_builder.append_value(variant); + direct_builder.append_value(variant); direct_builder.finish() } @@ -169,8 +171,7 @@ impl VariantArrayBuilder { /// variant_builder /// .new_object() /// .with_field("my_field", 42i64) - /// .finish() - /// .unwrap(); + /// .finish(); /// variant_builder.finish(); /// /// // finalize the array @@ -194,32 +195,23 @@ impl VariantArrayBuilder { /// /// See [`VariantArrayBuilder::variant_builder`] for an example pub struct VariantArrayVariantBuilder<'a> { - /// was finish called? - finished: bool, - /// starting offset in the variant_builder's `metadata` buffer - metadata_offset: usize, - /// starting offset in the variant_builder's `value` buffer - value_offset: usize, - /// Parent array builder that this variant builder writes to. 
Buffers - /// have been moved into the variant builder, and must be returned on - /// drop - array_builder: &'a mut VariantArrayBuilder, - /// Builder for the in progress variant value, temporarily owns the buffers - /// from `array_builder` - variant_builder: VariantBuilder, + parent_state: ParentState<'a>, + metadata_offsets: &'a mut Vec<usize>, + value_offsets: &'a mut Vec<usize>, + nulls: &'a mut NullBufferBuilder, } impl VariantBuilderExt for VariantArrayVariantBuilder<'_> { fn append_value<'m, 'v>(&mut self, value: impl Into<Variant<'m, 'v>>) { - self.variant_builder.append_value(value); + ValueBuilder::append_variant(self.parent_state(), value.into()); } fn try_new_list(&mut self) -> Result<ListBuilder<'_>, ArrowError> { - Ok(self.variant_builder.new_list()) + Ok(ListBuilder::new(self.parent_state(), false)) } fn try_new_object(&mut self) -> Result<ObjectBuilder<'_>, ArrowError> { - Ok(self.variant_builder.new_object()) + Ok(ObjectBuilder::new(self.parent_state(), false)) } } @@ -228,103 +220,40 @@ impl<'a> VariantArrayVariantBuilder<'a> { /// /// Note this is not public as this is a structure that is logically /// part of the [`VariantArrayBuilder`] and relies on its internal structure - fn new(array_builder: &'a mut VariantArrayBuilder) -> Self { - // append directly into the metadata and value buffers - let metadata_buffer = std::mem::take(&mut array_builder.metadata_buffer); - let value_buffer = std::mem::take(&mut array_builder.value_buffer); - let metadata_offset = metadata_buffer.len(); - let value_offset = value_buffer.len(); + fn new(builder: &'a mut VariantArrayBuilder) -> Self { + let parent_state = + ParentState::variant(&mut builder.value_builder, &mut builder.metadata_builder); VariantArrayVariantBuilder { - finished: false, - metadata_offset, - value_offset, - variant_builder: VariantBuilder::new_with_buffers(metadata_buffer, value_buffer), - array_builder, + parent_state, + metadata_offsets: &mut builder.metadata_offsets, + value_offsets: &mut builder.value_offsets, + nulls: &mut builder.nulls, } } - /// Return a reference to the underlying `VariantBuilder` - pub fn inner(&self) -> &VariantBuilder { - &self.variant_builder - } - - /// Return a mutable reference to the underlying `VariantBuilder` - pub fn inner_mut(&mut self) -> &mut VariantBuilder { - &mut self.variant_builder - } - /// Called to finish the in progress variant and write it to the underlying /// buffers /// /// Note if you do not call finish, on drop any changes made to the /// underlying buffers will be rolled back. pub fn finish(mut self) { - self.finished = true; - - let metadata_offset = self.metadata_offset; - let value_offset = self.value_offset; - // get the buffers back from the variant builder - let (metadata_buffer, value_buffer) = std::mem::take(&mut self.variant_builder).finish(); - - // Sanity Check: if the buffers got smaller, something went wrong (previous data was lost) - assert!( - metadata_offset <= metadata_buffer.len(), - "metadata length decreased unexpectedly" - ); - assert!( - value_offset <= value_buffer.len(), - "value length decreased unexpectedly" - ); - - // commit the changes by putting the - // ending offsets into the parent array builder. - let builder = &mut self.array_builder; - builder.metadata_offsets.push(metadata_buffer.len()); - builder.value_offsets.push(value_buffer.len()); - builder.nulls.append_non_null(); + // Record the ending offsets after finishing metadata and finish the parent state.
+ let (value_builder, metadata_builder) = self.parent_state.value_and_metadata_builders(); + self.metadata_offsets.push(metadata_builder.finish()); + self.value_offsets.push(value_builder.offset()); + self.nulls.append_non_null(); + self.parent_state.finish(); + } - // put the buffers back into the array builder - builder.metadata_buffer = metadata_buffer; - builder.value_buffer = value_buffer; + fn parent_state(&mut self) -> ParentState<'_> { + let (value_builder, metadata_builder) = self.parent_state.value_and_metadata_builders(); + ParentState::variant(value_builder, metadata_builder) } } +// Empty Drop to help with borrow checking - warns users if they forget to call finish() impl Drop for VariantArrayVariantBuilder<'_> { - /// If the builder was not finished, roll back any changes made to the - /// underlying buffers (by truncating them) - fn drop(&mut self) { - if self.finished { - return; - } - - // if the object was not finished, need to rollback any changes by - // truncating the buffers to the original offsets - let metadata_offset = self.metadata_offset; - let value_offset = self.value_offset; - - // get the buffers back from the variant builder - let (mut metadata_buffer, mut value_buffer) = - std::mem::take(&mut self.variant_builder).into_buffers(); - - // Sanity Check: if the buffers got smaller, something went wrong (previous data was lost) so panic immediately - metadata_buffer - .len() - .checked_sub(metadata_offset) - .expect("metadata length decreased unexpectedly"); - value_buffer - .len() - .checked_sub(value_offset) - .expect("value length decreased unexpectedly"); - - // Note this truncate is fast because truncate doesn't free any memory: - // it just has to drop elements (and u8 doesn't have a destructor) - metadata_buffer.truncate(metadata_offset); - value_buffer.truncate(value_offset); - - // put the buffers back into the array builder - self.array_builder.metadata_buffer = metadata_buffer; - self.array_builder.value_buffer = value_buffer; - } + fn drop(&mut self) {} } fn binary_view_array_from_buffers(buffer: Vec<u8>, offsets: Vec<usize>) -> BinaryViewArray { @@ -388,11 +317,7 @@ mod test { // let's make a sub-object in the next row let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("foo", "bar") - .finish() - .unwrap(); + sub_builder.new_object().with_field("foo", "bar").finish(); sub_builder.finish(); // must call finish to write the variant to the buffers // append a new list @@ -426,29 +351,17 @@ mod test { // make a sub-object in the first row let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("foo", 1i32) - .finish() - .unwrap(); + sub_builder.new_object().with_field("foo", 1i32).finish(); sub_builder.finish(); // must call finish to write the variant to the buffers // start appending an object but don't finish let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("bar", 2i32) - .finish() - .unwrap(); + sub_builder.new_object().with_field("bar", 2i32).finish(); drop(sub_builder); // drop the sub builder without finishing it // make a third sub-object (this should reset the previous unfinished object) let mut sub_builder = builder.variant_builder(); - sub_builder - .new_object() - .with_field("baz", 3i32) - .finish() - .unwrap(); + sub_builder.new_object().with_field("baz", 3i32).finish(); sub_builder.finish(); // must call finish to write the variant to the buffers let variant_array = builder.build(); @@ -457,12 +370,18 @@ mod test {
assert_eq!(variant_array.len(), 2); assert!(!variant_array.is_null(0)); let variant = variant_array.value(0); - let variant = variant.as_object().expect("variant to be an object"); - assert_eq!(variant.get("foo").unwrap(), Variant::from(1i32)); + assert_eq!( + variant.get_object_field("foo"), + Some(Variant::from(1i32)), + "Expected an object with field \"foo\", got: {variant:?}" + ); assert!(!variant_array.is_null(1)); let variant = variant_array.value(1); - let variant = variant.as_object().expect("variant to be an object"); - assert_eq!(variant.get("baz").unwrap(), Variant::from(3i32)); + assert_eq!( + variant.get_object_field("baz"), + Some(Variant::from(3i32)), + "Expected an object with field \"baz\", got: {variant:?}" + ); } } diff --git a/parquet-variant-compute/src/variant_get/mod.rs b/parquet-variant-compute/src/variant_get/mod.rs index 64d6c3980f65..10403b1369a6 100644 --- a/parquet-variant-compute/src/variant_get/mod.rs +++ b/parquet-variant-compute/src/variant_get/mod.rs @@ -74,13 +74,14 @@ pub(crate) fn follow_shredded_path_element<'a>( if !cast_options.safe { return Err(ArrowError::CastError(format!( "Cannot access field '{}' on non-struct type: {}", - name, typed_value.data_type() + name, + typed_value.data_type() ))); } // With safe cast options, return NULL (missing_path_step) return Ok(missing_path_step()); }; - + // Now try to find the column - missing column in a present struct is just missing data let Some(field) = struct_array.column_by_name(name) else { // Missing column in a present struct is just missing, not wrong - return Ok @@ -123,29 +124,29 @@ fn shredded_get_path( ) -> Result<ArrayRef, ArrowError> { // Helper that creates a new VariantArray from the given nested value and typed_value columns, // properly accounting for accumulated nulls from path traversal - let make_target_variant = |value: Option<BinaryViewArray>, typed_value: Option<ArrayRef>, accumulated_nulls: Option<NullBuffer>| { - let metadata = input.metadata_field().clone(); - VariantArray::from_parts( - metadata, - value, - typed_value, - accumulated_nulls, - ) - }; + let make_target_variant = + |value: Option<BinaryViewArray>, + typed_value: Option<ArrayRef>, + accumulated_nulls: Option<NullBuffer>| { + let metadata = input.metadata_field().clone(); + VariantArray::from_parts(metadata, value, typed_value, accumulated_nulls) + }; // Helper that shreds a VariantArray to a specific type. - let shred_basic_variant = |target: VariantArray, path: VariantPath<'_>, as_field: Option<&Field>| { - let as_type = as_field.map(|f| f.data_type()); - let mut builder = output::row_builder::make_shredding_row_builder(path, as_type)?; - for i in 0..target.len() { - if target.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(&target.value(i))?; + let shred_basic_variant = + |target: VariantArray, path: VariantPath<'_>, as_field: Option<&Field>| { + let as_type = as_field.map(|f| f.data_type()); + let mut builder = + output::row_builder::make_shredding_row_builder(path, as_type, cast_options)?; + for i in 0..target.len() { + if target.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(&target.value(i))?; + } } - } - builder.finish() - }; + builder.finish() + }; // Peel away the prefix of path elements that traverses the shredded parts of this variant // column. Shredding will traverse the rest of the path on a per-row basis.
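Before the per-row fallback below, it may help to see the call pattern this function ultimately serves. A condensed usage sketch, assuming `array` is an `ArrayRef` holding a `VariantArray` of objects like `{"x": 42}` (the same pattern the tests later in this diff use):

```rust
use arrow_schema::{DataType, Field, FieldRef};
use parquet_variant::VariantPath;

// Request `$.x` as Int32; with the default (safe) cast options, rows where
// `x` is missing or holds a non-integer come back as NULL instead of erroring.
let field = Field::new("result", DataType::Int32, true);
let options = GetOptions::new_with_path(VariantPath::from("x"))
    .with_as_type(Some(FieldRef::from(field)));
let result = variant_get(&array, options).unwrap();
```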
@@ -175,20 +176,17 @@ fn shredded_get_path( ... return Ok(arr); } ShreddedPathStep::NotShredded => { - let target = make_target_variant(shredding_state.value_field().cloned(), None, accumulated_nulls); + let target = make_target_variant( + shredding_state.value_field().cloned(), + None, + accumulated_nulls, + ); return shred_basic_variant(target, path[path_index..].into(), as_field); } }; } // Path exhausted! Create a new `VariantArray` for the location we landed on. - // Also union nulls from the final typed_value field we landed on - if let Some(typed_value) = shredding_state.typed_value_field() { - accumulated_nulls = arrow::buffer::NullBuffer::union( - accumulated_nulls.as_ref(), - typed_value.nulls(), - ); - } let target = make_target_variant( shredding_state.value_field().cloned(), shredding_state.typed_value_field().cloned(), @@ -246,7 +244,11 @@ pub fn variant_get(input: &ArrayRef, options: GetOptions) -> Result<ArrayRef, ArrowError> { ) })?; - let GetOptions { as_type, path, cast_options } = options; + let GetOptions { + as_type, + path, + cast_options, + } = options; shredded_get_path(variant_array, &path, as_type.as_deref(), &cast_options) } @@ -296,13 +298,18 @@ impl<'a> GetOptions<'a> { mod test { use std::sync::Arc; - use arrow::array::{Array, ArrayRef, BinaryViewArray, Int32Array, StringArray, StructArray}; + use arrow::array::{ + Array, ArrayRef, BinaryViewArray, Float16Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, StructArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, + }; use arrow::buffer::NullBuffer; + use arrow::compute::CastOptions; use arrow_schema::{DataType, Field, FieldRef, Fields}; use parquet_variant::{Variant, VariantPath}; use crate::json_to_variant; - use crate::{VariantArray, variant_array::ShreddedVariantFieldArray}; + use crate::{variant_array::ShreddedVariantFieldArray, VariantArray}; use super::{variant_get, GetOptions}; @@ -387,29 +394,91 @@ mod test { ); } - /// Shredding: extract a value as a VariantArray + /// Partial Shredding: extract a value as a VariantArray + macro_rules!
numeric_partially_shredded_test { + ($primitive_type:ty, $data_fn:ident) => { + let array = $data_fn(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 4); + + // Expect the values are the same as the original values + assert_eq!( + result.value(0), + Variant::from(<$primitive_type>::try_from(34u8).unwrap()) + ); + assert!(!result.is_valid(1)); + assert_eq!(result.value(2), Variant::from("n/a")); + assert_eq!( + result.value(3), + Variant::from(<$primitive_type>::try_from(100u8).unwrap()) + ); + }; + } + #[test] - fn get_variant_shredded_int32_as_variant() { - let array = shredded_int32_variant_array(); - let options = GetOptions::new(); - let result = variant_get(&array, options).unwrap(); + fn get_variant_partially_shredded_int8_as_variant() { + numeric_partially_shredded_test!(i8, partially_shredded_int8_variant_array); + } - // expect the result is a VariantArray - let result: &VariantArray = result.as_any().downcast_ref().unwrap(); - assert_eq!(result.len(), 4); + #[test] + fn get_variant_partially_shredded_int16_as_variant() { + numeric_partially_shredded_test!(i16, partially_shredded_int16_variant_array); + } - // Expect the values are the same as the original values - assert_eq!(result.value(0), Variant::Int32(34)); - assert!(!result.is_valid(1)); - assert_eq!(result.value(2), Variant::from("n/a")); - assert_eq!(result.value(3), Variant::Int32(100)); + #[test] + fn get_variant_partially_shredded_int32_as_variant() { + numeric_partially_shredded_test!(i32, partially_shredded_int32_variant_array); + } + + #[test] + fn get_variant_partially_shredded_int64_as_variant() { + numeric_partially_shredded_test!(i64, partially_shredded_int64_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint8_as_variant() { + numeric_partially_shredded_test!(u8, partially_shredded_uint8_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint16_as_variant() { + numeric_partially_shredded_test!(u16, partially_shredded_uint16_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint32_as_variant() { + numeric_partially_shredded_test!(u32, partially_shredded_uint32_variant_array); + } + + #[test] + fn get_variant_partially_shredded_uint64_as_variant() { + numeric_partially_shredded_test!(u64, partially_shredded_uint64_variant_array); + } + + #[test] + fn get_variant_partially_shredded_float16_as_variant() { + numeric_partially_shredded_test!(half::f16, partially_shredded_float16_variant_array); + } + + #[test] + fn get_variant_partially_shredded_float32_as_variant() { + numeric_partially_shredded_test!(f32, partially_shredded_float32_variant_array); + } + + #[test] + fn get_variant_partially_shredded_float64_as_variant() { + numeric_partially_shredded_test!(f64, partially_shredded_float64_variant_array); } /// Shredding: extract a value as an Int32Array #[test] fn get_variant_shredded_int32_as_int32_safe_cast() { // Extract the typed value as Int32Array - let array = shredded_int32_variant_array(); + let array = partially_shredded_int32_variant_array(); // specify we want the typed value as Int32 let field = Field::new("typed_value", DataType::Int32, true); let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); @@ -423,21 +492,105 @@ mod test { assert_eq!(&result, &expected) } + /// Shredding: extract a value as an Int32Array, unsafe cast (should error on "n/a") 
+ #[test] + fn get_variant_shredded_int32_as_int32_unsafe_cast() { + // Extract the typed value as Int32Array + let array = partially_shredded_int32_variant_array(); + let field = Field::new("typed_value", DataType::Int32, true); + let cast_options = CastOptions { + safe: false, // unsafe cast + ..Default::default() + }; + let options = GetOptions::new() + .with_as_type(Some(FieldRef::from(field))) + .with_cast_options(cast_options); + + let err = variant_get(&array, options).unwrap_err(); + // TODO make this error message nicer (not Debug format) + assert_eq!(err.to_string(), "Cast error: Failed to extract primitive of type Int32 from variant ShortString(ShortString(\"n/a\")) at path VariantPath([])"); + } + /// Perfect Shredding: extract the typed value as a VariantArray + macro_rules! numeric_perfectly_shredded_test { + ($primitive_type:ty, $data_fn:ident) => { + let array = $data_fn(); + let options = GetOptions::new(); + let result = variant_get(&array, options).unwrap(); + + // expect the result is a VariantArray + let result: &VariantArray = result.as_any().downcast_ref().unwrap(); + assert_eq!(result.len(), 3); + + // Expect the values are the same as the original values + assert_eq!( + result.value(0), + Variant::from(<$primitive_type>::try_from(1u8).unwrap()) + ); + assert_eq!( + result.value(1), + Variant::from(<$primitive_type>::try_from(2u8).unwrap()) + ); + assert_eq!( + result.value(2), + Variant::from(<$primitive_type>::try_from(3u8).unwrap()) + ); + }; + } + + #[test] + fn get_variant_perfectly_shredded_int8_as_variant() { + numeric_perfectly_shredded_test!(i8, perfectly_shredded_int8_variant_array); + } + + #[test] + fn get_variant_perfectly_shredded_int16_as_variant() { + numeric_perfectly_shredded_test!(i16, perfectly_shredded_int16_variant_array); + } + #[test] fn get_variant_perfectly_shredded_int32_as_variant() { - let array = perfectly_shredded_int32_variant_array(); - let options = GetOptions::new(); - let result = variant_get(&array, options).unwrap(); + numeric_perfectly_shredded_test!(i32, perfectly_shredded_int32_variant_array); + } - // expect the result is a VariantArray - let result: &VariantArray = result.as_any().downcast_ref().unwrap(); - assert_eq!(result.len(), 3); + #[test] + fn get_variant_perfectly_shredded_int64_as_variant() { + numeric_perfectly_shredded_test!(i64, perfectly_shredded_int64_variant_array); + } + + #[test] + fn get_variant_perfectly_shredded_uint8_as_variant() { + numeric_perfectly_shredded_test!(u8, perfectly_shredded_uint8_variant_array); + } + + #[test] + fn get_variant_perfectly_shredded_uint16_as_variant() { + numeric_perfectly_shredded_test!(u16, perfectly_shredded_uint16_variant_array); + } + + #[test] + fn get_variant_perfectly_shredded_uint32_as_variant() { + numeric_perfectly_shredded_test!(u32, perfectly_shredded_uint32_variant_array); + } + + #[test] + fn get_variant_perfectly_shredded_uint64_as_variant() { + numeric_perfectly_shredded_test!(u64, perfectly_shredded_uint64_variant_array); + } + + #[test] + fn get_variant_perfectly_shredded_float16_as_variant() { + numeric_perfectly_shredded_test!(half::f16, perfectly_shredded_float16_variant_array); + } - // Expect the values are the same as the original values - assert_eq!(result.value(0), Variant::Int32(1)); - assert_eq!(result.value(1), Variant::Int32(2)); - assert_eq!(result.value(2), Variant::Int32(3)); + #[test] + fn get_variant_perfectly_shredded_float32_as_variant() { + numeric_perfectly_shredded_test!(f32, perfectly_shredded_float32_variant_array); + } + + #[test] 
+ fn get_variant_perfectly_shredded_float64_as_variant() { + numeric_perfectly_shredded_test!(f64, perfectly_shredded_float64_variant_array); } /// Shredding: Extract the typed value as Int32Array @@ -487,14 +640,20 @@ mod test { assert_eq!(&result, &expected) } + #[test] + fn get_variant_perfectly_shredded_int16_as_int16() { + // Extract the typed value as Int16Array + let array = perfectly_shredded_int16_variant_array(); + // specify we want the typed value as Int16 + let field = Field::new("typed_value", DataType::Int16, true); + let options = GetOptions::new().with_as_type(Some(FieldRef::from(field))); + let result = variant_get(&array, options).unwrap(); + let expected: ArrayRef = Arc::new(Int16Array::from(vec![Some(1), Some(2), Some(3)])); + assert_eq!(&result, &expected) + } + /// Return a VariantArray that represents a perfectly "shredded" variant - /// for the following example (3 Variant::Int32 values): - /// - /// ```text - /// 1 - /// 2 - /// 3 - /// ``` + /// for the given typed value. /// /// The schema of the corresponding `StructArray` would look like this: /// @@ -504,24 +663,88 @@ mod test { /// typed_value: Int32Array, /// } /// ``` - fn perfectly_shredded_int32_variant_array() -> ArrayRef { - // At the time of writing, the `VariantArrayBuilder` does not support shredding. - // so we must construct the array manually. see https://github.com/apache/arrow-rs/issues/7895 - let (metadata, _value) = { parquet_variant::VariantBuilder::new().finish() }; - - let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 3)); - let typed_value = Int32Array::from(vec![Some(1), Some(2), Some(3)]); - - let struct_array = crate::variant_array::StructArrayBuilder::new() - .with_field("metadata", Arc::new(metadata)) - .with_field("typed_value", Arc::new(typed_value)) - .build(); - - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), - ) + macro_rules! numeric_perfectly_shredded_variant_array_fn { + ($func:ident, $array_type:ident, $primitive_type:ty) => { + fn $func() -> ArrayRef { + // At the time of writing, the `VariantArrayBuilder` does not support shredding. + // so we must construct the array manually. 
see https://github.com/apache/arrow-rs/issues/7895 + let (metadata, _value) = { parquet_variant::VariantBuilder::new().finish() }; + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 3)); + let typed_value = $array_type::from(vec![ + Some(<$primitive_type>::try_from(1u8).unwrap()), + Some(<$primitive_type>::try_from(2u8).unwrap()), + Some(<$primitive_type>::try_from(3u8).unwrap()), + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata)) + .with_field("typed_value", Arc::new(typed_value)) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)) + .expect("should create variant array"), + ) + } + }; } + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int8_variant_array, + Int8Array, + i8 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int16_variant_array, + Int16Array, + i16 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int32_variant_array, + Int32Array, + i32 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_int64_variant_array, + Int64Array, + i64 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint8_variant_array, + UInt8Array, + u8 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint16_variant_array, + UInt16Array, + u16 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint32_variant_array, + UInt32Array, + u32 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_uint64_variant_array, + UInt64Array, + u64 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_float16_variant_array, + Float16Array, + half::f16 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_float32_variant_array, + Float32Array, + f32 + ); + numeric_perfectly_shredded_variant_array_fn!( + perfectly_shredded_float64_variant_array, + Float64Array, + f64 + ); + /// Return a VariantArray that represents a normal "shredded" variant /// for the following example /// @@ -545,53 +768,114 @@ mod test { /// typed_value: Int32Array, /// } /// ``` - fn shredded_int32_variant_array() -> ArrayRef { - // At the time of writing, the `VariantArrayBuilder` does not support shredding. - // so we must construct the array manually. see https://github.com/apache/arrow-rs/issues/7895 - let (metadata, string_value) = { - let mut builder = parquet_variant::VariantBuilder::new(); - builder.append_value("n/a"); - builder.finish() - }; - - let nulls = NullBuffer::from(vec![ - true, // row 0 non null - false, // row 1 is null - true, // row 2 non null - true, // row 3 non null - ]); - - // metadata is the same for all rows - let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); - - // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY - // about why row1 is an empty but non null, value. - let values = BinaryViewArray::from(vec![ - None, // row 0 is shredded, so no value - Some(b"" as &[u8]), // row 1 is null, so empty value (why?) 
- Some(&string_value), // copy the string value "N/A" - None, // row 3 is shredded, so no value - ]); - - let typed_value = Int32Array::from(vec![ - Some(34), // row 0 is shredded, so it has a value - None, // row 1 is null, so no value - None, // row 2 is a string, so no typed value - Some(100), // row 3 is shredded, so it has a value - ]); - - let struct_array = crate::variant_array::StructArrayBuilder::new() - .with_field("metadata", Arc::new(metadata)) - .with_field("typed_value", Arc::new(typed_value)) - .with_field("value", Arc::new(values)) - .with_nulls(nulls) - .build(); + macro_rules! numeric_partially_shredded_variant_array_fn { + ($func:ident, $array_type:ident, $primitive_type:ty) => { + fn $func() -> ArrayRef { + // At the time of writing, the `VariantArrayBuilder` does not support shredding. + // so we must construct the array manually. see https://github.com/apache/arrow-rs/issues/7895 + let (metadata, string_value) = { + let mut builder = parquet_variant::VariantBuilder::new(); + builder.append_value("n/a"); + builder.finish() + }; - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), - ) + let nulls = NullBuffer::from(vec![ + true, // row 0 non null + false, // row 1 is null + true, // row 2 non null + true, // row 3 non null + ]); + + // metadata is the same for all rows + let metadata = BinaryViewArray::from_iter_values(std::iter::repeat_n(&metadata, 4)); + + // See https://docs.google.com/document/d/1pw0AWoMQY3SjD7R4LgbPvMjG_xSCtXp3rZHkVp9jpZ4/edit?disco=AAABml8WQrY + // about why row1 is an empty but non null, value. + let values = BinaryViewArray::from(vec![ + None, // row 0 is shredded, so no value + Some(b"" as &[u8]), // row 1 is null, so empty value (why?) + Some(&string_value), // copy the string value "N/A" + None, // row 3 is shredded, so no value + ]); + + let typed_value = $array_type::from(vec![ + Some(<$primitive_type>::try_from(34u8).unwrap()), // row 0 is shredded, so it has a value + None, // row 1 is null, so no value + None, // row 2 is a string, so no typed value + Some(<$primitive_type>::try_from(100u8).unwrap()), // row 3 is shredded, so it has a value + ]); + + let struct_array = StructArrayBuilder::new() + .with_field("metadata", Arc::new(metadata)) + .with_field("typed_value", Arc::new(typed_value)) + .with_field("value", Arc::new(values)) + .with_nulls(nulls) + .build(); + + Arc::new( + VariantArray::try_new(Arc::new(struct_array)) + .expect("should create variant array"), + ) + } + }; } + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int8_variant_array, + Int8Array, + i8 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int16_variant_array, + Int16Array, + i16 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int32_variant_array, + Int32Array, + i32 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_int64_variant_array, + Int64Array, + i64 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint8_variant_array, + UInt8Array, + u8 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint16_variant_array, + UInt16Array, + u16 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint32_variant_array, + UInt32Array, + u32 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_uint64_variant_array, + UInt64Array, + u64 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_float16_variant_array, + Float16Array, + 
half::f16 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_float32_variant_array, + Float32Array, + f32 + ); + numeric_partially_shredded_variant_array_fn!( + partially_shredded_float64_variant_array, + Float64Array, + f64 + ); + /// Builds struct arrays from component fields /// /// TODO: move to arrow crate @@ -636,7 +920,7 @@ mod test { /// /// ```text /// null - /// null + /// null /// null /// ``` /// @@ -668,21 +952,21 @@ mod test { VariantArray::try_new(Arc::new(struct_array)).expect("should create variant array"), ) } - /// This test manually constructs a shredded variant array representing objects + /// This test manually constructs a shredded variant array representing objects /// like {"x": 1, "y": "foo"} and {"x": 42} and tests extracting the "x" field /// as VariantArray using variant_get. #[test] fn test_shredded_object_field_access() { let array = shredded_object_with_x_field_variant_array(); - + // Test: Extract the "x" field as VariantArray first let options = GetOptions::new_with_path(VariantPath::from("x")); let result = variant_get(&array, options).unwrap(); - + let result_variant: &VariantArray = result.as_any().downcast_ref().unwrap(); assert_eq!(result_variant.len(), 2); - - // Row 0: expect x=1 + + // Row 0: expect x=1 assert_eq!(result_variant.value(0), Variant::Int32(1)); // Row 1: expect x=42 assert_eq!(result_variant.value(1), Variant::Int32(42)); @@ -692,37 +976,37 @@ mod test { #[test] fn test_shredded_object_field_as_int32() { let array = shredded_object_with_x_field_variant_array(); - + // Test: Extract the "x" field as Int32Array (type conversion) let field = Field::new("x", DataType::Int32, false); let options = GetOptions::new_with_path(VariantPath::from("x")) .with_as_type(Some(FieldRef::from(field))); let result = variant_get(&array, options).unwrap(); - + // Should get Int32Array let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(42)])); assert_eq!(&result, &expected); } - /// Helper function to create a shredded variant array representing objects - /// + /// Helper function to create a shredded variant array representing objects + /// /// This creates an array that represents: /// Row 0: {"x": 1, "y": "foo"} (x is shredded, y is in value field) /// Row 1: {"x": 42} (x is shredded, perfect shredding) /// /// The physical layout follows the shredding spec where: - /// - metadata: contains object metadata + /// - metadata: contains object metadata /// - typed_value: StructArray with field "x" (ShreddedVariantFieldArray) /// - value: contains fallback for unshredded fields like {"y": "foo"} /// - The "x" field has typed_value=Int32Array and value=NULL (perfect shredding) fn shredded_object_with_x_field_variant_array() -> ArrayRef { - // Create the base metadata for objects + // Create the base metadata for objects let (metadata, y_field_value) = { let mut builder = parquet_variant::VariantBuilder::new(); let mut obj = builder.new_object(); obj.insert("x", Variant::Int32(42)); obj.insert("y", Variant::from("foo")); - obj.finish().unwrap(); + obj.finish(); builder.finish() }; @@ -736,89 +1020,103 @@ mod test { let empty_object_value = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; - + let value_array = BinaryViewArray::from(vec![ - Some(y_field_value.as_slice()), // Row 0 has {"y": "foo"} - Some(empty_object_value.as_slice()), // Row 1 has {} + Some(y_field_value.as_slice()), // Row 0 
has {"y": "foo"} + Some(empty_object_value.as_slice()), // Row 1 has {} ]); // Create the "x" field as a ShreddedVariantFieldArray // This represents the shredded Int32 values for the "x" field let x_field_typed_value = Int32Array::from(vec![Some(1), Some(42)]); - + // For perfect shredding of the x field, no "value" column, only typed_value let x_field_struct = crate::variant_array::StructArrayBuilder::new() .with_field("typed_value", Arc::new(x_field_typed_value)) .build(); - + // Wrap the x field struct in a ShreddedVariantFieldArray let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) .expect("should create ShreddedVariantFieldArray"); // Create the main typed_value as a struct containing the "x" field - let typed_value_fields = Fields::from(vec![ - Field::new("x", x_field_shredded.data_type().clone(), true) - ]); + let typed_value_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); let typed_value_struct = StructArray::try_new( typed_value_fields, vec![Arc::new(x_field_shredded)], None, // No nulls - both rows have the object structure - ).unwrap(); + ) + .unwrap(); - // Create the main VariantArray + // Create the main VariantArray let main_struct = crate::variant_array::StructArrayBuilder::new() .with_field("metadata", Arc::new(metadata_array)) .with_field("value", Arc::new(value_array)) .with_field("typed_value", Arc::new(typed_value_struct)) .build(); - Arc::new( - VariantArray::try_new(Arc::new(main_struct)).expect("should create variant array"), - ) + Arc::new(VariantArray::try_new(Arc::new(main_struct)).expect("should create variant array")) } /// Simple test to check if nested paths are supported by current implementation - #[test] + #[test] fn test_simple_nested_path_support() { // Check: How does VariantPath parse different strings? 
println!("Testing path parsing:"); - + let path_x = VariantPath::from("x"); let elements_x: Vec<_> = path_x.iter().collect(); println!(" 'x' -> {} elements: {:?}", elements_x.len(), elements_x); - + let path_ax = VariantPath::from("a.x"); let elements_ax: Vec<_> = path_ax.iter().collect(); - println!(" 'a.x' -> {} elements: {:?}", elements_ax.len(), elements_ax); - + println!( + " 'a.x' -> {} elements: {:?}", + elements_ax.len(), + elements_ax + ); + let path_ax_alt = VariantPath::from("$.a.x"); let elements_ax_alt: Vec<_> = path_ax_alt.iter().collect(); - println!(" '$.a.x' -> {} elements: {:?}", elements_ax_alt.len(), elements_ax_alt); - + println!( + " '$.a.x' -> {} elements: {:?}", + elements_ax_alt.len(), + elements_ax_alt + ); + let path_nested = VariantPath::from("a").join("x"); let elements_nested: Vec<_> = path_nested.iter().collect(); - println!(" VariantPath::from('a').join('x') -> {} elements: {:?}", elements_nested.len(), elements_nested); - + println!( + " VariantPath::from('a').join('x') -> {} elements: {:?}", + elements_nested.len(), + elements_nested + ); + // Use your existing simple test data but try "a.x" instead of "x" let array = shredded_object_with_x_field_variant_array(); - + // Test if variant_get with REAL nested path throws not implemented error let real_nested_path = VariantPath::from("a").join("x"); let options = GetOptions::new_with_path(real_nested_path); let result = variant_get(&array, options); - + match result { Ok(_) => { println!("Nested path 'a.x' works unexpectedly!"); - }, + } Err(e) => { println!("Nested path 'a.x' error: {}", e); - if e.to_string().contains("not yet implemented") - || e.to_string().contains("NotYetImplemented") { + if e.to_string().contains("not yet implemented") + || e.to_string().contains("NotYetImplemented") + { println!("This is expected - nested paths are not implemented"); return; } @@ -834,36 +1132,34 @@ mod test { #[test] fn test_depth_0_int32_conversion() { println!("=== Testing Depth 0: Direct field access ==="); - + // Non-shredded test data: [{"x": 42}, {"x": "foo"}, {"y": 10}] let unshredded_array = create_depth_0_test_data(); - + let field = Field::new("result", DataType::Int32, true); let path = VariantPath::from("x"); - let options = GetOptions::new_with_path(path) - .with_as_type(Some(FieldRef::from(field))); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&unshredded_array, options).unwrap(); - + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(42), // {"x": 42} -> 42 - None, // {"x": "foo"} -> NULL (type mismatch) - None, // {"y": 10} -> NULL (field missing) + Some(42), // {"x": 42} -> 42 + None, // {"x": "foo"} -> NULL (type mismatch) + None, // {"y": 10} -> NULL (field missing) ])); assert_eq!(&result, &expected); println!("Depth 0 (unshredded) passed"); - + // Shredded test data: using simplified approach based on working pattern let shredded_array = create_depth_0_shredded_test_data_simple(); - + let field = Field::new("result", DataType::Int32, true); let path = VariantPath::from("x"); - let options = GetOptions::new_with_path(path) - .with_as_type(Some(FieldRef::from(field))); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&shredded_array, options).unwrap(); - + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(42), // {"x": 42} -> 42 (from typed_value) - None, // {"x": "foo"} -> NULL (type mismatch, from value field) + Some(42), // {"x": 42} 
-> 42 (from typed_value) + None, // {"x": "foo"} -> NULL (type mismatch, from value field) ])); assert_eq!(&result, &expected); println!("Depth 0 (shredded) passed"); @@ -874,35 +1170,33 @@ mod test { #[test] fn test_depth_1_int32_conversion() { println!("=== Testing Depth 1: Single nested field access ==="); - + // Non-shredded test data from the GitHub issue let unshredded_array = create_nested_path_test_data(); - + let field = Field::new("result", DataType::Int32, true); let path = VariantPath::from("a.x"); // Dot notation! - let options = GetOptions::new_with_path(path) - .with_as_type(Some(FieldRef::from(field))); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&unshredded_array, options).unwrap(); - + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(55), // {"a": {"x": 55}} -> 55 - None, // {"a": {"x": "foo"}} -> NULL (type mismatch) + Some(55), // {"a": {"x": 55}} -> 55 + None, // {"a": {"x": "foo"}} -> NULL (type mismatch) ])); assert_eq!(&result, &expected); println!("Depth 1 (unshredded) passed"); - - // Shredded test data: depth 1 nested shredding + + // Shredded test data: depth 1 nested shredding let shredded_array = create_depth_1_shredded_test_data_working(); - + let field = Field::new("result", DataType::Int32, true); let path = VariantPath::from("a.x"); // Dot notation! - let options = GetOptions::new_with_path(path) - .with_as_type(Some(FieldRef::from(field))); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&shredded_array, options).unwrap(); - + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(55), // {"a": {"x": 55}} -> 55 (from nested shredded x) - None, // {"a": {"x": "foo"}} -> NULL (type mismatch in nested value) + Some(55), // {"a": {"x": 55}} -> 55 (from nested shredded x) + None, // {"a": {"x": "foo"}} -> NULL (type mismatch in nested value) ])); assert_eq!(&result, &expected); println!("Depth 1 (shredded) passed"); @@ -913,16 +1207,15 @@ mod test { #[test] fn test_depth_2_int32_conversion() { println!("=== Testing Depth 2: Double nested field access ==="); - + // Non-shredded test data: [{"a": {"b": {"x": 100}}}, {"a": {"b": {"x": "bar"}}}, {"a": {"b": {"y": 200}}}] let unshredded_array = create_depth_2_test_data(); - + let field = Field::new("result", DataType::Int32, true); let path = VariantPath::from("a.b.x"); // Double nested dot notation! - let options = GetOptions::new_with_path(path) - .with_as_type(Some(FieldRef::from(field))); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&unshredded_array, options).unwrap(); - + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ Some(100), // {"a": {"b": {"x": 100}}} -> 100 None, // {"a": {"b": {"x": "bar"}}} -> NULL (type mismatch) @@ -930,16 +1223,15 @@ mod test { ])); assert_eq!(&result, &expected); println!("Depth 2 (unshredded) passed"); - - // Shredded test data: depth 2 nested shredding + + // Shredded test data: depth 2 nested shredding let shredded_array = create_depth_2_shredded_test_data_working(); - + let field = Field::new("result", DataType::Int32, true); let path = VariantPath::from("a.b.x"); // Double nested dot notation! 
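// A minimal sketch (an assumption for illustration, not part of this diff):
// dot notation is expected to be equivalent to the explicit join form used
// elsewhere in these tests, with one path element per dot-separated segment:
//
//     let dotted = VariantPath::from("a.b.x");
//     let joined = VariantPath::from("a").join("b").join("x");
//     assert_eq!(dotted.iter().count(), joined.iter().count());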
- let options = GetOptions::new_with_path(path) - .with_as_type(Some(FieldRef::from(field))); + let options = GetOptions::new_with_path(path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&shredded_array, options).unwrap(); - + let expected: ArrayRef = Arc::new(Int32Array::from(vec![ Some(100), // {"a": {"b": {"x": 100}}} -> 100 (from deeply nested shredded x) None, // {"a": {"b": {"x": "bar"}}} -> NULL (type mismatch in deep value) @@ -950,29 +1242,29 @@ mod test { } /// Test that demonstrates what CURRENTLY WORKS - /// + /// /// This shows that nested path functionality does work, but only when the /// test data matches what the current implementation expects #[test] fn test_current_nested_path_functionality() { let array = shredded_object_with_x_field_variant_array(); - + // Test: Extract the "x" field (single level) - this works let single_path = VariantPath::from("x"); let field = Field::new("result", DataType::Int32, true); - let options = GetOptions::new_with_path(single_path) - .with_as_type(Some(FieldRef::from(field))); + let options = + GetOptions::new_with_path(single_path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&array, options).unwrap(); - + println!("Single path 'x' works - result: {:?}", result); - + // Test: Try nested path "a.x" - this is what we need to implement let nested_path = VariantPath::from("a").join("x"); let field = Field::new("result", DataType::Int32, true); - let options = GetOptions::new_with_path(nested_path) - .with_as_type(Some(FieldRef::from(field))); + let options = + GetOptions::new_with_path(nested_path).with_as_type(Some(FieldRef::from(field))); let result = variant_get(&array, options).unwrap(); - + println!("Nested path 'a.x' result: {:?}", result); } @@ -980,7 +1272,7 @@ mod test { /// [{"x": 42}, {"x": "foo"}, {"y": 10}] fn create_depth_0_test_data() -> ArrayRef { let mut builder = crate::VariantArrayBuilder::new(3); - + // Row 1: {"x": 42} { let json_str = r#"{"x": 42}"#; @@ -991,7 +1283,7 @@ mod test { builder.append_null(); } } - + // Row 2: {"x": "foo"} { let json_str = r#"{"x": "foo"}"#; @@ -1002,7 +1294,7 @@ mod test { builder.append_null(); } } - + // Row 3: {"y": 10} (missing "x" field) { let json_str = r#"{"y": 10}"#; @@ -1013,7 +1305,7 @@ mod test { builder.append_null(); } } - + Arc::new(builder.build()) } @@ -1021,7 +1313,7 @@ mod test { /// This represents the exact scenarios from the GitHub issue: "a.x" fn create_nested_path_test_data() -> ArrayRef { let mut builder = crate::VariantArrayBuilder::new(2); - + // Row 1: {"a": {"x": 55}, "b": 42} { let json_str = r#"{"a": {"x": 55}, "b": 42}"#; @@ -1032,7 +1324,7 @@ mod test { builder.append_null(); } } - + // Row 2: {"a": {"x": "foo"}, "b": 42} { let json_str = r#"{"a": {"x": "foo"}, "b": 42}"#; @@ -1043,7 +1335,7 @@ mod test { builder.append_null(); } } - + Arc::new(builder.build()) } @@ -1051,7 +1343,7 @@ mod test { /// [{"a": {"b": {"x": 100}}}, {"a": {"b": {"x": "bar"}}}, {"a": {"b": {"y": 200}}}] fn create_depth_2_test_data() -> ArrayRef { let mut builder = crate::VariantArrayBuilder::new(3); - + // Row 1: {"a": {"b": {"x": 100}}} { let json_str = r#"{"a": {"b": {"x": 100}}}"#; @@ -1062,7 +1354,7 @@ mod test { builder.append_null(); } } - + // Row 2: {"a": {"b": {"x": "bar"}}} { let json_str = r#"{"a": {"b": {"x": "bar"}}}"#; @@ -1073,7 +1365,7 @@ mod test { builder.append_null(); } } - + // Row 3: {"a": {"b": {"y": 200}}} (missing "x" field) { let json_str = r#"{"a": {"b": {"y": 200}}}"#; @@ -1084,7 +1376,7 @@ mod test { 
builder.append_null(); } } - + Arc::new(builder.build()) } @@ -1096,7 +1388,7 @@ mod test { let mut builder = parquet_variant::VariantBuilder::new(); let mut obj = builder.new_object(); obj.insert("x", Variant::from("foo")); - obj.finish().unwrap(); + obj.finish(); builder.finish() }; @@ -1109,36 +1401,36 @@ mod test { let empty_object_value = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; - + let value_array = BinaryViewArray::from(vec![ - Some(empty_object_value.as_slice()), // Row 0: {} (x shredded out) - Some(string_x_value.as_slice()), // Row 1: {"x": "foo"} (fallback) + Some(empty_object_value.as_slice()), // Row 0: {} (x shredded out) + Some(string_x_value.as_slice()), // Row 1: {"x": "foo"} (fallback) ]); - // Create the "x" field as a ShreddedVariantFieldArray + // Create the "x" field as a ShreddedVariantFieldArray let x_field_typed_value = Int32Array::from(vec![Some(42), None]); - + // For the x field, only typed_value (perfect shredding when possible) let x_field_struct = crate::variant_array::StructArrayBuilder::new() .with_field("typed_value", Arc::new(x_field_typed_value)) .build(); - + let x_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(x_field_struct)) .expect("should create ShreddedVariantFieldArray"); // Create the main typed_value as a struct containing the "x" field - let typed_value_fields = Fields::from(vec![ - Field::new("x", x_field_shredded.data_type().clone(), true) - ]); - let typed_value_struct = StructArray::try_new( - typed_value_fields, - vec![Arc::new(x_field_shredded)], - None, - ).unwrap(); + let typed_value_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = + StructArray::try_new(typed_value_fields, vec![Arc::new(x_field_shredded)], None) + .unwrap(); // Build final VariantArray let struct_array = crate::variant_array::StructArrayBuilder::new() @@ -1147,9 +1439,7 @@ mod test { .with_field("typed_value", Arc::new(typed_value_struct)) .build(); - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray"), - ) + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) } /// Create working depth 1 shredded test data based on the existing working pattern @@ -1160,18 +1450,16 @@ mod test { // Create metadata following the working pattern from shredded_object_with_x_field_variant_array let (metadata, _) = { // Create nested structure: {"a": {"x": 55}, "b": 42} - let a_variant = { - let mut a_builder = parquet_variant::VariantBuilder::new(); - let mut a_obj = a_builder.new_object(); - a_obj.insert("x", Variant::Int32(55)); // "a.x" field (shredded when possible) - a_obj.finish().unwrap() - }; - let mut builder = parquet_variant::VariantBuilder::new(); let mut obj = builder.new_object(); - obj.insert("a", a_variant); + + // Create the nested "a" object + let mut a_obj = obj.new_object("a"); + a_obj.insert("x", Variant::Int32(55)); + a_obj.finish(); + obj.insert("b", Variant::Int32(42)); - obj.finish().unwrap(); + obj.finish(); builder.finish() }; @@ -1182,25 +1470,25 @@ mod test { let empty_object_value = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; - // Row 1 fallback: use the working pattern from the existing shredded test + // Row 1 fallback: 
use the working pattern from the existing shredded test // This avoids metadata issues by using the simple fallback approach let row1_fallback = { let mut builder = parquet_variant::VariantBuilder::new(); let mut obj = builder.new_object(); obj.insert("fallback", Variant::from("data")); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; let value_array = BinaryViewArray::from(vec![ - Some(empty_object_value.as_slice()), // Row 0: {} (everything shredded except b in unshredded fields) - Some(row1_fallback.as_slice()), // Row 1: {"a": {"x": "foo"}, "b": 42} (a.x can't be shredded) + Some(empty_object_value.as_slice()), // Row 0: {} (everything shredded except b in unshredded fields) + Some(row1_fallback.as_slice()), // Row 1: {"a": {"x": "foo"}, "b": 42} (a.x can't be shredded) ]); // Create the nested shredded structure @@ -1214,43 +1502,47 @@ mod test { // Level 1: a field containing x field + value field for fallbacks // The "a" field needs both typed_value (for shredded x) and value (for fallback cases) - + // Create the value field for "a" (for cases where a.x can't be shredded) let a_value_data = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; let a_value_array = BinaryViewArray::from(vec![ - None, // Row 0: x is shredded, so no value fallback needed - Some(a_value_data.as_slice()), // Row 1: fallback for a.x="foo" (but logic will check typed_value first) - ]); - - let a_inner_fields = Fields::from(vec![ - Field::new("x", x_field_shredded.data_type().clone(), true) + None, // Row 0: x is shredded, so no value fallback needed + Some(a_value_data.as_slice()), // Row 1: fallback for a.x="foo" (but logic will check typed_value first) ]); + + let a_inner_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); let a_inner_struct = crate::variant_array::StructArrayBuilder::new() - .with_field("typed_value", Arc::new(StructArray::try_new( - a_inner_fields, - vec![Arc::new(x_field_shredded)], - None, - ).unwrap())) + .with_field( + "typed_value", + Arc::new( + StructArray::try_new(a_inner_fields, vec![Arc::new(x_field_shredded)], None) + .unwrap(), + ), + ) .with_field("value", Arc::new(a_value_array)) .build(); let a_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(a_inner_struct)) .expect("should create ShreddedVariantFieldArray for a"); // Level 0: main typed_value struct containing a field - let typed_value_fields = Fields::from(vec![ - Field::new("a", a_field_shredded.data_type().clone(), true) - ]); - let typed_value_struct = StructArray::try_new( - typed_value_fields, - vec![Arc::new(a_field_shredded)], - None, - ).unwrap(); + let typed_value_fields = Fields::from(vec![Field::new( + "a", + a_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = + StructArray::try_new(typed_value_fields, vec![Arc::new(a_field_shredded)], None) + .unwrap(); // Build final VariantArray let struct_array = crate::variant_array::StructArrayBuilder::new() @@ -1259,38 +1551,29 @@ mod test { .with_field("typed_value", Arc::new(typed_value_struct)) .build(); - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray"), - ) + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) } /// Create working depth 2 shredded test data for "a.b.x" paths /// This creates a 3-level nested shredded structure where: 
/// - Row 0: {"a": {"b": {"x": 100}}} with a.b.x shredded into typed_value - /// - Row 1: {"a": {"b": {"x": "bar"}}} with type mismatch fallback + /// - Row 1: {"a": {"b": {"x": "bar"}}} with type mismatch fallback /// - Row 2: {"a": {"b": {"y": 200}}} with missing field fallback fn create_depth_2_shredded_test_data_working() -> ArrayRef { // Create metadata following the working pattern let (metadata, _) = { // Create deeply nested structure: {"a": {"b": {"x": 100}}} - let b_variant = { - let mut b_builder = parquet_variant::VariantBuilder::new(); - let mut b_obj = b_builder.new_object(); - b_obj.insert("x", Variant::Int32(100)); - b_obj.finish().unwrap() - }; - - let a_variant = { - let mut a_builder = parquet_variant::VariantBuilder::new(); - let mut a_obj = a_builder.new_object(); - a_obj.insert("b", b_variant); - a_obj.finish().unwrap() - }; - let mut builder = parquet_variant::VariantBuilder::new(); let mut obj = builder.new_object(); - obj.insert("a", a_variant); // "a" field containing b - obj.finish().unwrap(); + + // Create the nested "a.b" structure + let mut a_obj = obj.new_object("a"); + let mut b_obj = a_obj.new_object("b"); + b_obj.insert("x", Variant::Int32(100)); + b_obj.finish(); + a_obj.finish(); + + obj.finish(); builder.finish() }; @@ -1300,11 +1583,11 @@ mod test { let empty_object_value = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; - + // Simple fallback values - avoiding complex nested metadata let value_array = BinaryViewArray::from(vec![ Some(empty_object_value.as_slice()), // Row 0: fully shredded @@ -1313,7 +1596,7 @@ mod test { ]); // Create the deeply nested shredded structure: a.b.x - + // Level 3: x field (deepest level) let x_typed_value = Int32Array::from(vec![Some(100), None, None]); let x_field_struct = crate::variant_array::StructArrayBuilder::new() @@ -1326,25 +1609,29 @@ mod test { let b_value_data = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; let b_value_array = BinaryViewArray::from(vec![ - None, // Row 0: x is shredded - Some(b_value_data.as_slice()), // Row 1: fallback for b.x="bar" - Some(b_value_data.as_slice()), // Row 2: fallback for b.y=200 - ]); - - let b_inner_fields = Fields::from(vec![ - Field::new("x", x_field_shredded.data_type().clone(), true) + None, // Row 0: x is shredded + Some(b_value_data.as_slice()), // Row 1: fallback for b.x="bar" + Some(b_value_data.as_slice()), // Row 2: fallback for b.y=200 ]); + + let b_inner_fields = Fields::from(vec![Field::new( + "x", + x_field_shredded.data_type().clone(), + true, + )]); let b_inner_struct = crate::variant_array::StructArrayBuilder::new() - .with_field("typed_value", Arc::new(StructArray::try_new( - b_inner_fields, - vec![Arc::new(x_field_shredded)], - None, - ).unwrap())) + .with_field( + "typed_value", + Arc::new( + StructArray::try_new(b_inner_fields, vec![Arc::new(x_field_shredded)], None) + .unwrap(), + ), + ) .with_field("value", Arc::new(b_value_array)) .build(); let b_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(b_inner_struct)) @@ -1354,39 +1641,43 @@ mod test { let a_value_data = { let mut builder = parquet_variant::VariantBuilder::new(); let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); let (_, value) = builder.finish(); value }; let a_value_array = 
BinaryViewArray::from(vec![ - None, // Row 0: b is shredded - Some(a_value_data.as_slice()), // Row 1: fallback for a.b.* - Some(a_value_data.as_slice()), // Row 2: fallback for a.b.* - ]); - - let a_inner_fields = Fields::from(vec![ - Field::new("b", b_field_shredded.data_type().clone(), true) + None, // Row 0: b is shredded + Some(a_value_data.as_slice()), // Row 1: fallback for a.b.* + Some(a_value_data.as_slice()), // Row 2: fallback for a.b.* ]); + + let a_inner_fields = Fields::from(vec![Field::new( + "b", + b_field_shredded.data_type().clone(), + true, + )]); let a_inner_struct = crate::variant_array::StructArrayBuilder::new() - .with_field("typed_value", Arc::new(StructArray::try_new( - a_inner_fields, - vec![Arc::new(b_field_shredded)], - None, - ).unwrap())) + .with_field( + "typed_value", + Arc::new( + StructArray::try_new(a_inner_fields, vec![Arc::new(b_field_shredded)], None) + .unwrap(), + ), + ) .with_field("value", Arc::new(a_value_array)) .build(); let a_field_shredded = ShreddedVariantFieldArray::try_new(Arc::new(a_inner_struct)) .expect("should create ShreddedVariantFieldArray for a"); // Level 0: main typed_value struct containing a field - let typed_value_fields = Fields::from(vec![ - Field::new("a", a_field_shredded.data_type().clone(), true) - ]); - let typed_value_struct = StructArray::try_new( - typed_value_fields, - vec![Arc::new(a_field_shredded)], - None, - ).unwrap(); + let typed_value_fields = Fields::from(vec![Field::new( + "a", + a_field_shredded.data_type().clone(), + true, + )]); + let typed_value_struct = + StructArray::try_new(typed_value_fields, vec![Arc::new(a_field_shredded)], None) + .unwrap(); // Build final VariantArray let struct_array = crate::variant_array::StructArrayBuilder::new() @@ -1395,29 +1686,27 @@ mod test { .with_field("typed_value", Arc::new(typed_value_struct)) .build(); - Arc::new( - VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray"), - ) + Arc::new(VariantArray::try_new(Arc::new(struct_array)).expect("should create VariantArray")) } #[test] fn test_strict_cast_options_downcast_failure() { - use arrow::error::ArrowError; use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field}; - use std::sync::Arc; + use arrow::error::ArrowError; use parquet_variant::VariantPath; - + use std::sync::Arc; + // Use the existing simple test data that has Int32 as typed_value let variant_array = perfectly_shredded_int32_variant_array(); - + // Try to access a field with safe cast options (should return NULLs) let safe_options = GetOptions { path: VariantPath::from("nonexistent_field"), as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))), cast_options: CastOptions::default(), // safe = true }; - + let variant_array_ref: Arc = variant_array.clone(); let result = variant_get(&variant_array_ref, safe_options); // Should succeed and return NULLs (safe behavior) @@ -1427,63 +1716,74 @@ mod test { assert!(result_array.is_null(0)); assert!(result_array.is_null(1)); assert!(result_array.is_null(2)); - + // Try to access a field with strict cast options (should error) let strict_options = GetOptions { - path: VariantPath::from("nonexistent_field"), + path: VariantPath::from("nonexistent_field"), as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))), - cast_options: CastOptions { safe: false, ..Default::default() }, + cast_options: CastOptions { + safe: false, + ..Default::default() + }, }; - + let result = variant_get(&variant_array_ref, strict_options); // Should fail with a cast 
error assert!(result.is_err()); let error = result.unwrap_err(); assert!(matches!(error, ArrowError::CastError(_))); - assert!(error.to_string().contains("Cannot access field 'nonexistent_field' on non-struct type")); + assert!(error + .to_string() + .contains("Cannot access field 'nonexistent_field' on non-struct type")); } #[test] fn test_null_buffer_union_for_shredded_paths() { use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field}; - use std::sync::Arc; use parquet_variant::VariantPath; - + use std::sync::Arc; + // Test that null buffers are properly unioned when traversing shredded paths // This test verifies scovich's null buffer union requirement - + // Create a depth-1 shredded variant array where: // - The top-level variant array has some nulls // - The nested typed_value also has some nulls // - The result should be the union of both null buffers - + let variant_array = create_depth_1_shredded_test_data_working(); - + // Get the field "x" which should union nulls from: // 1. The top-level variant array nulls - // 2. The "a" field's typed_value nulls + // 2. The "a" field's typed_value nulls // 3. The "x" field's typed_value nulls let options = GetOptions { path: VariantPath::from("a.x"), as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))), cast_options: CastOptions::default(), }; - + let variant_array_ref: Arc = variant_array.clone(); let result = variant_get(&variant_array_ref, options).unwrap(); - + // Verify the result length matches input assert_eq!(result.len(), variant_array.len()); - + // The null pattern should reflect the union of all ancestor nulls // Row 0: Should have valid data (path exists and is shredded as Int32) // Row 1: Should be null (due to type mismatch - "foo" can't cast to Int32) assert!(!result.is_null(0), "Row 0 should have valid Int32 data"); - assert!(result.is_null(1), "Row 1 should be null due to type casting failure"); - + assert!( + result.is_null(1), + "Row 1 should be null due to type casting failure" + ); + // Verify the actual values - let int32_result = result.as_any().downcast_ref::().unwrap(); + let int32_result = result + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(int32_result.value(0), 55); // The valid Int32 value } @@ -1491,24 +1791,24 @@ mod test { fn test_struct_null_mask_union_from_children() { use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, Fields}; - use std::sync::Arc; use parquet_variant::VariantPath; + use std::sync::Arc; use arrow::array::StringArray; - + // Test that struct null masks properly union nulls from children field extractions // This verifies scovich's concern about incomplete null masks in struct construction - + // Create test data where some fields will fail type casting let json_strings = vec![ - r#"{"a": 42, "b": "hello"}"#, // Row 0: a=42 (castable to int), b="hello" (not castable to int) - r#"{"a": "world", "b": 100}"#, // Row 1: a="world" (not castable to int), b=100 (castable to int) - r#"{"a": 55, "b": 77}"#, // Row 2: a=55 (castable to int), b=77 (castable to int) + r#"{"a": 42, "b": "hello"}"#, // Row 0: a=42 (castable to int), b="hello" (not castable to int) + r#"{"a": "world", "b": 100}"#, // Row 1: a="world" (not castable to int), b=100 (castable to int) + r#"{"a": 55, "b": 77}"#, // Row 2: a=55 (castable to int), b=77 (castable to int) ]; - + let string_array: Arc = Arc::new(StringArray::from(json_strings)); let variant_array = json_to_variant(&string_array).unwrap(); - + // Request extraction as a struct with both fields as Int32 
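// (a child-level cast failure should null only that child slot, not the whole struct row)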
// This should create child arrays where some fields are null due to casting failures let struct_fields = Fields::from(vec![ @@ -1516,47 +1816,57 @@ mod test { Field::new("b", DataType::Int32, true), ]); let struct_type = DataType::Struct(struct_fields); - + let options = GetOptions { path: VariantPath::default(), // Extract the whole object as struct as_type: Some(Arc::new(Field::new("result", struct_type, true))), cast_options: CastOptions::default(), }; - + let variant_array_ref: Arc = Arc::new(variant_array); let result = variant_get(&variant_array_ref, options).unwrap(); - + // Verify the result is a StructArray - let struct_result = result.as_any().downcast_ref::().unwrap(); + let struct_result = result + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(struct_result.len(), 3); - + // Get the individual field arrays - let field_a = struct_result.column(0).as_any().downcast_ref::().unwrap(); - let field_b = struct_result.column(1).as_any().downcast_ref::().unwrap(); - + let field_a = struct_result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let field_b = struct_result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + // Verify field values and nulls // Row 0: a=42 (valid), b=null (casting failure) assert!(!field_a.is_null(0)); assert_eq!(field_a.value(0), 42); assert!(field_b.is_null(0)); // "hello" can't cast to int - + // Row 1: a=null (casting failure), b=100 (valid) assert!(field_a.is_null(1)); // "world" can't cast to int assert!(!field_b.is_null(1)); assert_eq!(field_b.value(1), 100); - + // Row 2: a=55 (valid), b=77 (valid) assert!(!field_a.is_null(2)); assert_eq!(field_a.value(2), 55); assert!(!field_b.is_null(2)); assert_eq!(field_b.value(2), 77); - - + // Verify the struct-level null mask properly unions child nulls // The struct should NOT be null in any row because each row has at least one valid field // (This tests that we're not incorrectly making the entire struct null when children fail) assert!(!struct_result.is_null(0)); // Has valid field 'a' - assert!(!struct_result.is_null(1)); // Has valid field 'b' + assert!(!struct_result.is_null(1)); // Has valid field 'b' assert!(!struct_result.is_null(2)); // Has both valid fields } @@ -1564,28 +1874,28 @@ mod test { fn test_field_nullability_preservation() { use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field}; - use std::sync::Arc; use parquet_variant::VariantPath; + use std::sync::Arc; use arrow::array::StringArray; - + // Test that field nullability from GetOptions.as_type is preserved in the result - + let json_strings = vec![ - r#"{"x": 42}"#, // Row 0: Valid int that should convert to Int32 - r#"{"x": "not_a_number"}"#, // Row 1: String that can't cast to Int32 - r#"{"x": null}"#, // Row 2: Explicit null value - r#"{"x": "hello"}"#, // Row 3: Another string (wrong type) - r#"{"y": 100}"#, // Row 4: Missing "x" field (SQL NULL case) - r#"{"x": 127}"#, // Row 5: Small int (could be Int8, widening cast candidate) - r#"{"x": 32767}"#, // Row 6: Medium int (could be Int16, widening cast candidate) - r#"{"x": 2147483647}"#, // Row 7: Max Int32 value (fits in Int32) - r#"{"x": 9223372036854775807}"#, // Row 8: Large Int64 value (cannot convert to Int32) + r#"{"x": 42}"#, // Row 0: Valid int that should convert to Int32 + r#"{"x": "not_a_number"}"#, // Row 1: String that can't cast to Int32 + r#"{"x": null}"#, // Row 2: Explicit null value + r#"{"x": "hello"}"#, // Row 3: Another string (wrong type) + r#"{"y": 100}"#, // Row 4: Missing "x" field (SQL NULL case) + r#"{"x": 
127}"#, // Row 5: Small int (could be Int8, widening cast candidate) + r#"{"x": 32767}"#, // Row 6: Medium int (could be Int16, widening cast candidate) + r#"{"x": 2147483647}"#, // Row 7: Max Int32 value (fits in Int32) + r#"{"x": 9223372036854775807}"#, // Row 8: Large Int64 value (cannot convert to Int32) ]; - + let string_array: Arc = Arc::new(StringArray::from(json_strings)); let variant_array = json_to_variant(&string_array).unwrap(); - + // Test 1: nullable field (should allow nulls from cast failures) let nullable_field = Arc::new(Field::new("result", DataType::Int32, true)); let options_nullable = GetOptions { @@ -1593,49 +1903,52 @@ mod test { as_type: Some(nullable_field.clone()), cast_options: CastOptions::default(), }; - + let variant_array_ref: Arc = Arc::new(variant_array); let result_nullable = variant_get(&variant_array_ref, options_nullable).unwrap(); - - // Verify we get an Int32Array with nulls for cast failures - let int32_result = result_nullable.as_any().downcast_ref::().unwrap(); + + // Verify we get an Int32Array with nulls for cast failures + let int32_result = result_nullable + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(int32_result.len(), 9); - + // Row 0: 42 converts successfully to Int32 - assert!(!int32_result.is_null(0)); + assert!(!int32_result.is_null(0)); assert_eq!(int32_result.value(0), 42); - + // Row 1: "not_a_number" fails to convert -> NULL - assert!(int32_result.is_null(1)); - + assert!(int32_result.is_null(1)); + // Row 2: explicit null value -> NULL - assert!(int32_result.is_null(2)); - + assert!(int32_result.is_null(2)); + // Row 3: "hello" (wrong type) fails to convert -> NULL - assert!(int32_result.is_null(3)); - + assert!(int32_result.is_null(3)); + // Row 4: missing "x" field (SQL NULL case) -> NULL assert!(int32_result.is_null(4)); - - // Row 5: 127 (small int, potential Int8 -> Int32 widening) + + // Row 5: 127 (small int, potential Int8 -> Int32 widening) // Current behavior: JSON parses to Int8, should convert to Int32 assert!(!int32_result.is_null(5)); assert_eq!(int32_result.value(5), 127); - + // Row 6: 32767 (medium int, potential Int16 -> Int32 widening) - // Current behavior: JSON parses to Int16, should convert to Int32 + // Current behavior: JSON parses to Int16, should convert to Int32 assert!(!int32_result.is_null(6)); assert_eq!(int32_result.value(6), 32767); - + // Row 7: 2147483647 (max Int32, fits exactly) // Current behavior: Should convert successfully assert!(!int32_result.is_null(7)); assert_eq!(int32_result.value(7), 2147483647); - + // Row 8: 9223372036854775807 (large Int64, cannot fit in Int32) // Current behavior: Should fail conversion -> NULL assert!(int32_result.is_null(8)); - + // Test 2: non-nullable field (behavior should be the same with safe casting) let non_nullable_field = Arc::new(Field::new("result", DataType::Int32, false)); let options_non_nullable = GetOptions { @@ -1643,27 +1956,30 @@ mod test { as_type: Some(non_nullable_field.clone()), cast_options: CastOptions::default(), // safe=true by default }; - + // Create variant array again since we moved it let variant_array_2 = json_to_variant(&string_array).unwrap(); let variant_array_ref_2: Arc = Arc::new(variant_array_2); let result_non_nullable = variant_get(&variant_array_ref_2, options_non_nullable).unwrap(); - let int32_result_2 = result_non_nullable.as_any().downcast_ref::().unwrap(); - + let int32_result_2 = result_non_nullable + .as_any() + .downcast_ref::() + .unwrap(); + // Even with a non-nullable field, safe casting should still 
produce nulls for failures assert_eq!(int32_result_2.len(), 9); - + // Row 0: 42 converts successfully to Int32 assert!(!int32_result_2.is_null(0)); assert_eq!(int32_result_2.value(0), 42); - + // Rows 1-4: All should be null due to safe casting behavior // (non-nullable field specification doesn't override safe casting behavior) - assert!(int32_result_2.is_null(1)); // "not_a_number" + assert!(int32_result_2.is_null(1)); // "not_a_number" assert!(int32_result_2.is_null(2)); // explicit null assert!(int32_result_2.is_null(3)); // "hello" assert!(int32_result_2.is_null(4)); // missing field - + // Rows 5-7: These should also convert successfully (numeric widening/fitting) assert!(!int32_result_2.is_null(5)); // 127 (Int8 -> Int32) assert_eq!(int32_result_2.value(5), 127); @@ -1671,10 +1987,8 @@ mod test { assert_eq!(int32_result_2.value(6), 32767); assert!(!int32_result_2.is_null(7)); // 2147483647 (fits in Int32) assert_eq!(int32_result_2.value(7), 2147483647); - + // Row 8: Large Int64 should fail conversion -> NULL assert!(int32_result_2.is_null(8)); // 9223372036854775807 (too large for Int32) } - - } diff --git a/parquet-variant-compute/src/variant_get/output/mod.rs b/parquet-variant-compute/src/variant_get/output/mod.rs index ca0db0670bdb..c3df183ec8b4 100644 --- a/parquet-variant-compute/src/variant_get/output/mod.rs +++ b/parquet-variant-compute/src/variant_get/output/mod.rs @@ -15,4 +15,4 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod row_builder; \ No newline at end of file +pub(crate) mod row_builder; diff --git a/parquet-variant-compute/src/variant_get/output/primitive.rs b/parquet-variant-compute/src/variant_get/output/primitive.rs index aabc9827a7b7..ff3e58c3c340 100644 --- a/parquet-variant-compute/src/variant_get/output/primitive.rs +++ b/parquet-variant-compute/src/variant_get/output/primitive.rs @@ -24,7 +24,7 @@ use arrow::array::{ NullBufferBuilder, PrimitiveArray, }; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::Int32Type; +use arrow::datatypes::{Int16Type, Int32Type}; use arrow_schema::{ArrowError, FieldRef}; use parquet_variant::{Variant, VariantPath}; use std::marker::PhantomData; @@ -176,3 +176,9 @@ impl ArrowPrimitiveVariant for Int32Type { variant.as_int32() } } + +impl ArrowPrimitiveVariant for Int16Type { + fn from_variant(variant: &Variant) -> Option { + variant.as_int16() + } +} diff --git a/parquet-variant-compute/src/variant_get/output/row_builder.rs b/parquet-variant-compute/src/variant_get/output/row_builder.rs index 7d8b432b3f1f..787bdd610d81 100644 --- a/parquet-variant-compute/src/variant_get/output/row_builder.rs +++ b/parquet-variant-compute/src/variant_get/output/row_builder.rs @@ -16,6 +16,7 @@ // under the License. 
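// A minimal sketch of the safe vs. strict cast behavior threaded through the
// builders below (assumes the GetOptions/CastOptions shapes used in the tests
// above; illustrative only):
//
//     let strict = GetOptions {
//         path: VariantPath::from("x"),
//         as_type: Some(Arc::new(Field::new("result", DataType::Int32, true))),
//         cast_options: CastOptions { safe: false, ..Default::default() },
//     };
//     // With the default (safe: true), conversion failures append NULL instead
//     // of returning ArrowError::CastError.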
use arrow::array::ArrayRef; +use arrow::compute::CastOptions; use arrow::datatypes; use arrow::datatypes::ArrowPrimitiveType; use arrow::error::{ArrowError, Result}; @@ -29,17 +30,62 @@ pub(crate) fn make_shredding_row_builder<'a>( //metadata: &BinaryViewArray, path: VariantPath<'a>, data_type: Option<&'a datatypes::DataType>, + cast_options: &'a CastOptions, ) -> Result> { use arrow::array::PrimitiveBuilder; - use datatypes::Int32Type; - + use datatypes::{ + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + }; + // support non-empty paths (field access) and some empty path cases if path.is_empty() { return match data_type { + Some(datatypes::DataType::Int8) => { + let builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + Ok(Box::new(builder)) + } + Some(datatypes::DataType::Int16) => { + let builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + Ok(Box::new(builder)) + } Some(datatypes::DataType::Int32) => { - // Return PrimitiveInt32Builder for type conversion let builder = PrimitiveVariantShreddingRowBuilder { builder: PrimitiveBuilder::::new(), + cast_options, + }; + Ok(Box::new(builder)) + } + Some(datatypes::DataType::Int64) => { + let builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + Ok(Box::new(builder)) + } + Some(datatypes::DataType::Float16) => { + let builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + Ok(Box::new(builder)) + } + Some(datatypes::DataType::Float32) => { + let builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + Ok(Box::new(builder)) + } + Some(datatypes::DataType::Float64) => { + let builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, }; Ok(Box::new(builder)) } @@ -48,13 +94,10 @@ pub(crate) fn make_shredding_row_builder<'a>( let builder = VariantArrayShreddingRowBuilder::new(16); Ok(Box::new(builder)) } - _ => { - // only Int32 supported for empty paths - Err(ArrowError::NotYetImplemented(format!( - "variant_get with empty path and data_type={:?} not yet implemented", - data_type - ))) - } + _ => Err(ArrowError::NotYetImplemented(format!( + "variant_get with empty path and data_type={:?} not yet implemented", + data_type + ))), }; } @@ -70,10 +113,52 @@ pub(crate) fn make_shredding_row_builder<'a>( } match data_type { + Some(datatypes::DataType::Int8) => { + let inner_builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + wrap_with_path!(inner_builder) + } + Some(datatypes::DataType::Int16) => { + let inner_builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + wrap_with_path!(inner_builder) + } Some(datatypes::DataType::Int32) => { - // Create a primitive builder and wrap it with path functionality let inner_builder = PrimitiveVariantShreddingRowBuilder { builder: PrimitiveBuilder::::new(), + cast_options, + }; + wrap_with_path!(inner_builder) + } + Some(datatypes::DataType::Int64) => { + let inner_builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + wrap_with_path!(inner_builder) + } + Some(datatypes::DataType::Float16) => { + let inner_builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + 
}; + wrap_with_path!(inner_builder) + } + Some(datatypes::DataType::Float32) => { + let inner_builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, + }; + wrap_with_path!(inner_builder) + } + Some(datatypes::DataType::Float64) => { + let inner_builder = PrimitiveVariantShreddingRowBuilder { + builder: PrimitiveBuilder::::new(), + cast_options, }; wrap_with_path!(inner_builder) } @@ -82,13 +167,10 @@ pub(crate) fn make_shredding_row_builder<'a>( let inner_builder = VariantArrayShreddingRowBuilder::new(16); wrap_with_path!(inner_builder) } - _ => { - // only Int32 and VariantArray supported - Err(ArrowError::NotYetImplemented(format!( - "variant_get with path={:?} and data_type={:?} not yet implemented", - path, data_type - ))) - } + _ => Err(ArrowError::NotYetImplemented(format!( + "variant_get with path={:?} and data_type={:?} not yet implemented", + path, data_type + ))), } } @@ -133,23 +215,68 @@ impl VariantShreddingRowBuilder for VariantPathRo trait VariantAsPrimitive { fn as_primitive(&self) -> Option; } + impl VariantAsPrimitive for Variant<'_, '_> { fn as_primitive(&self) -> Option { self.as_int32() } } +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int16() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int8() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_int64() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_f16() + } +} +impl VariantAsPrimitive for Variant<'_, '_> { + fn as_primitive(&self) -> Option { + self.as_f32() + } +} impl VariantAsPrimitive for Variant<'_, '_> { fn as_primitive(&self) -> Option { self.as_f64() } } +/// Helper function to get a user-friendly type name +fn get_type_name() -> &'static str { + match std::any::type_name::() { + "arrow_array::types::Int32Type" => "Int32", + "arrow_array::types::Int16Type" => "Int16", + "arrow_array::types::Int8Type" => "Int8", + "arrow_array::types::Int64Type" => "Int64", + "arrow_array::types::UInt32Type" => "UInt32", + "arrow_array::types::UInt16Type" => "UInt16", + "arrow_array::types::UInt8Type" => "UInt8", + "arrow_array::types::UInt64Type" => "UInt64", + "arrow_array::types::Float32Type" => "Float32", + "arrow_array::types::Float64Type" => "Float64", + "arrow_array::types::Float16Type" => "Float16", + _ => "Unknown", + } +} + /// Builder for shredding variant values to primitive values -struct PrimitiveVariantShreddingRowBuilder { +struct PrimitiveVariantShreddingRowBuilder<'a, T: ArrowPrimitiveType> { builder: arrow::array::PrimitiveBuilder, + cast_options: &'a CastOptions<'a>, } -impl VariantShreddingRowBuilder for PrimitiveVariantShreddingRowBuilder +impl<'a, T> VariantShreddingRowBuilder for PrimitiveVariantShreddingRowBuilder<'a, T> where T: ArrowPrimitiveType, for<'m, 'v> Variant<'m, 'v>: VariantAsPrimitive, @@ -164,9 +291,15 @@ where self.builder.append_value(v); Ok(true) } else { - // append null on conversion failure (safe casting behavior) - // This matches the default CastOptions::safe = true behavior - // TODO: In future steps, respect CastOptions for safe vs unsafe casting + if !self.cast_options.safe { + // Unsafe casting: return error on conversion failure + return Err(ArrowError::CastError(format!( + "Failed to extract primitive of type {} from variant {:?} at path VariantPath([])", + get_type_name::(), + value + ))); + } + // Safe casting: append null 
on conversion failure self.builder.append_null(); Ok(false) } @@ -207,5 +340,3 @@ impl VariantShreddingRowBuilder for VariantArrayShreddingRowBuilder { Ok(Arc::new(builder.build())) } } - - diff --git a/parquet-variant-compute/src/variant_get/output/variant.rs b/parquet-variant-compute/src/variant_get/output/variant.rs index 7c8b4da2f5c1..8a1fe8335fde 100644 --- a/parquet-variant-compute/src/variant_get/output/variant.rs +++ b/parquet-variant-compute/src/variant_get/output/variant.rs @@ -16,13 +16,38 @@ // under the License. use crate::variant_get::output::OutputBuilder; -use crate::{VariantArray, VariantArrayBuilder}; +use crate::{type_conversion::primitive_conversion_array, VariantArray, VariantArrayBuilder}; use arrow::array::{Array, ArrayRef, AsArray, BinaryViewArray}; -use arrow::datatypes::Int32Type; +use arrow::datatypes::{ + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; use arrow_schema::{ArrowError, DataType}; use parquet_variant::{Variant, VariantPath}; use std::sync::Arc; +macro_rules! cast_partially_shredded_primitive { + ($typed_value:expr, $variant_array:expr, $arrow_type:ty) => {{ + let mut array_builder = VariantArrayBuilder::new($variant_array.len()); + let primitive_array = $typed_value.as_primitive::<$arrow_type>(); + for i in 0..$variant_array.len() { + if $variant_array.is_null(i) { + array_builder.append_null(); + } else if $typed_value.is_null(i) { + // fall back to the value (variant) field + // (TODO could copy the variant bytes directly) + let value = $variant_array.value(i); + array_builder.append_variant(value); + } else { + // otherwise we have a typed value, so we can use it directly + let value = primitive_array.value(i); + array_builder.append_variant(Variant::from(value)); + } + } + Ok(Arc::new(array_builder.build())) + }}; +} + /// Outputs VariantArrays pub(super) struct VariantOutputBuilder<'a> { /// What path to extract @@ -44,40 +69,59 @@ impl OutputBuilder for VariantOutputBuilder<'_> { _value_field: &BinaryViewArray, typed_value: &ArrayRef, ) -> arrow::error::Result { - // in this case dispatch on the typed_value and - // TODO macro'ize this using downcast! 
to handle all other primitive types // TODO(perf): avoid builders entirely (and write the raw variant directly as we know the metadata is the same) - let mut array_builder = VariantArrayBuilder::new(variant_array.len()); match typed_value.data_type() { + DataType::Int8 => { + cast_partially_shredded_primitive!(typed_value, variant_array, Int8Type) + } + + DataType::Int16 => { + cast_partially_shredded_primitive!(typed_value, variant_array, Int16Type) + } + DataType::Int32 => { - let primitive_array = typed_value.as_primitive::(); - for i in 0..variant_array.len() { - if variant_array.is_null(i) { - array_builder.append_null(); - continue; - } - - if typed_value.is_null(i) { - // fall back to the value (variant) field - // (TODO could copy the variant bytes directly) - let value = variant_array.value(i); - array_builder.append_variant(value); - continue; - } - - // otherwise we have a typed value, so we can use it directly - let int_value = primitive_array.value(i); - array_builder.append_variant(Variant::from(int_value)); - } + cast_partially_shredded_primitive!(typed_value, variant_array, Int32Type) + } + + DataType::Int64 => { + cast_partially_shredded_primitive!(typed_value, variant_array, Int64Type) + } + + DataType::UInt8 => { + cast_partially_shredded_primitive!(typed_value, variant_array, UInt8Type) + } + + DataType::UInt16 => { + cast_partially_shredded_primitive!(typed_value, variant_array, UInt16Type) + } + + DataType::UInt32 => { + cast_partially_shredded_primitive!(typed_value, variant_array, UInt32Type) } + + DataType::UInt64 => { + cast_partially_shredded_primitive!(typed_value, variant_array, UInt64Type) + } + + DataType::Float16 => { + cast_partially_shredded_primitive!(typed_value, variant_array, Float16Type) + } + + DataType::Float32 => { + cast_partially_shredded_primitive!(typed_value, variant_array, Float32Type) + } + + DataType::Float64 => { + cast_partially_shredded_primitive!(typed_value, variant_array, Float64Type) + } + dt => { // https://github.com/apache/arrow-rs/issues/8086 - return Err(ArrowError::NotYetImplemented(format!( - "variant_get fully_shredded with typed_value={dt} is not implemented yet", - ))); + Err(ArrowError::NotYetImplemented(format!( + "variant_get partially shredded with typed_value={dt} is not implemented yet", + ))) } - }; - Ok(Arc::new(array_builder.build())) + } } fn typed( @@ -87,30 +131,33 @@ impl OutputBuilder for VariantOutputBuilder<'_> { _metadata: &BinaryViewArray, typed_value: &ArrayRef, ) -> arrow::error::Result { - // in this case dispatch on the typed_value and - // TODO macro'ize this using downcast! 
to handle all other primitive types // TODO(perf): avoid builders entirely (and write the raw variant directly as we know the metadata is the same) let mut array_builder = VariantArrayBuilder::new(variant_array.len()); match typed_value.data_type() { - DataType::Int32 => { - let primitive_array = typed_value.as_primitive::(); - for i in 0..variant_array.len() { - if primitive_array.is_null(i) { - array_builder.append_null(); - continue; - } - - let int_value = primitive_array.value(i); - array_builder.append_variant(Variant::from(int_value)); - } + DataType::Int8 => primitive_conversion_array!(Int8Type, typed_value, array_builder), + DataType::Int16 => primitive_conversion_array!(Int16Type, typed_value, array_builder), + DataType::Int32 => primitive_conversion_array!(Int32Type, typed_value, array_builder), + DataType::Int64 => primitive_conversion_array!(Int64Type, typed_value, array_builder), + DataType::UInt8 => primitive_conversion_array!(UInt8Type, typed_value, array_builder), + DataType::UInt16 => primitive_conversion_array!(UInt16Type, typed_value, array_builder), + DataType::UInt32 => primitive_conversion_array!(UInt32Type, typed_value, array_builder), + DataType::UInt64 => primitive_conversion_array!(UInt64Type, typed_value, array_builder), + DataType::Float16 => { + primitive_conversion_array!(Float16Type, typed_value, array_builder) + } + DataType::Float32 => { + primitive_conversion_array!(Float32Type, typed_value, array_builder) + } + DataType::Float64 => { + primitive_conversion_array!(Float64Type, typed_value, array_builder) } dt => { // https://github.com/apache/arrow-rs/issues/8087 return Err(ArrowError::NotYetImplemented(format!( - "variant_get fully_shredded with typed_value={dt} is not implemented yet", + "variant_get perfectly shredded with typed_value={dt} is not implemented yet", ))); } - }; + } Ok(Arc::new(array_builder.build())) } diff --git a/parquet-variant-json/src/from_json.rs b/parquet-variant-json/src/from_json.rs index 164d9b5facaf..90b26f7d307b 100644 --- a/parquet-variant-json/src/from_json.rs +++ b/parquet-variant-json/src/from_json.rs @@ -126,7 +126,7 @@ fn append_json(json: &Value, builder: &mut impl VariantBuilderExt) -> Result<(), }; append_json(value, &mut field_builder)?; } - obj_builder.finish()?; + obj_builder.finish(); } }; Ok(()) @@ -489,7 +489,7 @@ mod test { let mut list_builder = variant_builder.new_list(); let mut object_builder_inner = list_builder.new_object(); object_builder_inner.insert("age", Variant::Int8(32)); - object_builder_inner.finish().unwrap(); + object_builder_inner.finish(); list_builder.append_value(Variant::Int16(128)); list_builder.append_value(Variant::BooleanFalse); list_builder.finish(); @@ -553,7 +553,7 @@ mod test { let mut object_builder = variant_builder.new_object(); object_builder.insert("a", Variant::Int8(3)); object_builder.insert("b", Variant::Int8(2)); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; JsonToVariantTest { @@ -577,7 +577,7 @@ mod test { inner_list_builder.append_value(Variant::Double(-3e0)); inner_list_builder.append_value(Variant::Double(1001e-3)); inner_list_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; JsonToVariantTest { @@ -643,9 +643,9 @@ mod test { } list_builder.finish(); }); - inner_object_builder.finish().unwrap(); + 
inner_object_builder.finish(); }); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; @@ -669,7 +669,7 @@ mod test { let mut object_builder = variant_builder.new_object(); object_builder.insert("a", Variant::Int8(1)); object_builder.insert("爱", Variant::ShortString(ShortString::try_new("अ")?)); - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = variant_builder.finish(); let variant = Variant::try_new(&metadata, &value)?; diff --git a/parquet-variant-json/src/to_json.rs b/parquet-variant-json/src/to_json.rs index b1894a64f837..b9f5364cf5b6 100644 --- a/parquet-variant-json/src/to_json.rs +++ b/parquet-variant-json/src/to_json.rs @@ -966,8 +966,7 @@ mod tests { .with_field("age", 30i32) .with_field("active", true) .with_field("score", 95.5f64) - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; @@ -997,7 +996,7 @@ mod tests { { let obj = builder.new_object(); - obj.finish().unwrap(); + obj.finish(); } let (metadata, value) = builder.finish(); @@ -1022,8 +1021,7 @@ mod tests { .with_field("message", "Hello \"World\"\nWith\tTabs") .with_field("path", "C:\\Users\\Alice\\Documents") .with_field("unicode", "😀 Smiley") - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value)?; @@ -1135,7 +1133,7 @@ mod tests { obj.insert("zebra", "last"); obj.insert("alpha", "first"); obj.insert("beta", "second"); - obj.finish().unwrap(); + obj.finish(); } let (metadata, value) = builder.finish(); @@ -1202,7 +1200,7 @@ mod tests { obj.insert("float_field", 2.71f64); obj.insert("null_field", ()); obj.insert("long_field", 999i64); - obj.finish().unwrap(); + obj.finish(); } let (metadata, value) = builder.finish(); diff --git a/parquet-variant/Cargo.toml b/parquet-variant/Cargo.toml index a4d4792e09f5..6e88bff6bd3a 100644 --- a/parquet-variant/Cargo.toml +++ b/parquet-variant/Cargo.toml @@ -33,6 +33,7 @@ rust-version = { workspace = true } [dependencies] arrow-schema = { workspace = true } chrono = { workspace = true } +half = { version = "2.1", default-features = false } indexmap = "2.10.0" uuid = { version = "1.18.0", features = ["v4"]} diff --git a/parquet-variant/benches/variant_builder.rs b/parquet-variant/benches/variant_builder.rs index a42327fe1335..5d00cc054e55 100644 --- a/parquet-variant/benches/variant_builder.rs +++ b/parquet-variant/benches/variant_builder.rs @@ -77,7 +77,7 @@ fn bench_object_field_names_reverse_order(c: &mut Criterion) { object_builder.insert(format!("{}", 1000 - i).as_str(), string_table.next()); } - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); }) }); @@ -113,7 +113,7 @@ fn bench_object_same_schema(c: &mut Criterion) { inner_list_builder.append_value(string_table.next()); inner_list_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); } @@ -154,7 +154,7 @@ fn bench_object_list_same_schema(c: &mut Criterion) { list_builder.append_value(string_table.next()); list_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); } list_builder.finish(); @@ -189,7 +189,7 @@ fn bench_object_unknown_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + 
inner_object_builder.finish(); continue; } @@ -202,7 +202,7 @@ fn bench_object_unknown_schema(c: &mut Criterion) { inner_list_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); } }) @@ -241,7 +241,7 @@ fn bench_object_list_unknown_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); continue; } @@ -254,7 +254,7 @@ fn bench_object_list_unknown_schema(c: &mut Criterion) { inner_list_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); } list_builder.finish(); @@ -314,10 +314,10 @@ fn bench_object_partially_same_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); hint::black_box(variant.finish()); } }) @@ -376,10 +376,10 @@ fn bench_object_list_partially_same_schema(c: &mut Criterion) { let key = string_table.next(); inner_object_builder.insert(key, key); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); } - object_builder.finish().unwrap(); + object_builder.finish(); } list_builder.finish(); @@ -408,7 +408,7 @@ fn bench_validation_validated_vs_unvalidated(c: &mut Criterion) { } list.finish(); - obj.finish().unwrap(); + obj.finish(); test_data.push(builder.finish()); } @@ -462,7 +462,7 @@ fn bench_iteration_performance(c: &mut Criterion) { let mut obj = list.new_object(); obj.insert(&format!("field_{i}"), rng.random::()); obj.insert("nested_data", format!("data_{i}").as_str()); - obj.finish().unwrap(); + obj.finish(); } list.finish(); diff --git a/parquet-variant/benches/variant_validation.rs b/parquet-variant/benches/variant_validation.rs index 0ccc10117898..dcf7681a76ed 100644 --- a/parquet-variant/benches/variant_validation.rs +++ b/parquet-variant/benches/variant_validation.rs @@ -40,9 +40,9 @@ fn generate_large_object() -> (Vec, Vec) { } list_builder.finish(); } - inner_object.finish().unwrap(); + inner_object.finish(); } - outer_object.finish().unwrap(); + outer_object.finish(); variant_builder.finish() } @@ -72,9 +72,9 @@ fn generate_complex_object() -> (Vec, Vec) { let key = format!("{}", 1024 - i); inner_object_builder.insert(&key, i); } - inner_object_builder.finish().unwrap(); + inner_object_builder.finish(); - object_builder.finish().unwrap(); + object_builder.finish(); variant_builder.finish() } diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index aa202460a44e..2fa8d0981c5b 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -24,6 +24,8 @@ use chrono::Timelike; use indexmap::{IndexMap, IndexSet}; use uuid::Uuid; +use std::collections::HashMap; + const BASIC_TYPE_BITS: u8 = 2; const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); @@ -86,25 +88,46 @@ fn append_packed_u32(dest: &mut Vec, value: u32, value_size: usize) { /// /// You can reuse an existing `Vec` by using the `from` impl #[derive(Debug, Default)] -struct ValueBuilder(Vec); +pub struct ValueBuilder(Vec); impl ValueBuilder { /// Construct a ValueBuffer that will write to a new underlying `Vec` - fn new() -> Self { + pub fn new() -> Self { Default::default() } } -impl From> for ValueBuilder { - fn from(value: Vec) -> Self { - Self(value) - } -} - -impl From for Vec { - fn from(value_buffer: 
ValueBuilder) -> Self {
- value_buffer.0
- }
+/// Macro to generate the match statement for each append_variant, try_append_variant, and
+/// append_variant_bytes -- each handles objects and lists slightly differently.
+macro_rules! variant_append_value {
+ ($builder:expr, $value:expr, $object_pat:pat => $object_arm:expr, $list_pat:pat => $list_arm:expr) => {
+ match $value {
+ Variant::Null => $builder.append_null(),
+ Variant::BooleanTrue => $builder.append_bool(true),
+ Variant::BooleanFalse => $builder.append_bool(false),
+ Variant::Int8(v) => $builder.append_int8(v),
+ Variant::Int16(v) => $builder.append_int16(v),
+ Variant::Int32(v) => $builder.append_int32(v),
+ Variant::Int64(v) => $builder.append_int64(v),
+ Variant::Date(v) => $builder.append_date(v),
+ Variant::Time(v) => $builder.append_time_micros(v),
+ Variant::TimestampMicros(v) => $builder.append_timestamp_micros(v),
+ Variant::TimestampNtzMicros(v) => $builder.append_timestamp_ntz_micros(v),
+ Variant::TimestampNanos(v) => $builder.append_timestamp_nanos(v),
+ Variant::TimestampNtzNanos(v) => $builder.append_timestamp_ntz_nanos(v),
+ Variant::Decimal4(decimal4) => $builder.append_decimal4(decimal4),
+ Variant::Decimal8(decimal8) => $builder.append_decimal8(decimal8),
+ Variant::Decimal16(decimal16) => $builder.append_decimal16(decimal16),
+ Variant::Float(v) => $builder.append_float(v),
+ Variant::Double(v) => $builder.append_double(v),
+ Variant::Binary(v) => $builder.append_binary(v),
+ Variant::String(s) => $builder.append_string(s),
+ Variant::ShortString(s) => $builder.append_short_string(s),
+ Variant::Uuid(v) => $builder.append_uuid(v),
+ $object_pat => $object_arm,
+ $list_pat => $list_arm,
+ }
+ };
}
impl ValueBuilder {
@@ -120,8 +143,9 @@ impl ValueBuilder {
self.0.push(primitive_header(primitive_type));
}
- fn into_inner(self) -> Vec {
- self.into()
+ /// Returns the underlying buffer, consuming self
+ pub fn into_inner(self) -> Vec {
+ self.0
}
fn inner_mut(&mut self) -> &mut Vec {
@@ -258,7 +282,7 @@ impl ValueBuilder {
object_builder.insert(field_name, value);
}
- object_builder.finish().unwrap();
+ object_builder.finish();
}
fn try_append_object(state: ParentState<'_>, obj: VariantObject) -> Result<(), ArrowError> {
@@ -269,7 +293,8 @@ impl ValueBuilder {
object_builder.try_insert(field_name, value)?;
}
- object_builder.finish()
+ object_builder.finish();
+ Ok(())
}
fn append_list(state: ParentState<'_>, list: VariantList) {
@@ -292,7 +317,8 @@ impl ValueBuilder {
Ok(())
}
- fn offset(&self) -> usize {
+ /// Returns the current size of the underlying buffer
+ pub fn offset(&self) -> usize {
self.0.len()
}
@@ -302,34 +328,14 @@ impl ValueBuilder {
///
/// This method will panic if the variant contains duplicate field names in objects
/// when validation is enabled.
For a fallible version, use [`ValueBuilder::try_append_variant`] - fn append_variant(mut state: ParentState<'_>, variant: Variant<'_, '_>) { + pub fn append_variant(mut state: ParentState<'_>, variant: Variant<'_, '_>) { let builder = state.value_builder(); - match variant { - Variant::Null => builder.append_null(), - Variant::BooleanTrue => builder.append_bool(true), - Variant::BooleanFalse => builder.append_bool(false), - Variant::Int8(v) => builder.append_int8(v), - Variant::Int16(v) => builder.append_int16(v), - Variant::Int32(v) => builder.append_int32(v), - Variant::Int64(v) => builder.append_int64(v), - Variant::Date(v) => builder.append_date(v), - Variant::Time(v) => builder.append_time_micros(v), - Variant::TimestampMicros(v) => builder.append_timestamp_micros(v), - Variant::TimestampNtzMicros(v) => builder.append_timestamp_ntz_micros(v), - Variant::TimestampNanos(v) => builder.append_timestamp_nanos(v), - Variant::TimestampNtzNanos(v) => builder.append_timestamp_ntz_nanos(v), - Variant::Decimal4(decimal4) => builder.append_decimal4(decimal4), - Variant::Decimal8(decimal8) => builder.append_decimal8(decimal8), - Variant::Decimal16(decimal16) => builder.append_decimal16(decimal16), - Variant::Float(v) => builder.append_float(v), - Variant::Double(v) => builder.append_double(v), - Variant::Binary(v) => builder.append_binary(v), - Variant::String(s) => builder.append_string(s), - Variant::ShortString(s) => builder.append_short_string(s), - Variant::Uuid(v) => builder.append_uuid(v), + variant_append_value!( + builder, + variant, Variant::Object(obj) => return Self::append_object(state, obj), - Variant::List(list) => return Self::append_list(state, list), - } + Variant::List(list) => return Self::append_list(state, list) + ); state.finish(); } @@ -342,37 +348,35 @@ impl ValueBuilder { variant: Variant<'_, '_>, ) -> Result<(), ArrowError> { let builder = state.value_builder(); - match variant { - Variant::Null => builder.append_null(), - Variant::BooleanTrue => builder.append_bool(true), - Variant::BooleanFalse => builder.append_bool(false), - Variant::Int8(v) => builder.append_int8(v), - Variant::Int16(v) => builder.append_int16(v), - Variant::Int32(v) => builder.append_int32(v), - Variant::Int64(v) => builder.append_int64(v), - Variant::Date(v) => builder.append_date(v), - Variant::Time(v) => builder.append_time_micros(v), - Variant::TimestampMicros(v) => builder.append_timestamp_micros(v), - Variant::TimestampNtzMicros(v) => builder.append_timestamp_ntz_micros(v), - Variant::TimestampNanos(v) => builder.append_timestamp_nanos(v), - Variant::TimestampNtzNanos(v) => builder.append_timestamp_ntz_nanos(v), - Variant::Decimal4(decimal4) => builder.append_decimal4(decimal4), - Variant::Decimal8(decimal8) => builder.append_decimal8(decimal8), - Variant::Decimal16(decimal16) => builder.append_decimal16(decimal16), - Variant::Float(v) => builder.append_float(v), - Variant::Double(v) => builder.append_double(v), - Variant::Binary(v) => builder.append_binary(v), - Variant::String(s) => builder.append_string(s), - Variant::ShortString(s) => builder.append_short_string(s), - Variant::Uuid(v) => builder.append_uuid(v), + variant_append_value!( + builder, + variant, Variant::Object(obj) => return Self::try_append_object(state, obj), - Variant::List(list) => return Self::try_append_list(state, list), - } - + Variant::List(list) => return Self::try_append_list(state, list) + ); state.finish(); Ok(()) } + /// Appends a variant to the buffer by copying raw bytes when possible. 
+ /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended. + pub fn append_variant_bytes(mut state: ParentState<'_>, variant: Variant<'_, '_>) { + let builder = state.value_builder(); + variant_append_value!( + builder, + variant, + Variant::Object(obj) => builder.append_slice(obj.value), + Variant::List(list) => builder.append_slice(list.value) + ); + state.finish(); + } + /// Writes out the header byte for a variant object or list, from the starting position /// of the builder, will return the position after this write fn append_header_start_from_buf_pos( @@ -431,13 +435,111 @@ impl ValueBuilder { } } +/// A trait for building variant metadata dictionaries, to be used in conjunction with a +/// [`ValueBuilder`]. The trait provides methods for managing field names and their IDs, as well as +/// rolling back a failed builder operation that might have created new field ids. +pub trait MetadataBuilder: std::fmt::Debug { + /// Attempts to register a field name, returning the corresponding (possibly newly-created) + /// field id on success. Attempting to register the same field name twice will _generally_ + /// produce the same field id both times, but the variant spec does not actually require it. + fn try_upsert_field_name(&mut self, field_name: &str) -> Result; + + /// Retrieves the field name for a given field id, which must be less than + /// [`Self::num_field_names`]. Panics if the field id is out of bounds. + fn field_name(&self, field_id: usize) -> &str; + + /// Returns the number of field names stored in this metadata builder. Any number less than this + /// is a valid field id. The builder can be reverted back to this size later on (discarding any + /// newer/higher field ids) by calling [`Self::truncate_field_names`]. + fn num_field_names(&self) -> usize; + + /// Reverts the field names to a previous size, discarding any newly out of bounds field ids. + fn truncate_field_names(&mut self, new_size: usize); + + /// Finishes the current metadata dictionary, returning the new size of the underlying buffer. + fn finish(&mut self) -> usize; +} + +impl MetadataBuilder for WritableMetadataBuilder { + fn try_upsert_field_name(&mut self, field_name: &str) -> Result { + Ok(self.upsert_field_name(field_name)) + } + fn field_name(&self, field_id: usize) -> &str { + self.field_name(field_id) + } + fn num_field_names(&self) -> usize { + self.num_field_names() + } + fn truncate_field_names(&mut self, new_size: usize) { + self.field_names.truncate(new_size) + } + fn finish(&mut self) -> usize { + self.finish() + } +} + +/// A metadata builder that cannot register new field names, and merely returns the field id +/// associated with a known field name. This is useful for variant unshredding operations, where the +/// metadata column is fixed and -- per variant shredding spec -- already contains all field names +/// from the typed_value column. It is also useful when projecting a subset of fields from a variant +/// object value, since the bytes can be copied across directly without re-encoding their field ids. +/// +/// NOTE: [`Self::finish`] is a no-op. 
If the intent is to make a copy of the underlying bytes each +/// time `finish` is called, a different trait impl will be needed. +#[derive(Debug)] +pub struct ReadOnlyMetadataBuilder<'m> { + metadata: VariantMetadata<'m>, + // A cache that tracks field names this builder has already seen, because finding the field id + // for a given field name is expensive -- O(n) for a large and unsorted metadata dictionary. + known_field_names: HashMap<&'m str, u32>, +} + +impl<'m> ReadOnlyMetadataBuilder<'m> { + /// Creates a new read-only metadata builder from the given metadata dictionary. + pub fn new(metadata: VariantMetadata<'m>) -> Self { + Self { + metadata, + known_field_names: HashMap::new(), + } + } +} + +impl MetadataBuilder for ReadOnlyMetadataBuilder<'_> { + fn try_upsert_field_name(&mut self, field_name: &str) -> Result { + if let Some(field_id) = self.known_field_names.get(field_name) { + return Ok(*field_id); + } + + let Some((field_id, field_name)) = self.metadata.get_entry(field_name) else { + return Err(ArrowError::InvalidArgumentError(format!( + "Field name '{field_name}' not found in metadata dictionary" + ))); + }; + + self.known_field_names.insert(field_name, field_id); + Ok(field_id) + } + fn field_name(&self, field_id: usize) -> &str { + &self.metadata[field_id] + } + fn num_field_names(&self) -> usize { + self.metadata.len() + } + fn truncate_field_names(&mut self, new_size: usize) { + debug_assert_eq!(self.metadata.len(), new_size); + } + fn finish(&mut self) -> usize { + self.metadata.bytes.len() + } +} + /// Builder for constructing metadata for [`Variant`] values. /// /// This is used internally by the [`VariantBuilder`] to construct the metadata /// /// You can use an existing `Vec` as the metadata buffer by using the `from` impl. #[derive(Default, Debug)] -struct MetadataBuilder { +pub struct WritableMetadataBuilder { // Field names -- field_ids are assigned in insert order field_names: IndexSet, @@ -448,17 +550,7 @@ struct MetadataBuilder { metadata_buffer: Vec, } -/// Create a new MetadataBuilder that will write to the specified metadata buffer -impl From> for MetadataBuilder { - fn from(metadata_buffer: Vec) -> Self { - Self { - metadata_buffer, - ..Default::default() - } - } -} - -impl MetadataBuilder { +impl WritableMetadataBuilder { /// Upsert field name to dictionary, return its ID fn upsert_field_name(&mut self, field_name: &str) -> u32 { let (id, new_entry) = self.field_names.insert_full(field_name.to_string()); @@ -477,6 +569,11 @@ impl MetadataBuilder { id as u32 } + /// The current length of the underlying metadata buffer + pub fn offset(&self) -> usize { + self.metadata_buffer.len() + } + /// Returns the number of field names stored in the metadata builder. /// Note: this method should be the only place to call `self.field_names.len()` /// @@ -498,17 +595,18 @@ impl MetadataBuilder { self.field_names.iter().map(|k| k.len()).sum() } - fn finish(self) -> Vec { + /// Finalizes the metadata dictionary and appends its serialized bytes to the underlying buffer, + /// returning the resulting [`Self::offset`]. The builder state is reset and ready to start + /// building a new metadata dictionary. 
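+ ///
+ /// A minimal usage sketch (hand-written, not from this crate's docs; it assumes only the
+ /// public `extend`, `finish`, and `into_inner` APIs shown in this diff):
+ ///
+ /// ```ignore
+ /// let mut mb = WritableMetadataBuilder::default();
+ /// mb.extend(["a", "b"]);
+ /// let first_end = mb.finish(); // serializes dictionary #1, resets the builder state
+ /// mb.extend(["c"]);
+ /// let second_end = mb.finish(); // dictionary #2 lands after the first
+ /// let buffer = mb.into_inner(); // both dictionaries, back to back
+ /// assert_eq!(buffer.len(), second_end);
+ /// ```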
+ pub fn finish(&mut self) -> usize { let nkeys = self.num_field_names(); // Calculate metadata size let total_dict_size: usize = self.metadata_size(); - let Self { - field_names, - is_sorted, - mut metadata_buffer, - } = self; + let metadata_buffer = &mut self.metadata_buffer; + let is_sorted = std::mem::take(&mut self.is_sorted); + let field_names = std::mem::take(&mut self.field_names); // Determine appropriate offset size based on the larger of dict size or total string size let max_offset = std::cmp::max(total_dict_size, nkeys); @@ -524,32 +622,32 @@ impl MetadataBuilder { metadata_buffer.push(0x01 | (is_sorted as u8) << 4 | ((offset_size - 1) << 6)); // Write dictionary size - write_offset(&mut metadata_buffer, nkeys, offset_size); + write_offset(metadata_buffer, nkeys, offset_size); // Write offsets let mut cur_offset = 0; for key in field_names.iter() { - write_offset(&mut metadata_buffer, cur_offset, offset_size); + write_offset(metadata_buffer, cur_offset, offset_size); cur_offset += key.len(); } // Write final offset - write_offset(&mut metadata_buffer, cur_offset, offset_size); + write_offset(metadata_buffer, cur_offset, offset_size); // Write string data for key in field_names { metadata_buffer.extend_from_slice(key.as_bytes()); } - metadata_buffer + metadata_buffer.len() } - /// Return the inner buffer, without finalizing any in progress metadata. - pub(crate) fn into_inner(self) -> Vec { + /// Returns the inner buffer, consuming self without finalizing any in progress metadata. + pub fn into_inner(self) -> Vec { self.metadata_buffer } } -impl> FromIterator for MetadataBuilder { +impl> FromIterator for WritableMetadataBuilder { fn from_iter>(iter: T) -> Self { let mut this = Self::default(); this.extend(iter); @@ -558,7 +656,7 @@ impl> FromIterator for MetadataBuilder { } } -impl> Extend for MetadataBuilder { +impl> Extend for WritableMetadataBuilder { fn extend>(&mut self, iter: T) { let iter = iter.into_iter(); let (min, _) = iter.size_hint(); @@ -585,18 +683,18 @@ impl> Extend for MetadataBuilder { /// treat the variants as a union, so that accessing a `value_builder` or `metadata_builder` is /// branch-free. #[derive(Debug)] -enum ParentState<'a> { +pub enum ParentState<'a> { Variant { value_builder: &'a mut ValueBuilder, saved_value_builder_offset: usize, - metadata_builder: &'a mut MetadataBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, saved_metadata_builder_dict_size: usize, finished: bool, }, List { value_builder: &'a mut ValueBuilder, saved_value_builder_offset: usize, - metadata_builder: &'a mut MetadataBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, saved_metadata_builder_dict_size: usize, offsets: &'a mut Vec, saved_offsets_size: usize, @@ -605,7 +703,7 @@ enum ParentState<'a> { Object { value_builder: &'a mut ValueBuilder, saved_value_builder_offset: usize, - metadata_builder: &'a mut MetadataBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, saved_metadata_builder_dict_size: usize, fields: &'a mut IndexMap, saved_fields_size: usize, @@ -614,9 +712,12 @@ enum ParentState<'a> { } impl<'a> ParentState<'a> { - fn variant( + /// Creates a new instance suitable for a top-level variant builder + /// (e.g. [`VariantBuilder`]). The value and metadata builder state is checkpointed and will + /// roll back on drop, unless [`Self::finish`] is called. 
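+ ///
+ /// A rough sketch of direct use (a hand-written example, using only types declared in
+ /// this module):
+ ///
+ /// ```ignore
+ /// let mut value_builder = ValueBuilder::new();
+ /// let mut metadata_builder = WritableMetadataBuilder::default();
+ /// // Checkpoints both builders; a successful append calls `finish`,
+ /// // otherwise the drop handler rolls both back.
+ /// let state = ParentState::variant(&mut value_builder, &mut metadata_builder);
+ /// ValueBuilder::append_variant(state, Variant::Int8(1));
+ /// ```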
+ pub fn variant( value_builder: &'a mut ValueBuilder, - metadata_builder: &'a mut MetadataBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, ) -> Self { ParentState::Variant { saved_value_builder_offset: value_builder.offset(), @@ -627,9 +728,12 @@ impl<'a> ParentState<'a> { } } - fn list( + /// Creates a new instance suitable for a [`ListBuilder`]. The value and metadata builder state + /// is checkpointed and will roll back on drop, unless [`Self::finish`] is called. The new + /// element's offset is also captured eagerly and will also roll back if not finished. + pub fn list( value_builder: &'a mut ValueBuilder, - metadata_builder: &'a mut MetadataBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, offsets: &'a mut Vec, saved_parent_value_builder_offset: usize, ) -> Self { @@ -651,9 +755,14 @@ impl<'a> ParentState<'a> { } } - fn try_object( + /// Creates a new instance suitable for an [`ObjectBuilder`]. The value and metadata builder state + /// is checkpointed and will roll back on drop, unless [`Self::finish`] is called. The new + /// field's name and offset are also captured eagerly and will also roll back if not finished. + /// + /// The call fails if the field name is invalid (e.g. because it duplicates an existing field). + pub fn try_object( value_builder: &'a mut ValueBuilder, - metadata_builder: &'a mut MetadataBuilder, + metadata_builder: &'a mut dyn MetadataBuilder, fields: &'a mut IndexMap, saved_parent_value_builder_offset: usize, field_name: &str, @@ -665,7 +774,7 @@ impl<'a> ParentState<'a> { let saved_value_builder_offset = value_builder.offset(); let saved_fields_size = fields.len(); let saved_metadata_builder_dict_size = metadata_builder.num_field_names(); - let field_id = metadata_builder.upsert_field_name(field_name); + let field_id = metadata_builder.try_upsert_field_name(field_name)?; let field_start = saved_value_builder_offset - saved_parent_value_builder_offset; if fields.insert(field_id, field_start).is_some() && validate_unique_fields { return Err(ArrowError::InvalidArgumentError(format!( @@ -688,7 +797,7 @@ impl<'a> ParentState<'a> { self.value_and_metadata_builders().0 } - fn metadata_builder(&mut self) -> &mut MetadataBuilder { + fn metadata_builder(&mut self) -> &mut dyn MetadataBuilder { self.value_and_metadata_builders().1 } @@ -717,8 +826,8 @@ impl<'a> ParentState<'a> { } } - // Mark the insertion as having succeeded. - fn finish(&mut self) { + /// Mark the insertion as having succeeded. Internal state will no longer roll back on drop. + pub fn finish(&mut self) { *self.is_finished() = true } @@ -754,9 +863,7 @@ impl<'a> ParentState<'a> { value_builder .inner_mut() .truncate(*saved_value_builder_offset); - metadata_builder - .field_names - .truncate(*saved_metadata_builder_dict_size); + metadata_builder.truncate_field_names(*saved_metadata_builder_dict_size); } }; @@ -778,7 +885,7 @@ impl<'a> ParentState<'a> { /// Return mutable references to the value and metadata builders that this /// parent state is using. 
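+ ///
+ /// Sketch (hypothetical caller): destructuring the pair gives simultaneous mutable
+ /// access to both builders without borrowing the parent state twice:
+ ///
+ /// ```ignore
+ /// let (value_builder, metadata_builder) = state.value_and_metadata_builders();
+ /// let field_id = metadata_builder.try_upsert_field_name("a")?;
+ /// let field_start = value_builder.offset();
+ /// ```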
- fn value_and_metadata_builders(&mut self) -> (&mut ValueBuilder, &mut MetadataBuilder) { + pub fn value_and_metadata_builders(&mut self) -> (&mut ValueBuilder, &mut dyn MetadataBuilder) { match self { ParentState::Variant { value_builder, @@ -986,41 +1093,6 @@ impl Drop for ParentState<'_> { /// ); /// /// ``` -/// # Example: Reusing Buffers -/// -/// You can use the [`VariantBuilder`] to write into existing buffers (for -/// example to write multiple variants back to back in the same buffer) -/// -/// ``` -/// // we will write two variants back to back -/// use parquet_variant::{Variant, VariantBuilder}; -/// // Append 12345 -/// let mut builder = VariantBuilder::new(); -/// builder.append_value(12345); -/// let (metadata, value) = builder.finish(); -/// // remember where the first variant ends -/// let (first_meta_offset, first_meta_len) = (0, metadata.len()); -/// let (first_value_offset, first_value_len) = (0, value.len()); -/// -/// // now, append a second variant to the same buffers -/// let mut builder = VariantBuilder::new_with_buffers(metadata, value); -/// builder.append_value("Foo"); -/// let (metadata, value) = builder.finish(); -/// -/// // The variants can be referenced in their appropriate location -/// let variant1 = Variant::new( -/// &metadata[first_meta_offset..first_meta_len], -/// &value[first_value_offset..first_value_len] -/// ); -/// assert_eq!(variant1, Variant::Int32(12345)); -/// -/// let variant2 = Variant::new( -/// &metadata[first_meta_len..], -/// &value[first_value_len..] -/// ); -/// assert_eq!(variant2, Variant::from("Foo")); -/// ``` -/// /// # Example: Unique Field Validation /// /// This example shows how enabling unique field validation will cause an error @@ -1053,7 +1125,7 @@ impl Drop for ParentState<'_> { /// obj.insert("name", "Alice"); /// obj.insert("age", 30); /// obj.insert("score", 95.5); -/// obj.finish().unwrap(); +/// obj.finish(); /// /// let (metadata, value) = builder.finish(); /// let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -1071,7 +1143,7 @@ impl Drop for ParentState<'_> { /// obj.insert("name", "Bob"); // field id = 3 /// obj.insert("age", 25); /// obj.insert("score", 88.0); -/// obj.finish().unwrap(); +/// obj.finish(); /// /// let (metadata, value) = builder.finish(); /// let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -1079,7 +1151,7 @@ impl Drop for ParentState<'_> { #[derive(Default, Debug)] pub struct VariantBuilder { value_builder: ValueBuilder, - metadata_builder: MetadataBuilder, + metadata_builder: WritableMetadataBuilder, validate_unique_fields: bool, } @@ -1088,7 +1160,7 @@ impl VariantBuilder { pub fn new() -> Self { Self { value_builder: ValueBuilder::new(), - metadata_builder: MetadataBuilder::default(), + metadata_builder: WritableMetadataBuilder::default(), validate_unique_fields: false, } } @@ -1100,16 +1172,6 @@ impl VariantBuilder { self } - /// Create a new VariantBuilder that will write the metadata and values to - /// the specified buffers. - pub fn new_with_buffers(metadata_buffer: Vec, value_buffer: Vec) -> Self { - Self { - value_builder: ValueBuilder::from(value_buffer), - metadata_builder: MetadataBuilder::from(metadata_buffer), - validate_unique_fields: false, - } - } - /// Enables validation of unique field keys in nested objects. 
/// /// This setting is propagated to all [`ObjectBuilder`]s created through this [`VariantBuilder`] @@ -1126,7 +1188,7 @@ impl VariantBuilder { /// You can use this to pre-populate a [`VariantBuilder`] with a sorted dictionary if you /// know the field names beforehand. Sorted dictionaries can accelerate field access when /// reading [`Variant`]s. - pub fn with_field_names<'a>(mut self, field_names: impl Iterator) -> Self { + pub fn with_field_names<'a>(mut self, field_names: impl IntoIterator) -> Self { self.metadata_builder.extend(field_names); self @@ -1214,20 +1276,22 @@ impl VariantBuilder { ValueBuilder::try_append_variant(state, value.into()) } - /// Finish the builder and return the metadata and value buffers. - pub fn finish(self) -> (Vec, Vec) { - ( - self.metadata_builder.finish(), - self.value_builder.into_inner(), - ) + /// Appends a variant value to the builder by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary entries are already built and correct for + /// any objects or lists being appended. + pub fn append_value_bytes<'m, 'd>(&mut self, value: impl Into>) { + let state = ParentState::variant(&mut self.value_builder, &mut self.metadata_builder); + ValueBuilder::append_variant_bytes(state, value.into()); } - /// Return the inner metadata buffers and value buffer. - /// - /// This can be used to get the underlying buffers provided via - /// [`VariantBuilder::new_with_buffers`] without finalizing the metadata or - /// values (for rolling back changes). - pub fn into_buffers(self) -> (Vec, Vec) { + /// Finish the builder and return the metadata and value buffers. + pub fn finish(mut self) -> (Vec, Vec) { + self.metadata_builder.finish(); ( self.metadata_builder.into_inner(), self.value_builder.into_inner(), @@ -1246,7 +1310,8 @@ pub struct ListBuilder<'a> { } impl<'a> ListBuilder<'a> { - fn new(parent_state: ParentState<'a>, validate_unique_fields: bool) -> Self { + /// Creates a new list builder, nested on top of the given parent state. + pub fn new(parent_state: ParentState<'a>, validate_unique_fields: bool) -> Self { Self { parent_state, offsets: vec![], @@ -1312,6 +1377,19 @@ impl<'a> ListBuilder<'a> { ValueBuilder::try_append_variant(state, value.into()) } + /// Appends a variant value to this list by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy. For other variant types, this falls back to the standard append + /// behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended. + pub fn append_value_bytes<'m, 'd>(&mut self, value: impl Into>) { + let (state, _) = self.parent_state(); + ValueBuilder::append_variant_bytes(state, value.into()) + } + /// Builder-style API for appending a value to the list and returning self to enable method chaining. /// /// # Panics @@ -1388,7 +1466,8 @@ pub struct ObjectBuilder<'a> { } impl<'a> ObjectBuilder<'a> { - fn new(parent_state: ParentState<'a>, validate_unique_fields: bool) -> Self { + /// Creates a new object builder, nested on top of the given parent state. 
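+ ///
+ /// Most callers reach this through [`VariantBuilder::new_object`]; direct construction
+ /// looks roughly like this (a sketch, mirroring the tests below):
+ ///
+ /// ```ignore
+ /// let mut value_builder = ValueBuilder::new();
+ /// let mut metadata_builder = WritableMetadataBuilder::default();
+ /// let state = ParentState::variant(&mut value_builder, &mut metadata_builder);
+ /// let mut obj = ObjectBuilder::new(state, false); // no unique-field validation
+ /// obj.insert("a", 1i8);
+ /// obj.finish();
+ /// ```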
+ pub fn new(parent_state: ParentState<'a>, validate_unique_fields: bool) -> Self { Self { parent_state, fields: IndexMap::new(), @@ -1417,7 +1496,8 @@ impl<'a> ObjectBuilder<'a> { /// - [`ObjectBuilder::insert`] for an infallible version that panics /// - [`ObjectBuilder::try_with_field`] for a builder-style API. /// - /// # Note Attempting to insert a duplicate field name produces an error if unique field + /// # Note + /// Attempting to insert a duplicate field name produces an error if unique field /// validation is enabled. Otherwise, the new value overwrites the previous field mapping /// without erasing the old value, resulting in a larger variant pub fn try_insert<'m, 'd, T: Into>>( @@ -1429,6 +1509,45 @@ impl<'a> ObjectBuilder<'a> { ValueBuilder::try_append_variant(state, value.into()) } + /// Add a field with key and value to the object by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy, and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended, but the value's new field name is handled normally. + /// + /// # Panics + /// + /// This method will panic if the variant contains duplicate field names in objects + /// when validation is enabled. For a fallible version, use [`ObjectBuilder::try_insert_bytes`] + pub fn insert_bytes<'m, 'd>(&mut self, key: &str, value: impl Into>) { + self.try_insert_bytes(key, value).unwrap() + } + + /// Add a field with key and value to the object by copying raw bytes when possible. + /// + /// For objects and lists, this directly copies their underlying byte representation instead of + /// performing a logical copy, and without touching the metadata builder. For other variant + /// types, this falls back to the standard append behavior. + /// + /// The caller must ensure that the metadata dictionary is already built and correct for + /// any objects or lists being appended, but the value's new field name is handled normally. + /// + /// # Note + /// When inserting duplicate keys, the new value overwrites the previous mapping, + /// but the old value remains in the buffer, resulting in a larger variant + pub fn try_insert_bytes<'m, 'd>( + &mut self, + key: &str, + value: impl Into>, + ) -> Result<(), ArrowError> { + let (state, _) = self.parent_state(key)?; + ValueBuilder::append_variant_bytes(state, value.into()); + Ok(()) + } + /// Builder style API for adding a field with key and value to the object /// /// Same as [`ObjectBuilder::insert`], but returns `self` for chaining. @@ -1516,7 +1635,7 @@ impl<'a> ObjectBuilder<'a> { } /// Finalizes this object and appends it to its parent, which otherwise remains unmodified. - pub fn finish(mut self) -> Result<(), ArrowError> { + pub fn finish(mut self) { let metadata_builder = self.parent_state.metadata_builder(); self.fields.sort_by(|&field_a_id, _, &field_b_id, _| { @@ -1579,8 +1698,6 @@ impl<'a> ObjectBuilder<'a> { offset_size, ); self.parent_state.finish(); - - Ok(()) } } @@ -1589,18 +1706,27 @@ impl<'a> ObjectBuilder<'a> { /// Allows users to append values to a [`VariantBuilder`], [`ListBuilder`] or /// [`ObjectBuilder`]. using the same interface. pub trait VariantBuilderExt { + /// Appends a new variant value to this builder. See e.g. [`VariantBuilder::append_value`]. 
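+ ///
+ /// The same call shape works for every implementor (a sketch; it assumes the trait
+ /// impls for the concrete builders described above):
+ ///
+ /// ```ignore
+ /// let mut builder = VariantBuilder::new();
+ /// builder.append_value(42i8);
+ /// let mut list = builder.new_list();
+ /// list.append_value("same interface, different builder");
+ /// list.finish();
+ /// ```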
fn append_value<'m, 'v>(&mut self, value: impl Into>); + /// Creates a nested list builder. See e.g. [`VariantBuilder::new_list`]. Panics if the nested + /// builder cannot be created, see e.g. [`ObjectBuilder::new_list`]. fn new_list(&mut self) -> ListBuilder<'_> { self.try_new_list().unwrap() } + /// Creates a nested object builder. See e.g. [`VariantBuilder::new_object`]. Panics if the + /// nested builder cannot be created, see e.g. [`ObjectBuilder::new_object`]. fn new_object(&mut self) -> ObjectBuilder<'_> { self.try_new_object().unwrap() } + /// Creates a nested list builder. See e.g. [`VariantBuilder::new_list`]. Returns an error if + /// the nested builder cannot be created, see e.g. [`ObjectBuilder::try_new_list`]. fn try_new_list(&mut self) -> Result, ArrowError>; + /// Creates a nested object builder. See e.g. [`VariantBuilder::new_object`]. Returns an error + /// if the nested builder cannot be created, see e.g. [`ObjectBuilder::try_new_object`]. fn try_new_object(&mut self) -> Result, ArrowError>; } @@ -1779,8 +1905,7 @@ mod tests { .new_object() .with_field("name", "John") .with_field("age", 42i8) - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); assert!(!metadata.is_empty()); @@ -1796,8 +1921,7 @@ mod tests { .with_field("zebra", "stripes") .with_field("apple", "red") .with_field("banana", "yellow") - .finish() - .unwrap(); + .finish(); let (_, value) = builder.finish(); @@ -1821,8 +1945,7 @@ mod tests { .new_object() .with_field("name", "Ron Artest") .with_field("name", "Metta World Peace") // Duplicate field - .finish() - .unwrap(); + .finish(); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -1941,15 +2064,13 @@ mod tests { .new_object() .with_field("id", 1) .with_field("type", "Cauliflower") - .finish() - .unwrap(); + .finish(); list_builder .new_object() .with_field("id", 2) .with_field("type", "Beets") - .finish() - .unwrap(); + .finish(); list_builder.finish(); @@ -1986,17 +2107,9 @@ mod tests { let mut list_builder = builder.new_list(); - list_builder - .new_object() - .with_field("a", 1) - .finish() - .unwrap(); + list_builder.new_object().with_field("a", 1).finish(); - list_builder - .new_object() - .with_field("b", 2) - .finish() - .unwrap(); + list_builder.new_object().with_field("b", 2).finish(); list_builder.finish(); @@ -2042,7 +2155,7 @@ mod tests { { let mut object_builder = list_builder.new_object(); object_builder.insert("a", 1); - let _ = object_builder.finish(); + object_builder.finish(); } list_builder.append_value(2); @@ -2050,7 +2163,7 @@ mod tests { { let mut object_builder = list_builder.new_object(); object_builder.insert("b", 2); - let _ = object_builder.finish(); + object_builder.finish(); } list_builder.append_value(3); @@ -2100,10 +2213,10 @@ mod tests { { let mut inner_object_builder = outer_object_builder.new_object("c"); inner_object_builder.insert("b", "a"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2142,11 +2255,11 @@ mod tests { inner_object_builder.insert("b", false); inner_object_builder.insert("c", "a"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } outer_object_builder.insert("b", false); - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2190,10 +2303,10 @@ mod tests { .with_value(false) 
.finish(); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2253,15 +2366,15 @@ mod tests { { let mut inner_inner_object_builder = inner_object_builder.new_object("c"); inner_inner_object_builder.insert("aa", "bb"); - let _ = inner_inner_object_builder.finish(); + inner_inner_object_builder.finish(); } { let mut inner_inner_object_builder = inner_object_builder.new_object("d"); inner_inner_object_builder.insert("cc", "dd"); - let _ = inner_inner_object_builder.finish(); + inner_inner_object_builder.finish(); } - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } outer_object_builder.insert("b", true); @@ -2285,10 +2398,10 @@ mod tests { inner_list_builder.finish(); } - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } - let _ = outer_object_builder.finish(); + outer_object_builder.finish(); } let (metadata, value) = builder.finish(); @@ -2388,7 +2501,7 @@ mod tests { let mut inner_object_builder = inner_list_builder.new_object(); inner_object_builder.insert("a", "b"); inner_object_builder.insert("b", "c"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } { @@ -2397,7 +2510,7 @@ mod tests { let mut inner_object_builder = inner_list_builder.new_object(); inner_object_builder.insert("c", "d"); inner_object_builder.insert("d", "e"); - let _ = inner_object_builder.finish(); + inner_object_builder.finish(); } inner_list_builder.finish(); @@ -2483,7 +2596,7 @@ mod tests { let mut obj = builder.new_object(); obj.insert("a", 1); obj.insert("a", 2); - assert!(obj.finish().is_ok()); + obj.finish(); // Deeply nested list structure with duplicates let mut builder = VariantBuilder::new(); @@ -2493,12 +2606,8 @@ mod tests { nested_obj.insert("x", 1); nested_obj.insert("x", 2); nested_obj.new_list("x").with_value(3).finish(); - nested_obj - .new_object("x") - .with_field("y", 4) - .finish() - .unwrap(); - assert!(nested_obj.finish().is_ok()); + nested_obj.new_object("x").with_field("y", 4).finish(); + nested_obj.finish(); inner_list.finish(); outer_list.finish(); @@ -2558,14 +2667,14 @@ mod tests { valid_obj.insert("m", 1); valid_obj.insert("n", 2); - let valid_result = valid_obj.finish(); - assert!(valid_result.is_ok()); + valid_obj.finish(); + list.finish(); } #[test] fn test_sorted_dictionary() { // check if variant metadatabuilders are equivalent from different ways of constructing them - let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"].into_iter()); + let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"]); let mut variant2 = { let mut builder = VariantBuilder::new(); @@ -2615,7 +2724,7 @@ mod tests { #[test] fn test_object_sorted_dictionary() { // predefine the list of field names - let mut variant1 = VariantBuilder::new().with_field_names(["a", "b", "c"].into_iter()); + let mut variant1 = VariantBuilder::new().with_field_names(["a", "b", "c"]); let mut obj = variant1.new_object(); obj.insert("c", true); @@ -2628,7 +2737,7 @@ mod tests { // add a field name that wasn't pre-defined but doesn't break the sort order obj.insert("d", 2); - obj.finish().unwrap(); + obj.finish(); let (metadata, value) = variant1.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -2649,7 +2758,7 @@ mod tests { #[test] fn test_object_not_sorted_dictionary() { // predefine the list of field names - let mut variant1 = 
VariantBuilder::new().with_field_names(["b", "c", "d"].into_iter()); + let mut variant1 = VariantBuilder::new().with_field_names(["b", "c", "d"]); let mut obj = variant1.new_object(); obj.insert("c", true); @@ -2662,7 +2771,7 @@ mod tests { // add a field name that wasn't pre-defined but breaks the sort order obj.insert("a", 2); - obj.finish().unwrap(); + obj.finish(); let (metadata, value) = variant1.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); @@ -2691,40 +2800,40 @@ mod tests { assert!(builder.metadata_builder.is_sorted); assert_eq!(builder.metadata_builder.num_field_names(), 1); - let builder = builder.with_field_names(["b", "c", "d"].into_iter()); + let builder = builder.with_field_names(["b", "c", "d"]); assert!(builder.metadata_builder.is_sorted); assert_eq!(builder.metadata_builder.num_field_names(), 4); - let builder = builder.with_field_names(["z", "y"].into_iter()); + let builder = builder.with_field_names(["z", "y"]); assert!(!builder.metadata_builder.is_sorted); assert_eq!(builder.metadata_builder.num_field_names(), 6); } #[test] fn test_metadata_builder_from_iter() { - let metadata = MetadataBuilder::from_iter(vec!["apple", "banana", "cherry"]); + let metadata = WritableMetadataBuilder::from_iter(vec!["apple", "banana", "cherry"]); assert_eq!(metadata.num_field_names(), 3); assert_eq!(metadata.field_name(0), "apple"); assert_eq!(metadata.field_name(1), "banana"); assert_eq!(metadata.field_name(2), "cherry"); assert!(metadata.is_sorted); - let metadata = MetadataBuilder::from_iter(["zebra", "apple", "banana"]); + let metadata = WritableMetadataBuilder::from_iter(["zebra", "apple", "banana"]); assert_eq!(metadata.num_field_names(), 3); assert_eq!(metadata.field_name(0), "zebra"); assert_eq!(metadata.field_name(1), "apple"); assert_eq!(metadata.field_name(2), "banana"); assert!(!metadata.is_sorted); - let metadata = MetadataBuilder::from_iter(Vec::<&str>::new()); + let metadata = WritableMetadataBuilder::from_iter(Vec::<&str>::new()); assert_eq!(metadata.num_field_names(), 0); assert!(!metadata.is_sorted); } #[test] fn test_metadata_builder_extend() { - let mut metadata = MetadataBuilder::default(); + let mut metadata = WritableMetadataBuilder::default(); assert_eq!(metadata.num_field_names(), 0); assert!(!metadata.is_sorted); @@ -2749,7 +2858,7 @@ mod tests { #[test] fn test_metadata_builder_extend_sort_order() { - let mut metadata = MetadataBuilder::default(); + let mut metadata = WritableMetadataBuilder::default(); metadata.extend(["middle"]); assert!(metadata.is_sorted); @@ -2765,95 +2874,23 @@ mod tests { #[test] fn test_metadata_builder_from_iter_with_string_types() { // &str - let metadata = MetadataBuilder::from_iter(["a", "b", "c"]); + let metadata = WritableMetadataBuilder::from_iter(["a", "b", "c"]); assert_eq!(metadata.num_field_names(), 3); // string - let metadata = - MetadataBuilder::from_iter(vec!["a".to_string(), "b".to_string(), "c".to_string()]); + let metadata = WritableMetadataBuilder::from_iter(vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + ]); assert_eq!(metadata.num_field_names(), 3); // mixed types (anything that implements AsRef) let field_names: Vec> = vec!["a".into(), "b".into(), "c".into()]; - let metadata = MetadataBuilder::from_iter(field_names); + let metadata = WritableMetadataBuilder::from_iter(field_names); assert_eq!(metadata.num_field_names(), 3); } - /// Test reusing buffers with nested objects - #[test] - fn test_with_existing_buffers_nested() { - let mut builder = VariantBuilder::new(); - 
append_test_list(&mut builder); - let (m1, v1) = builder.finish(); - let variant1 = Variant::new(&m1, &v1); - - let mut builder = VariantBuilder::new(); - append_test_object(&mut builder); - let (m2, v2) = builder.finish(); - let variant2 = Variant::new(&m2, &v2); - - let mut builder = VariantBuilder::new(); - builder.append_value("This is a string"); - let (m3, v3) = builder.finish(); - let variant3 = Variant::new(&m3, &v3); - - // Now, append those three variants to the a new buffer that is reused - let mut builder = VariantBuilder::new(); - append_test_list(&mut builder); - let (metadata, value) = builder.finish(); - let (meta1_offset, meta1_end) = (0, metadata.len()); - let (value1_offset, value1_end) = (0, value.len()); - - // reuse same buffer - let mut builder = VariantBuilder::new_with_buffers(metadata, value); - append_test_object(&mut builder); - let (metadata, value) = builder.finish(); - let (meta2_offset, meta2_end) = (meta1_end, metadata.len()); - let (value2_offset, value2_end) = (value1_end, value.len()); - - // Append a string - let mut builder = VariantBuilder::new_with_buffers(metadata, value); - builder.append_value("This is a string"); - let (metadata, value) = builder.finish(); - let (meta3_offset, meta3_end) = (meta2_end, metadata.len()); - let (value3_offset, value3_end) = (value2_end, value.len()); - - // verify we can read the variants back correctly - let roundtrip1 = Variant::new( - &metadata[meta1_offset..meta1_end], - &value[value1_offset..value1_end], - ); - assert_eq!(roundtrip1, variant1,); - - let roundtrip2 = Variant::new( - &metadata[meta2_offset..meta2_end], - &value[value2_offset..value2_end], - ); - assert_eq!(roundtrip2, variant2,); - - let roundtrip3 = Variant::new( - &metadata[meta3_offset..meta3_end], - &value[value3_offset..value3_end], - ); - assert_eq!(roundtrip3, variant3); - } - - /// append a simple List variant - fn append_test_list(builder: &mut VariantBuilder) { - builder - .new_list() - .with_value(1234) - .with_value("a string value") - .finish(); - } - - /// append an object variant - fn append_test_object(builder: &mut VariantBuilder) { - let mut obj = builder.new_object(); - obj.insert("a", true); - obj.finish().unwrap(); - } - #[test] fn test_variant_builder_to_list_builder_no_finish() { // Create a list builder but never finish it @@ -2978,7 +3015,7 @@ mod tests { // Create a nested object builder and finish it let mut nested_object_builder = list_builder.new_object(); nested_object_builder.insert("name", "unknown"); - nested_object_builder.finish().unwrap(); + nested_object_builder.finish(); // Drop the outer list builder without finishing it drop(list_builder); @@ -3008,7 +3045,7 @@ mod tests { object_builder.insert("second", 2i8); // The parent object should only contain the original fields - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = builder.finish(); let metadata = VariantMetadata::try_new(&metadata).unwrap(); @@ -3062,7 +3099,7 @@ mod tests { object_builder.insert("second", 2i8); // The parent object should only contain the original fields - object_builder.finish().unwrap(); + object_builder.finish(); let (metadata, value) = builder.finish(); let metadata = VariantMetadata::try_new(&metadata).unwrap(); @@ -3086,7 +3123,7 @@ mod tests { // Create a nested object builder and finish it let mut nested_object_builder = object_builder.new_object("nested"); nested_object_builder.insert("name", "unknown"); - nested_object_builder.finish().unwrap(); + nested_object_builder.finish(); // Drop 
the outer object builder without finishing it drop(object_builder); @@ -3124,7 +3161,7 @@ mod tests { obj.insert("b", true); obj.insert("a", false); - obj.finish().unwrap(); + obj.finish(); builder.finish() } @@ -3153,10 +3190,10 @@ mod tests { { let mut inner_obj = outer_obj.new_object("b"); inner_obj.insert("a", "inner_value"); - inner_obj.finish().unwrap(); + inner_obj.finish(); } - outer_obj.finish().unwrap(); + outer_obj.finish(); } builder.finish() @@ -3234,7 +3271,7 @@ mod tests { } } if i % skip != 0 { - object.finish().unwrap(); + object.finish(); } } if i % skip != 0 { @@ -3242,7 +3279,7 @@ mod tests { } } if i % skip != 0 { - object.finish().unwrap(); + object.finish(); } } list.finish(); @@ -3255,4 +3292,411 @@ mod tests { assert_eq!(format!("{v1:?}"), format!("{v2:?}")); } + + #[test] + fn test_read_only_metadata_builder() { + // First create some metadata with a few field names + let mut default_builder = VariantBuilder::new(); + default_builder.add_field_name("name"); + default_builder.add_field_name("age"); + default_builder.add_field_name("active"); + let (metadata_bytes, _) = default_builder.finish(); + + // Use the metadata to build new variant values + let metadata = VariantMetadata::try_new(&metadata_bytes).unwrap(); + let mut metadata_builder = ReadOnlyMetadataBuilder::new(metadata); + let mut value_builder = ValueBuilder::new(); + + { + let state = ParentState::variant(&mut value_builder, &mut metadata_builder); + let mut obj = ObjectBuilder::new(state, false); + + // These should succeed because the fields exist in the metadata + obj.insert("name", "Alice"); + obj.insert("age", 30i8); + obj.insert("active", true); + obj.finish(); + } + + let value = value_builder.into_inner(); + + // Verify the variant was built correctly + let variant = Variant::try_new(&metadata_bytes, &value).unwrap(); + let obj = variant.as_object().unwrap(); + assert_eq!(obj.get("name"), Some(Variant::from("Alice"))); + assert_eq!(obj.get("age"), Some(Variant::Int8(30))); + assert_eq!(obj.get("active"), Some(Variant::from(true))); + } + + #[test] + fn test_read_only_metadata_builder_fails_on_unknown_field() { + // Create metadata with only one field + let mut default_builder = VariantBuilder::new(); + default_builder.add_field_name("known_field"); + let (metadata_bytes, _) = default_builder.finish(); + + // Use the metadata to build new variant values + let metadata = VariantMetadata::try_new(&metadata_bytes).unwrap(); + let mut metadata_builder = ReadOnlyMetadataBuilder::new(metadata); + let mut value_builder = ValueBuilder::new(); + + { + let state = ParentState::variant(&mut value_builder, &mut metadata_builder); + let mut obj = ObjectBuilder::new(state, false); + + // This should succeed + obj.insert("known_field", "value"); + + // This should fail because "unknown_field" is not in the metadata + let result = obj.try_insert("unknown_field", "value"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Field name 'unknown_field' not found")); + } + } + + #[test] + fn test_append_variant_bytes_round_trip() { + // Create a complex variant with the normal builder + let mut builder = VariantBuilder::new(); + { + let mut obj = builder.new_object(); + obj.insert("name", "Alice"); + obj.insert("age", 30i32); + { + let mut scores_list = obj.new_list("scores"); + scores_list.append_value(95i32); + scores_list.append_value(87i32); + scores_list.append_value(92i32); + scores_list.finish(); + } + { + let mut address = obj.new_object("address"); + 
address.insert("street", "123 Main St"); + address.insert("city", "Anytown"); + address.finish(); + } + obj.finish(); + } + let (metadata, value1) = builder.finish(); + let variant1 = Variant::try_new(&metadata, &value1).unwrap(); + + // Copy using the new bytes API + let metadata = VariantMetadata::new(&metadata); + let mut metadata = ReadOnlyMetadataBuilder::new(metadata); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata); + ValueBuilder::append_variant_bytes(state, variant1.clone()); + let value2 = builder2.into_inner(); + + // The bytes should be identical, we merely copied them across. + assert_eq!(value1, value2); + } + + #[test] + fn test_object_insert_bytes_subset() { + // Create an original object, making sure to inject the field names we'll add later. + let mut builder = VariantBuilder::new().with_field_names(["new_field", "another_field"]); + { + let mut obj = builder.new_object(); + obj.insert("field1", "value1"); + obj.insert("field2", 42i32); + obj.insert("field3", true); + obj.insert("field4", "value4"); + obj.finish(); + } + let (metadata1, value1) = builder.finish(); + let original_variant = Variant::try_new(&metadata1, &value1).unwrap(); + let original_obj = original_variant.as_object().unwrap(); + + // Create a new object copying subset of fields interleaved with new ones + let metadata2 = VariantMetadata::new(&metadata1); + let mut metadata2 = ReadOnlyMetadataBuilder::new(metadata2); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata2); + { + let mut obj = ObjectBuilder::new(state, true); + + // Copy field1 using bytes API + obj.insert_bytes("field1", original_obj.get("field1").unwrap()); + + // Add new field + obj.insert("new_field", "new_value"); + + // Copy field3 using bytes API + obj.insert_bytes("field3", original_obj.get("field3").unwrap()); + + // Add another new field + obj.insert("another_field", 99i32); + + // Copy field2 using bytes API + obj.insert_bytes("field2", original_obj.get("field2").unwrap()); + + obj.finish(); + } + let value2 = builder2.into_inner(); + let result_variant = Variant::try_new(&metadata1, &value2).unwrap(); + let result_obj = result_variant.as_object().unwrap(); + + // Verify the object contains expected fields + assert_eq!(result_obj.len(), 5); + assert_eq!( + result_obj.get("field1").unwrap().as_string().unwrap(), + "value1" + ); + assert_eq!(result_obj.get("field2").unwrap().as_int32().unwrap(), 42); + assert!(result_obj.get("field3").unwrap().as_boolean().unwrap()); + assert_eq!( + result_obj.get("new_field").unwrap().as_string().unwrap(), + "new_value" + ); + assert_eq!( + result_obj.get("another_field").unwrap().as_int32().unwrap(), + 99 + ); + } + + #[test] + fn test_list_append_bytes_subset() { + // Create an original list + let mut builder = VariantBuilder::new(); + { + let mut list = builder.new_list(); + list.append_value("item1"); + list.append_value(42i32); + list.append_value(true); + list.append_value("item4"); + list.append_value(1.234f64); + list.finish(); + } + let (metadata1, value1) = builder.finish(); + let original_variant = Variant::try_new(&metadata1, &value1).unwrap(); + let original_list = original_variant.as_list().unwrap(); + + // Create a new list copying subset of elements interleaved with new ones + let metadata2 = VariantMetadata::new(&metadata1); + let mut metadata2 = ReadOnlyMetadataBuilder::new(metadata2); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut 
builder2, &mut metadata2); + { + let mut list = ListBuilder::new(state, true); + + // Copy first element using bytes API + list.append_value_bytes(original_list.get(0).unwrap()); + + // Add new element + list.append_value("new_item"); + + // Copy third element using bytes API + list.append_value_bytes(original_list.get(2).unwrap()); + + // Add another new element + list.append_value(99i32); + + // Copy last element using bytes API + list.append_value_bytes(original_list.get(4).unwrap()); + + list.finish(); + } + let value2 = builder2.into_inner(); + let result_variant = Variant::try_new(&metadata1, &value2).unwrap(); + let result_list = result_variant.as_list().unwrap(); + + // Verify the list contains expected elements + assert_eq!(result_list.len(), 5); + assert_eq!(result_list.get(0).unwrap().as_string().unwrap(), "item1"); + assert_eq!(result_list.get(1).unwrap().as_string().unwrap(), "new_item"); + assert!(result_list.get(2).unwrap().as_boolean().unwrap()); + assert_eq!(result_list.get(3).unwrap().as_int32().unwrap(), 99); + assert_eq!(result_list.get(4).unwrap().as_f64().unwrap(), 1.234); + } + + #[test] + fn test_complex_nested_filtering_injection() { + // Create a complex nested structure: object -> list -> objects. Make sure to pre-register + // the extra field names we'll need later while manipulating variant bytes. + let mut builder = VariantBuilder::new().with_field_names([ + "active_count", + "active_users", + "computed_score", + "processed_at", + "status", + ]); + + { + let mut root_obj = builder.new_object(); + root_obj.insert("metadata", "original"); + + { + let mut users_list = root_obj.new_list("users"); + + // User 1 + { + let mut user1 = users_list.new_object(); + user1.insert("id", 1i32); + user1.insert("name", "Alice"); + user1.insert("active", true); + user1.finish(); + } + + // User 2 + { + let mut user2 = users_list.new_object(); + user2.insert("id", 2i32); + user2.insert("name", "Bob"); + user2.insert("active", false); + user2.finish(); + } + + // User 3 + { + let mut user3 = users_list.new_object(); + user3.insert("id", 3i32); + user3.insert("name", "Charlie"); + user3.insert("active", true); + user3.finish(); + } + + users_list.finish(); + } + + root_obj.insert("total_count", 3i32); + root_obj.finish(); + } + let (metadata1, value1) = builder.finish(); + let original_variant = Variant::try_new(&metadata1, &value1).unwrap(); + let original_obj = original_variant.as_object().unwrap(); + let original_users = original_obj.get("users").unwrap(); + let original_users = original_users.as_list().unwrap(); + + // Create filtered/modified version: only copy active users and inject new data + let metadata2 = VariantMetadata::new(&metadata1); + let mut metadata2 = ReadOnlyMetadataBuilder::new(metadata2); + let mut builder2 = ValueBuilder::new(); + let state = ParentState::variant(&mut builder2, &mut metadata2); + { + let mut root_obj = ObjectBuilder::new(state, true); + + // Copy metadata using bytes API + root_obj.insert_bytes("metadata", original_obj.get("metadata").unwrap()); + + // Add processing timestamp + root_obj.insert("processed_at", "2024-01-01T00:00:00Z"); + + { + let mut filtered_users = root_obj.new_list("active_users"); + + // Copy only active users and inject additional data + for i in 0..original_users.len() { + let user = original_users.get(i).unwrap(); + let user = user.as_object().unwrap(); + if user.get("active").unwrap().as_boolean().unwrap() { + { + let mut new_user = filtered_users.new_object(); + + // Copy existing fields using bytes API + 
new_user.insert_bytes("id", user.get("id").unwrap()); + new_user.insert_bytes("name", user.get("name").unwrap()); + + // Inject new computed field + let user_id = user.get("id").unwrap().as_int32().unwrap(); + new_user.insert("computed_score", user_id * 10); + + // Add status transformation (don't copy the 'active' field) + new_user.insert("status", "verified"); + + new_user.finish(); + } + } + } + + // Inject a completely new user + { + let mut new_user = filtered_users.new_object(); + new_user.insert("id", 999i32); + new_user.insert("name", "System User"); + new_user.insert("computed_score", 0i32); + new_user.insert("status", "system"); + new_user.finish(); + } + + filtered_users.finish(); + } + + // Update count + root_obj.insert("active_count", 3i32); // 2 active + 1 new + + root_obj.finish(); + } + let value2 = builder2.into_inner(); + let result_variant = Variant::try_new(&metadata1, &value2).unwrap(); + let result_obj = result_variant.as_object().unwrap(); + + // Verify the filtered/modified structure + assert_eq!( + result_obj.get("metadata").unwrap().as_string().unwrap(), + "original" + ); + assert_eq!( + result_obj.get("processed_at").unwrap().as_string().unwrap(), + "2024-01-01T00:00:00Z" + ); + assert_eq!( + result_obj.get("active_count").unwrap().as_int32().unwrap(), + 3 + ); + + let active_users = result_obj.get("active_users").unwrap(); + let active_users = active_users.as_list().unwrap(); + assert_eq!(active_users.len(), 3); + + // Verify Alice (id=1, was active) + let alice = active_users.get(0).unwrap(); + let alice = alice.as_object().unwrap(); + assert_eq!(alice.get("id").unwrap().as_int32().unwrap(), 1); + assert_eq!(alice.get("name").unwrap().as_string().unwrap(), "Alice"); + assert_eq!(alice.get("computed_score").unwrap().as_int32().unwrap(), 10); + assert_eq!( + alice.get("status").unwrap().as_string().unwrap(), + "verified" + ); + assert!(alice.get("active").is_none()); // This field was not copied + + // Verify Charlie (id=3, was active) - Bob (id=2) was not active so not included + let charlie = active_users.get(1).unwrap(); + let charlie = charlie.as_object().unwrap(); + assert_eq!(charlie.get("id").unwrap().as_int32().unwrap(), 3); + assert_eq!(charlie.get("name").unwrap().as_string().unwrap(), "Charlie"); + assert_eq!( + charlie.get("computed_score").unwrap().as_int32().unwrap(), + 30 + ); + assert_eq!( + charlie.get("status").unwrap().as_string().unwrap(), + "verified" + ); + + // Verify injected system user + let system_user = active_users.get(2).unwrap(); + let system_user = system_user.as_object().unwrap(); + assert_eq!(system_user.get("id").unwrap().as_int32().unwrap(), 999); + assert_eq!( + system_user.get("name").unwrap().as_string().unwrap(), + "System User" + ); + assert_eq!( + system_user + .get("computed_score") + .unwrap() + .as_int32() + .unwrap(), + 0 + ); + assert_eq!( + system_user.get("status").unwrap().as_string().unwrap(), + "system" + ); + } } diff --git a/parquet-variant/src/utils.rs b/parquet-variant/src/utils.rs index 8374105e0af8..872e90ad51f9 100644 --- a/parquet-variant/src/utils.rs +++ b/parquet-variant/src/utils.rs @@ -18,6 +18,7 @@ use std::{array::TryFromSliceError, ops::Range, str}; use arrow_schema::ArrowError; +use std::cmp::Ordering; use std::fmt::Debug; use std::slice::SliceIndex; @@ -115,23 +116,20 @@ pub(crate) fn string_from_slice( /// * `Some(Ok(index))` - Element found at the given index /// * `Some(Err(index))` - Element not found, but would be inserted at the given index /// * `None` - Key extraction failed -pub(crate) 
+pub(crate) fn try_binary_search_range_by<F>(
     range: Range<usize>,
-    target: &K,
-    key_extractor: F,
+    cmp: F,
 ) -> Option<Result<usize, usize>>
 where
-    K: Ord,
-    F: Fn(usize) -> Option<K>,
+    F: Fn(usize) -> Option<Ordering>,
 {
     let Range { mut start, mut end } = range;
     while start < end {
         let mid = start + (end - start) / 2;
-        let key = key_extractor(mid)?;
-        match key.cmp(target) {
-            std::cmp::Ordering::Equal => return Some(Ok(mid)),
-            std::cmp::Ordering::Greater => end = mid,
-            std::cmp::Ordering::Less => start = mid + 1,
+        match cmp(mid)? {
+            Ordering::Equal => return Some(Ok(mid)),
+            Ordering::Greater => end = mid,
+            Ordering::Less => start = mid + 1,
         }
     }
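The refactor above replaces the `target`/`key_extractor` pair with a single fallible comparator, so callers can compare against entries they must decode on the fly. The crate-private helper itself is not exported; the sketch below is a standalone copy of the same shape, purely for illustration (the `Some(Err(start))` tail mirrors the contract documented above):

```rust
use std::cmp::Ordering;
use std::ops::Range;

// Same contract as the helper above: Some(Ok(i)) = found at i,
// Some(Err(i)) = not found but would insert at i, None = decode failure.
fn try_binary_search_range_by<F>(range: Range<usize>, cmp: F) -> Option<Result<usize, usize>>
where
    F: Fn(usize) -> Option<Ordering>,
{
    let Range { mut start, mut end } = range;
    while start < end {
        let mid = start + (end - start) / 2;
        match cmp(mid)? {
            Ordering::Equal => return Some(Ok(mid)),
            Ordering::Greater => end = mid,
            Ordering::Less => start = mid + 1,
        }
    }
    Some(Err(start))
}

fn main() {
    // Look up "baz" in a sorted dictionary whose entries decode fallibly.
    let dict = ["bar", "baz", "foo"];
    let cmp = |i: usize| dict.get(i).map(|s| s.cmp(&"baz"));
    assert_eq!(try_binary_search_range_by(0..dict.len(), cmp), Some(Ok(1)));

    // A comparator that fails to decode aborts the whole search with None.
    let broken = |_: usize| None::<Ordering>;
    assert_eq!(try_binary_search_range_by(0..dict.len(), broken), None);
}
```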
diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs
index 003d46c122a4..3dae4daa0ff8 100644
--- a/parquet-variant/src/variant.rs
+++ b/parquet-variant/src/variant.rs
@@ -28,6 +28,7 @@ use std::ops::Deref;
 
 use arrow_schema::ArrowError;
 use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
+use half::f16;
 use uuid::Uuid;
 
 mod decimal;
@@ -804,6 +805,166 @@ impl<'m, 'v> Variant<'m, 'v> {
         }
     }
 
+    fn generic_convert_unsigned_primitive<T>(&self) -> Option<T>
+    where
+        T: TryFrom<i8> + TryFrom<i16> + TryFrom<i32> + TryFrom<i64> + TryFrom<i128>,
+    {
+        match *self {
+            Variant::Int8(i) => i.try_into().ok(),
+            Variant::Int16(i) => i.try_into().ok(),
+            Variant::Int32(i) => i.try_into().ok(),
+            Variant::Int64(i) => i.try_into().ok(),
+            Variant::Decimal4(d) if d.scale() == 0 => d.integer().try_into().ok(),
+            Variant::Decimal8(d) if d.scale() == 0 => d.integer().try_into().ok(),
+            Variant::Decimal16(d) if d.scale() == 0 => d.integer().try_into().ok(),
+            _ => None,
+        }
+    }
+
+    /// Converts this variant to a `u8` if possible.
+    ///
+    /// Returns `Some(u8)` for integer variants that fit in `u8`, and
+    /// `None` for non-integer variants or values that would overflow.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use parquet_variant::{Variant, VariantDecimal4};
+    ///
+    /// // you can read an int64 variant into a u8
+    /// let v1 = Variant::from(123i64);
+    /// assert_eq!(v1.as_u8(), Some(123u8));
+    ///
+    /// // or a Decimal4 with scale 0 into a u8
+    /// let d = VariantDecimal4::try_new(26, 0).unwrap();
+    /// let v2 = Variant::from(d);
+    /// assert_eq!(v2.as_u8(), Some(26u8));
+    ///
+    /// // but not a variant whose value doesn't fit into the range
+    /// let v3 = Variant::from(-1);
+    /// assert_eq!(v3.as_u8(), None);
+    ///
+    /// // nor a decimal variant with a nonzero scale
+    /// let d = VariantDecimal4::try_new(1, 2).unwrap();
+    /// let v4 = Variant::from(d);
+    /// assert_eq!(v4.as_u8(), None);
+    ///
+    /// // nor a variant that cannot be cast to an integer
+    /// let v5 = Variant::from("hello!");
+    /// assert_eq!(v5.as_u8(), None);
+    /// ```
+    pub fn as_u8(&self) -> Option<u8> {
+        self.generic_convert_unsigned_primitive::<u8>()
+    }
+
+    /// Converts this variant to a `u16` if possible.
+    ///
+    /// Returns `Some(u16)` for integer variants that fit in `u16`, and
+    /// `None` for non-integer variants or values that would overflow.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use parquet_variant::{Variant, VariantDecimal4};
+    ///
+    /// // you can read an int64 variant into a u16
+    /// let v1 = Variant::from(123i64);
+    /// assert_eq!(v1.as_u16(), Some(123u16));
+    ///
+    /// // or a Decimal4 with scale 0 into a u16
+    /// let d = VariantDecimal4::try_new(u16::MAX as i32, 0).unwrap();
+    /// let v2 = Variant::from(d);
+    /// assert_eq!(v2.as_u16(), Some(u16::MAX));
+    ///
+    /// // but not a variant whose value doesn't fit into the range
+    /// let v3 = Variant::from(-1);
+    /// assert_eq!(v3.as_u16(), None);
+    ///
+    /// // nor a decimal variant with a nonzero scale
+    /// let d = VariantDecimal4::try_new(1, 2).unwrap();
+    /// let v4 = Variant::from(d);
+    /// assert_eq!(v4.as_u16(), None);
+    ///
+    /// // nor a variant that cannot be cast to an integer
+    /// let v5 = Variant::from("hello!");
+    /// assert_eq!(v5.as_u16(), None);
+    /// ```
+    pub fn as_u16(&self) -> Option<u16> {
+        self.generic_convert_unsigned_primitive::<u16>()
+    }
+
+    /// Converts this variant to a `u32` if possible.
+    ///
+    /// Returns `Some(u32)` for integer variants that fit in `u32`, and
+    /// `None` for non-integer variants or values that would overflow.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use parquet_variant::{Variant, VariantDecimal8};
+    ///
+    /// // you can read an int64 variant into a u32
+    /// let v1 = Variant::from(123i64);
+    /// assert_eq!(v1.as_u32(), Some(123u32));
+    ///
+    /// // or a Decimal8 with scale 0 into a u32
+    /// let d = VariantDecimal8::try_new(u32::MAX as i64, 0).unwrap();
+    /// let v2 = Variant::from(d);
+    /// assert_eq!(v2.as_u32(), Some(u32::MAX));
+    ///
+    /// // but not a variant whose value doesn't fit into the range
+    /// let v3 = Variant::from(-1);
+    /// assert_eq!(v3.as_u32(), None);
+    ///
+    /// // nor a decimal variant with a nonzero scale
+    /// let d = VariantDecimal8::try_new(1, 2).unwrap();
+    /// let v4 = Variant::from(d);
+    /// assert_eq!(v4.as_u32(), None);
+    ///
+    /// // nor a variant that cannot be cast to an integer
+    /// let v5 = Variant::from("hello!");
+    /// assert_eq!(v5.as_u32(), None);
+    /// ```
+    pub fn as_u32(&self) -> Option<u32> {
+        self.generic_convert_unsigned_primitive::<u32>()
+    }
+
+    /// Converts this variant to a `u64` if possible.
+    ///
+    /// Returns `Some(u64)` for integer variants that fit in `u64`, and
+    /// `None` for non-integer variants or values that would overflow.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use parquet_variant::{Variant, VariantDecimal16};
+    ///
+    /// // you can read an int64 variant into a u64
+    /// let v1 = Variant::from(123i64);
+    /// assert_eq!(v1.as_u64(), Some(123u64));
+    ///
+    /// // or a Decimal16 with scale 0 into a u64
+    /// let d = VariantDecimal16::try_new(u64::MAX as i128, 0).unwrap();
+    /// let v2 = Variant::from(d);
+    /// assert_eq!(v2.as_u64(), Some(u64::MAX));
+    ///
+    /// // but not a variant whose value doesn't fit into the range
+    /// let v3 = Variant::from(-1);
+    /// assert_eq!(v3.as_u64(), None);
+    ///
+    /// // nor a decimal variant with a nonzero scale
+    /// let d = VariantDecimal16::try_new(1, 2).unwrap();
+    /// let v4 = Variant::from(d);
+    /// assert_eq!(v4.as_u64(), None);
+    ///
+    /// // nor a variant that cannot be cast to an integer
+    /// let v5 = Variant::from("hello!");
+    /// assert_eq!(v5.as_u64(), None);
+    /// ```
+    pub fn as_u64(&self) -> Option<u64> {
+        self.generic_convert_unsigned_primitive::<u64>()
+    }
+
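Taken together, the four accessors give range-checked narrowing from any signed-integer or scale-0 decimal representation, so the same value can succeed for a wider target and fail for a narrower one. A small demonstration, assuming only the behavior documented above:

```rust
use parquet_variant::Variant;

fn main() {
    // One value, four range checks: 70_000 fits in u32/u64 but not u8/u16.
    let v = Variant::from(70_000i64);
    assert_eq!(v.as_u8(), None);
    assert_eq!(v.as_u16(), None);
    assert_eq!(v.as_u32(), Some(70_000u32));
    assert_eq!(v.as_u64(), Some(70_000u64));
}
```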
     /// Converts this variant to tuple with a 4-byte unscaled value if possible.
     ///
     /// Returns `Some((i32, u8))` for decimal variants where the unscaled value
@@ -915,6 +1076,37 @@ impl<'m, 'v> Variant<'m, 'v> {
             _ => None,
         }
     }
+
+    /// Converts this variant to an `f16` if possible.
+    ///
+    /// Returns `Some(f16)` for float and double variants, and
+    /// `None` for non-floating-point variants.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use parquet_variant::Variant;
+    /// use half::f16;
+    ///
+    /// // you can extract an f16 from a float variant
+    /// let v1 = Variant::from(std::f32::consts::PI);
+    /// assert_eq!(v1.as_f16(), Some(f16::from_f32(std::f32::consts::PI)));
+    ///
+    /// // and from a double variant (with loss of precision to nearest f16)
+    /// let v2 = Variant::from(std::f64::consts::PI);
+    /// assert_eq!(v2.as_f16(), Some(f16::from_f64(std::f64::consts::PI)));
+    ///
+    /// // but not from other variants
+    /// let v3 = Variant::from("hello!");
+    /// assert_eq!(v3.as_f16(), None);
+    /// ```
+    pub fn as_f16(&self) -> Option<f16> {
+        match *self {
+            Variant::Float(i) => Some(f16::from_f32(i)),
+            Variant::Double(i) => Some(f16::from_f64(i)),
+            _ => None,
+        }
+    }
+
     /// Converts this variant to an `f32` if possible.
     ///
     /// Returns `Some(f32)` for float and double variants,
@@ -1149,7 +1341,7 @@ impl<'m, 'v> Variant<'m, 'v> {
     /// # list.append_value("bar");
     /// # list.append_value("baz");
     /// # list.finish();
-    /// # obj.finish().unwrap();
+    /// # obj.finish();
     /// # let (metadata, value) = builder.finish();
     /// // given a variant like `{"foo": ["bar", "baz"]}`
     /// let variant = Variant::new(&metadata, &value);
@@ -1278,6 +1470,12 @@
     }
 }
 
+impl From<half::f16> for Variant<'_, '_> {
+    fn from(value: half::f16) -> Self {
+        Variant::Float(value.into())
+    }
+}
+
 impl From<f32> for Variant<'_, '_> {
     fn from(value: f32) -> Self {
         Variant::Float(value)
@@ -1578,7 +1776,7 @@ mod tests {
         let mut nested_obj = root_obj.new_object("nested_object");
         nested_obj.insert("inner_key1", "inner_value1");
         nested_obj.insert("inner_key2", 999i32);
-        nested_obj.finish().unwrap();
+        nested_obj.finish();
 
         // Add list with mixed types
         let mut mixed_list = root_obj.new_list("mixed_list");
@@ -1596,7 +1794,7 @@
 
         mixed_list.finish();
 
-        root_obj.finish().unwrap();
+        root_obj.finish();
 
         let (metadata, value) = builder.finish();
         let variant = Variant::try_new(&metadata, &value).unwrap();
diff --git a/parquet-variant/src/variant/list.rs b/parquet-variant/src/variant/list.rs
index e3053ce9100e..438faddffe15 100644
--- a/parquet-variant/src/variant/list.rs
+++ b/parquet-variant/src/variant/list.rs
@@ -697,7 +697,7 @@ mod tests {
         // list3 (10..20)
         let (metadata3, value3) = make_listi32(10i32..20i32);
         object_builder.insert("list3", Variant::new(&metadata3, &value3));
-        object_builder.finish().unwrap();
+        object_builder.finish();
         builder.finish()
     };
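The `as_f16` accessor and `From<half::f16>` impl added in the variant.rs hunk above are complementary: `From<f16>` widens losslessly into a `Float` variant, so an `f16` value round-trips exactly. A quick sketch, assuming the conversions behave as documented:

```rust
use half::f16;
use parquet_variant::Variant;

fn main() {
    // f16 -> f32 widening is exact, so the round trip through a Float
    // variant is lossless.
    let half_pi = f16::from_f32(std::f32::consts::PI);
    let v = Variant::from(half_pi);
    assert_eq!(v.as_f16(), Some(half_pi));

    // The same variant is readable at wider float widths too.
    assert_eq!(v.as_f32(), Some(half_pi.to_f32()));
}
```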
diff --git a/parquet-variant/src/variant/metadata.rs b/parquet-variant/src/variant/metadata.rs
index 0e356e34c41e..1c9da6bcc0cf 100644
--- a/parquet-variant/src/variant/metadata.rs
+++ b/parquet-variant/src/variant/metadata.rs
@@ -16,7 +16,10 @@
 // under the License.
 
 use crate::decoder::{map_bytes_to_offsets, OffsetSizeBytes};
-use crate::utils::{first_byte_from_slice, overflow_error, slice_from_slice, string_from_slice};
+use crate::utils::{
+    first_byte_from_slice, overflow_error, slice_from_slice, string_from_slice,
+    try_binary_search_range_by,
+};
 
 use arrow_schema::ArrowError;
 
@@ -315,6 +318,32 @@ impl<'m> VariantMetadata<'m> {
         string_from_slice(self.bytes, self.first_value_byte as _, byte_range)
     }
 
+    // Helper method used by our `impl Index` and also by `get_entry`. Panics if the underlying
+    // bytes are invalid. Needed because the `Index` trait forces the returned result to have the
+    // lifetime of `self` instead of the string's own (longer) lifetime `'m`.
+    fn get_impl(&self, i: usize) -> &'m str {
+        self.get(i).expect("Invalid metadata dictionary entry")
+    }
+
+    /// Attempts to retrieve a dictionary entry and its field id, returning `None` if the
+    /// requested field name is not present. The search cost is logarithmic if
+    /// [`Self::is_sorted`] and linear otherwise.
+    ///
+    /// WARNING: This method panics if the underlying bytes are [invalid].
+    ///
+    /// [invalid]: Self#Validation
+    pub fn get_entry(&self, field_name: &str) -> Option<(u32, &'m str)> {
+        let field_id = if self.is_sorted() && self.len() > 10 {
+            // Binary search is faster for a not-tiny sorted metadata dictionary
+            let cmp = |i| Some(self.get_impl(i).cmp(field_name));
+            try_binary_search_range_by(0..self.len(), cmp)?.ok()?
+        } else {
+            // Fall back to linear search for a tiny or unsorted dictionary
+            (0..self.len()).find(|i| self.get_impl(*i) == field_name)?
+        };
+        Some((field_id as u32, self.get_impl(field_id)))
+    }
+
     /// Returns an iterator that attempts to visit all dictionary entries, producing `Err` if the
     /// iterator encounters [invalid] data.
     ///
@@ -341,7 +370,7 @@ impl std::ops::Index<usize> for VariantMetadata<'_> {
     type Output = str;
 
     fn index(&self, i: usize) -> &str {
-        self.get(i).expect("Invalid metadata dictionary entry")
+        self.get_impl(i)
     }
 }
 
@@ -544,7 +573,7 @@ mod tests {
         o.insert("a", false);
         o.insert("b", false);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, _) = b.finish();
 
@@ -579,7 +608,7 @@
         o.insert("a", false);
         o.insert("b", false);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, _) = b.finish();
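The new `get_entry` gives a by-name lookup alongside the panicking `Index` impl. A usage sketch, under the assumption that the builder and `Variant::metadata()` accessor behave as the tests in this diff use them:

```rust
use parquet_variant::{Variant, VariantBuilder};

fn main() {
    let mut b = VariantBuilder::new();
    let mut o = b.new_object();
    o.insert("a", 1i32);
    o.insert("b", 2i32);
    o.finish();
    let (m, v) = b.finish();

    let variant = Variant::try_new(&m, &v).unwrap();
    let metadata = variant.metadata().unwrap();

    // get_entry returns the field id plus the dictionary's own copy of the name.
    let (id, name) = metadata.get_entry("b").unwrap();
    assert_eq!(name, "b");
    assert_eq!(&metadata[id as usize], "b");

    // Absent names report None instead of panicking like the Index impl.
    assert!(metadata.get_entry("missing").is_none());
}
```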
diff --git a/parquet-variant/src/variant/object.rs b/parquet-variant/src/variant/object.rs
index b809fe278cb4..df1857846302 100644
--- a/parquet-variant/src/variant/object.rs
+++ b/parquet-variant/src/variant/object.rs
@@ -397,8 +397,8 @@ impl<'m, 'v> VariantObject<'m, 'v> {
         // NOTE: This does not require a sorted metadata dictionary, because the variant spec
         // requires object field ids to be lexically sorted by their corresponding string values,
         // and probing the dictionary for a field id is always O(1) work.
-        let i = try_binary_search_range_by(0..self.len(), &name, |i| self.field_name(i))?.ok()?;
-
+        let cmp = |i| Some(self.field_name(i)?.cmp(name));
+        let i = try_binary_search_range_by(0..self.len(), cmp)?.ok()?;
         self.field(i)
     }
 }
@@ -550,7 +550,7 @@ mod tests {
     #[test]
     fn test_variant_object_empty_fields() {
         let mut builder = VariantBuilder::new();
-        builder.new_object().with_field("", 42).finish().unwrap();
+        builder.new_object().with_field("", 42).finish();
         let (metadata, value) = builder.finish();
 
         // Resulting object is valid and has a single empty field
@@ -676,7 +676,7 @@
             obj.insert(&field_names[i as usize], i);
         }
 
-        obj.finish().unwrap();
+        obj.finish();
 
         let (metadata, value) = builder.finish();
         let variant = Variant::new(&metadata, &value);
@@ -737,7 +737,7 @@
             obj.insert(&key, str_val.as_str());
         }
 
-        obj.finish().unwrap();
+        obj.finish();
 
         let (metadata, value) = builder.finish();
         let variant = Variant::new(&metadata, &value);
@@ -783,7 +783,7 @@
         o.insert("c", ());
         o.insert("a", ());
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
 
@@ -801,7 +801,7 @@
         o.insert("a", ());
         o.insert("b", false);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
         let v1 = Variant::try_new(&m, &v).unwrap();
@@ -812,7 +812,7 @@
         o.insert("a", ());
         o.insert("b", false);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
         let v2 = Variant::try_new(&m, &v).unwrap();
@@ -828,7 +828,7 @@
         o.insert("a", ());
         o.insert("b", 4.3);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
 
@@ -841,8 +841,8 @@
         o.insert("a", ());
         let mut inner_o = o.new_object("b");
         inner_o.insert("a", 3.3);
-        inner_o.finish().unwrap();
-        o.finish().unwrap();
+        inner_o.finish();
+        o.finish();
 
         let (m, v) = b.finish();
 
@@ -866,7 +866,7 @@
         o.insert("a", ());
         o.insert("b", 4.3);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
 
@@ -879,7 +879,7 @@
         o.insert("aardvark", ());
         o.insert("barracuda", 3.3);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
         let v2 = Variant::try_new(&m, &v).unwrap();
@@ -895,7 +895,7 @@
         o.insert("b", false);
         o.insert("a", ());
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
 
@@ -904,13 +904,13 @@
 
         // create another object pre-filled with field names, b and a
         // but insert the fields in the order of a, b
-        let mut b = VariantBuilder::new().with_field_names(["b", "a"].into_iter());
+        let mut b = VariantBuilder::new().with_field_names(["b", "a"]);
 
         let mut o = b.new_object();
 
         o.insert("a", ());
         o.insert("b", false);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
 
@@ -930,7 +930,7 @@
         o.insert("a", ());
         o.insert("b", 4.3);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (meta1, value1) = b.finish();
 
@@ -939,13 +939,13 @@
         assert!(v1.metadata().unwrap().is_sorted());
 
         // create a second object with different insertion order
-        let mut b = VariantBuilder::new().with_field_names(["d", "c", "b", "a"].into_iter());
+        let mut b = VariantBuilder::new().with_field_names(["d", "c", "b", "a"]);
 
         let mut o = b.new_object();
 
         o.insert("b", 4.3);
         o.insert("a", ());
 
-        o.finish().unwrap();
+        o.finish();
 
         let (meta2, value2) = b.finish();
 
@@ -969,7 +969,7 @@
         o.insert("a", false);
         o.insert("b", false);
 
-        o.finish().unwrap();
+        o.finish();
 
         let (m, v) = b.finish();
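The object tests above depend on equality being logical rather than byte-wise: pre-registering field names with `with_field_names` changes the metadata layout (and possibly `is_sorted`) without changing what the object means. A condensed sketch of that property, assuming the same semantics the tests assert:

```rust
use parquet_variant::{Variant, VariantBuilder};

fn main() {
    // Same logical object {"a": (), "b": false} built two ways.
    let mut b1 = VariantBuilder::new();
    let mut o = b1.new_object();
    o.insert("a", ());
    o.insert("b", false);
    o.finish();
    let (m1, v1) = b1.finish();

    // Pre-registering names out of lexical order yields different bytes...
    let mut b2 = VariantBuilder::new().with_field_names(["b", "a"]);
    let mut o = b2.new_object();
    o.insert("a", ());
    o.insert("b", false);
    o.finish();
    let (m2, v2) = b2.finish();

    // ...but the two variants still compare equal.
    let first = Variant::try_new(&m1, &v1).unwrap();
    let second = Variant::try_new(&m2, &v2).unwrap();
    assert_eq!(first, second);
}
```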
diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs
index 07ff6d01b410..00c326c06406 100644
--- a/parquet-variant/tests/variant_interop.rs
+++ b/parquet-variant/tests/variant_interop.rs
@@ -272,7 +272,7 @@ fn variant_object_builder() {
     obj.insert("null_field", ());
     obj.insert("timestamp_field", "2025-04-16T12:34:56.78");
 
-    obj.finish().unwrap();
+    obj.finish();
 
     let (built_metadata, built_value) = builder.finish();
     let actual = Variant::try_new(&built_metadata, &built_value).unwrap();
@@ -384,7 +384,7 @@ fn generate_random_value(rng: &mut StdRng, builder: &mut VariantBuilder, max_dep
                 let key = format!("field_{i}");
                 object_builder.insert(&key, rng.random::());
             }
-            object_builder.finish().unwrap();
+            object_builder.finish();
         }
         15 => {
             // Time
diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml
index f601ac7cefdc..bae90a51f0a8 100644
--- a/parquet/Cargo.toml
+++ b/parquet/Cargo.toml
@@ -65,7 +65,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"], op
 seq-macro = { version = "0.3", default-features = false }
 futures = { version = "0.3", default-features = false, features = ["std"], optional = true }
 tokio = { version = "1.0", optional = true, default-features = false, features = ["macros", "rt", "io-util"] }
-hashbrown = { version = "0.15", default-features = false }
+hashbrown = { version = "0.16", default-features = false }
 twox-hash = { version = "2.0", default-features = false, features = ["xxhash64"] }
 paste = { version = "1.0" }
 half = { version = "2.1", default-features = false, features = ["num-traits"] }
diff --git a/parquet/benches/metadata.rs b/parquet/benches/metadata.rs
index 949e0d98ea39..8c886e4d5eea 100644
--- a/parquet/benches/metadata.rs
+++ b/parquet/benches/metadata.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use parquet::file::metadata::ParquetMetaDataReader;
 use rand::Rng;
 use thrift::protocol::TCompactOutputProtocol;
 
@@ -25,7 +26,7 @@
 use parquet::file::reader::SerializedFileReader;
 use parquet::file::serialized_reader::ReadOptionsBuilder;
 use parquet::format::{
     ColumnChunk, ColumnMetaData, CompressionCodec, Encoding, FieldRepetitionType, FileMetaData,
-    RowGroup, SchemaElement, Type,
+    PageEncodingStats, PageType, RowGroup, SchemaElement, Type,
 };
 use parquet::thrift::TSerializable;
 
@@ -93,7 +94,18 @@ fn encoded_meta() -> Vec<u8> {
         index_page_offset: Some(rng.random()),
         dictionary_page_offset: Some(rng.random()),
         statistics: Some(stats.clone()),
-        encoding_stats: None,
+        encoding_stats: Some(vec![
+            PageEncodingStats {
+                page_type: PageType::DICTIONARY_PAGE,
+                encoding: Encoding::PLAIN,
+                count: 1,
+            },
+            PageEncodingStats {
+                page_type: PageType::DATA_PAGE,
+                encoding: Encoding::RLE_DICTIONARY,
+                count: 10,
+            },
+        ]),
         bloom_filter_offset: None,
         bloom_filter_length: None,
         size_statistics: None,
@@ -151,6 +163,36 @@ fn get_footer_bytes(data: Bytes) -> Bytes {
     data.slice(meta_start..meta_end)
 }
 
+#[cfg(feature = "arrow")]
+fn rewrite_file(bytes: Bytes) -> (Bytes, FileMetaData) {
+    use arrow::array::RecordBatchReader;
+    use parquet::arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter};
+    use parquet::file::properties::{EnabledStatistics, WriterProperties};
+
+    let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(bytes)
+        .expect("parquet open")
+        .build()
+        .expect("parquet open");
+    let writer_properties = WriterProperties::builder()
+        .set_statistics_enabled(EnabledStatistics::Page)
+        .set_write_page_header_statistics(true)
+        .build();
+    let mut output = Vec::new();
+    let mut parquet_writer = ArrowWriter::try_new(
+        &mut output,
+        parquet_reader.schema(),
+        Some(writer_properties),
+    )
+    .expect("create arrow writer");
+
+    for maybe_batch in parquet_reader {
+        let batch = maybe_batch.expect("reading batch");
+        parquet_writer.write(&batch).expect("writing data");
+    }
+    let file_meta = parquet_writer.close().expect("finalizing file");
+    (output.into(), file_meta)
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
     // Read file into memory to isolate filesystem performance
     let file = "../parquet-testing/data/alltypes_tiny_pages.parquet";
@@ -168,19 +210,54 @@ fn criterion_benchmark(c: &mut Criterion) {
         })
     });
 
-    let meta_data = get_footer_bytes(data);
-    c.bench_function("decode file metadata", |b| {
+    let meta_data = get_footer_bytes(data.clone());
+    c.bench_function("decode parquet metadata", |b| {
+        b.iter(|| {
+            ParquetMetaDataReader::decode_metadata(&meta_data).unwrap();
+        })
+    });
+
+    c.bench_function("decode thrift file metadata", |b| {
         b.iter(|| {
             parquet::thrift::bench_file_metadata(&meta_data);
         })
     });
 
-    let buf = black_box(encoded_meta()).into();
-    c.bench_function("decode file metadata (wide)", |b| {
+    let buf: Bytes = black_box(encoded_meta()).into();
+    c.bench_function("decode parquet metadata (wide)", |b| {
+        b.iter(|| {
+            ParquetMetaDataReader::decode_metadata(&buf).unwrap();
+        })
+    });
+
+    c.bench_function("decode thrift file metadata (wide)", |b| {
         b.iter(|| {
             parquet::thrift::bench_file_metadata(&buf);
         })
     });
+
+    // Rewrite the file with page statistics, then read the page headers.
+    #[cfg(feature = "arrow")]
+    let (file_bytes, metadata) = rewrite_file(data.clone());
+    #[cfg(feature = "arrow")]
+    c.bench_function("page headers", |b| {
+        b.iter(|| {
+            metadata.row_groups.iter().for_each(|rg| {
+                rg.columns.iter().for_each(|col| {
+                    if let Some(col_meta) = &col.meta_data {
+                        if let Some(dict_offset) = col_meta.dictionary_page_offset {
+                            parquet::thrift::bench_page_header(
+                                &file_bytes.slice(dict_offset as usize..),
+                            );
+                        }
+                        parquet::thrift::bench_page_header(
+                            &file_bytes.slice(col_meta.data_page_offset as usize..),
+                        );
+                    }
+                });
+            });
+        })
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);
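The new bench names separate the high-level `ParquetMetaDataReader::decode_metadata` path from raw thrift decoding. For reference, the same entry point in ordinary code, a minimal sketch where `meta_bytes` is assumed to hold the footer's metadata section (e.g. as produced by the bench's `get_footer_bytes` helper above):

```rust
use parquet::file::metadata::ParquetMetaDataReader;

// Decode the thrift-encoded FileMetaData section of a Parquet footer and
// report how many row groups it describes.
fn row_group_count(meta_bytes: &[u8]) -> usize {
    let metadata = ParquetMetaDataReader::decode_metadata(meta_bytes).unwrap();
    metadata.num_row_groups()
}
```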
diff --git a/parquet/src/arrow/async_writer/mod.rs b/parquet/src/arrow/async_writer/mod.rs
index 3a74aa7c9c20..4547f71274b7 100644
--- a/parquet/src/arrow/async_writer/mod.rs
+++ b/parquet/src/arrow/async_writer/mod.rs
@@ -61,7 +61,7 @@ mod store;
 pub use store::*;
 
 use crate::{
-    arrow::arrow_writer::ArrowWriterOptions,
+    arrow::arrow_writer::{ArrowColumnChunk, ArrowColumnWriter, ArrowWriterOptions},
     arrow::ArrowWriter,
     errors::{ParquetError, Result},
     file::{metadata::RowGroupMetaData, properties::WriterProperties},
@@ -288,6 +288,22 @@ impl<W: AsyncFileWriter> AsyncArrowWriter<W> {
 
         Ok(())
     }
+
+    /// Create a new row group writer and return its column writers.
+    pub async fn get_column_writers(&mut self) -> Result<Vec<ArrowColumnWriter>> {
+        let before = self.sync_writer.flushed_row_groups().len();
+        let writers = self.sync_writer.get_column_writers()?;
+        if before != self.sync_writer.flushed_row_groups().len() {
+            self.do_write().await?;
+        }
+        Ok(writers)
+    }
+
+    /// Append the given column chunks to the file as a new row group.
+    pub async fn append_row_group(&mut self, chunks: Vec<ArrowColumnChunk>) -> Result<()> {
+        self.sync_writer.append_row_group(chunks)?;
+        self.do_write().await
+    }
 }
 
 #[cfg(test)]
@@ -298,6 +314,7 @@ mod tests {
     use std::sync::Arc;
 
     use crate::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
+    use crate::arrow::arrow_writer::compute_leaves;
 
     use super::*;
 
@@ -332,6 +349,51 @@ mod tests {
         assert_eq!(to_write, read);
     }
 
+    #[tokio::test]
+    async fn test_async_arrow_group_writer() {
+        let col = Arc::new(Int64Array::from_iter_values([4, 5, 6])) as ArrayRef;
+        let to_write_record = RecordBatch::try_from_iter([("col", col)]).unwrap();
+
+        let mut buffer = Vec::new();
+        let mut writer =
+            AsyncArrowWriter::try_new(&mut buffer, to_write_record.schema(), None).unwrap();
+
+        // Use classic API
+        writer.write(&to_write_record).await.unwrap();
+
+        let mut writers = writer.get_column_writers().await.unwrap();
+        let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef;
+        let to_write_arrow_group = RecordBatch::try_from_iter([("col", col)]).unwrap();
+
+        for (field, column) in to_write_arrow_group
+            .schema()
+            .fields()
+            .iter()
+            .zip(to_write_arrow_group.columns())
+        {
+            for leaf in compute_leaves(field.as_ref(), column).unwrap() {
+                writers[0].write(&leaf).unwrap();
+            }
+        }
+
+        let columns: Vec<_> = writers.into_iter().map(|w| w.close().unwrap()).collect();
+        // Append the arrow group as a new row group, flushing any in-progress writes
+        writer.append_row_group(columns).await.unwrap();
+        writer.close().await.unwrap();
+
+        let buffer = Bytes::from(buffer);
+        let mut reader = ParquetRecordBatchReaderBuilder::try_new(buffer)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let col = Arc::new(Int64Array::from_iter_values([4, 5, 6, 1, 2, 3])) as ArrayRef;
+        let expected = RecordBatch::try_from_iter([("col", col)]).unwrap();
+
+        let read = reader.next().unwrap().unwrap();
+        assert_eq!(expected, read);
+    }
+
     // Read the data from the test file and write it by the async writer and sync writer.
     // And then compares the results of the two writers.
     #[tokio::test]
diff --git a/parquet/src/bloom_filter/mod.rs b/parquet/src/bloom_filter/mod.rs
index 384a4a10486e..09302bab8fec 100644
--- a/parquet/src/bloom_filter/mod.rs
+++ b/parquet/src/bloom_filter/mod.rs
@@ -119,6 +119,13 @@ impl Block {
         Self(result)
     }
 
+    #[inline]
+    #[cfg(not(target_endian = "little"))]
+    fn to_ne_bytes(self) -> [u8; 32] {
+        // SAFETY: [u32; 8] and [u8; 32] have the same size and neither has invalid bit patterns.
+        unsafe { std::mem::transmute(self.0) }
+    }
+
     #[inline]
     #[cfg(not(target_endian = "little"))]
     fn to_le_bytes(self) -> [u8; 32] {
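The pair of methods added to `AsyncArrowWriter` above mirrors the sync writer's low-level column API: `get_column_writers` hands out one `ArrowColumnWriter` per leaf column, and `append_row_group` stitches the finished chunks back into the file. A condensed sketch of the flow the new test exercises, assuming a flat schema so fields and writers zip one-to-one (nested types produce more than one leaf per field):

```rust
use arrow_array::RecordBatch;
use parquet::arrow::arrow_writer::compute_leaves;
use parquet::arrow::async_writer::AsyncFileWriter;
use parquet::arrow::AsyncArrowWriter;

async fn write_one_row_group<W: AsyncFileWriter>(
    writer: &mut AsyncArrowWriter<W>,
    batch: &RecordBatch,
) -> parquet::errors::Result<()> {
    // One writer per leaf column; these could be moved onto separate
    // tasks for CPU-parallel encoding before being closed and appended.
    let mut writers = writer.get_column_writers().await?;
    for ((field, column), col_writer) in batch
        .schema()
        .fields()
        .iter()
        .zip(batch.columns())
        .zip(writers.iter_mut())
    {
        for leaf in compute_leaves(field.as_ref(), column)? {
            col_writer.write(&leaf)?;
        }
    }
    let chunks = writers
        .into_iter()
        .map(|w| w.close())
        .collect::<parquet::errors::Result<Vec<_>>>()?;
    writer.append_row_group(chunks).await
}
```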
diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs
index 9374e226b87f..82b8ba166f14 100644
--- a/parquet/src/column/writer/mod.rs
+++ b/parquet/src/column/writer/mod.rs
@@ -1104,12 +1104,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
             rep_levels_byte_len + def_levels_byte_len + values_data.buf.len();
 
         // Data Page v2 compresses values only.
-        match self.compressor {
+        let is_compressed = match self.compressor {
             Some(ref mut cmpr) => {
+                let buffer_len = buffer.len();
                 cmpr.compress(&values_data.buf, &mut buffer)?;
+                if uncompressed_size <= buffer.len() - buffer_len {
+                    buffer.truncate(buffer_len);
+                    buffer.extend_from_slice(&values_data.buf);
+                    false
+                } else {
+                    true
+                }
             }
-            None => buffer.extend_from_slice(&values_data.buf),
-        }
+            None => {
+                buffer.extend_from_slice(&values_data.buf);
+                false
+            }
+        };
 
         let data_page = Page::DataPageV2 {
             buf: buffer.into(),
@@ -1119,7 +1130,7 @@
             num_rows: self.page_metrics.num_buffered_rows,
             def_levels_byte_len: def_levels_byte_len as u32,
             rep_levels_byte_len: rep_levels_byte_len as u32,
-            is_compressed: self.compressor.is_some(),
+            is_compressed,
             statistics: page_statistics,
         };
 
@@ -4236,4 +4247,33 @@ mod tests {
             .unwrap();
         ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path)
     }
+
+    #[test]
+    fn test_page_v2_snappy_compression_fallback() {
+        // Test that PageV2 sets is_compressed to false when Snappy compression increases data size
+        let page_writer = TestPageWriter {};
+
+        // Create WriterProperties with PageV2 and Snappy compression
+        let props = WriterProperties::builder()
+            .set_writer_version(WriterVersion::PARQUET_2_0)
+            // Disable dictionary to ensure data is written directly
+            .set_dictionary_enabled(false)
+            .set_compression(Compression::SNAPPY)
+            .build();
+
+        let mut column_writer =
+            get_test_column_writer::<ByteArrayType>(Box::new(page_writer), 0, 0, Arc::new(props));
+
+        // Create small, simple data that Snappy compression will likely increase in size
+        // due to compression overhead for very small data
+        let values = vec![ByteArray::from("a")];
+
+        column_writer.write_batch(&values, None, None).unwrap();
+
+        let result = column_writer.close().unwrap();
+        assert_eq!(
+            result.metadata.uncompressed_size(),
+            result.metadata.compressed_size()
+        );
+    }
 }
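The fix above stores a v2 data page uncompressed whenever the codec fails to shrink it, and records that decision in the page's `is_compressed` flag instead of assuming `compressor.is_some()`. The rule in isolation, as a standalone restatement (a hypothetical helper, not the parquet crate's API):

```rust
// Append whichever encoding of `values` is smaller and report whether the
// compressed form was kept, so the page header can be set accordingly.
fn append_maybe_compressed(values: &[u8], compressed: Vec<u8>, buffer: &mut Vec<u8>) -> bool {
    if compressed.len() < values.len() {
        buffer.extend_from_slice(&compressed);
        true // page header must say is_compressed: true
    } else {
        buffer.extend_from_slice(values);
        false // readers must not try to decompress this page
    }
}
```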
diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs
index 96e3706e27d7..603db6660f45 100644
--- a/parquet/src/file/properties.rs
+++ b/parquet/src/file/properties.rs
@@ -193,6 +193,12 @@ impl WriterProperties {
         WriterPropertiesBuilder::default()
     }
 
+    /// Converts this [`WriterProperties`] into a [`WriterPropertiesBuilder`].
+    /// Used for mutating existing property settings.
+    pub fn into_builder(self) -> WriterPropertiesBuilder {
+        self.into()
+    }
+
     /// Returns data page size limit.
     ///
     /// Note: this is a best effort limit based on the write batch size
@@ -435,6 +441,7 @@ impl WriterProperties {
 /// Builder for [`WriterProperties`] Parquet writer configuration.
 ///
 /// See example on [`WriterProperties`]
+#[derive(Debug, Clone)]
 pub struct WriterPropertiesBuilder {
     data_page_size_limit: usize,
     data_page_row_count_limit: usize,
@@ -934,6 +941,30 @@ impl WriterPropertiesBuilder {
     }
 }
 
+impl From<WriterProperties> for WriterPropertiesBuilder {
+    fn from(props: WriterProperties) -> Self {
+        WriterPropertiesBuilder {
+            data_page_size_limit: props.data_page_size_limit,
+            data_page_row_count_limit: props.data_page_row_count_limit,
+            write_batch_size: props.write_batch_size,
+            max_row_group_size: props.max_row_group_size,
+            bloom_filter_position: props.bloom_filter_position,
+            writer_version: props.writer_version,
+            created_by: props.created_by,
+            offset_index_disabled: props.offset_index_disabled,
+            key_value_metadata: props.key_value_metadata,
+            default_column_properties: props.default_column_properties,
+            column_properties: props.column_properties,
+            sorting_columns: props.sorting_columns,
+            column_index_truncate_length: props.column_index_truncate_length,
+            statistics_truncate_length: props.statistics_truncate_length,
+            coerce_types: props.coerce_types,
+            #[cfg(feature = "encryption")]
+            file_encryption_properties: props.file_encryption_properties,
+        }
+    }
+}
+
 /// Controls the level of statistics to be computed by the writer and stored in
 /// the parquet file.
 ///
@@ -1377,50 +1408,59 @@ mod tests {
             .set_column_bloom_filter_fpp(ColumnPath::from("col"), 0.1)
             .build();
 
-        assert_eq!(props.writer_version(), WriterVersion::PARQUET_2_0);
-        assert_eq!(props.data_page_size_limit(), 10);
-        assert_eq!(props.dictionary_page_size_limit(), 20);
-        assert_eq!(props.write_batch_size(), 30);
-        assert_eq!(props.max_row_group_size(), 40);
-        assert_eq!(props.created_by(), "default");
-        assert_eq!(
-            props.key_value_metadata(),
-            Some(&vec![
-                KeyValue::new("key".to_string(), "value".to_string(),)
-            ])
-        );
+        fn test_props(props: &WriterProperties) {
+            assert_eq!(props.writer_version(), WriterVersion::PARQUET_2_0);
+            assert_eq!(props.data_page_size_limit(), 10);
+            assert_eq!(props.dictionary_page_size_limit(), 20);
+            assert_eq!(props.write_batch_size(), 30);
+            assert_eq!(props.max_row_group_size(), 40);
+            assert_eq!(props.created_by(), "default");
+            assert_eq!(
+                props.key_value_metadata(),
+                Some(&vec![
+                    KeyValue::new("key".to_string(), "value".to_string(),)
+                ])
+            );
 
-        assert_eq!(
-            props.encoding(&ColumnPath::from("a")),
-            Some(Encoding::DELTA_BINARY_PACKED)
-        );
-        assert_eq!(
-            props.compression(&ColumnPath::from("a")),
-            Compression::GZIP(Default::default())
-        );
-        assert!(!props.dictionary_enabled(&ColumnPath::from("a")));
-        assert_eq!(
-            props.statistics_enabled(&ColumnPath::from("a")),
-            EnabledStatistics::None
-        );
+            assert_eq!(
+                props.encoding(&ColumnPath::from("a")),
+                Some(Encoding::DELTA_BINARY_PACKED)
+            );
+            assert_eq!(
+                props.compression(&ColumnPath::from("a")),
+                Compression::GZIP(Default::default())
+            );
+            assert!(!props.dictionary_enabled(&ColumnPath::from("a")));
+            assert_eq!(
+                props.statistics_enabled(&ColumnPath::from("a")),
+                EnabledStatistics::None
+            );
 
-        assert_eq!(
-            props.encoding(&ColumnPath::from("col")),
-            Some(Encoding::RLE)
-        );
-        assert_eq!(
-            props.compression(&ColumnPath::from("col")),
-            Compression::SNAPPY
-        );
-        assert!(props.dictionary_enabled(&ColumnPath::from("col")));
-        assert_eq!(
-            props.statistics_enabled(&ColumnPath::from("col")),
-            EnabledStatistics::Chunk
-        );
-        assert_eq!(
-            props.bloom_filter_properties(&ColumnPath::from("col")),
-            Some(&BloomFilterProperties { fpp: 0.1, ndv: 100 })
-        );
+            assert_eq!(
+                props.encoding(&ColumnPath::from("col")),
+                Some(Encoding::RLE)
+            );
+            assert_eq!(
+                props.compression(&ColumnPath::from("col")),
+                Compression::SNAPPY
+            );
+            assert!(props.dictionary_enabled(&ColumnPath::from("col")));
+            assert_eq!(
+                props.statistics_enabled(&ColumnPath::from("col")),
+                EnabledStatistics::Chunk
+            );
+            assert_eq!(
+                props.bloom_filter_properties(&ColumnPath::from("col")),
+                Some(&BloomFilterProperties { fpp: 0.1, ndv: 100 })
+            );
+        }
+
+        // Test direct build of properties
+        test_props(&props);
+
+        // Test that into_builder() gives the same result
+        let props_into_builder_and_back = props.into_builder().build();
+        test_props(&props_into_builder_and_back);
     }
 
     #[test]
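`into_builder` (backed by the `From` impl above, which carries every field across) makes it cheap to derive a variant of otherwise-shared properties by overriding a single setting:

```rust
use parquet::file::properties::{WriterProperties, WriterVersion};

fn main() {
    let base = WriterProperties::builder()
        .set_writer_version(WriterVersion::PARQUET_2_0)
        .set_max_row_group_size(64 * 1024)
        .build();

    // Round-trip through the builder to override one knob; everything else
    // (version, row group size, ...) is preserved by the From impl.
    let tweaked = base.into_builder().set_created_by("my-tool".into()).build();

    assert_eq!(tweaked.writer_version(), WriterVersion::PARQUET_2_0);
    assert_eq!(tweaked.max_row_group_size(), 64 * 1024);
    assert_eq!(tweaked.created_by(), "my-tool");
}
```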
diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs
index fc391abe87d7..e16e394be2bb 100644
--- a/parquet/src/thrift.rs
+++ b/parquet/src/thrift.rs
@@ -33,12 +33,20 @@ pub trait TSerializable: Sized {
     fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()>;
 }
 
-/// Public function to aid benchmarking.
+// Public function to aid benchmarking. Reads a Parquet `FileMetaData` encoded in `bytes`.
+#[doc(hidden)]
 pub fn bench_file_metadata(bytes: &bytes::Bytes) {
     let mut input = TCompactSliceInputProtocol::new(bytes);
     crate::format::FileMetaData::read_from_in_protocol(&mut input).unwrap();
 }
 
+// Public function to aid benchmarking. Reads a Parquet `PageHeader` encoded in `bytes`.
+#[doc(hidden)]
+pub fn bench_page_header(bytes: &bytes::Bytes) {
+    let mut prot = TCompactSliceInputProtocol::new(bytes);
+    crate::format::PageHeader::read_from_in_protocol(&mut prot).unwrap();
+}
+
 /// A more performant implementation of [`TCompactInputProtocol`] that reads a slice
 ///
 /// [`TCompactInputProtocol`]: thrift::protocol::TCompactInputProtocol