|
1 | 1 | import ctypes
|
2 |
| -from functools import wraps |
3 | 2 | import inspect
|
| 3 | +from collections import defaultdict |
| 4 | +from functools import wraps |
| 5 | +from typing import Callable |
4 | 6 |
|
5 | 7 | from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
|
6 |
| -from mlir.ir import InsertionPoint, Value, Type |
| 8 | +from mlir.ir import InsertionPoint, Value, Type, TypeID |
7 | 9 |
|
8 | 10 |
|
9 | 11 | def get_result_or_results(op):
|
@@ -31,20 +33,52 @@ def maybe_no_args(*args, **kwargs):
|
31 | 33 | return maybe_no_args
|
32 | 34 |
|
33 | 35 |
|
| 36 | +__VALUE_CASTERS: defaultdict[ |
| 37 | + TypeID, list[Callable[[Value], Value | None]] |
| 38 | +] = defaultdict(list) |
| 39 | + |
| 40 | + |
| 41 | +def register_value_caster( |
| 42 | + typeid: TypeID, caster: Callable[[Value], Value], priority: int = None |
| 43 | +): |
| 44 | + if not isinstance(typeid, TypeID): |
| 45 | + raise ValueError(f"{typeid=} is not a TypeID") |
| 46 | + if priority is None: |
| 47 | + __VALUE_CASTERS[typeid].append(caster) |
| 48 | + else: |
| 49 | + __VALUE_CASTERS[typeid].insert(priority, caster) |
| 50 | + |
| 51 | + |
| 52 | +def has_value_caster(typeid: TypeID): |
| 53 | + if not isinstance(typeid, TypeID): |
| 54 | + raise ValueError(f"{typeid=} is not a TypeID") |
| 55 | + if not typeid in __VALUE_CASTERS: |
| 56 | + return False |
| 57 | + return True |
| 58 | + |
| 59 | + |
| 60 | +def get_value_caster(typeid: TypeID): |
| 61 | + if not has_value_caster(typeid): |
| 62 | + raise ValueError(f"no registered caster for {typeid=}") |
| 63 | + return __VALUE_CASTERS[typeid] |
| 64 | + |
| 65 | + |
34 | 66 | def maybe_cast(val: Value):
|
35 | 67 | """Maybe cast an ir.Value to one of Tensor, Scalar.
|
36 | 68 |
|
37 | 69 | Args:
|
38 | 70 | val: The ir.Value to maybe cast.
|
39 | 71 | """
|
40 |
| - from mlir_utils.dialects.ext.tensor import Tensor |
41 | 72 | from mlir_utils.dialects.ext.arith import Scalar
|
42 | 73 |
|
43 | 74 | if not isinstance(val, Value):
|
44 | 75 | return val
|
45 | 76 |
|
46 |
| - if Tensor.isinstance(val): |
47 |
| - return Tensor(val) |
| 77 | + if has_value_caster(val.type.typeid): |
| 78 | + for caster in get_value_caster(val.type.typeid): |
| 79 | + if casted := caster(val): |
| 80 | + return casted |
| 81 | + raise ValueError(f"no successful casts for {val=}") |
48 | 82 | if Scalar.isinstance(val):
|
49 | 83 | return Scalar(val)
|
50 | 84 | return val
|
|
0 commit comments