Skip to content

Commit d8916e8

Browse files
fix: Correct schema deserialization on failure (#128)
1 parent 327d649 commit d8916e8

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

dataframely/failure.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .schema import Schema
2020

2121
RULE_METADATA_KEY = "dataframely_rule_columns"
22+
UNKNOWN_SCHEMA_NAME = "__DATAFRAMELY_UNKNOWN__"
2223

2324
S = TypeVar("S", bound=BaseSchema)
2425

@@ -176,7 +177,7 @@ def scan_parquet(
176177
def _from_parquet(
177178
cls, source: str | Path | IO[bytes], scan: bool, **kwargs: Any
178179
) -> FailureInfo[Schema]:
179-
from .schema import deserialize_schema
180+
from .schema import Schema, deserialize_schema
180181

181182
metadata = pl.read_parquet_metadata(source)
182183
schema_metadata = metadata.get(SCHEMA_METADATA_KEY)
@@ -189,10 +190,13 @@ def _from_parquet(
189190
if scan
190191
else pl.read_parquet(source, **kwargs).lazy()
191192
)
193+
failure_schema = deserialize_schema(schema_metadata, strict=False) or type(
194+
UNKNOWN_SCHEMA_NAME, (Schema,), {}
195+
)
192196
return FailureInfo(
193197
lf,
194198
json.loads(rule_metadata),
195-
schema=deserialize_schema(schema_metadata),
199+
schema=failure_schema,
196200
)
197201

198202

dataframely/schema.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -959,27 +959,33 @@ def read_parquet_metadata_schema(
959959
"""
960960
metadata = pl.read_parquet_metadata(source)
961961
if (schema_metadata := metadata.get(SCHEMA_METADATA_KEY)) is not None:
962-
try:
963-
return deserialize_schema(schema_metadata)
964-
except (JSONDecodeError, plexc.ComputeError):
965-
return None
962+
return deserialize_schema(schema_metadata, strict=False)
966963
return None
967964

968965

969-
def deserialize_schema(data: str) -> type[Schema]:
966+
@overload
967+
def deserialize_schema(data: str, strict: Literal[True] = True) -> type[Schema]: ...
968+
969+
970+
@overload
971+
def deserialize_schema(data: str, strict: Literal[False]) -> type[Schema] | None: ...
972+
973+
974+
def deserialize_schema(data: str, strict: bool = True) -> type[Schema] | None:
970975
"""Deserialize a schema from a JSON string.
971976
972977
This method allows to dynamically load a schema from its serialization, without
973978
having to know the schema to load in advance.
974979
975980
Args:
976981
data: The JSON string created via :meth:`Schema.serialize`.
982+
strict: Whether to raise an exception if the schema cannot be deserialized.
977983
978984
Returns:
979985
The schema loaded from the JSON data.
980986
981987
Raises:
982-
ValueError: If the schema format version is not supported.
988+
ValueError: If the schema format version is not supported and ``strict=True``.
983989
984990
Attention:
985991
This functionality is considered unstable. It may be changed at any time
@@ -988,10 +994,15 @@ def deserialize_schema(data: str) -> type[Schema]:
988994
See also:
989995
:meth:`Schema.serialize` for additional information on serialization.
990996
"""
991-
decoded = json.loads(data, cls=SchemaJSONDecoder)
992-
if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION:
993-
raise ValueError(f"Unsupported schema format version: {format}")
994-
return _schema_from_dict(decoded)
997+
try:
998+
decoded = json.loads(data, cls=SchemaJSONDecoder)
999+
if (format := decoded["versions"]["format"]) != SERIALIZATION_FORMAT_VERSION:
1000+
raise ValueError(f"Unsupported schema format version: {format}")
1001+
return _schema_from_dict(decoded)
1002+
except (ValueError, JSONDecodeError, plexc.ComputeError) as e:
1003+
if strict:
1004+
raise e from e
1005+
return None
9951006

9961007

9971008
def _schema_from_dict(data: dict[str, Any]) -> type[Schema]:

tests/test_failure_info.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from polars.testing import assert_frame_equal
1010

1111
import dataframely as dy
12+
from dataframely._serialization import SCHEMA_METADATA_KEY
13+
from dataframely.failure import RULE_METADATA_KEY, UNKNOWN_SCHEMA_NAME, FailureInfo
1214

1315

1416
class MySchema(dy.Schema):
@@ -68,7 +70,9 @@ def test_write_parquet_custom_metadata(tmp_path: Path) -> None:
6870
"read_fn",
6971
[dy.FailureInfo.read_parquet, dy.FailureInfo.scan_parquet],
7072
)
71-
def test_missing_metadata(tmp_path: Path, read_fn: Callable[[Path], None]) -> None:
73+
def test_missing_metadata(
74+
tmp_path: Path, read_fn: Callable[[Path], FailureInfo]
75+
) -> None:
7276
df = pl.DataFrame(
7377
{
7478
"a": [4, 5, 6, 6, 7, 8],
@@ -79,3 +83,27 @@ def test_missing_metadata(tmp_path: Path, read_fn: Callable[[Path], None]) -> No
7983

8084
with pytest.raises(ValueError, match=r"does not contain the required metadata"):
8185
read_fn(tmp_path / "failure.parquet")
86+
87+
88+
@pytest.mark.parametrize(
89+
"read_fn",
90+
[dy.FailureInfo.read_parquet, dy.FailureInfo.scan_parquet],
91+
)
92+
def test_invalid_schema_deserialization(
93+
tmp_path: Path, read_fn: Callable[[Path], FailureInfo]
94+
) -> None:
95+
df = pl.DataFrame(
96+
{
97+
"a": [1, 2, 3],
98+
"b": [False, True, False],
99+
}
100+
)
101+
df.write_parquet(
102+
tmp_path / "failure.parquet",
103+
metadata={
104+
SCHEMA_METADATA_KEY: "{WRONG",
105+
RULE_METADATA_KEY: '["b"]',
106+
},
107+
)
108+
failure = read_fn(tmp_path / "failure.parquet")
109+
assert failure.schema.__name__ == UNKNOWN_SCHEMA_NAME

0 commit comments

Comments
 (0)