Skip to content

Commit 2d5332f

Browse files
refactor: Store Enum categories as list instead of pl.Series (#150)
1 parent 56b2f86 commit 2d5332f

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

dataframely/columns/_base.py

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -381,15 +381,6 @@ def _attributes_match(
381381
if name == "check":
382382
return _compare_checks(lhs, rhs, column_expr)
383383

384-
lhs_is_series = isinstance(lhs, pl.Series)
385-
rhs_is_series = isinstance(rhs, pl.Series)
386-
387-
if lhs_is_series != rhs_is_series:
388-
return False
389-
390-
if lhs_is_series and rhs_is_series:
391-
return _compare_series(lhs, rhs)
392-
393384
return lhs == rhs
394385

395386
# -------------------------------- DUNDER METHODS -------------------------------- #
@@ -413,10 +404,6 @@ def __str__(self) -> str:
413404
return self.__class__.__name__.lower()
414405

415406

416-
def _compare_series(lhs: pl.Series, rhs: pl.Series) -> bool:
417-
return (len(lhs) == len(rhs)) and lhs.equals(rhs)
418-
419-
420407
def _compare_checks(lhs: Check | None, rhs: Check | None, expr: pl.Expr) -> bool:
421408
match (lhs, rhs):
422409
case (None, None):

dataframely/columns/enum.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -67,12 +67,8 @@ def __init__(
6767
metadata=metadata,
6868
)
6969
if isclass(categories) and issubclass(categories, enum.Enum):
70-
categories = pl.Series(
71-
values=[getattr(v, "value", v) for v in categories.__members__.values()]
72-
)
73-
elif not isinstance(categories, pl.Series):
74-
categories = pl.Series(values=categories)
75-
self.categories = categories
70+
categories = (item.value for item in categories)
71+
self.categories = list(categories)
7672

7773
@property
7874
def dtype(self) -> pl.DataType:
@@ -81,7 +77,7 @@ def dtype(self) -> pl.DataType:
8177
def validate_dtype(self, dtype: PolarsDataType) -> bool:
8278
if not isinstance(dtype, pl.Enum):
8379
return False
84-
return self.categories.equals(dtype.categories)
80+
return self.categories == dtype.categories.to_list()
8581

8682
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
8783
category_lengths = [len(c) for c in self.categories]
@@ -102,6 +98,6 @@ def pyarrow_dtype(self) -> pa.DataType:
10298
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
10399
return generator.sample_choice(
104100
n,
105-
choices=self.categories.to_list(),
101+
choices=self.categories,
106102
null_probability=self._null_probability,
107103
).cast(self.dtype)

0 commit comments

Comments (0)