Skip to content

Commit 1fb0a59

Browse files
committed
chore: Move literal expression transform to a visitor
After #1207 I thought it would be a nice touch to move the traversal of the schema into a visitor, which cleanly separates the traversal flow from the actual transformation logic.
1 parent 023930c commit 1fb0a59

File tree

5 files changed

+139
-121
lines changed

5 files changed

+139
-121
lines changed

kernel/src/engine/arrow_expression/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ fn test_create_one_mismatching_scalar_types() {
828828
let handler = ArrowEvaluationHandler;
829829
assert_result_error_with_message(
830830
handler.create_one(schema, values),
831-
"Schema error: Mismatched scalar type while creating Expression: expected Integer, got Long",
831+
"Schema error: Mismatched scalar type while creating Expression: expected Primitive(Integer), got Primitive(Long)",
832832
);
833833
}
834834

kernel/src/expressions/literal_expression_transform.rs

Lines changed: 71 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,9 @@
11
//! The [`LiteralExpressionTransform`] is a `SchemaVisitor` that transforms a [`Schema`] and an
22
//! ordered list of leaf values (scalars) into an [`Expression`] with a literal value for each leaf.
33
4-
use std::borrow::Cow;
5-
use std::ops::Deref as _;
6-
7-
use tracing::debug;
8-
94
use crate::expressions::{Expression, Scalar};
10-
use crate::schema::{
11-
ArrayType, DataType, MapType, PrimitiveType, SchemaTransform, StructField, StructType,
12-
};
13-
14-
/// [`SchemaTransform`] that will transform a [`Schema`] and an ordered list of leaf values
15-
/// (Scalars) into an Expression with a [`Literal`] expr for each leaf.
16-
#[derive(Debug)]
17-
pub(crate) struct LiteralExpressionTransform<'a, T: Iterator<Item = &'a Scalar>> {
18-
/// Leaf values to insert in schema order.
19-
scalars: T,
20-
/// A stack of built Expressions. After visiting children, we pop them off to
21-
/// build the parent container, then push the parent back on.
22-
stack: Vec<Expression>,
23-
/// Since schema transforms are infallible we keep track of errors here
24-
error: Result<(), Error>,
25-
}
5+
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, StructField, StructType};
6+
use crate::DeltaResult;
267

278
/// Any error for [`LiteralExpressionTransform`]
289
#[derive(thiserror::Error, Debug)]
@@ -48,97 +29,71 @@ pub enum Error {
4829
Unsupported(String),
4930
}
5031

32+
#[derive(Debug, Default)]
33+
pub(crate) struct LiteralExpressionTransform<'a, T: Iterator<Item = &'a Scalar>> {
34+
/// Leaf values to insert in schema order.
35+
scalars: T,
36+
}
37+
5138
impl<'a, I: Iterator<Item = &'a Scalar>> LiteralExpressionTransform<'a, I> {
5239
pub(crate) fn new(scalars: impl IntoIterator<IntoIter = I>) -> Self {
5340
Self {
5441
scalars: scalars.into_iter(),
55-
stack: Vec::new(),
56-
error: Ok(()),
5742
}
5843
}
5944

60-
/// return the Expression we just built (or propagate Error). the top of `stack` should be our
61-
/// final Expression
62-
pub(crate) fn try_into_expr(mut self) -> Result<Expression, Error> {
63-
self.error?;
45+
/// Bind the visitor to a StructType and produce an Expression
46+
pub(crate) fn bind(&mut self, struct_type: &StructType) -> DeltaResult<Expression> {
47+
use crate::schema::visitor::visit_struct;
48+
let result = visit_struct(struct_type, self)?;
6449

65-
if let Some(s) = self.scalars.next() {
66-
return Err(Error::ExcessScalars(s.clone()));
50+
// Check for excess scalars after visiting
51+
if let Some(scalar) = self.scalars.next() {
52+
return Err(Error::ExcessScalars(scalar.clone()).into());
6753
}
6854

69-
self.stack.pop().ok_or(Error::EmptyStack)
70-
}
71-
72-
fn set_error(&mut self, error: Error) {
73-
// Only set when the error not yet set
74-
if let Err(ref existing_error) = self.error {
75-
debug!("Trying to overwrite an existing error: {existing_error:?} with {error:?}");
76-
} else {
77-
self.error = Err(error);
78-
}
55+
Ok(result)
7956
}
80-
}
81-
82-
// All leaf types (primitive, array, map) share the same "shape" of transformation logic
83-
macro_rules! transform_leaf {
84-
($self:ident, $type_variant:path, $type:ident) => {{
85-
// first always check error to terminate early if possible
86-
$self.error.as_ref().ok()?;
8757

88-
let Some(scalar) = $self.scalars.next() else {
89-
$self.set_error(Error::InsufficientScalars);
90-
return None;
58+
fn visit_leaf(&mut self, schema_type: &DataType) -> DeltaResult<Expression> {
59+
let Some(scalar) = self.scalars.next() else {
60+
return Err(Error::InsufficientScalars.into());
9161
};
9262

93-
// NOTE: Grab a reference here so code below can leverage the blanket impl<T> Deref for &T
94-
let $type_variant(ref scalar_type) = scalar.data_type() else {
95-
$self.set_error(Error::Schema(format!(
96-
"Mismatched scalar type while creating Expression: expected {}({:?}), got {:?}",
97-
stringify!($type_variant),
98-
$type,
63+
if schema_type.clone() != scalar.data_type() {
64+
return Err(Error::Schema(format!(
65+
"Mismatched scalar type while creating Expression: expected {:?}, got {:?}",
66+
schema_type,
9967
scalar.data_type()
100-
)));
101-
return None;
68+
))
69+
.into());
10270
};
10371

104-
// NOTE: &T and &Box<T> both deref to &T
105-
if scalar_type.deref() != $type {
106-
$self.set_error(Error::Schema(format!(
107-
"Mismatched scalar type while creating Expression: expected {:?}, got {:?}",
108-
$type, scalar_type
109-
)));
110-
return None;
111-
}
112-
113-
$self.stack.push(Expression::Literal(scalar.clone()));
114-
None
115-
}};
72+
Ok(Expression::Literal(scalar.clone()))
73+
}
11674
}
11775

118-
impl<'a, T: Iterator<Item = &'a Scalar>> SchemaTransform<'a> for LiteralExpressionTransform<'a, T> {
119-
fn transform_primitive(
120-
&mut self,
121-
prim_type: &'a PrimitiveType,
122-
) -> Option<Cow<'a, PrimitiveType>> {
123-
transform_leaf!(self, DataType::Primitive, prim_type)
76+
impl<'a, I: Iterator<Item = &'a Scalar>> delta_kernel::schema::visitor::SchemaVisitor
77+
for LiteralExpressionTransform<'a, I>
78+
{
79+
type T = Expression;
80+
81+
fn field(&mut self, field: &StructField, value: Self::T) -> DeltaResult<Self::T> {
82+
match &field.data_type {
83+
DataType::Struct(_) => Ok(value),
84+
DataType::Primitive(_) => self.visit_leaf(&field.data_type),
85+
DataType::Array(_) => self.visit_leaf(&field.data_type),
86+
DataType::Map(_) => self.visit_leaf(&field.data_type),
87+
DataType::Variant(_) => self.visit_leaf(&field.data_type),
88+
}
12489
}
12590

126-
fn transform_struct(&mut self, struct_type: &'a StructType) -> Option<Cow<'a, StructType>> {
127-
// first always check error to terminate early if possible
128-
self.error.as_ref().ok()?;
129-
130-
// Only consume newly-added entries (if any). There could be fewer than expected if
131-
// the recursion encountered an error.
132-
let mark = self.stack.len();
133-
self.recurse_into_struct(struct_type)?;
134-
let field_exprs = self.stack.split_off(mark);
135-
91+
fn r#struct(
92+
&mut self,
93+
struct_type: &StructType,
94+
field_exprs: Vec<Self::T>,
95+
) -> DeltaResult<Self::T> {
13696
let fields = struct_type.fields();
137-
if field_exprs.len() != fields.len() {
138-
self.set_error(Error::InsufficientScalars);
139-
return None;
140-
}
141-
14297
let mut found_non_nullable_null = false;
14398
let mut all_null = true;
14499
for (field, expr) in fields.zip(&field_exprs) {
@@ -154,36 +109,42 @@ impl<'a, T: Iterator<Item = &'a Scalar>> SchemaTransform<'a> for LiteralExpressi
154109
let struct_expr = if found_non_nullable_null {
155110
if !all_null {
156111
// we found a non_nullable NULL, but other siblings are non-null: error
157-
self.set_error(Error::Schema(
112+
return Err(Error::Schema(
158113
"NULL value for non-nullable struct field with non-NULL siblings".to_string(),
159-
));
160-
return None;
114+
)
115+
.into());
161116
}
162117
Expression::null_literal(struct_type.clone().into())
163118
} else {
164119
Expression::struct_from(field_exprs)
165120
};
166121

167-
self.stack.push(struct_expr);
168-
None
122+
Ok(struct_expr)
169123
}
170124

171-
fn transform_struct_field(&mut self, field: &'a StructField) -> Option<Cow<'a, StructField>> {
172-
// first always check error to terminate early if possible
173-
self.error.as_ref().ok()?;
125+
fn list(&mut self, _list: &ArrayType, _value: Self::T) -> DeltaResult<Self::T> {
126+
// Everything is handled on the field level
127+
Ok(Expression::Unknown("Should not happen".to_string()))
128+
}
174129

175-
self.recurse_into_struct_field(field);
176-
Some(Cow::Borrowed(field))
130+
fn map(
131+
&mut self,
132+
_map: &MapType,
133+
_key_value: Self::T,
134+
_value: Self::T,
135+
) -> DeltaResult<Self::T> {
136+
// Everything is handled on the field level
137+
Ok(Expression::Unknown("Should not happen".to_string()))
177138
}
178139

179-
// arrays treated as leaves
180-
fn transform_array(&mut self, array_type: &'a ArrayType) -> Option<Cow<'a, ArrayType>> {
181-
transform_leaf!(self, DataType::Array, array_type)
140+
fn primitive(&mut self, _p: &PrimitiveType) -> DeltaResult<Self::T> {
141+
// Everything is handled on the field level
142+
Ok(Expression::Unknown("Should not happen".to_string()))
182143
}
183144

184-
// maps treated as leaves
185-
fn transform_map(&mut self, map_type: &'a MapType) -> Option<Cow<'a, MapType>> {
186-
transform_leaf!(self, DataType::Map, map_type)
145+
fn variant(&mut self, _struct: &StructType) -> DeltaResult<Self::T> {
146+
// Everything is handled on the field level
147+
Ok(Expression::Unknown("Should not happen".to_string()))
187148
}
188149
}
189150

@@ -208,18 +169,15 @@ mod tests {
208169
schema: SchemaRef,
209170
expected: Result<Expr, ()>,
210171
) {
211-
let mut schema_transform = LiteralExpressionTransform::new(values);
212-
let datatype = schema.into();
213-
let _transformed = schema_transform.transform(&datatype);
172+
let actual = LiteralExpressionTransform::new(values).bind(&schema);
214173
match expected {
215174
Ok(expected_expr) => {
216-
let actual_expr = schema_transform.try_into_expr().unwrap();
217175
// TODO: we can't compare NULLs so we convert with .to_string to workaround
218-
// see: https://github.com/delta-io/delta-kernel-rs/pull/677
219-
assert_eq!(expected_expr.to_string(), actual_expr.to_string());
176+
// see: https://github.com/delta-io/delta-kernel-rs/pull/1267
177+
assert_eq!(expected_expr.to_string(), actual.unwrap().to_string());
220178
}
221179
Err(()) => {
222-
assert!(schema_transform.try_into_expr().is_err());
180+
assert!(actual.is_err());
223181
}
224182
}
225183
}

kernel/src/lib.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ pub use snapshot::Snapshot;
147147

148148
use expressions::literal_expression_transform::LiteralExpressionTransform;
149149
use expressions::Scalar;
150-
use schema::{SchemaTransform, StructField, StructType};
150+
use schema::{StructField, StructType};
151151

152152
#[cfg(any(
153153
feature = "default-engine-native-tls",
@@ -458,12 +458,12 @@ trait EvaluationHandlerExtension: EvaluationHandler {
458458
let null_row = self.null_row(null_row_schema.clone())?;
459459

460460
// Convert schema and leaf values to an expression
461-
let mut schema_transform = LiteralExpressionTransform::new(values);
462-
schema_transform.transform_struct(schema.as_ref());
463-
let row_expr = schema_transform.try_into_expr()?;
464-
465-
let eval = self.new_expression_evaluator(null_row_schema, row_expr.into(), schema.into());
466-
eval.evaluate(null_row.as_ref())
461+
LiteralExpressionTransform::new(values)
462+
.bind(schema.as_ref())
463+
.and_then(|row_expr| {
464+
self.new_expression_evaluator(null_row_schema, row_expr.into(), schema.into())
465+
.evaluate(null_row.as_ref())
466+
})
467467
}
468468
}
469469

kernel/src/schema/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod derive_macro_utils;
2424
#[cfg(not(feature = "internal-api"))]
2525
pub(crate) mod derive_macro_utils;
2626
pub(crate) mod variant_utils;
27+
pub(crate) mod visitor;
2728

2829
pub type Schema = StructType;
2930
pub type SchemaRef = Arc<StructType>;

kernel/src/schema/visitor.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use crate::schema::ArrayType;
2+
use delta_kernel::schema::{DataType, MapType, PrimitiveType, StructField, StructType};
3+
use delta_kernel::DeltaResult;
4+
5+
/// A post order schema visitor.
6+
///
7+
/// For the order in which methods are called, see [`visit_type`] and [`visit_struct`].
8+
pub(crate) trait SchemaVisitor {
9+
/// Return type of this visitor.
10+
type T;
11+
12+
/// Called after the struct field's type has been visited.
13+
fn field(&mut self, field: &StructField, value: Self::T) -> DeltaResult<Self::T>;
14+
/// Called after all of the struct's fields have been visited.
15+
fn r#struct(&mut self, r#struct: &StructType, results: Vec<Self::T>) -> DeltaResult<Self::T>;
16+
/// Called after the list's element type has been visited.
17+
fn list(&mut self, list: &ArrayType, value: Self::T) -> DeltaResult<Self::T>;
18+
/// Called after the map's key and value types have been visited.
19+
fn map(&mut self, map: &MapType, key_value: Self::T, value: Self::T) -> DeltaResult<Self::T>;
20+
/// Called when a primitive type is encountered.
21+
fn primitive(&mut self, p: &PrimitiveType) -> DeltaResult<Self::T>;
22+
/// Called when a variant type is encountered.
23+
fn variant(&mut self, r#struct: &StructType) -> DeltaResult<Self::T>;
24+
}
25+
26+
/// Visiting a type in post order.
27+
#[allow(dead_code)] // Reserved for future use
28+
pub(crate) fn visit_type<V: SchemaVisitor>(
29+
r#type: &DataType,
30+
visitor: &mut V,
31+
) -> DeltaResult<V::T> {
32+
match r#type {
33+
DataType::Primitive(p) => visitor.primitive(p),
34+
DataType::Array(list) => {
35+
let value = visit_type(&list.element_type, visitor)?;
36+
visitor.list(list, value)
37+
}
38+
DataType::Map(map) => {
39+
let key_result = visit_type(&map.key_type, visitor)?;
40+
let value_result = visit_type(&map.value_type, visitor)?;
41+
42+
visitor.map(map, key_result, value_result)
43+
}
44+
DataType::Struct(s) => visit_struct(s, visitor),
45+
DataType::Variant(v) => visitor.variant(v),
46+
}
47+
}
48+
49+
/// Visit struct type in post order.
50+
pub(crate) fn visit_struct<V: SchemaVisitor>(s: &StructType, visitor: &mut V) -> DeltaResult<V::T> {
51+
let mut results = Vec::with_capacity(s.fields().len());
52+
for field in s.fields() {
53+
let result = visit_type(&field.data_type, visitor)?;
54+
let result = visitor.field(field, result)?;
55+
results.push(result);
56+
}
57+
58+
visitor.r#struct(s, results)
59+
}

0 commit comments

Comments
 (0)