3
3
4
4
from __future__ import annotations
5
5
6
- from collections .abc import Sequence
6
+ import enum
7
+ from collections .abc import Iterable
8
+ from inspect import isclass
7
9
from typing import Any
8
10
9
11
import polars as pl
@@ -22,7 +24,7 @@ class Enum(Column):
22
24
23
25
def __init__ (
24
26
self ,
25
- categories : Sequence [str ],
27
+ categories : pl . Series | Iterable [str ] | type [ enum . Enum ],
26
28
* ,
27
29
nullable : bool | None = None ,
28
30
primary_key : bool = False ,
@@ -32,7 +34,8 @@ def __init__(
32
34
):
33
35
"""
34
36
Args:
35
- categories: The list of valid categories for the enum.
37
+ categories: The set of valid categories for the enum, or an existing Python
38
+ string-valued enum.
36
39
nullable: Whether this column may contain null values.
37
40
Explicitly set `nullable=True` if you want your column to be nullable.
38
41
In a future release, `nullable=False` will be the default if `nullable`
@@ -63,7 +66,13 @@ def __init__(
63
66
alias = alias ,
64
67
metadata = metadata ,
65
68
)
66
- self .categories = list (categories )
69
+ if isclass (categories ) and issubclass (categories , enum .Enum ):
70
+ categories = pl .Series (
71
+ values = [getattr (v , "value" , v ) for v in categories .__members__ .values ()]
72
+ )
73
+ elif not isinstance (categories , pl .Series ):
74
+ categories = pl .Series (values = categories )
75
+ self .categories = categories
67
76
68
77
@property
69
78
def dtype (self ) -> pl .DataType :
@@ -72,7 +81,7 @@ def dtype(self) -> pl.DataType:
72
81
def validate_dtype(self, dtype: PolarsDataType) -> bool:
    """Check whether ``dtype`` is a polars enum matching this column.

    Args:
        dtype: The polars data type to check.

    Returns:
        ``True`` only if ``dtype`` is a ``pl.Enum`` whose categories equal
        ``self.categories``; ``False`` otherwise.
    """
    is_enum = isinstance(dtype, pl.Enum)
    # Short-circuit so `.categories` is only read from actual enum dtypes.
    return is_enum and self.categories.equals(dtype.categories)
76
85
77
86
def sqlalchemy_dtype (self , dialect : sa .Dialect ) -> sa_TypeEngine :
78
87
category_lengths = [len (c ) for c in self .categories ]
@@ -92,5 +101,7 @@ def pyarrow_dtype(self) -> pa.DataType:
92
101
93
102
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
    """Draw ``n`` values from this column's categories.

    Args:
        generator: The generator whose ``sample_choice`` performs the draw.
        n: The number of values requested from the generator.

    Returns:
        The generator's sample, cast to this column's ``dtype``.
    """
    # The categories are handed to the generator as a plain Python list.
    category_choices = self.categories.to_list()
    sampled = generator.sample_choice(
        n,
        choices=category_choices,
        null_probability=self._null_probability,
    )
    return sampled.cast(self.dtype)
0 commit comments