Skip to content

Commit 22e50f9

Browse files
committed
add shaped type wrappers
1 parent 80af1ea commit 22e50f9

File tree

2 files changed

+87
-8
lines changed

2 files changed

+87
-8
lines changed

mlir_utils/types.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
1+
from functools import partial
12
from typing import Union
23

34
import numpy as np
45
from mlir.ir import (
5-
IntegerType,
6-
F64Type,
7-
RankedTensorType,
8-
IndexType,
6+
Attribute,
97
F16Type,
108
F32Type,
9+
F64Type,
10+
IndexType,
11+
IntegerType,
12+
MemRefType,
13+
RankedTensorType,
1114
Type,
15+
UnrankedMemRefType,
16+
UnrankedTensorType,
17+
VectorType,
1218
)
1319

1420
index_t = IndexType.get()
@@ -66,15 +72,55 @@ def infer_mlir_type(
6672
)
6773

6874

69-
def tensor_t(*args, element_type: Type = None):
70-
if (element_type is None and not isinstance(args[-1], Type)) or (
71-
isinstance(args[-1], Type) and element_type is not None
75+
def shaped_t(*args, element_type: Type = None, type_constructor=None):
76+
if type_constructor is None:
77+
raise ValueError("shaped_t is an abstract base class - cannot be constructed")
78+
if (element_type is None and args and not isinstance(args[-1], Type)) or (
79+
args and isinstance(args[-1], Type) and element_type is not None
7280
):
7381
raise ValueError(
7482
f"either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type"
7583
)
7684
if element_type is not None:
7785
type = element_type
86+
sizes = args
7887
else:
7988
type = args[-1]
80-
return RankedTensorType.get(args[:-1], type)
89+
sizes = args[:-1]
90+
if sizes:
91+
return type_constructor(sizes, type)
92+
else:
93+
return type_constructor(type)
94+
95+
96+
def vector_t(*args, element_type: Type = None):
97+
return shaped_t(*args, element_type=element_type, type_constructor=VectorType.get)
98+
99+
100+
def tensor_t(*args, element_type: Type = None):
101+
if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
102+
return shaped_t(
103+
*args, element_type=element_type, type_constructor=UnrankedTensorType.get
104+
)
105+
else:
106+
return shaped_t(
107+
*args, element_type=element_type, type_constructor=RankedTensorType.get
108+
)
109+
110+
111+
def memref_t(*args, element_type: Type = None, memory_space: int = None):
112+
if memory_space is None:
113+
memory_space = 0
114+
memory_space = Attribute.parse(str(memory_space))
115+
if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
116+
return shaped_t(
117+
*args,
118+
element_type=element_type,
119+
type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
120+
)
121+
else:
122+
return shaped_t(
123+
*args,
124+
element_type=element_type,
125+
type_constructor=partial(MemRefType.get, memory_space=memory_space),
126+
)

tests/test_types.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
from mlir_utils.dialects.ext.tensor import S
4+
5+
# noinspection PyUnresolvedReferences
6+
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
7+
from mlir_utils.types import f64_t, tensor_t, memref_t, vector_t
8+
9+
# needed since the fix isn't defined here nor conftest.py
10+
pytest.mark.usefixtures("ctx")
11+
12+
13+
def test_shaped_types(ctx: MLIRContext):
14+
t = tensor_t(S, 3, S, f64_t)
15+
assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
16+
ut = tensor_t(f64_t)
17+
assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
18+
t = tensor_t(S, 3, S, element_type=f64_t)
19+
assert repr(t) == "RankedTensorType(tensor<?x3x?xf64>)"
20+
ut = tensor_t(element_type=f64_t)
21+
assert repr(ut) == "UnrankedTensorType(tensor<*xf64>)"
22+
23+
m = memref_t(S, 3, S, f64_t)
24+
assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
25+
um = memref_t(f64_t)
26+
assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
27+
m = memref_t(S, 3, S, element_type=f64_t)
28+
assert repr(m) == "MemRefType(memref<?x3x?xf64>)"
29+
um = memref_t(element_type=f64_t)
30+
assert repr(um) == "UnrankedMemRefType(memref<*xf64>)"
31+
32+
v = vector_t(3, 3, 3, f64_t)
33+
assert repr(v) == "VectorType(vector<3x3x3xf64>)"

0 commit comments

Comments
 (0)