Skip to content

Commit fbf08c8

Browse files
authored
fix: Be more lenient when reading metadata (#114)
1 parent 9157ed3 commit fbf08c8

File tree

6 files changed

+79
-10
lines changed

6 files changed

+79
-10
lines changed

dataframely/_serialization.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def default(self, obj: Any) -> Any:
4646
case pl.Expr():
4747
return {
4848
"__type__": "expression",
49-
"value": obj.meta.serialize(format="json"),
49+
"value": base64.b64encode(obj.meta.serialize()).decode("utf-8"),
5050
}
5151
case pl.LazyFrame():
5252
return {
@@ -87,8 +87,14 @@ def object_hook(self, dct: dict[str, Any]) -> Any:
8787
case "tuple":
8888
return tuple(dct["value"])
8989
case "expression":
90-
data = BytesIO(cast(str, dct["value"]).encode("utf-8"))
91-
return pl.Expr.deserialize(data, format="json")
90+
value_str = cast(str, dct["value"]).encode("utf-8")
91+
if value_str.startswith(b"{"):
92+
# NOTE: This branch is for backwards-compatibility only
93+
data = BytesIO(value_str)
94+
return pl.Expr.deserialize(data, format="json")
95+
else:
96+
data = BytesIO(base64.b64decode(value_str))
97+
return pl.Expr.deserialize(data)
9298
case "lazyframe":
9399
data = BytesIO(
94100
base64.b64decode(cast(str, dct["value"]).encode("utf-8"))

dataframely/collection.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def _from_parquet(
870870
if (collection_type is None) and (schema_file := path / "schema.json").exists():
871871
try:
872872
collection_type = deserialize_collection(schema_file.read_text())
873-
except JSONDecodeError:
873+
except (JSONDecodeError, plexc.ComputeError):
874874
pass
875875

876876
return data, collection_type
@@ -947,12 +947,15 @@ def read_parquet_metadata_collection(
947947
source: Path to a parquet file or a file-like object that contains the metadata.
948948
949949
Returns:
950-
The collection that was serialized to the metadata or ``None`` if no collection metadata
951-
is found.
950+
The collection that was serialized to the metadata. ``None`` if no collection
951+
metadata is found or the deserialization fails.
952952
"""
953953
metadata = pl.read_parquet_metadata(source)
954954
if (schema_metadata := metadata.get(COLLECTION_METADATA_KEY)) is not None:
955-
return deserialize_collection(schema_metadata)
955+
try:
956+
return deserialize_collection(schema_metadata)
957+
except (JSONDecodeError, plexc.ComputeError):
958+
return None
956959
return None
957960

958961

dataframely/schema.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import warnings
99
from abc import ABC
1010
from collections.abc import Iterable, Mapping, Sequence
11+
from json import JSONDecodeError
1112
from pathlib import Path
1213
from typing import IO, Any, Literal, overload
1314

@@ -970,12 +971,15 @@ def read_parquet_metadata_schema(
970971
source: Path to a parquet file or a file-like object that contains the metadata.
971972
972973
Returns:
973-
The schema that was serialized to the metadata or ``None`` if no schema metadata
974-
is found.
974+
The schema that was serialized to the metadata. ``None`` if no schema metadata
975+
is found or the deserialization fails.
975976
"""
976977
metadata = pl.read_parquet_metadata(source)
977978
if (schema_metadata := metadata.get(SCHEMA_METADATA_KEY)) is not None:
978-
return deserialize_schema(schema_metadata)
979+
try:
980+
return deserialize_schema(schema_metadata)
981+
except (JSONDecodeError, plexc.ComputeError):
982+
return None
979983
return None
980984

981985

tests/collection/test_read_write_parquet.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from polars.testing import assert_frame_equal
1313

1414
import dataframely as dy
15+
from dataframely._serialization import COLLECTION_METADATA_KEY
1516
from dataframely.collection import _reconcile_collection_types
1617
from dataframely.exc import ValidationRequiredError
1718
from dataframely.testing import create_collection, create_schema
@@ -378,3 +379,20 @@ def test_reconcile_collection_types(
378379
inputs: list[type[dy.Collection] | None], output: type[dy.Collection] | None
379380
) -> None:
380381
assert output == _reconcile_collection_types(inputs)
382+
383+
384+
# ---------------------------------- MANUAL METADATA --------------------------------- #
385+
386+
387+
def test_read_invalid_parquet_metadata_collection(tmp_path: Path) -> None:
388+
# Arrange
389+
df = pl.DataFrame({"a": [1, 2, 3]})
390+
df.write_parquet(
391+
tmp_path / "df.parquet", metadata={COLLECTION_METADATA_KEY: "invalid"}
392+
)
393+
394+
# Act
395+
collection = dy.read_parquet_metadata_collection(tmp_path / "df.parquet")
396+
397+
# Assert
398+
assert collection is None

tests/schema/test_read_write_parquet.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from polars.testing import assert_frame_equal
1111

1212
import dataframely as dy
13+
from dataframely._serialization import SCHEMA_METADATA_KEY
1314
from dataframely.exc import ValidationRequiredError
1415
from dataframely.testing import create_schema
1516

@@ -216,3 +217,18 @@ def test_read_write_parquet_validation_skip_invalid_schema(
216217

217218
# Assert
218219
spy.assert_not_called()
220+
221+
222+
# ---------------------------------- MANUAL METADATA --------------------------------- #
223+
224+
225+
def test_read_invalid_parquet_metadata_schema(tmp_path: Path) -> None:
226+
# Arrange
227+
df = pl.DataFrame({"a": [1, 2, 3]})
228+
df.write_parquet(tmp_path / "df.parquet", metadata={SCHEMA_METADATA_KEY: "invalid"})
229+
230+
# Act
231+
schema = dy.read_parquet_metadata_schema(tmp_path / "df.parquet")
232+
233+
# Assert
234+
assert schema is None

tests/test_serialization.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) QuantCo 2025-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import json
5+
6+
import polars as pl
7+
8+
from dataframely._serialization import SchemaJSONDecoder
9+
10+
11+
def test_decode_json_expression() -> None:
12+
# Arrange
13+
expr = pl.col("a") + 1
14+
encoded = json.dumps(
15+
{"__type__": "expression", "value": expr.meta.serialize(format="json")}
16+
)
17+
18+
# Act
19+
decoded = json.loads(encoded, cls=SchemaJSONDecoder)
20+
21+
# Assert
22+
assert expr.meta.eq(decoded)

0 commit comments

Comments
 (0)