Skip to content

Commit 777b6de

Browse files
feat: Allow semantic comparison of collections (#62)
Co-authored-by: Oliver Borchert <[email protected]>
1 parent eb61574 commit 777b6de

File tree

5 files changed

+193
-2
lines changed

5 files changed

+193
-2
lines changed

dataframely/_base_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
149149
def __repr__(cls) -> str:
150150
parts = [f'[Schema "{cls.__name__}"]']
151151
parts.append(textwrap.indent("Columns:", prefix=" " * 2))
152-
for name, col in cls.columns().items(): # type: ignore
152+
for name, col in cls.columns().items():
153153
parts.append(textwrap.indent(f'- "{name}": {col!r}', prefix=" " * 4))
154-
if validation_rules := cls._schema_validation_rules(): # type: ignore
154+
if validation_rules := cls._schema_validation_rules():
155155
parts.append(textwrap.indent("Rules:", prefix=" " * 2))
156156
for name, rule in validation_rules.items():
157157
parts.append(textwrap.indent(f'- "{name}": {rule!r}', prefix=" " * 4))

dataframely/_filter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]:
3434
Attention:
3535
Make sure to provide unique combinations of the primary keys or the filters
3636
might introduce duplicate rows.
37+
38+
Attention:
39+
The filter logic should return a lazy frame with a static computational graph.
40+
Other implementations using arbitrary python logic works for filtering and
41+
validation, but may lead to wrong results in Collection comparisons
42+
and (de-)serialization.
3743
"""
3844

3945
def decorator(validation_fn: Callable[[C], pl.LazyFrame]) -> Filter[C]:

dataframely/_rule.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction],
106106
rules. By default, any rule that evaluates to ``null`` because one of the
107107
columns used in the rule is ``null`` is interpreted as ``true``, i.e. the row
108108
is assumed to be valid.
109+
110+
Attention:
111+
The rule logic should return a static result.
112+
Other implementations using arbitrary python logic works for filtering and
113+
validation, but may lead to wrong results in Schema comparisons
114+
and (de-)serialization.
109115
"""
110116

111117
def decorator(validation_fn: ValidationFunction) -> Rule:

dataframely/collection.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import warnings
55
from abc import ABC
66
from collections.abc import Mapping, Sequence
7+
from dataclasses import asdict
78
from pathlib import Path
89
from typing import Any, Self, cast
910

@@ -225,6 +226,65 @@ def sample(
225226
# 3) Eventually, we initialize the final collection and return
226227
return cls.validate(members)
227228

229+
@classmethod
230+
def matches(cls, other: type["Collection"]) -> bool:
231+
"""Check whether this collection semantically matches another.
232+
233+
Args:
234+
other: The collection to compare with.
235+
236+
Returns:
237+
Whether the two collections are semantically equal.
238+
239+
Attention:
240+
For custom filters, reliable comparison results are only guaranteed
241+
if the filter always returns a static polars expression.
242+
Otherwise, this function may falsely indicate a match.
243+
"""
244+
245+
def _members_match() -> bool:
246+
members_lhs = cls.members()
247+
members_rhs = other.members()
248+
249+
# Member names must match
250+
if members_lhs.keys() != members_rhs.keys():
251+
return False
252+
253+
# Member attributes must match
254+
for name in members_lhs:
255+
lhs = asdict(members_lhs[name])
256+
rhs = asdict(members_rhs[name])
257+
for attr in lhs.keys() | rhs.keys():
258+
if attr == "schema":
259+
if not lhs[attr].matches(rhs[attr]):
260+
return False
261+
else:
262+
if lhs[attr] != rhs[attr]:
263+
return False
264+
return True
265+
266+
def _filters_match() -> bool:
267+
filters_lhs = cls._filters()
268+
filters_rhs = other._filters()
269+
270+
# Filter names must match
271+
if filters_lhs.keys() != filters_rhs.keys():
272+
return False
273+
274+
# Computational graph of filter logic must match
275+
# Evaluate on empty dataframes
276+
empty_left = cls.create_empty()
277+
empty_right = other.create_empty()
278+
279+
for name in filters_lhs:
280+
lhs = filters_lhs[name].logic(empty_left)
281+
rhs = filters_rhs[name].logic(empty_right)
282+
if lhs.serialize(format="json") != rhs.serialize(format="json"):
283+
return False
284+
return True
285+
286+
return _members_match() and _filters_match()
287+
228288
@classmethod
229289
def _preprocess_sample(
230290
cls, sample: dict[str, Any], index: int, generator: Generator

tests/collection/test_matches.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) QuantCo 2025-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import polars as pl
5+
6+
import dataframely as dy
7+
8+
9+
def test_collection_matches_itself() -> None:
10+
"""Collections should match themselves."""
11+
12+
class MySchema(dy.Schema):
13+
foo = dy.Integer()
14+
15+
# First collection has one member
16+
class MyCollection1(dy.Collection):
17+
x: dy.LazyFrame[MySchema]
18+
19+
assert MyCollection1.matches(MyCollection1)
20+
21+
22+
def test_collection_matches_different_members() -> None:
23+
"""Collections should count as different if they have members with different
24+
names."""
25+
26+
class MySchema(dy.Schema):
27+
foo = dy.Integer()
28+
29+
class MyCollection1(dy.Collection):
30+
x: dy.LazyFrame[MySchema]
31+
32+
class MyCollection2(dy.Collection):
33+
y: dy.LazyFrame[MySchema]
34+
35+
# Should not match
36+
assert not MyCollection1.matches(MyCollection2)
37+
38+
39+
def test_collection_matches_different_schemas() -> None:
40+
"""Collections should count as different if their members have different schemas."""
41+
42+
class MyIntSchema(dy.Schema):
43+
foo = dy.Integer()
44+
45+
class MyStringSchema(dy.Schema):
46+
foo = dy.String()
47+
48+
assert not MyIntSchema.matches(MyStringSchema), (
49+
"Test schemas must not match for test setup to make sense"
50+
)
51+
52+
# Collections have the same member name
53+
# but mismatching schemas
54+
class MyCollection1(dy.Collection):
55+
x: dy.LazyFrame[MyIntSchema]
56+
57+
class MyCollection2(dy.Collection):
58+
x: dy.LazyFrame[MyStringSchema]
59+
60+
# Should not match
61+
assert not MyCollection1.matches(MyCollection2)
62+
63+
64+
def test_collection_matches_different_filter_names() -> None:
65+
"""Collections should count as different if they have the same members but different
66+
names."""
67+
68+
class MyIntSchema(dy.Schema):
69+
foo = dy.Integer(primary_key=True)
70+
71+
class MyCollection1(dy.Collection):
72+
x: dy.LazyFrame[MyIntSchema]
73+
74+
class MyCollection2(MyCollection1):
75+
@dy.filter()
76+
def test_filter(self) -> pl.LazyFrame:
77+
return dy.filter_relationship_one_to_one(self.x, self.x, ["foo"])
78+
79+
# Should not match
80+
assert not MyCollection1.matches(MyCollection2)
81+
82+
83+
def test_collection_matches_different_filter_logc() -> None:
84+
"""Collections should count as different if they have the same members but different
85+
filter logic."""
86+
87+
class MyIntSchema(dy.Schema):
88+
foo = dy.Integer(primary_key=True)
89+
90+
class BaseCollection(dy.Collection):
91+
x: dy.LazyFrame[MyIntSchema]
92+
93+
class MyCollection1(BaseCollection):
94+
@dy.filter()
95+
def test_filter(self) -> pl.LazyFrame:
96+
return dy.filter_relationship_one_to_one(self.x, self.x, ["foo"])
97+
98+
class MyCollection2(BaseCollection):
99+
@dy.filter()
100+
def test_filter(self) -> pl.LazyFrame:
101+
return dy.filter_relationship_one_to_at_least_one(self.x, self.x, ["foo"])
102+
103+
assert not MyCollection1.matches(MyCollection2)
104+
105+
106+
def test_collection_matches_different_optional() -> None:
107+
"""Collections should count as different if their members differ in whether they are
108+
optional or not."""
109+
110+
class FooSchema(dy.Schema):
111+
x = dy.Integer()
112+
113+
class MandatoryFooCollection(dy.Collection):
114+
foo: dy.LazyFrame[FooSchema]
115+
116+
class OptionalFooCollection(dy.Collection):
117+
foo: dy.LazyFrame[FooSchema] | None
118+
119+
assert not MandatoryFooCollection.matches(OptionalFooCollection)

0 commit comments

Comments
 (0)