Skip to content

Commit eb61574

Browse files
authored
refactor: Refactor (group-)rule to be lazily evaluated (#64)
1 parent d82dcf0 commit eb61574

File tree

5 files changed

+42
-80
lines changed

5 files changed

+42
-80
lines changed

dataframely/_base_schema.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
import polars as pl
1313

14-
from ._rule import GroupRule, Rule, with_evaluation_rules
14+
from ._rule import GroupRule, Rule
1515
from .columns import Column
16-
from .exc import ImplementationError, RuleImplementationError
16+
from .exc import ImplementationError
1717

1818
_COLUMN_ATTR = "__dataframely_columns__"
1919
_RULE_ATTR = "__dataframely_rules__"
@@ -112,23 +112,6 @@ def __new__(
112112
f"which are not in the schema: {missing_list}."
113113
)
114114

115-
# 3) Assuming that non-custom rules are implemented correctly, we check that all
116-
# custom rules are _also_ implemented correctly by evaluating rules on an
117-
# empty data frame and checking for the evaluated dtypes.
118-
if len(result.rules) > 0:
119-
lf_empty = pl.LazyFrame(
120-
schema={col_name: col.dtype for col_name, col in result.columns.items()}
121-
)
122-
# NOTE: For some reason, `polars` does not yield correct dtypes when calling
123-
# `collect_schema()`
124-
schema = with_evaluation_rules(lf_empty, result.rules).collect().schema
125-
for rule_name, rule in result.rules.items():
126-
dtype = schema[rule_name]
127-
if not isinstance(dtype, pl.Boolean):
128-
raise RuleImplementationError(
129-
rule_name, dtype, isinstance(rule, GroupRule)
130-
)
131-
132115
return super().__new__(mcs, name, bases, namespace, *args, **kwargs)
133116

134117
def __getattribute__(cls, name: str) -> Any:

dataframely/_rule.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@
1515
class Rule:
1616
"""Internal class representing validation rules."""
1717

18-
def __init__(self, expr: pl.Expr) -> None:
19-
self.expr = expr
18+
def __init__(self, expr: pl.Expr | ValidationFunction) -> None:
19+
self._expr = expr
20+
21+
@property
22+
def expr(self) -> pl.Expr:
23+
"""Get the expression of the rule."""
24+
if callable(self._expr):
25+
return self._expr()
26+
return self._expr
2027

2128
def matches(self, other: Rule) -> bool:
2229
"""Check whether this rule semantically matches another rule.
@@ -49,7 +56,9 @@ def __repr__(self) -> str:
4956
class GroupRule(Rule):
5057
"""Rule that is evaluated on a group of columns."""
5158

52-
def __init__(self, expr: pl.Expr, group_columns: list[str]) -> None:
59+
def __init__(
60+
self, expr: pl.Expr | ValidationFunction, group_columns: list[str]
61+
) -> None:
5362
super().__init__(expr)
5463
self.group_columns = group_columns
5564

@@ -101,8 +110,8 @@ def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction],
101110

102111
def decorator(validation_fn: ValidationFunction) -> Rule:
103112
if group_by is not None:
104-
return GroupRule(expr=validation_fn(), group_columns=group_by)
105-
return Rule(expr=validation_fn())
113+
return GroupRule(expr=validation_fn, group_columns=group_by)
114+
return Rule(expr=validation_fn)
106115

107116
return decorator
108117

dataframely/exc.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
from collections import defaultdict
55

6-
import polars as pl
7-
86
from ._polars import PolarsDataType
97

108

@@ -108,27 +106,3 @@ def __init__(self, attr: str, kls: type) -> None:
108106
"`from __future__ import annotations` in the file that defines the collection."
109107
)
110108
super().__init__(message)
111-
112-
113-
class RuleImplementationError(ImplementationError):
114-
"""Error raised when a rule is implemented incorrectly."""
115-
116-
def __init__(
117-
self, name: str, return_dtype: pl.DataType, is_group_rule: bool
118-
) -> None:
119-
if is_group_rule:
120-
details = (
121-
" When implementing a group rule (i.e. when using the `group_by` "
122-
"parameter), make sure to use an aggregation function such as `.any()`, "
123-
"`.all()`, and others to reduce an expression evaluated on multiple "
124-
"rows in the same group to a single boolean value for the group."
125-
)
126-
else:
127-
details = ""
128-
129-
message = (
130-
f"Validation rule '{name}' has not been implemented correctly. It "
131-
f"returns dtype '{return_dtype}' but it must return a boolean value."
132-
+ details
133-
)
134-
super().__init__(message)

tests/schema/test_rule_implementation.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import dataframely as dy
88
from dataframely._rule import GroupRule, Rule
9-
from dataframely.exc import ImplementationError, RuleImplementationError
9+
from dataframely.exc import ImplementationError
1010
from dataframely.testing import create_schema
1111

1212

@@ -29,32 +29,6 @@ def test_group_rule_group_by_error() -> None:
2929
)
3030

3131

32-
def test_rule_implementation_error() -> None:
33-
with pytest.raises(
34-
RuleImplementationError, match=r"rule 'integer_rule'.*returns dtype 'Int64'"
35-
):
36-
create_schema(
37-
"test",
38-
columns={"a": dy.Integer()},
39-
rules={"integer_rule": Rule(pl.col("a") + 1)},
40-
)
41-
42-
43-
def test_group_rule_implementation_error() -> None:
44-
with pytest.raises(
45-
RuleImplementationError,
46-
match=(
47-
r"rule 'b_greater_zero'.*returns dtype 'List\(Boolean\)'.*"
48-
r"make sure to use an aggregation function"
49-
),
50-
):
51-
create_schema(
52-
"test",
53-
columns={"a": dy.Integer(), "b": dy.Integer()},
54-
rules={"b_greater_zero": GroupRule(pl.col("b") > 0, group_columns=["a"])},
55-
)
56-
57-
5832
def test_rule_column_overlap_error() -> None:
5933
with pytest.raises(
6034
ImplementationError,

tests/schema/test_validate.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ def b_unique_within_a() -> pl.Expr:
2828
return pl.col("b").n_unique() == 1
2929

3030

31+
class MyComplexSchemaWithLazyRules(dy.Schema):
32+
a = dy.Int64()
33+
b = dy.Int64()
34+
35+
@dy.rule()
36+
def b_greater_a() -> pl.Expr:
37+
return MyComplexSchemaWithLazyRules.b.col > MyComplexSchemaWithLazyRules.a.col
38+
39+
@dy.rule(group_by=["a"])
40+
def b_unique_within_a() -> pl.Expr:
41+
return (
42+
MyComplexSchemaWithLazyRules.b.col.n_unique() == SOME_CONSTANT_DEFINED_LATER
43+
)
44+
45+
46+
SOME_CONSTANT_DEFINED_LATER = 1
47+
48+
3149
# -------------------------------------- COLUMNS ------------------------------------- #
3250

3351

@@ -119,9 +137,13 @@ def test_success_multi_row_strip_cast(
119137

120138

121139
@pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
122-
def test_group_rule_on_nulls(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
140+
@pytest.mark.parametrize("schema", [MyComplexSchema, MyComplexSchemaWithLazyRules])
141+
def test_group_rule_on_nulls(
142+
df_type: type[pl.DataFrame] | type[pl.LazyFrame],
143+
schema: type[MyComplexSchema] | type[MyComplexSchemaWithLazyRules],
144+
) -> None:
123145
# The schema is violated because we have multiple "b" values for the same "a" value
124146
df = df_type({"a": [None, None], "b": [1, 2]})
125147
with pytest.raises(RuleValidationError):
126-
MyComplexSchema.validate(df, cast=True)
127-
assert not MyComplexSchema.is_valid(df, cast=True)
148+
schema.validate(df, cast=True)
149+
assert not schema.is_valid(df, cast=True)

0 commit comments

Comments
 (0)