diff --git a/kernel/src/engine/arrow_expression/tests.rs b/kernel/src/engine/arrow_expression/tests.rs index 0c97723d2..5391dc4c0 100644 --- a/kernel/src/engine/arrow_expression/tests.rs +++ b/kernel/src/engine/arrow_expression/tests.rs @@ -828,7 +828,7 @@ fn test_create_one_mismatching_scalar_types() { let handler = ArrowEvaluationHandler; assert_result_error_with_message( handler.create_one(schema, values), - "Schema error: Mismatched scalar type while creating Expression: expected Integer, got Long", + "Schema error: Mismatched scalar type while creating Expression: expected Primitive(Integer), got Primitive(Long)", ); } diff --git a/kernel/src/expressions/literal_expression_transform.rs b/kernel/src/expressions/literal_expression_transform.rs index 1c8a483e4..cb3792074 100644 --- a/kernel/src/expressions/literal_expression_transform.rs +++ b/kernel/src/expressions/literal_expression_transform.rs @@ -1,28 +1,9 @@ //! The [`LiteralExpressionTransform`] is a [`SchemaTransform`] that transforms a [`Schema`] and an //! ordered list of leaf values (scalars) into an [`Expression`] with a literal value for each leaf. -use std::borrow::Cow; -use std::ops::Deref as _; - -use tracing::debug; - use crate::expressions::{Expression, Scalar}; -use crate::schema::{ - ArrayType, DataType, MapType, PrimitiveType, SchemaTransform, StructField, StructType, -}; - -/// [`SchemaTransform`] that will transform a [`Schema`] and an ordered list of leaf values -/// (Scalars) into an Expression with a [`Literal`] expr for each leaf. -#[derive(Debug)] -pub(crate) struct LiteralExpressionTransform<'a, T: Iterator> { - /// Leaf values to insert in schema order. - scalars: T, - /// A stack of built Expressions. After visiting children, we pop them off to - /// build the parent container, then push the parent back on. - stack: Vec, - /// Since schema transforms are infallible we keep track of errors here - error: Result<(), Error>, -} +use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType}; +use crate::DeltaResult; /// Any error for [`LiteralExpressionTransform`] #[derive(thiserror::Error, Debug)] @@ -48,97 +29,71 @@ pub enum Error { Unsupported(String), } +#[derive(Debug, Default)] +pub(crate) struct LiteralExpressionTransform<'a, T: Iterator> { + /// Leaf values to insert in schema order. + scalars: T, +} + impl<'a, I: Iterator> LiteralExpressionTransform<'a, I> { pub(crate) fn new(scalars: impl IntoIterator) -> Self { Self { scalars: scalars.into_iter(), - stack: Vec::new(), - error: Ok(()), } } - /// return the Expression we just built (or propagate Error). the top of `stack` should be our - /// final Expression - pub(crate) fn try_into_expr(mut self) -> Result { - self.error?; + /// Bind the visitor to a StructType and produce an Expression + pub(crate) fn bind(&mut self, struct_type: &StructType) -> DeltaResult { + use crate::schema::visitor::visit_struct; + let result = visit_struct(struct_type, self)?; - if let Some(s) = self.scalars.next() { - return Err(Error::ExcessScalars(s.clone())); + // Check for excess scalars after visiting + if let Some(scalar) = self.scalars.next() { + return Err(Error::ExcessScalars(scalar.clone()).into()); } - self.stack.pop().ok_or(Error::EmptyStack) - } - - fn set_error(&mut self, error: Error) { - // Only set when the error not yet set - if let Err(ref existing_error) = self.error { - debug!("Trying to overwrite an existing error: {existing_error:?} with {error:?}"); - } else { - self.error = Err(error); - } + Ok(result) } -} -// All leaf types (primitive, array, map) share the same "shape" of transformation logic -macro_rules! transform_leaf { - ($self:ident, $type_variant:path, $type:ident) => {{ - // first always check error to terminate early if possible - $self.error.as_ref().ok()?; - - let Some(scalar) = $self.scalars.next() else { - $self.set_error(Error::InsufficientScalars); - return None; + fn visit_leaf(&mut self, schema_type: &DataType) -> DeltaResult { + let Some(scalar) = self.scalars.next() else { + return Err(Error::InsufficientScalars.into()); }; - // NOTE: Grab a reference here so code below can leverage the blanket impl Deref for &T - let $type_variant(ref scalar_type) = scalar.data_type() else { - $self.set_error(Error::Schema(format!( - "Mismatched scalar type while creating Expression: expected {}({:?}), got {:?}", - stringify!($type_variant), - $type, + if schema_type.clone() != scalar.data_type() { + return Err(Error::Schema(format!( + "Mismatched scalar type while creating Expression: expected {:?}, got {:?}", + schema_type, scalar.data_type() - ))); - return None; + )) + .into()); }; - // NOTE: &T and &Box both deref to &T - if scalar_type.deref() != $type { - $self.set_error(Error::Schema(format!( - "Mismatched scalar type while creating Expression: expected {:?}, got {:?}", - $type, scalar_type - ))); - return None; - } - - $self.stack.push(Expression::Literal(scalar.clone())); - None - }}; + Ok(Expression::Literal(scalar.clone())) + } } -impl<'a, T: Iterator> SchemaTransform<'a> for LiteralExpressionTransform<'a, T> { - fn transform_primitive( - &mut self, - prim_type: &'a PrimitiveType, - ) -> Option> { - transform_leaf!(self, DataType::Primitive, prim_type) +impl<'a, I: Iterator> delta_kernel::schema::visitor::SchemaVisitor + for LiteralExpressionTransform<'a, I> +{ + type T = Expression; + + fn field(&mut self, field: &StructField, value: Self::T) -> DeltaResult { + match &field.data_type { + DataType::Struct(_) => Ok(value), + DataType::Primitive(_) => self.visit_leaf(&field.data_type), + DataType::Array(_) => self.visit_leaf(&field.data_type), + DataType::Map(_) => self.visit_leaf(&field.data_type), + DataType::Variant(_) => self.visit_leaf(&field.data_type), + } } - fn transform_struct(&mut self, struct_type: &'a StructType) -> Option> { - // first always check error to terminate early if possible - self.error.as_ref().ok()?; - - // Only consume newly-added entries (if any). There could be fewer than expected if - // the recursion encountered an error. - let mark = self.stack.len(); - self.recurse_into_struct(struct_type)?; - let field_exprs = self.stack.split_off(mark); - + fn r#struct( + &mut self, + struct_type: &StructType, + field_exprs: Vec, + ) -> DeltaResult { let fields = struct_type.fields(); - if field_exprs.len() != fields.len() { - self.set_error(Error::InsufficientScalars); - return None; - } - let mut found_non_nullable_null = false; let mut all_null = true; for (field, expr) in fields.zip(&field_exprs) { @@ -154,36 +109,93 @@ impl<'a, T: Iterator> SchemaTransform<'a> for LiteralExpressi let struct_expr = if found_non_nullable_null { if !all_null { // we found a non_nullable NULL, but other siblings are non-null: error - self.set_error(Error::Schema( + return Err(Error::Schema( "NULL value for non-nullable struct field with non-NULL siblings".to_string(), - )); - return None; + ) + .into()); } Expression::null_literal(struct_type.clone().into()) } else { Expression::struct_from(field_exprs) }; - self.stack.push(struct_expr); - None + Ok(struct_expr) + } + + fn list(&mut self, _list: &ArrayType, _value: Self::T) -> DeltaResult { + // Everything is handled on the field level + Ok(Expression::Unknown("Should not happen".to_string())) } - fn transform_struct_field(&mut self, field: &'a StructField) -> Option> { - // first always check error to terminate early if possible - self.error.as_ref().ok()?; + fn map( + &mut self, + _map: &MapType, + _key_value: Self::T, + _value: Self::T, + ) -> DeltaResult { + // Everything is handled on the field level + Ok(Expression::Unknown("Should not happen".to_string())) + } - self.recurse_into_struct_field(field); - Some(Cow::Borrowed(field)) + fn primitive(&mut self, _p: &PrimitiveType) -> DeltaResult { + // Everything is handled on the field level + Ok(Expression::Unknown("Should not happen".to_string())) } - // arrays treated as leaves - fn transform_array(&mut self, array_type: &'a ArrayType) -> Option> { - transform_leaf!(self, DataType::Array, array_type) + fn variant(&mut self, _struct: &StructType) -> DeltaResult { + // Everything is handled on the field level + Ok(Expression::Unknown("Should not happen".to_string())) } +} - // maps treated as leaves - fn transform_map(&mut self, map_type: &'a MapType) -> Option> { - transform_leaf!(self, DataType::Map, map_type) +struct StructFieldIterator { + // Stack of (struct_type, current_position, deferred_struct_field) + // deferred_struct_field is the struct field that contains this struct_type + // and should be returned after all children are processed + stack: Vec<(StructType, usize, Option)>, +} + +impl StructFieldIterator { + fn new(root: StructType) -> Self { + StructFieldIterator { + stack: vec![(root, 0, None)], + } + } +} + +impl Iterator for StructFieldIterator { + type Item = StructField; + + fn next(&mut self) -> Option { + while !self.stack.is_empty() { + let (struct_type, current_pos, deferred_field) = self.stack.last().unwrap().clone(); + + if current_pos < struct_type.fields_len() { + // Get the current field and increment position + let field = struct_type.by_index(current_pos).clone(); + self.stack.last_mut().unwrap().1 += 1; + + // Check if this field is a nested struct + if let DataType::Struct(nested_struct) = &field.data_type { + // Push the nested struct onto the stack with the current field as deferred + self.stack.push((nested_struct.as_ref().clone(), 0, Some(field))); + // Continue to process the nested struct first + } else { + // Non-struct field: return it immediately + return Some(field); + } + } else { + // Current level exhausted - pop it and return any deferred field + self.stack.pop(); + if let Some(deferred) = deferred_field { + // This was a struct field that we deferred - return it now + return Some(deferred); + } + // No deferred field, continue with parent level + } + } + + None } } @@ -208,18 +220,15 @@ mod tests { schema: SchemaRef, expected: Result, ) { - let mut schema_transform = LiteralExpressionTransform::new(values); - let datatype = schema.into(); - let _transformed = schema_transform.transform(&datatype); + let actual = LiteralExpressionTransform::new(values).bind(&schema); match expected { Ok(expected_expr) => { - let actual_expr = schema_transform.try_into_expr().unwrap(); // TODO: we can't compare NULLs so we convert with .to_string to workaround - // see: https://github.com/delta-io/delta-kernel-rs/pull/677 - assert_eq!(expected_expr.to_string(), actual_expr.to_string()); + // see: https://github.com/delta-io/delta-kernel-rs/pull/1267 + assert_eq!(expected_expr.to_string(), actual.unwrap().to_string()); } Err(()) => { - assert!(schema_transform.try_into_expr().is_err()); + assert!(actual.is_err()); } } } @@ -545,4 +554,253 @@ mod tests { (N, N) -> Null, } } + + #[cfg(test)] + mod struct_field_iterator_tests { + use super::*; + use crate::DataType as DeltaDataTypes; + + #[test] + fn test_simple_flat_struct() { + let schema = StructType::new([ + StructField::nullable("field1", DeltaDataTypes::INTEGER), + StructField::not_null("field2", DeltaDataTypes::STRING), + StructField::nullable("field3", DeltaDataTypes::BOOLEAN), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + assert_eq!(fields.len(), 3); + assert_eq!(fields[0].name(), "field1"); + assert_eq!(fields[0].data_type(), &DeltaDataTypes::INTEGER); + assert_eq!(fields[1].name(), "field2"); + assert_eq!(fields[1].data_type(), &DeltaDataTypes::STRING); + assert_eq!(fields[2].name(), "field3"); + assert_eq!(fields[2].data_type(), &DeltaDataTypes::BOOLEAN); + } + + #[test] + fn test_nested_struct() { + let inner_struct = StructType::new([ + StructField::nullable("inner1", DeltaDataTypes::INTEGER), + StructField::nullable("inner2", DeltaDataTypes::STRING), + ]); + + let schema = StructType::new([ + StructField::nullable("outer1", DeltaDataTypes::INTEGER), + StructField::nullable("nested", inner_struct.clone()), + StructField::nullable("outer2", DeltaDataTypes::STRING), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + // Post-order traversal: outer1, inner1, inner2, nested, outer2 + assert_eq!(fields.len(), 5); + assert_eq!(fields[0].name(), "outer1"); + assert_eq!(fields[0].data_type(), &DeltaDataTypes::INTEGER); + assert_eq!(fields[1].name(), "inner1"); + assert_eq!(fields[1].data_type(), &DeltaDataTypes::INTEGER); + assert_eq!(fields[2].name(), "inner2"); + assert_eq!(fields[2].data_type(), &DeltaDataTypes::STRING); + assert_eq!(fields[3].name(), "nested"); + assert!(matches!(fields[3].data_type(), DataType::Struct(_))); + assert_eq!(fields[4].name(), "outer2"); + assert_eq!(fields[4].data_type(), &DeltaDataTypes::STRING); + } + + #[test] + fn test_deeply_nested_structs() { + let level2 = StructType::new([StructField::nullable( + "level2_field", + DeltaDataTypes::INTEGER, + )]); + + let level1 = StructType::new([ + StructField::nullable("level1_field", DeltaDataTypes::STRING), + StructField::nullable("level2", level2), + ]); + + let schema = StructType::new([ + StructField::nullable("root_field", DeltaDataTypes::BOOLEAN), + StructField::nullable("level1", level1), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + // Post-order traversal: root_field, level1_field, level2_field, level2, level1 + assert_eq!(fields.len(), 5); + assert_eq!(fields[0].name(), "root_field"); + assert_eq!(fields[1].name(), "level1_field"); + assert_eq!(fields[2].name(), "level2_field"); + assert_eq!(fields[3].name(), "level2"); + assert_eq!(fields[4].name(), "level1"); + } + + #[test] + fn test_struct_with_array_types() { + let array_type = ArrayType::new(DeltaDataTypes::INTEGER, true); + let schema = StructType::new([ + StructField::nullable("before_array", DeltaDataTypes::STRING), + StructField::nullable( + "array_field", + DeltaDataTypes::Array(Box::new(array_type.clone())), + ), + StructField::nullable("after_array", DeltaDataTypes::BOOLEAN), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + assert_eq!(fields.len(), 3); + assert_eq!(fields[0].name(), "before_array"); + assert_eq!(fields[0].data_type(), &DeltaDataTypes::STRING); + assert_eq!(fields[1].name(), "array_field"); + assert!(matches!(fields[1].data_type(), DataType::Array(_))); + assert_eq!(fields[2].name(), "after_array"); + assert_eq!(fields[2].data_type(), &DeltaDataTypes::BOOLEAN); + } + + #[test] + fn test_struct_with_map_types() { + let map_type = MapType::new(DeltaDataTypes::STRING, DeltaDataTypes::INTEGER, false); + let schema = StructType::new([ + StructField::nullable("before_map", DeltaDataTypes::STRING), + StructField::nullable("map_field", DeltaDataTypes::Map(Box::new(map_type.clone()))), + StructField::nullable("after_map", DeltaDataTypes::BOOLEAN), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + assert_eq!(fields.len(), 3); + assert_eq!(fields[0].name(), "before_map"); + assert_eq!(fields[0].data_type(), &DeltaDataTypes::STRING); + assert_eq!(fields[1].name(), "map_field"); + assert!(matches!(fields[1].data_type(), DataType::Map(_))); + assert_eq!(fields[2].name(), "after_map"); + assert_eq!(fields[2].data_type(), &DeltaDataTypes::BOOLEAN); + } + + #[test] + fn test_complex_mixed_structure() { + // Create a complex structure with nested structs, arrays, and maps + let array_type = ArrayType::new(DeltaDataTypes::STRING, true); + let map_type = MapType::new(DeltaDataTypes::STRING, DeltaDataTypes::INTEGER, false); + + let inner_struct = StructType::new([ + StructField::nullable("inner_primitive", DeltaDataTypes::DOUBLE), + StructField::nullable( + "inner_array", + DeltaDataTypes::Array(Box::new(array_type.clone())), + ), + ]); + + let schema = StructType::new([ + StructField::nullable("root_string", DeltaDataTypes::STRING), + StructField::nullable("root_map", DeltaDataTypes::Map(Box::new(map_type.clone()))), + StructField::nullable("nested_struct", inner_struct), + StructField::nullable("root_array", DeltaDataTypes::Array(Box::new(array_type))), + StructField::nullable("root_int", DeltaDataTypes::INTEGER), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + // Post-order traversal: root_string, root_map, inner_primitive, inner_array, nested_struct, root_array, root_int + assert_eq!(fields.len(), 7); + + let field_names: Vec<&str> = fields.iter().map(|f| f.name().as_str()).collect(); + assert_eq!( + field_names, + vec![ + "root_string", + "root_map", + "inner_primitive", + "inner_array", + "nested_struct", + "root_array", + "root_int" + ] + ); + + // Verify data types + assert_eq!(fields[0].data_type(), &DeltaDataTypes::STRING); + assert!(matches!(fields[1].data_type(), DataType::Map(_))); + assert_eq!(fields[2].data_type(), &DeltaDataTypes::DOUBLE); + assert!(matches!(fields[3].data_type(), DataType::Array(_))); + assert!(matches!(fields[4].data_type(), DataType::Struct(_))); + assert!(matches!(fields[5].data_type(), DataType::Array(_))); + assert_eq!(fields[6].data_type(), &DeltaDataTypes::INTEGER); + } + + #[test] + fn test_empty_struct() { + let schema = StructType::new([]); + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + assert_eq!(fields.len(), 0); + } + + #[test] + fn test_struct_with_nested_empty_struct() { + let empty_struct = StructType::new([]); + let schema = StructType::new([ + StructField::nullable("before_empty", DeltaDataTypes::STRING), + StructField::nullable("empty_nested", empty_struct), + StructField::nullable("after_empty", DeltaDataTypes::INTEGER), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + // Should get: before_empty, empty_nested, after_empty + assert_eq!(fields.len(), 3); + assert_eq!(fields[0].name(), "before_empty"); + assert_eq!(fields[1].name(), "empty_nested"); + assert_eq!(fields[2].name(), "after_empty"); + } + + #[test] + fn test_multiple_sibling_nested_structs() { + let struct1 = StructType::new([StructField::nullable( + "struct1_field", + DeltaDataTypes::STRING, + )]); + + let struct2 = StructType::new([ + StructField::nullable("struct2_field1", DeltaDataTypes::INTEGER), + StructField::nullable("struct2_field2", DeltaDataTypes::BOOLEAN), + ]); + + let schema = StructType::new([ + StructField::nullable("root_field", DeltaDataTypes::DOUBLE), + StructField::nullable("first_nested", struct1), + StructField::nullable("second_nested", struct2), + StructField::nullable("final_field", DeltaDataTypes::STRING), + ]); + + let iterator = StructFieldIterator::new(schema); + let fields: Vec = iterator.collect(); + + // Post-order traversal: root_field, struct1_field, first_nested, struct2_field1, struct2_field2, second_nested, final_field + assert_eq!(fields.len(), 7); + + let field_names: Vec<&str> = fields.iter().map(|f| f.name().as_str()).collect(); + assert_eq!( + field_names, + vec![ + "root_field", + "struct1_field", + "first_nested", + "struct2_field1", + "struct2_field2", + "second_nested", + "final_field" + ] + ); + } + } } diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index cfba0c764..8a226906e 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -147,7 +147,7 @@ pub use snapshot::Snapshot; use expressions::literal_expression_transform::LiteralExpressionTransform; use expressions::Scalar; -use schema::{SchemaTransform, StructField, StructType}; +use schema::{StructField, StructType}; #[cfg(any( feature = "default-engine-native-tls", @@ -458,12 +458,12 @@ trait EvaluationHandlerExtension: EvaluationHandler { let null_row = self.null_row(null_row_schema.clone())?; // Convert schema and leaf values to an expression - let mut schema_transform = LiteralExpressionTransform::new(values); - schema_transform.transform_struct(schema.as_ref()); - let row_expr = schema_transform.try_into_expr()?; - - let eval = self.new_expression_evaluator(null_row_schema, row_expr.into(), schema.into()); - eval.evaluate(null_row.as_ref()) + LiteralExpressionTransform::new(values) + .bind(schema.as_ref()) + .and_then(|row_expr| { + self.new_expression_evaluator(null_row_schema, row_expr.into(), schema.into()) + .evaluate(null_row.as_ref()) + }) } } diff --git a/kernel/src/schema/mod.rs b/kernel/src/schema/mod.rs index 9f42d6eee..de7ab5718 100644 --- a/kernel/src/schema/mod.rs +++ b/kernel/src/schema/mod.rs @@ -3,6 +3,7 @@ use std::borrow::Cow; use std::collections::HashMap; use std::fmt::{Display, Formatter}; +use std::ops::Index; use std::sync::Arc; use indexmap::IndexMap; @@ -24,6 +25,7 @@ pub mod derive_macro_utils; #[cfg(not(feature = "internal-api"))] pub(crate) mod derive_macro_utils; pub(crate) mod variant_utils; +pub(crate) mod visitor; pub type Schema = StructType; pub type SchemaRef = Arc; @@ -412,6 +414,10 @@ impl StructType { self.fields.get_index_of(name.as_ref()) } + pub fn by_index(&self, pos: usize) -> &StructField { + self.fields.index(pos) + } + pub fn fields(&self) -> impl ExactSizeIterator { self.fields.values() } diff --git a/kernel/src/schema/visitor.rs b/kernel/src/schema/visitor.rs new file mode 100644 index 000000000..dd9cb7f33 --- /dev/null +++ b/kernel/src/schema/visitor.rs @@ -0,0 +1,59 @@ +use crate::schema::ArrayType; +use delta_kernel::schema::{DataType, MapType, PrimitiveType, StructField, StructType}; +use delta_kernel::DeltaResult; + +/// A post order schema visitor. +/// +/// For order of methods called, please refer to [`visit_schema`]. +pub(crate) trait SchemaVisitor { + /// Return type of this visitor. + type T; + + /// Called after struct's field type visited. + fn field(&mut self, field: &StructField, value: Self::T) -> DeltaResult; + /// Called after struct's fields visited. + fn r#struct(&mut self, r#struct: &StructType, results: Vec) -> DeltaResult; + /// Called after list fields visited. + fn list(&mut self, list: &ArrayType, value: Self::T) -> DeltaResult; + /// Called after map's key and value fields visited. + fn map(&mut self, map: &MapType, key_value: Self::T, value: Self::T) -> DeltaResult; + /// Called when see a primitive type. + fn primitive(&mut self, p: &PrimitiveType) -> DeltaResult; + /// Called when see a primitive type. + fn variant(&mut self, r#struct: &StructType) -> DeltaResult; +} + +/// Visiting a type in post order. +#[allow(dead_code)] // Reserved for future use +pub(crate) fn visit_type( + r#type: &DataType, + visitor: &mut V, +) -> DeltaResult { + match r#type { + DataType::Primitive(p) => visitor.primitive(p), + DataType::Array(list) => { + let value = visit_type(&list.element_type, visitor)?; + visitor.list(list, value) + } + DataType::Map(map) => { + let key_result = visit_type(&map.key_type, visitor)?; + let value_result = visit_type(&map.value_type, visitor)?; + + visitor.map(map, key_result, value_result) + } + DataType::Struct(s) => visit_struct(s, visitor), + DataType::Variant(v) => visitor.variant(v), + } +} + +/// Visit struct type in post order. +pub(crate) fn visit_struct(s: &StructType, visitor: &mut V) -> DeltaResult { + let mut results = Vec::with_capacity(s.fields().len()); + for field in s.fields() { + let result = visit_type(&field.data_type, visitor)?; + let result = visitor.field(field, result)?; + results.push(result); + } + + visitor.r#struct(s, results) +}