Skip to content

Commit 22f4b7d

Browse files
authored
feat: Extend allowed types of the categories arg of dy.Enum (#138)
1 parent f303b3b commit 22f4b7d

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

dataframely/columns/enum.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@
33

44
from __future__ import annotations
55

6-
from collections.abc import Sequence
6+
import enum
7+
from collections.abc import Iterable
8+
from inspect import isclass
79
from typing import Any
810

911
import polars as pl
@@ -22,7 +24,7 @@ class Enum(Column):
2224

2325
def __init__(
2426
self,
25-
categories: Sequence[str],
27+
categories: pl.Series | Iterable[str] | type[enum.Enum],
2628
*,
2729
nullable: bool | None = None,
2830
primary_key: bool = False,
@@ -32,7 +34,8 @@ def __init__(
3234
):
3335
"""
3436
Args:
35-
categories: The list of valid categories for the enum.
37+
categories: The set of valid categories for the enum, or an existing Python
38+
string-valued enum.
3639
nullable: Whether this column may contain null values.
3740
Explicitly set `nullable=True` if you want your column to be nullable.
3841
In a future release, `nullable=False` will be the default if `nullable`
@@ -63,7 +66,13 @@ def __init__(
6366
alias=alias,
6467
metadata=metadata,
6568
)
66-
self.categories = list(categories)
69+
if isclass(categories) and issubclass(categories, enum.Enum):
70+
categories = pl.Series(
71+
values=[getattr(v, "value", v) for v in categories.__members__.values()]
72+
)
73+
elif not isinstance(categories, pl.Series):
74+
categories = pl.Series(values=categories)
75+
self.categories = categories
6776

6877
@property
6978
def dtype(self) -> pl.DataType:
@@ -72,7 +81,7 @@ def dtype(self) -> pl.DataType:
7281
def validate_dtype(self, dtype: PolarsDataType) -> bool:
7382
if not isinstance(dtype, pl.Enum):
7483
return False
75-
return self.categories == dtype.categories.to_list()
84+
return self.categories.equals(dtype.categories)
7685

7786
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
7887
category_lengths = [len(c) for c in self.categories]
@@ -92,5 +101,7 @@ def pyarrow_dtype(self) -> pa.DataType:
92101

93102
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
94103
return generator.sample_choice(
95-
n, choices=self.categories, null_probability=self._null_probability
104+
n,
105+
choices=self.categories.to_list(),
106+
null_probability=self._null_probability,
96107
).cast(self.dtype)

tests/column_types/test_enum.py

Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
# Copyright (c) QuantCo 2025-2025
22
# SPDX-License-Identifier: BSD-3-Clause
3-
3+
import enum
4+
from collections.abc import Iterable
5+
from enum import Enum
46
from typing import Any
57

68
import polars as pl
@@ -61,3 +63,48 @@ def test_different_sequences(type1: type, type2: type) -> None:
6163
S = create_schema("test", {"x": dy.Enum(type1(allowed))})
6264
df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(type2(allowed)))})
6365
S.validate(df)
66+
67+
68+
def test_enum_of_enum_136() -> None:
69+
class Categories(str, Enum):
70+
a = "a"
71+
b = "b"
72+
73+
assert pl.Enum(Categories) == dy.Enum(Categories).dtype
74+
75+
76+
def test_enum_of_series() -> None:
77+
categories = pl.Series(["a", "b"])
78+
assert pl.Enum(categories) == dy.Enum(categories).dtype
79+
80+
81+
def test_enum_of_iterable() -> None:
82+
categories = (x for x in ["a", "b"])
83+
assert pl.Enum(["a", "b"]) == dy.Enum(categories).dtype
84+
85+
86+
@pytest.mark.parametrize(
87+
"categories1",
88+
[
89+
["a", "b"],
90+
("a", "b"),
91+
pl.Series(["a", "b"]),
92+
Enum("Categories", {"a": "a", "b": "b"}),
93+
],
94+
)
95+
@pytest.mark.parametrize(
96+
"categories2",
97+
[
98+
["a", "b"],
99+
("a", "b"),
100+
pl.Series(["a", "b"]),
101+
Enum("Categories", {"a": "a", "b": "b"}),
102+
],
103+
)
104+
def test_sequences_and_enums(
105+
categories1: pl.Series | Iterable[str] | type[enum.Enum],
106+
categories2: pl.Series | Iterable[str] | type[enum.Enum],
107+
) -> None:
108+
S = create_schema("test", {"x": dy.Enum(categories1)})
109+
df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))})
110+
S.validate(df)

0 commit comments

Comments
 (0)