@@ -67,12 +67,8 @@ def __init__(
67
67
metadata = metadata ,
68
68
)
69
69
if isclass (categories ) and issubclass (categories , enum .Enum ):
70
- categories = pl .Series (
71
- values = [getattr (v , "value" , v ) for v in categories .__members__ .values ()]
72
- )
73
- elif not isinstance (categories , pl .Series ):
74
- categories = pl .Series (values = categories )
75
- self .categories = categories
70
+ categories = (item .value for item in categories )
71
+ self .categories = list (categories )
76
72
77
73
@property
78
74
def dtype (self ) -> pl .DataType :
@@ -81,7 +77,7 @@ def dtype(self) -> pl.DataType:
81
77
def validate_dtype (self , dtype : PolarsDataType ) -> bool :
82
78
if not isinstance (dtype , pl .Enum ):
83
79
return False
84
- return self .categories . equals ( dtype .categories )
80
+ return self .categories == dtype .categories . to_list ( )
85
81
86
82
def sqlalchemy_dtype (self , dialect : sa .Dialect ) -> sa_TypeEngine :
87
83
category_lengths = [len (c ) for c in self .categories ]
@@ -102,6 +98,6 @@ def pyarrow_dtype(self) -> pa.DataType:
102
98
def _sample_unchecked (self , generator : Generator , n : int ) -> pl .Series :
103
99
return generator .sample_choice (
104
100
n ,
105
- choices = self .categories . to_list () ,
101
+ choices = self .categories ,
106
102
null_probability = self ._null_probability ,
107
103
).cast (self .dtype )
0 commit comments