|
| 1 | +from functools import partial |
1 | 2 | from typing import Union
|
2 | 3 |
|
3 | 4 | import numpy as np
|
4 | 5 | from mlir.ir import (
|
5 |
| - IntegerType, |
6 |
| - F64Type, |
7 |
| - RankedTensorType, |
8 |
| - IndexType, |
| 6 | + Attribute, |
9 | 7 | F16Type,
|
10 | 8 | F32Type,
|
| 9 | + F64Type, |
| 10 | + IndexType, |
| 11 | + IntegerType, |
| 12 | + MemRefType, |
| 13 | + RankedTensorType, |
11 | 14 | Type,
|
| 15 | + UnrankedMemRefType, |
| 16 | + UnrankedTensorType, |
| 17 | + VectorType, |
12 | 18 | )
|
13 | 19 |
|
14 | 20 | index_t = IndexType.get()
|
@@ -66,15 +72,55 @@ def infer_mlir_type(
|
66 | 72 | )
|
67 | 73 |
|
68 | 74 |
|
69 |
| -def tensor_t(*args, element_type: Type = None): |
70 |
| - if (element_type is None and not isinstance(args[-1], Type)) or ( |
71 |
| - isinstance(args[-1], Type) and element_type is not None |
| 75 | +def shaped_t(*args, element_type: Type = None, type_constructor=None): |
| 76 | + if type_constructor is None: |
| 77 | + raise ValueError("shaped_t is an abstract base class - cannot be constructed") |
| 78 | + if (element_type is None and args and not isinstance(args[-1], Type)) or ( |
| 79 | + args and isinstance(args[-1], Type) and element_type is not None |
72 | 80 | ):
|
73 | 81 | raise ValueError(
|
74 | 82 | f"either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type"
|
75 | 83 | )
|
76 | 84 | if element_type is not None:
|
77 | 85 | type = element_type
|
| 86 | + sizes = args |
78 | 87 | else:
|
79 | 88 | type = args[-1]
|
80 |
| - return RankedTensorType.get(args[:-1], type) |
| 89 | + sizes = args[:-1] |
| 90 | + if sizes: |
| 91 | + return type_constructor(sizes, type) |
| 92 | + else: |
| 93 | + return type_constructor(type) |
| 94 | + |
| 95 | + |
| 96 | +def vector_t(*args, element_type: Type = None): |
| 97 | + return shaped_t(*args, element_type=element_type, type_constructor=VectorType.get) |
| 98 | + |
| 99 | + |
| 100 | +def tensor_t(*args, element_type: Type = None): |
| 101 | + if not len(args) or len(args) == 1 and isinstance(args[-1], Type): |
| 102 | + return shaped_t( |
| 103 | + *args, element_type=element_type, type_constructor=UnrankedTensorType.get |
| 104 | + ) |
| 105 | + else: |
| 106 | + return shaped_t( |
| 107 | + *args, element_type=element_type, type_constructor=RankedTensorType.get |
| 108 | + ) |
| 109 | + |
| 110 | + |
| 111 | +def memref_t(*args, element_type: Type = None, memory_space: int = None): |
| 112 | + if memory_space is None: |
| 113 | + memory_space = 0 |
| 114 | + memory_space = Attribute.parse(str(memory_space)) |
| 115 | + if not len(args) or len(args) == 1 and isinstance(args[-1], Type): |
| 116 | + return shaped_t( |
| 117 | + *args, |
| 118 | + element_type=element_type, |
| 119 | + type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space), |
| 120 | + ) |
| 121 | + else: |
| 122 | + return shaped_t( |
| 123 | + *args, |
| 124 | + element_type=element_type, |
| 125 | + type_constructor=partial(MemRefType.get, memory_space=memory_space), |
| 126 | + ) |
0 commit comments