From a795030a777ffdb037250a4e34a3ccc3c26dc6cf Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Tue, 24 Jun 2025 05:51:35 -0400 Subject: [PATCH 1/6] [Variant] Use `BTreeMap` for `VariantBuilder.dict` and `ObjectBuilder.fields` to maintain invariants upon entry writes (#7720) # Which issue does this PR close? - It doesn't directly close the issue, but it's related to https://github.com/apache/arrow-rs/issues/7698 # Rationale for this change This commit changes the `dict` field in `VariantBuilder` + the `fields` field in `ObjectBuilder` to be `BTreeMap`s, and checks for existing field names in a object before appending a new field. These collections are often used in places where having an already sorted structure would be more performant. Inside of `ObjectBuilder::finish()`, we sort the fields by `field_name` and we can use the fact that `VariantBuilder`'s `dict` maintains a sorted mapping to `field_id` by `field_name`. To check whether an existing field name exists in a object, it is simply two lookups: 1) to find the `field_name: &str`'s unique `field_name_id`, and 2) check if the `ObjectBuilder` `fields` already has a key with that `field_name_id`. We make `ObjectBuilder` `fields` a `BTreeMap` sorted by `field_id`. Since `field_id`s correlate to insertion order, we now have some notion of which fields were inserted first. This also improves the time to look up the max field id, as it changes the linear scan over the entire `fields` collection to a logarithmic call using `fields.keys().last()`. --- parquet-variant/src/builder.rs | 114 +++++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 20 deletions(-) diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index c595d72e0afc..a5fb66a84ff4 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -16,7 +16,7 @@ // under the License. use crate::decoder::{VariantBasicType, VariantPrimitiveType}; use crate::{ShortString, Variant}; -use std::collections::HashMap; +use std::collections::BTreeMap; const BASIC_TYPE_BITS: u8 = 2; const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); @@ -166,7 +166,7 @@ fn make_room_for_header(buffer: &mut Vec, start_pos: usize, header_size: usi /// pub struct VariantBuilder { buffer: Vec, - dict: HashMap, + dict: BTreeMap, dict_keys: Vec, } @@ -174,7 +174,7 @@ impl VariantBuilder { pub fn new() -> Self { Self { buffer: Vec::new(), - dict: HashMap::new(), + dict: BTreeMap::new(), dict_keys: Vec::new(), } } @@ -296,7 +296,7 @@ impl VariantBuilder { /// Add key to dictionary, return its ID fn add_key(&mut self, key: &str) -> u32 { - use std::collections::hash_map::Entry; + use std::collections::btree_map::Entry; match self.dict.entry(key.to_string()) { Entry::Occupied(entry) => *entry.get(), Entry::Vacant(entry) => { @@ -482,7 +482,7 @@ impl<'a> ListBuilder<'a> { pub struct ObjectBuilder<'a> { parent: &'a mut VariantBuilder, start_pos: usize, - fields: Vec<(u32, usize)>, // (field_id, offset) + fields: BTreeMap, // (field_id, offset) } impl<'a> ObjectBuilder<'a> { @@ -491,7 +491,7 @@ impl<'a> ObjectBuilder<'a> { Self { parent, start_pos, - fields: Vec::new(), + fields: BTreeMap::new(), } } @@ -500,25 +500,27 @@ impl<'a> ObjectBuilder<'a> { let id = self.parent.add_key(key); let field_start = self.parent.offset() - self.start_pos; self.parent.append_value(value); - self.fields.push((id, field_start)); + let res = self.fields.insert(id, field_start); + debug_assert!(res.is_none()); } /// Finalize object with sorted fields - pub fn finish(mut self) { - // Sort fields by key name - self.fields.sort_by(|a, b| { - let key_a = &self.parent.dict_keys[a.0 as usize]; - let key_b = &self.parent.dict_keys[b.0 as usize]; - key_a.cmp(key_b) - }); - + pub fn finish(self) { let data_size = self.parent.offset() - self.start_pos; let num_fields = self.fields.len(); let is_large = num_fields > u8::MAX as usize; let size_bytes = if is_large { 4 } else { 1 }; - let max_id = self.fields.iter().map(|&(id, _)| id).max().unwrap_or(0); - let id_size = int_size(max_id as usize); + let field_ids_by_sorted_field_name = self + .parent + .dict + .iter() + .filter_map(|(_, id)| self.fields.contains_key(id).then_some(*id)) + .collect::>(); + + let max_id = self.fields.keys().last().copied().unwrap_or(0) as usize; + + let id_size = int_size(max_id); let offset_size = int_size(data_size); let header_size = 1 @@ -542,17 +544,18 @@ impl<'a> ObjectBuilder<'a> { } // Write field IDs (sorted order) - for &(id, _) in &self.fields { + for id in &field_ids_by_sorted_field_name { write_offset( &mut self.parent.buffer[pos..pos + id_size as usize], - id as usize, + *id as usize, id_size, ); pos += id_size as usize; } // Write field offsets - for &(_, offset) in &self.fields { + for id in &field_ids_by_sorted_field_name { + let &offset = self.fields.get(id).unwrap(); write_offset( &mut self.parent.buffer[pos..pos + offset_size as usize], offset, @@ -749,6 +752,77 @@ mod tests { assert_eq!(field_ids, vec![1, 2, 0]); } + #[test] + fn test_object_and_metadata_ordering() { + let mut builder = VariantBuilder::new(); + + let mut obj = builder.new_object(); + + obj.append_value("zebra", "stripes"); // ID = 0 + obj.append_value("apple", "red"); // ID = 1 + + { + // fields_map is ordered by insertion order (field id) + let fields_map = obj.fields.keys().copied().collect::>(); + assert_eq!(fields_map, vec![0, 1]); + + // dict is ordered by field names + // NOTE: when we support nested objects, we'll want to perform a filter by fields_map field ids + let dict_metadata = obj + .parent + .dict + .iter() + .map(|(f, i)| (f.as_str(), *i)) + .collect::>(); + + assert_eq!(dict_metadata, vec![("apple", 1), ("zebra", 0)]); + + // dict_keys is ordered by insertion order (field id) + let dict_keys = obj + .parent + .dict_keys + .iter() + .map(|k| k.as_str()) + .collect::>(); + assert_eq!(dict_keys, vec!["zebra", "apple"]); + } + + obj.append_value("banana", "yellow"); // ID = 2 + + { + // fields_map is ordered by insertion order (field id) + let fields_map = obj.fields.keys().copied().collect::>(); + assert_eq!(fields_map, vec![0, 1, 2]); + + // dict is ordered by field names + // NOTE: when we support nested objects, we'll want to perform a filter by fields_map field ids + let dict_metadata = obj + .parent + .dict + .iter() + .map(|(f, i)| (f.as_str(), *i)) + .collect::>(); + + assert_eq!( + dict_metadata, + vec![("apple", 1), ("banana", 2), ("zebra", 0)] + ); + + // dict_keys is ordered by insertion order (field id) + let dict_keys = obj + .parent + .dict_keys + .iter() + .map(|k| k.as_str()) + .collect::>(); + assert_eq!(dict_keys, vec!["zebra", "apple", "banana"]); + } + + obj.finish(); + + builder.finish(); + } + #[test] fn test_append_object() { let (object_metadata, object_value) = { From 2b40d1dfc35862ff350a40dfbc66f8a14f4eea31 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 24 Jun 2025 07:26:36 -0400 Subject: [PATCH 2/6] [Variant] Add Variant::as_object and Variant::as_list (#7755) # Which issue does this PR close? - part of https://github.com/apache/arrow-rs/issues/6736 # Rationale for this change - While reviewing @friendlymatthew 's PR https://github.com/apache/arrow-rs/pull/7740 I found that the code to get the Variant object was awkward I think that an accessor is similar to the existing `as_null`, `as_i32,` etc APIs. # What changes are included in this PR? 1. Add Variant::as_object and Variant::as_list # Are there any user-facing changes? New API (and docs with tests) --- parquet-variant/src/variant.rs | 64 ++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 2e042b6074cb..51327b4d2528 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -809,6 +809,70 @@ impl<'m, 'v> Variant<'m, 'v> { } } + /// Converts this variant to an `Object` if it is an [`VariantObject`]. + /// + /// Returns `Some(&VariantObject)` for object variants, + /// `None` for non-object variants. + /// + /// # Examples + /// ``` + /// # use parquet_variant::{Variant, VariantBuilder, VariantObject}; + /// # let (metadata, value) = { + /// # let mut builder = VariantBuilder::new(); + /// # let mut obj = builder.new_object(); + /// # obj.append_value("name", "John"); + /// # obj.finish(); + /// # builder.finish() + /// # }; + /// // object that is {"name": "John"} + /// let variant = Variant::try_new(&metadata, &value).unwrap(); + /// // use the `as_object` method to access the object + /// let obj = variant.as_object().expect("variant should be an object"); + /// assert_eq!(obj.field_by_name("name").unwrap(), Some(Variant::from("John"))); + /// ``` + pub fn as_object(&'m self) -> Option<&'m VariantObject<'m, 'v>> { + if let Variant::Object(obj) = self { + Some(obj) + } else { + None + } + } + + /// Converts this variant to a `List` if it is a [`VariantList`]. + /// + /// Returns `Some(&VariantList)` for list variants, + /// `None` for non-list variants. + /// + /// # Examples + /// ``` + /// # use parquet_variant::{Variant, VariantBuilder, VariantList}; + /// # let (metadata, value) = { + /// # let mut builder = VariantBuilder::new(); + /// # let mut list = builder.new_list(); + /// # list.append_value("John"); + /// # list.append_value("Doe"); + /// # list.finish(); + /// # builder.finish() + /// # }; + /// // list that is ["John", "Doe"] + /// let variant = Variant::try_new(&metadata, &value).unwrap(); + /// // use the `as_list` method to access the list + /// let list = variant.as_list().expect("variant should be a list"); + /// assert_eq!(list.len(), 2); + /// assert_eq!(list.get(0).unwrap(), Variant::from("John")); + /// assert_eq!(list.get(1).unwrap(), Variant::from("Doe")); + /// ``` + pub fn as_list(&'m self) -> Option<&'m VariantList<'m, 'v>> { + if let Variant::List(list) = self { + Some(list) + } else { + None + } + } + + /// Return the metadata associated with this variant, if any. + /// + /// Returns `Some(&VariantMetadata)` for object and list variants, pub fn metadata(&self) -> Option<&'m VariantMetadata> { match self { Variant::Object(VariantObject { metadata, .. }) From a49ce3e22f192cefeba8058230dd7588a4c47e31 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 24 Jun 2025 12:28:10 -0400 Subject: [PATCH 3/6] Add testing section to pull request template (#7749) # Which issue does this PR close? N/A # Rationale for this change It is critical and generally required to add tests for any changes to arrow-rs and it something we look for in our PR reviews. It would be nice to remind people of this explicitly in the PR # What changes are included in this PR? 1. Update the PR template to include a section on testing 2. Add a list marker (`-`) to the closes section which causes github to render the name of the PR in markdown not just the number (see rationale on https://github.com/apache/datafusion/pull/14507) I copied the wording from DataFusion: https://github.com/apache/datafusion/blob/b6c8cc57760686fffe4878e69c1be27e4d9f5e68/.github/pull_request_template.md?plain=1#L22 # Are there any user-facing changes? A new section on PR descriptions --- .github/pull_request_template.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e999f505bca1..49b34c6137f7 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -2,7 +2,7 @@ We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. -Closes #NNN. +- Closes #NNN. # Rationale for this change @@ -13,6 +13,14 @@ Explaining clearly why changes are proposed helps reviewers understand your chan There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. +# Are these changes tested? + +We typically require tests for all PRs in order to: +1. Prevent the code from being accidentally broken by subsequent changes +2. Serve as another way to document the expected behavior of the code + +If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? + # Are there any user-facing changes? If there are user-facing changes then we may require documentation to be updated before approving the PR. From 121371ca59af249e4eae404abe4d2281276daa2a Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 24 Jun 2025 20:08:30 +0200 Subject: [PATCH 4/6] feat: [Variant] Add Validation for Variant Deciaml (#7738) # Which issue does this PR close? - Closes #7697 # Rationale for this change # What changes are included in this PR? - Introduced new types: VariantDecimal4, VariantDecimal8, and VariantDecimal16 - These types encapsulate decimal values and ensure proper validation and wrapping # Are there any user-facing changes? --- parquet-variant/src/builder.rs | 14 +- parquet-variant/src/variant.rs | 237 ++++++++++++++++++----- parquet-variant/tests/variant_interop.rs | 21 +- 3 files changed, 209 insertions(+), 63 deletions(-) diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index a5fb66a84ff4..1c6ebe23d24f 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. use crate::decoder::{VariantBasicType, VariantPrimitiveType}; -use crate::{ShortString, Variant}; +use crate::{ShortString, Variant, VariantDecimal16, VariantDecimal4, VariantDecimal8}; use std::collections::BTreeMap; const BASIC_TYPE_BITS: u8 = 2; @@ -384,9 +384,15 @@ impl VariantBuilder { Variant::Date(v) => self.append_date(v), Variant::TimestampMicros(v) => self.append_timestamp_micros(v), Variant::TimestampNtzMicros(v) => self.append_timestamp_ntz_micros(v), - Variant::Decimal4 { integer, scale } => self.append_decimal4(integer, scale), - Variant::Decimal8 { integer, scale } => self.append_decimal8(integer, scale), - Variant::Decimal16 { integer, scale } => self.append_decimal16(integer, scale), + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { + self.append_decimal4(integer, scale) + } + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { + self.append_decimal8(integer, scale) + } + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { + self.append_decimal16(integer, scale) + } Variant::Float(v) => self.append_float(v), Variant::Double(v) => self.append_double(v), Variant::Binary(v) => self.append_binary(v), diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 51327b4d2528..b343a538d54c 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -40,8 +40,100 @@ const MAX_SHORT_STRING_BYTES: usize = 0x3F; #[derive(Debug, Clone, Copy, PartialEq)] pub struct ShortString<'a>(pub(crate) &'a str); +/// Represents a 4-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 32-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is limited to 9 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal4 { + pub(crate) integer: i32, + pub(crate) scale: u8, +} + +impl VariantDecimal4 { + pub fn try_new(integer: i32, scale: u8) -> Result { + const PRECISION_MAX: u32 = 9; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 9 for 4-byte decimal", + scale + ))); + } + + Ok(VariantDecimal4 { integer, scale }) + } +} + +/// Represents an 8-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 64-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is between 10 and 18 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal8 { + pub(crate) integer: i64, + pub(crate) scale: u8, +} + +impl VariantDecimal8 { + pub fn try_new(integer: i64, scale: u8) -> Result { + const PRECISION_MAX: u32 = 18; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 18 for 8-byte decimal", + scale + ))); + } + + Ok(VariantDecimal8 { integer, scale }) + } +} + +/// Represents an 16-byte decimal value in the Variant format. +/// +/// This struct stores a decimal number using a 128-bit signed integer for the coefficient +/// and an 8-bit unsigned integer for the scale (number of decimal places). Its precision is between 19 and 38 digits. +/// +/// For valid precision and scale values, see the Variant specification: +/// +/// +/// +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct VariantDecimal16 { + pub(crate) integer: i128, + pub(crate) scale: u8, +} + +impl VariantDecimal16 { + pub fn try_new(integer: i128, scale: u8) -> Result { + const PRECISION_MAX: u32 = 38; + + // Validate that scale doesn't exceed precision + if scale as u32 > PRECISION_MAX { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {} cannot be greater than precision 38 for 16-byte decimal", + scale + ))); + } + + Ok(VariantDecimal16 { integer, scale }) + } +} + impl<'a> ShortString<'a> { - /// Attempts to interpret `value` as a variant short string value. + /// Attempts to interpret `value` as a variant short string value. /// /// # Validation /// @@ -194,11 +286,11 @@ pub enum Variant<'m, 'v> { /// Primitive (type_id=1): TIMESTAMP(isAdjustedToUTC=false, MICROS) TimestampNtzMicros(NaiveDateTime), /// Primitive (type_id=1): DECIMAL(precision, scale) 32-bits - Decimal4 { integer: i32, scale: u8 }, + Decimal4(VariantDecimal4), /// Primitive (type_id=1): DECIMAL(precision, scale) 64-bits - Decimal8 { integer: i64, scale: u8 }, + Decimal8(VariantDecimal8), /// Primitive (type_id=1): DECIMAL(precision, scale) 128-bits - Decimal16 { integer: i128, scale: u8 }, + Decimal16(VariantDecimal16), /// Primitive (type_id=1): FLOAT Float(f32), /// Primitive (type_id=1): DOUBLE @@ -269,15 +361,15 @@ impl<'m, 'v> Variant<'m, 'v> { VariantPrimitiveType::Int64 => Variant::Int64(decoder::decode_int64(value_data)?), VariantPrimitiveType::Decimal4 => { let (integer, scale) = decoder::decode_decimal4(value_data)?; - Variant::Decimal4 { integer, scale } + Variant::Decimal4(VariantDecimal4 { integer, scale }) } VariantPrimitiveType::Decimal8 => { let (integer, scale) = decoder::decode_decimal8(value_data)?; - Variant::Decimal8 { integer, scale } + Variant::Decimal8(VariantDecimal8 { integer, scale }) } VariantPrimitiveType::Decimal16 => { let (integer, scale) = decoder::decode_decimal16(value_data)?; - Variant::Decimal16 { integer, scale } + Variant::Decimal16(VariantDecimal16 { integer, scale }) } VariantPrimitiveType::Float => Variant::Float(decoder::decode_float(value_data)?), VariantPrimitiveType::Double => { @@ -640,18 +732,18 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i32, 2)); + /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap()); /// assert_eq!(v1.as_decimal_int32(), Some((1234_i32, 2))); /// /// // and from larger decimal variants if they fit - /// let v2 = Variant::from((1234_i64, 2)); + /// let v2 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap()); /// assert_eq!(v2.as_decimal_int32(), Some((1234_i32, 2))); /// /// // but not if the value would overflow i32 - /// let v3 = Variant::from((12345678901i64, 2)); + /// let v3 = Variant::from(VariantDecimal8::try_new(12345678901i64, 2).unwrap()); /// assert_eq!(v3.as_decimal_int32(), None); /// /// // or if the variant is not a decimal @@ -660,17 +752,17 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int32(&self) -> Option<(i32, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer, scale)), - Variant::Decimal8 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal4(decimal4) => Some((decimal4.integer, decimal4.scale)), + Variant::Decimal8(decimal8) => { + if let Ok(converted_integer) = decimal8.integer.try_into() { + Some((converted_integer, decimal8.scale)) } else { None } } - Variant::Decimal16 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal16(decimal16) => { + if let Ok(converted_integer) = decimal16.integer.try_into() { + Some((converted_integer, decimal16.scale)) } else { None } @@ -688,18 +780,18 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal8, VariantDecimal16}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i64, 2)); + /// let v1 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap()); /// assert_eq!(v1.as_decimal_int64(), Some((1234_i64, 2))); /// /// // and from larger decimal variants if they fit - /// let v2 = Variant::from((1234_i128, 2)); + /// let v2 = Variant::from(VariantDecimal16::try_new(1234_i128, 2).unwrap()); /// assert_eq!(v2.as_decimal_int64(), Some((1234_i64, 2))); /// /// // but not if the value would overflow i64 - /// let v3 = Variant::from((2e19 as i128, 2)); + /// let v3 = Variant::from(VariantDecimal16::try_new(2e19 as i128, 2).unwrap()); /// assert_eq!(v3.as_decimal_int64(), None); /// /// // or if the variant is not a decimal @@ -708,11 +800,11 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int64(&self) -> Option<(i64, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal8 { integer, scale } => Some((integer, scale)), - Variant::Decimal16 { integer, scale } => { - if let Ok(converted_integer) = integer.try_into() { - Some((converted_integer, scale)) + Variant::Decimal4(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal8(decimal) => Some((decimal.integer, decimal.scale)), + Variant::Decimal16(decimal) => { + if let Ok(converted_integer) = decimal.integer.try_into() { + Some((converted_integer, decimal.scale)) } else { None } @@ -730,10 +822,10 @@ impl<'m, 'v> Variant<'m, 'v> { /// # Examples /// /// ``` - /// use parquet_variant::Variant; + /// use parquet_variant::{Variant, VariantDecimal16}; /// /// // you can extract decimal parts from smaller or equally-sized decimal variants - /// let v1 = Variant::from((1234_i128, 2)); + /// let v1 = Variant::from(VariantDecimal16::try_new(1234_i128, 2).unwrap()); /// assert_eq!(v1.as_decimal_int128(), Some((1234_i128, 2))); /// /// // but not if the variant is not a decimal @@ -742,9 +834,9 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_decimal_int128(&self) -> Option<(i128, u8)> { match *self { - Variant::Decimal4 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal8 { integer, scale } => Some((integer.into(), scale)), - Variant::Decimal16 { integer, scale } => Some((integer, scale)), + Variant::Decimal4(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal8(decimal) => Some((decimal.integer.into(), decimal.scale)), + Variant::Decimal16(decimal) => Some((decimal.integer, decimal.scale)), _ => None, } } @@ -912,30 +1004,21 @@ impl From for Variant<'_, '_> { } } -impl From<(i32, u8)> for Variant<'_, '_> { - fn from(value: (i32, u8)) -> Self { - Variant::Decimal4 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal4) -> Self { + Variant::Decimal4(value) } } -impl From<(i64, u8)> for Variant<'_, '_> { - fn from(value: (i64, u8)) -> Self { - Variant::Decimal8 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal8) -> Self { + Variant::Decimal8(value) } } -impl From<(i128, u8)> for Variant<'_, '_> { - fn from(value: (i128, u8)) -> Self { - Variant::Decimal16 { - integer: value.0, - scale: value.1, - } +impl From for Variant<'_, '_> { + fn from(value: VariantDecimal16) -> Self { + Variant::Decimal16(value) } } @@ -994,6 +1077,36 @@ impl<'v> From<&'v str> for Variant<'_, 'v> { } } +impl TryFrom<(i32, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i32, u8)) -> Result { + Ok(Variant::Decimal4(VariantDecimal4::try_new( + value.0, value.1, + )?)) + } +} + +impl TryFrom<(i64, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i64, u8)) -> Result { + Ok(Variant::Decimal8(VariantDecimal8::try_new( + value.0, value.1, + )?)) + } +} + +impl TryFrom<(i128, u8)> for Variant<'_, '_> { + type Error = ArrowError; + + fn try_from(value: (i128, u8)) -> Result { + Ok(Variant::Decimal16(VariantDecimal16::try_new( + value.0, value.1, + )?)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1007,4 +1120,28 @@ mod tests { let res = ShortString::try_new(&long_string); assert!(res.is_err()); } + + #[test] + fn test_variant_decimal_conversion() { + let decimal4 = VariantDecimal4::try_new(1234_i32, 2).unwrap(); + let variant = Variant::from(decimal4); + assert_eq!(variant.as_decimal_int32(), Some((1234_i32, 2))); + + let decimal8 = VariantDecimal8::try_new(12345678901_i64, 2).unwrap(); + let variant = Variant::from(decimal8); + assert_eq!(variant.as_decimal_int64(), Some((12345678901_i64, 2))); + + let decimal16 = VariantDecimal16::try_new(123456789012345678901234567890_i128, 2).unwrap(); + let variant = Variant::from(decimal16); + assert_eq!( + variant.as_decimal_int128(), + Some((123456789012345678901234567890_i128, 2)) + ); + } + + #[test] + fn test_invalid_variant_decimal_conversion() { + let decimal4 = VariantDecimal4::try_new(123456789_i32, 20); + assert!(decimal4.is_err(), "i32 overflow should fail"); + } } diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index bfa2ab267c27..be63357422e4 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -24,7 +24,9 @@ use std::fs; use std::path::{Path, PathBuf}; use chrono::NaiveDate; -use parquet_variant::{ShortString, Variant, VariantBuilder}; +use parquet_variant::{ + ShortString, Variant, VariantBuilder, VariantDecimal16, VariantDecimal4, VariantDecimal8, +}; fn cases_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")) @@ -63,9 +65,10 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("primitive_boolean_false", Variant::BooleanFalse), ("primitive_boolean_true", Variant::BooleanTrue), ("primitive_date", Variant::Date(NaiveDate::from_ymd_opt(2025, 4 , 16).unwrap())), - ("primitive_decimal4", Variant::Decimal4{integer: 1234, scale: 2}), - ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), - ("primitive_decimal16", Variant::Decimal16{integer: 1234567891234567890, scale: 2}), + ("primitive_decimal4", Variant::from(VariantDecimal4::try_new(1234i32, 2u8).unwrap())), + // ("primitive_decimal8", Variant::Decimal8{integer: 1234567890, scale: 2}), + ("primitive_decimal8", Variant::Decimal8(VariantDecimal8::try_new(1234567890,2).unwrap())), + ("primitive_decimal16", Variant::Decimal16(VariantDecimal16::try_new(1234567891234567890, 2).unwrap())), ("primitive_float", Variant::Float(1234567890.1234)), ("primitive_double", Variant::Double(1234567890.1234)), ("primitive_int8", Variant::Int8(42)), @@ -123,10 +126,7 @@ fn variant_object_primitive() { // spark wrote this as a decimal4 (not a double) ( "double_field", - Variant::Decimal4 { - integer: 123456789, - scale: 8, - }, + Variant::Decimal4(VariantDecimal4::try_new(123456789, 8).unwrap()), ), ("int_field", Variant::Int8(1)), ("null_field", Variant::Null), @@ -210,7 +210,10 @@ fn variant_object_builder() { // The double field is actually encoded as decimal4 with scale 8 // Value: 123456789, Scale: 8 -> 1.23456789 - obj.append_value("double_field", (123456789i32, 8u8)); + obj.append_value( + "double_field", + VariantDecimal4::try_new(123456789i32, 8u8).unwrap(), + ); obj.append_value("boolean_true_field", true); obj.append_value("boolean_false_field", false); obj.append_value("string_field", "Apache Parquet"); From fe94db2fcf5e9f8ef0fbc00d84e383368af7e081 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 24 Jun 2025 14:41:43 -0400 Subject: [PATCH 5/6] Fix logical conflict with #7738 --- parquet-variant/src/to_json.rs | 69 ++++++++++++---------------------- 1 file changed, 23 insertions(+), 46 deletions(-) diff --git a/parquet-variant/src/to_json.rs b/parquet-variant/src/to_json.rs index 0cdcb8b49e63..82a677206a13 100644 --- a/parquet-variant/src/to_json.rs +++ b/parquet-variant/src/to_json.rs @@ -23,6 +23,7 @@ use serde_json::Value; use std::io::Write; use crate::variant::{Variant, VariantList, VariantObject}; +use crate::{VariantDecimal16, VariantDecimal4, VariantDecimal8}; // Format string constants to avoid duplication and reduce errors const DATE_FORMAT: &str = "%Y-%m-%d"; @@ -106,7 +107,7 @@ pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Resul Variant::Double(f) => { write!(json_buffer, "{}", f)?; } - Variant::Decimal4 { integer, scale } => { + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { // Convert decimal to string representation using integer arithmetic if *scale == 0 { write!(json_buffer, "{}", integer)?; @@ -123,7 +124,7 @@ pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Resul } } } - Variant::Decimal8 { integer, scale } => { + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { // Convert decimal to string representation using integer arithmetic if *scale == 0 { write!(json_buffer, "{}", integer)?; @@ -140,7 +141,7 @@ pub fn variant_to_json(json_buffer: &mut impl Write, variant: &Variant) -> Resul } } } - Variant::Decimal16 { integer, scale } => { + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { // Convert decimal to string representation using integer arithmetic if *scale == 0 { write!(json_buffer, "{}", integer)?; @@ -364,7 +365,7 @@ pub fn variant_to_json_value(variant: &Variant) -> Result { Variant::Double(f) => serde_json::Number::from_f64(*f) .map(Value::Number) .ok_or_else(|| ArrowError::InvalidArgumentError("Invalid double value".to_string())), - Variant::Decimal4 { integer, scale } => { + Variant::Decimal4(VariantDecimal4 { integer, scale }) => { // Use integer arithmetic to avoid f64 precision loss if *scale == 0 { Ok(Value::Number((*integer).into())) @@ -390,7 +391,7 @@ pub fn variant_to_json_value(variant: &Variant) -> Result { }) } } - Variant::Decimal8 { integer, scale } => { + Variant::Decimal8(VariantDecimal8 { integer, scale }) => { // Use integer arithmetic to avoid f64 precision loss if *scale == 0 { Ok(Value::Number((*integer).into())) @@ -416,7 +417,7 @@ pub fn variant_to_json_value(variant: &Variant) -> Result { }) } } - Variant::Decimal16 { integer, scale } => { + Variant::Decimal16(VariantDecimal16 { integer, scale }) => { // Use integer arithmetic to avoid f64 precision loss if *scale == 0 { Ok(Value::Number((*integer as i64).into())) // Convert to i64 for JSON compatibility @@ -482,18 +483,12 @@ mod tests { #[test] fn test_decimal_edge_cases() -> Result<(), ArrowError> { // Test negative decimal - let negative_variant = Variant::Decimal4 { - integer: -12345, - scale: 3, - }; + let negative_variant = Variant::from(VariantDecimal4::try_new(-12345, 3)?); let negative_json = variant_to_json_string(&negative_variant)?; assert_eq!(negative_json, "-12.345"); // Test large scale decimal - let large_scale_variant = Variant::Decimal8 { - integer: 123456789, - scale: 6, - }; + let large_scale_variant = Variant::from(VariantDecimal8::try_new(123456789, 6)?); let large_scale_json = variant_to_json_string(&large_scale_variant)?; assert_eq!(large_scale_json, "123.456789"); @@ -502,10 +497,7 @@ mod tests { #[test] fn test_decimal16_to_json() -> Result<(), ArrowError> { - let variant = Variant::Decimal16 { - integer: 123456789012345, - scale: 4, - }; + let variant = Variant::from(VariantDecimal16::try_new(123456789012345, 4)?); let json = variant_to_json_string(&variant)?; assert_eq!(json, "12345678901.2345"); @@ -513,10 +505,7 @@ mod tests { assert!(matches!(json_value, Value::Number(_))); // Test very large number - let large_variant = Variant::Decimal16 { - integer: 999999999999999999, - scale: 2, - }; + let large_variant = Variant::from(VariantDecimal16::try_new(999999999999999999, 2)?); let large_json = variant_to_json_string(&large_variant)?; // Due to f64 precision limits, very large numbers may lose precision assert!( @@ -839,10 +828,7 @@ mod tests { // Decimals JsonTest { - variant: Variant::Decimal4 { - integer: 12345, - scale: 2, - }, + variant: Variant::from(VariantDecimal4::try_new(12345, 2).unwrap()), expected_json: "123.45", expected_value: serde_json::Number::from_f64(123.45) .map(Value::Number) @@ -851,10 +837,7 @@ mod tests { .run(); JsonTest { - variant: Variant::Decimal4 { - integer: 42, - scale: 0, - }, + variant: Variant::from(VariantDecimal4::try_new(42, 0).unwrap()), expected_json: "42", expected_value: serde_json::Number::from_f64(42.0) .map(Value::Number) @@ -863,10 +846,7 @@ mod tests { .run(); JsonTest { - variant: Variant::Decimal8 { - integer: 1234567890, - scale: 3, - }, + variant: Variant::from(VariantDecimal8::try_new(1234567890, 3).unwrap()), expected_json: "1234567.89", expected_value: serde_json::Number::from_f64(1234567.89) .map(Value::Number) @@ -875,10 +855,7 @@ mod tests { .run(); JsonTest { - variant: Variant::Decimal16 { - integer: 123456789012345, - scale: 4, - }, + variant: Variant::from(VariantDecimal16::try_new(123456789012345, 4).unwrap()), expected_json: "12345678901.2345", expected_value: serde_json::Number::from_f64(12345678901.2345) .map(Value::Number) @@ -1277,10 +1254,10 @@ mod tests { fn test_high_precision_decimal_no_loss() -> Result<(), ArrowError> { // Test case that would lose precision with f64 conversion // This is a 63-bit precision decimal8 value that f64 cannot represent exactly - let high_precision_decimal8 = Variant::Decimal8 { - integer: 9007199254740993, // 2^53 + 1, exceeds f64 precision - scale: 6, - }; + let high_precision_decimal8 = Variant::from(VariantDecimal8::try_new( + 9007199254740993, // 2^53 + 1, exceeds f64 precision + 6, + )?); let json_string = variant_to_json_string(&high_precision_decimal8)?; let json_value = variant_to_json_value(&high_precision_decimal8)?; @@ -1294,10 +1271,10 @@ mod tests { assert_eq!(parsed, json_value); // Test another case with trailing zeros that should be trimmed - let decimal_with_zeros = Variant::Decimal8 { - integer: 1234567890000, // Should result in 1234567.89 (trailing zeros trimmed) - scale: 6, - }; + let decimal_with_zeros = Variant::from(VariantDecimal8::try_new( + 1234567890000, // Should result in 1234567.89 (trailing zeros trimmed) + 6, + )?); let json_string_zeros = variant_to_json_string(&decimal_with_zeros)?; assert_eq!(json_string_zeros, "1234567.89"); From 4b18f7f75a8142f7c20d2b9719d68059efbe4661 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 24 Jun 2025 14:45:28 -0400 Subject: [PATCH 6/6] Less explicit panics --- parquet-variant/src/to_json.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/parquet-variant/src/to_json.rs b/parquet-variant/src/to_json.rs index 82a677206a13..80759b80a5c8 100644 --- a/parquet-variant/src/to_json.rs +++ b/parquet-variant/src/to_json.rs @@ -980,9 +980,7 @@ mod tests { // Parse the JSON to verify structure - handle JSON parsing errors manually let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Object(obj) = parsed else { - panic!("Expected JSON object"); - }; + let obj = parsed.as_object().expect("expected JSON object"); assert_eq!(obj.get("name"), Some(&Value::String("Alice".to_string()))); assert_eq!(obj.get("age"), Some(&Value::Number(30.into()))); assert_eq!(obj.get("active"), Some(&Value::Bool(true))); @@ -1071,9 +1069,7 @@ mod tests { assert_eq!(json, "[1,2,3,4,5]"); let json_value = variant_to_json_value(&variant)?; - let Value::Array(arr) = json_value else { - panic!("Expected JSON array"); - }; + let arr = json_value.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 5); assert_eq!(arr[0], Value::Number(1.into())); assert_eq!(arr[4], Value::Number(5.into())); @@ -1125,9 +1121,7 @@ mod tests { let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Array(arr) = parsed else { - panic!("Expected JSON array"); - }; + let arr = parsed.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 5); assert_eq!(arr[0], Value::String("hello".to_string())); assert_eq!(arr[1], Value::Number(42.into())); @@ -1160,9 +1154,7 @@ mod tests { // Parse and verify all fields are present let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Object(obj) = parsed else { - panic!("Expected JSON object"); - }; + let obj = parsed.as_object().expect("expected JSON object"); assert_eq!(obj.len(), 3); assert_eq!(obj.get("alpha"), Some(&Value::String("first".to_string()))); assert_eq!(obj.get("beta"), Some(&Value::String("second".to_string()))); @@ -1195,9 +1187,7 @@ mod tests { let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Array(arr) = parsed else { - panic!("Expected JSON array"); - }; + let arr = parsed.as_array().expect("expected JSON array"); assert_eq!(arr.len(), 7); assert_eq!(arr[0], Value::String("string_value".to_string())); assert_eq!(arr[1], Value::Number(42.into())); @@ -1233,9 +1223,7 @@ mod tests { let parsed: Value = serde_json::from_str(&json) .map_err(|e| ArrowError::ParseError(format!("JSON parse error: {}", e)))?; - let Value::Object(obj) = parsed else { - panic!("Expected JSON object"); - }; + let obj = parsed.as_object().expect("expected JSON object"); assert_eq!(obj.len(), 6); assert_eq!( obj.get("string_field"),