|
1 | 1 | # Copyright (c) QuantCo 2025-2025
|
2 | 2 | # SPDX-License-Identifier: BSD-3-Clause
|
3 | 3 |
|
4 |
| -import sys |
5 | 4 | import warnings
|
6 | 5 | from abc import ABC
|
7 | 6 | from collections.abc import Mapping, Sequence
|
8 | 7 | from pathlib import Path
|
9 |
| -from typing import Any, Generic, Self, TypeVar, cast |
| 8 | +from typing import Any, Self, cast |
10 | 9 |
|
11 | 10 | import polars as pl
|
12 | 11 | import polars.exceptions as plexc
|
|
18 | 17 | from .failure import FailureInfo
|
19 | 18 | from .random import Generator
|
20 | 19 |
|
21 |
| -if sys.version_info >= (3, 13): |
22 |
| - SamplingType = TypeVar( |
23 |
| - "SamplingType", bound=Mapping[str, Any], default=Mapping[str, Any] |
24 |
| - ) |
25 |
| -else: # pragma: no cover |
26 |
| - SamplingType = TypeVar("SamplingType", bound=Mapping[str, Any]) |
27 | 20 |
|
28 |
| - |
29 |
| -class Collection(BaseCollection, ABC, Generic[SamplingType]): |
| 21 | +class Collection(BaseCollection, ABC): |
30 | 22 | """Base class for all collections of data frames with a predefined schema.
|
31 | 23 |
|
32 | 24 | A collection is comprised of a set of *members* which are collectively "consistent",
|
@@ -86,7 +78,7 @@ def sample(
|
86 | 78 | cls,
|
87 | 79 | num_rows: int | None = None,
|
88 | 80 | *,
|
89 |
| - overrides: Sequence[SamplingType] | None = None, |
| 81 | + overrides: Sequence[Mapping[str, Any]] | None = None, |
90 | 82 | generator: Generator | None = None,
|
91 | 83 | ) -> Self:
|
92 | 84 | """Create a random sample from the members of this collection.
|
@@ -162,10 +154,11 @@ def sample(
|
162 | 154 | samples = (
|
163 | 155 | overrides
|
164 | 156 | if overrides is not None
|
165 |
| - else [cast(SamplingType, {}) for _ in range(cast(int, num_rows))] |
| 157 | + else [{} for _ in range(cast(int, num_rows))] |
166 | 158 | )
|
167 | 159 | processed_samples = [
|
168 |
| - cls._preprocess_sample(sample, i, g) for i, sample in enumerate(samples) |
| 160 | + cls._preprocess_sample(dict(sample.items()), i, g) |
| 161 | + for i, sample in enumerate(samples) |
169 | 162 | ]
|
170 | 163 |
|
171 | 164 | # 2) Ensure that all samples have primary keys assigned to ensure that we
|
@@ -234,8 +227,8 @@ def sample(
|
234 | 227 |
|
235 | 228 | @classmethod
|
236 | 229 | def _preprocess_sample(
|
237 |
| - cls, sample: SamplingType, index: int, generator: Generator |
238 |
| - ) -> SamplingType: |
| 230 | + cls, sample: dict[str, Any], index: int, generator: Generator |
| 231 | + ) -> dict[str, Any]: |
239 | 232 | """Overridable method to preprocess a sample passed to :meth:`sample`.
|
240 | 233 |
|
241 | 234 | The purpose of this method is to (1) set the primary key columns to enable
|
|
0 commit comments