Skip to content

Commit 8d4b41b

Browse files
authored
refactor: Remove generic parameter from collection (#50)
1 parent f891a59 commit 8d4b41b

File tree

2 files changed

+9
-16
lines changed

2 files changed

+9
-16
lines changed

dataframely/collection.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# Copyright (c) QuantCo 2025-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
import sys
54
import warnings
65
from abc import ABC
76
from collections.abc import Mapping, Sequence
87
from pathlib import Path
9-
from typing import Any, Generic, Self, TypeVar, cast
8+
from typing import Any, Self, cast
109

1110
import polars as pl
1211
import polars.exceptions as plexc
@@ -18,15 +17,8 @@
1817
from .failure import FailureInfo
1918
from .random import Generator
2019

21-
if sys.version_info >= (3, 13):
22-
SamplingType = TypeVar(
23-
"SamplingType", bound=Mapping[str, Any], default=Mapping[str, Any]
24-
)
25-
else: # pragma: no cover
26-
SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any])
2720

28-
29-
class Collection(BaseCollection, ABC, Generic[SamplingType]):
21+
class Collection(BaseCollection, ABC):
3022
"""Base class for all collections of data frames with a predefined schema.
3123
3224
A collection is comprised of a set of *members* which are collectively "consistent",
@@ -86,7 +78,7 @@ def sample(
8678
cls,
8779
num_rows: int | None = None,
8880
*,
89-
overrides: Sequence[SamplingType] | None = None,
81+
overrides: Sequence[Mapping[str, Any]] | None = None,
9082
generator: Generator | None = None,
9183
) -> Self:
9284
"""Create a random sample from the members of this collection.
@@ -162,10 +154,11 @@ def sample(
162154
samples = (
163155
overrides
164156
if overrides is not None
165-
else [cast(SamplingType, {}) for _ in range(cast(int, num_rows))]
157+
else [{} for _ in range(cast(int, num_rows))]
166158
)
167159
processed_samples = [
168-
cls._preprocess_sample(sample, i, g) for i, sample in enumerate(samples)
160+
cls._preprocess_sample(dict(sample.items()), i, g)
161+
for i, sample in enumerate(samples)
169162
]
170163

171164
# 2) Ensure that all samples have primary keys assigned to ensure that we
@@ -234,8 +227,8 @@ def sample(
234227

235228
@classmethod
236229
def _preprocess_sample(
237-
cls, sample: SamplingType, index: int, generator: Generator
238-
) -> SamplingType:
230+
cls, sample: dict[str, Any], index: int, generator: Generator
231+
) -> dict[str, Any]:
239232
"""Overridable method to preprocess a sample passed to :meth:`sample`.
240233
241234
The purpose of this method is to (1) set the primary key columns to enable

tests/test_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class SamplingType(TypedDict):
9494
second: NotRequired[SamplingTypeSecond]
9595

9696

97-
class MyCollection(dy.Collection[SamplingType]):
97+
class MyCollection(dy.Collection):
9898
first: dy.LazyFrame[MyFirstSchema]
9999
second: dy.LazyFrame[MySecondSchema]
100100

0 commit comments

Comments
 (0)