Skip to content

Commit 08972c6

Browse files
heshamdarborchero
andauthored
feat: Allow for multiple column level checks with custom naming (#19)
Co-authored-by: Oliver Borchert <[email protected]>
1 parent 4f0367a commit 08972c6

File tree

13 files changed

+355
-46
lines changed

13 files changed

+355
-46
lines changed

dataframely/columns/_base.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
from abc import ABC, abstractmethod
7+
from collections import Counter
78
from collections.abc import Callable
89
from typing import Any
910

@@ -31,7 +32,12 @@ def __init__(
3132
*,
3233
nullable: bool | None = None,
3334
primary_key: bool = False,
34-
check: Callable[[pl.Expr], pl.Expr] | None = None,
35+
check: (
36+
Callable[[pl.Expr], pl.Expr]
37+
| list[Callable[[pl.Expr], pl.Expr]]
38+
| dict[str, Callable[[pl.Expr], pl.Expr]]
39+
| None
40+
) = None,
3541
alias: str | None = None,
3642
metadata: dict[str, Any] | None = None,
3743
):
@@ -43,8 +49,17 @@ def __init__(
4349
is not specified.
4450
primary_key: Whether this column is part of the primary key of the schema.
4551
If ``True``, ``nullable`` is automatically set to ``False``.
46-
check: A custom check to run for this column. Must return a non-aggregated
47-
boolean expression.
52+
check: A custom rule or multiple rules to run for this column. This can be:
53+
- A single callable that returns a non-aggregated boolean expression.
54+
The name of the rule is derived from the callable name, or defaults to
55+
"check" for lambdas.
56+
- A list of callables, where each callable returns a non-aggregated
57+
boolean expression. The name of the rule is derived from the callable
58+
name, or defaults to "check" for lambdas. Where multiple rules result
59+
in the same name, the suffix __i is appended to the name.
60+
- A dictionary mapping rule names to callables, where each callable
61+
returns a non-aggregated boolean expression.
62+
All rule names provided here are given the prefix "check_".
4863
alias: An overwrite for this column's name which allows for using a column
4964
name that is not a valid Python identifier. Especially note that setting
5065
this option does _not_ allow to refer to the column with two different
@@ -104,10 +119,59 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
104119
result = {}
105120
if not self.nullable:
106121
result["nullability"] = expr.is_not_null()
122+
107123
if self.check is not None:
108-
result["check"] = self.check(expr)
124+
if isinstance(self.check, dict):
125+
for rule_name, rule_callable in self.check.items():
126+
result[f"check__{rule_name}"] = rule_callable(expr)
127+
else:
128+
list_of_rules = (
129+
self.check if isinstance(self.check, list) else [self.check]
130+
)
131+
# Get unique names for rules from callables
132+
rule_names = self._derive_check_rule_names(list_of_rules)
133+
for rule_name, rule_callable in zip(rule_names, list_of_rules):
134+
result[rule_name] = rule_callable(expr)
135+
109136
return result
110137

138+
def _derive_check_rule_names(
139+
self, rules: list[Callable[[pl.Expr], pl.Expr]]
140+
) -> list[str]:
141+
"""Generate unique names for rule callables.
142+
143+
For callables with the same name, appends a suffix __i where i is the index
144+
of occurrence (starting from 0), but only if there are duplicates.
145+
146+
Args:
147+
rules: List of rule callables.
148+
149+
Returns:
150+
List of unique names corresponding to the rule callables.
151+
"""
152+
base_names = [
153+
f"check__{rule.__name__}" if rule.__name__ != "<lambda>" else "check"
154+
for rule in rules
155+
]
156+
157+
# Count occurrences using Counter
158+
name_counts = Counter(base_names)
159+
160+
# Append suffixes to names that are duplicated
161+
final_names = []
162+
duplicate_counter: dict[str, int] = {
163+
name: 0 for name in name_counts if name_counts[name] > 1
164+
}
165+
for name in base_names:
166+
if name_counts[name] > 1:
167+
postfix = duplicate_counter[name]
168+
final_names.append(f"{name}__{postfix}")
169+
duplicate_counter[name] += 1
170+
else:
171+
final_names.append(name)
172+
173+
return final_names
174+
111175
# -------------------------------------- SQL ------------------------------------- #
112176

113177
def sqlalchemy_column(self, name: str, dialect: sa.Dialect) -> sa.Column:

dataframely/columns/any.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,28 @@ class Any(Column):
2525
def __init__(
2626
self,
2727
*,
28-
check: Callable[[pl.Expr], pl.Expr] | None = None,
28+
check: (
29+
Callable[[pl.Expr], pl.Expr]
30+
| list[Callable[[pl.Expr], pl.Expr]]
31+
| dict[str, Callable[[pl.Expr], pl.Expr]]
32+
| None
33+
) = None,
2934
alias: str | None = None,
3035
metadata: dict[str, Any] | None = None,
3136
):
3237
"""
3338
Args:
34-
check: A custom check to run for this column. Must return a non-aggregated
35-
boolean expression.
39+
check: A custom rule or multiple rules to run for this column. This can be:
40+
- A single callable that returns a non-aggregated boolean expression.
41+
The name of the rule is derived from the callable name, or defaults to
42+
"check" for lambdas.
43+
- A list of callables, where each callable returns a non-aggregated
44+
boolean expression. The name of the rule is derived from the callable
45+
name, or defaults to "check" for lambdas. Where multiple rules result
46+
in the same name, the suffix __i is appended to the name.
47+
- A dictionary mapping rule names to callables, where each callable
48+
returns a non-aggregated boolean expression.
49+
All rule names provided here are given the prefix "check_".
3650
alias: An overwrite for this column's name which allows for using a column
3751
name that is not a valid Python identifier. Especially note that setting
3852
this option does _not_ allow to refer to the column with two different

dataframely/columns/array.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ def __init__(
2828
# polars doesn't yet support grouping by arrays,
2929
# see https://github.com/pola-rs/polars/issues/22574
3030
primary_key: Literal[False] = False,
31-
check: Callable[[pl.Expr], pl.Expr] | None = None,
31+
check: (
32+
Callable[[pl.Expr], pl.Expr]
33+
| list[Callable[[pl.Expr], pl.Expr]]
34+
| dict[str, Callable[[pl.Expr], pl.Expr]]
35+
| None
36+
) = None,
3237
alias: str | None = None,
3338
metadata: dict[str, Any] | None = None,
3439
):
@@ -39,8 +44,17 @@ def __init__(
3944
nullable: Whether this column may contain null values.
4045
primary_key: Whether this column is part of the primary key of the schema.
4146
Not yet supported for the Array type.
42-
check: A custom check to run for this column. Must return a non-aggregated
43-
boolean expression.
47+
check: A custom rule or multiple rules to run for this column. This can be:
48+
- A single callable that returns a non-aggregated boolean expression.
49+
The name of the rule is derived from the callable name, or defaults to
50+
"check" for lambdas.
51+
- A list of callables, where each callable returns a non-aggregated
52+
boolean expression. The name of the rule is derived from the callable
53+
name, or defaults to "check" for lambdas. Where multiple rules result
54+
in the same name, the suffix __i is appended to the name.
55+
- A dictionary mapping rule names to callables, where each callable
56+
returns a non-aggregated boolean expression.
57+
All rule names provided here are given the prefix "check_".
4458
alias: An overwrite for this column's name which allows for using a column
4559
name that is not a valid Python identifier. Especially note that setting
4660
this option does _not_ allow to refer to the column with two different

dataframely/columns/datetime.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ def __init__(
3939
max: dt.date | None = None,
4040
max_exclusive: dt.date | None = None,
4141
resolution: str | None = None,
42-
check: Callable[[pl.Expr], pl.Expr] | None = None,
42+
check: (
43+
Callable[[pl.Expr], pl.Expr]
44+
| list[Callable[[pl.Expr], pl.Expr]]
45+
| dict[str, Callable[[pl.Expr], pl.Expr]]
46+
| None
47+
) = None,
4348
alias: str | None = None,
4449
metadata: dict[str, Any] | None = None,
4550
):
@@ -61,8 +66,17 @@ def __init__(
6166
formatting language used by :mod:`polars` datetime ``round`` method.
6267
For example, a value ``1mo`` expects all dates to be on the first of the
6368
month. Note that this setting does *not* affect the storage resolution.
64-
check: A custom check to run for this column. Must return a non-aggregated
65-
boolean expression.
69+
check: A custom rule or multiple rules to run for this column. This can be:
70+
- A single callable that returns a non-aggregated boolean expression.
71+
The name of the rule is derived from the callable name, or defaults to
72+
"check" for lambdas.
73+
- A list of callables, where each callable returns a non-aggregated
74+
boolean expression. The name of the rule is derived from the callable
75+
name, or defaults to "check" for lambdas. Where multiple rules result
76+
in the same name, the suffix __i is appended to the name.
77+
- A dictionary mapping rule names to callables, where each callable
78+
returns a non-aggregated boolean expression.
79+
All rule names provided here are given the prefix "check_".
6680
alias: An overwrite for this column's name which allows for using a column
6781
name that is not a valid Python identifier. Especially note that setting
6882
this option does _not_ allow to refer to the column with two different
@@ -152,7 +166,12 @@ def __init__(
152166
max: dt.time | None = None,
153167
max_exclusive: dt.time | None = None,
154168
resolution: str | None = None,
155-
check: Callable[[pl.Expr], pl.Expr] | None = None,
169+
check: (
170+
Callable[[pl.Expr], pl.Expr]
171+
| list[Callable[[pl.Expr], pl.Expr]]
172+
| dict[str, Callable[[pl.Expr], pl.Expr]]
173+
| None
174+
) = None,
156175
alias: str | None = None,
157176
metadata: dict[str, Any] | None = None,
158177
):
@@ -174,8 +193,17 @@ def __init__(
174193
formatting language used by :mod:`polars` datetime ``round`` method.
175194
For example, a value ``1h`` expects all times to be full hours. Note
176195
that this setting does *not* affect the storage resolution.
177-
check: A custom check to run for this column. Must return a non-aggregated
178-
boolean expression.
196+
check: A custom rule or multiple rules to run for this column. This can be:
197+
- A single callable that returns a non-aggregated boolean expression.
198+
The name of the rule is derived from the callable name, or defaults to
199+
"check" for lambdas.
200+
- A list of callables, where each callable returns a non-aggregated
201+
boolean expression. The name of the rule is derived from the callable
202+
name, or defaults to "check" for lambdas. Where multiple rules result
203+
in the same name, the suffix __i is appended to the name.
204+
- A dictionary mapping rule names to callables, where each callable
205+
returns a non-aggregated boolean expression.
206+
All rule names provided here are given the prefix "check_".
179207
alias: An overwrite for this column's name which allows for using a column
180208
name that is not a valid Python identifier. Especially note that setting
181209
this option does _not_ allow to refer to the column with two different
@@ -271,7 +299,12 @@ def __init__(
271299
max: dt.datetime | None = None,
272300
max_exclusive: dt.datetime | None = None,
273301
resolution: str | None = None,
274-
check: Callable[[pl.Expr], pl.Expr] | None = None,
302+
check: (
303+
Callable[[pl.Expr], pl.Expr]
304+
| list[Callable[[pl.Expr], pl.Expr]]
305+
| dict[str, Callable[[pl.Expr], pl.Expr]]
306+
| None
307+
) = None,
275308
alias: str | None = None,
276309
metadata: dict[str, Any] | None = None,
277310
):
@@ -293,8 +326,17 @@ def __init__(
293326
the formatting language used by :mod:`polars` datetime ``round`` method.
294327
For example, a value ``1h`` expects all datetimes to be full hours. Note
295328
that this setting does *not* affect the storage resolution.
296-
check: A custom check to run for this column. Must return a non-aggregated
297-
boolean expression.
329+
check: A custom rule or multiple rules to run for this column. This can be:
330+
- A single callable that returns a non-aggregated boolean expression.
331+
The name of the rule is derived from the callable name, or defaults to
332+
"check" for lambdas.
333+
- A list of callables, where each callable returns a non-aggregated
334+
boolean expression. The name of the rule is derived from the callable
335+
name, or defaults to "check" for lambdas. Where multiple rules result
336+
in the same name, the suffix __i is appended to the name.
337+
- A dictionary mapping rule names to callables, where each callable
338+
returns a non-aggregated boolean expression.
339+
All rule names provided here are given the prefix "check_".
298340
alias: An overwrite for this column's name which allows for using a column
299341
name that is not a valid Python identifier. Especially note that setting
300342
this option does _not_ allow to refer to the column with two different
@@ -380,7 +422,12 @@ def __init__(
380422
max: dt.timedelta | None = None,
381423
max_exclusive: dt.timedelta | None = None,
382424
resolution: str | None = None,
383-
check: Callable[[pl.Expr], pl.Expr] | None = None,
425+
check: (
426+
Callable[[pl.Expr], pl.Expr]
427+
| list[Callable[[pl.Expr], pl.Expr]]
428+
| dict[str, Callable[[pl.Expr], pl.Expr]]
429+
| None
430+
) = None,
384431
alias: str | None = None,
385432
metadata: dict[str, Any] | None = None,
386433
):
@@ -402,8 +449,17 @@ def __init__(
402449
the formatting language used by :mod:`polars` datetime ``round`` method.
403450
For example, a value ``1h`` expects all durations to be full hours. Note
404451
that this setting does *not* affect the storage resolution.
405-
check: A custom check to run for this column. Must return a non-aggregated
406-
boolean expression.
452+
check: A custom rule or multiple rules to run for this column. This can be:
453+
- A single callable that returns a non-aggregated boolean expression.
454+
The name of the rule is derived from the callable name, or defaults to
455+
"check" for lambdas.
456+
- A list of callables, where each callable returns a non-aggregated
457+
boolean expression. The name of the rule is derived from the callable
458+
name, or defaults to "check" for lambdas. Where multiple rules result
459+
in the same name, the suffix __i is appended to the name.
460+
- A dictionary mapping rule names to callables, where each callable
461+
returns a non-aggregated boolean expression.
462+
All rule names provided here are given the prefix "check_".
407463
alias: An overwrite for this column's name which allows for using a column
408464
name that is not a valid Python identifier. Especially note that setting
409465
this option does _not_ allow to refer to the column with two different

dataframely/columns/decimal.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ def __init__(
3333
min_exclusive: decimal.Decimal | None = None,
3434
max: decimal.Decimal | None = None,
3535
max_exclusive: decimal.Decimal | None = None,
36-
check: Callable[[pl.Expr], pl.Expr] | None = None,
36+
check: (
37+
Callable[[pl.Expr], pl.Expr]
38+
| list[Callable[[pl.Expr], pl.Expr]]
39+
| dict[str, Callable[[pl.Expr], pl.Expr]]
40+
| None
41+
) = None,
3742
alias: str | None = None,
3843
metadata: dict[str, Any] | None = None,
3944
):
@@ -53,8 +58,17 @@ def __init__(
5358
max: The maximum value for decimals in this column (inclusive).
5459
max_exclusive: Like ``max`` but exclusive. May not be specified if ``max``
5560
is specified and vice versa.
56-
check: A custom check to run for this column. Must return a non-aggregated
57-
boolean expression.
61+
check: A custom rule or multiple rules to run for this column. This can be:
62+
- A single callable that returns a non-aggregated boolean expression.
63+
The name of the rule is derived from the callable name, or defaults to
64+
"check" for lambdas.
65+
- A list of callables, where each callable returns a non-aggregated
66+
boolean expression. The name of the rule is derived from the callable
67+
name, or defaults to "check" for lambdas. Where multiple rules result
68+
in the same name, the suffix __i is appended to the name.
69+
- A dictionary mapping rule names to callables, where each callable
70+
returns a non-aggregated boolean expression.
71+
All rule names provided here are given the prefix "check_".
5872
alias: An overwrite for this column's name which allows for using a column
5973
name that is not a valid Python identifier. Especially note that setting
6074
this option does _not_ allow to refer to the column with two different

0 commit comments

Comments
 (0)