|
9 | 9 |
|
10 | 10 |
|
11 | 11 | # mixin that requires `is_constant`
|
12 |
| -class ShapedValue: |
| 12 | +def ShapedValue(cls): |
13 | 13 | @cached_property
|
14 | 14 | def literal_value(self) -> np.ndarray:
|
15 | 15 | if not self.is_constant:
|
@@ -42,3 +42,22 @@ def n_elements(self) -> int:
|
42 | 42 | @cached_property
|
43 | 43 | def dtype(self) -> Type:
|
44 | 44 | return self._shaped_type.element_type
|
| 45 | + |
| 46 | + setattr(cls, "literal_value", literal_value) |
| 47 | + cls.literal_value.__set_name__(None, "literal_value") |
| 48 | + setattr(cls, "_shaped_type", _shaped_type) |
| 49 | + cls._shaped_type.__set_name__(None, "_shaped_type") |
| 50 | + |
| 51 | + setattr(cls, "has_static_shape", has_static_shape) |
| 52 | + setattr(cls, "has_rank", has_rank) |
| 53 | + |
| 54 | + setattr(cls, "rank", rank) |
| 55 | + cls.rank.__set_name__(None, "rank") |
| 56 | + setattr(cls, "shape", shape) |
| 57 | + cls.shape.__set_name__(None, "shape") |
| 58 | + setattr(cls, "n_elements", n_elements) |
| 59 | + cls.n_elements.__set_name__(None, "n_elements") |
| 60 | + setattr(cls, "dtype", dtype) |
| 61 | + cls.dtype.__set_name__(None, "dtype") |
| 62 | + |
| 63 | + return cls |
0 commit comments