Skip to content

Commit 8282a78

Browse files
committed
[mlir][python] add type wrappers
1 parent 0e42df4 commit 8282a78

File tree

4 files changed

+310
-16
lines changed

4 files changed

+310
-16
lines changed

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
463463

464464
static void bindDerived(ClassTy &c) {
465465
c.def_static("get", &PyVectorType::get, py::arg("shape"),
466-
py::arg("elementType"), py::kw_only(),
466+
py::arg("element_type"), py::kw_only(),
467467
py::arg("scalable") = py::none(),
468468
py::arg("scalable_dims") = py::none(),
469469
py::arg("loc") = py::none(), "Create a vector type")
@@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
689689
static void bindDerived(ClassTy &c) {
690690
c.def_static(
691691
"get_tuple",
692-
[](py::list elementList, DefaultingPyMlirContext context) {
693-
intptr_t num = py::len(elementList);
694-
// Mapping py::list to SmallVector.
695-
SmallVector<MlirType, 4> elements;
696-
for (auto element : elementList)
697-
elements.push_back(element.cast<PyType>());
698-
MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
692+
[](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
693+
MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
694+
elements.data());
699695
return PyTupleType(context->getRef(), t);
700696
},
701697
py::arg("elements"), py::arg("context") = py::none(),
@@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
727723
static void bindDerived(ClassTy &c) {
728724
c.def_static(
729725
"get",
730-
[](std::vector<PyType> inputs, std::vector<PyType> results,
726+
[](std::vector<MlirType> inputs, std::vector<MlirType> results,
731727
DefaultingPyMlirContext context) {
732-
SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
733-
SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
734-
MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
735-
inputsRaw.data(), resultsRaw.size(),
736-
resultsRaw.data());
728+
MlirType t =
729+
mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
730+
results.size(), results.data());
737731
return PyFunctionType(context->getRef(), t);
738732
},
739733
py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
@@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
742736
"inputs",
743737
[](PyFunctionType &self) {
744738
MlirType t = self;
745-
auto contextRef = self.getContext();
746739
py::list types;
747740
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
748741
++i) {
@@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
754747
c.def_property_readonly(
755748
"results",
756749
[](PyFunctionType &self) {
757-
auto contextRef = self.getContext();
758750
py::list types;
759751
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
760752
++i) {

mlir/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
2121
_mlir_libs/__init__.py
2222
ir.py
2323
passmanager.py
24+
types.py
2425
dialects/_ods_common.py
2526

2627
# The main _mlir module has submodules: include stubs from each.

mlir/python/mlir/types.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from functools import partial
6+
from typing import Optional, List
7+
8+
from .ir import (
9+
Attribute,
10+
BF16Type,
11+
ComplexType,
12+
Context,
13+
F16Type,
14+
F32Type,
15+
F64Type,
16+
Float8E4M3B11FNUZType,
17+
Float8E4M3FNType,
18+
Float8E5M2Type,
19+
FunctionType,
20+
IndexType,
21+
IntegerType,
22+
MemRefType,
23+
NoneType,
24+
OpaqueType,
25+
RankedTensorType,
26+
StridedLayoutAttr,
27+
StringAttr,
28+
TupleType,
29+
Type,
30+
UnrankedMemRefType,
31+
UnrankedTensorType,
32+
VectorType,
33+
)
34+
35+
# Star-import exports nothing on purpose; individual names are resolved
# dynamically through the module-level __getattr__ defined below.
__all__ = []

# NOTE: these aliases are deliberately written as lambdas rather than `def`s.
# The module-level __getattr__ in this file treats any global whose __name__
# is "<lambda>" as a zero-argument builder and *calls* it, so `types.i32`
# resolves to a fresh Type instance under the active Context, while real
# functions such as `_i` or `_vector` are returned uncalled. Do not convert
# these to `def`s — that would change how __getattr__ dispatches them.

# Index and 1-bit boolean types.
_index = lambda: IndexType.get()
_bool = lambda: IntegerType.get_signless(1)

# Signless integer types.
_i8 = lambda: IntegerType.get_signless(8)
_i16 = lambda: IntegerType.get_signless(16)
_i32 = lambda: IntegerType.get_signless(32)
_i64 = lambda: IntegerType.get_signless(64)

# Signed integer types.
_si8 = lambda: IntegerType.get_signed(8)
_si16 = lambda: IntegerType.get_signed(16)
_si32 = lambda: IntegerType.get_signed(32)
_si64 = lambda: IntegerType.get_signed(64)

# Unsigned integer types.
_ui8 = lambda: IntegerType.get_unsigned(8)
_ui16 = lambda: IntegerType.get_unsigned(16)
_ui32 = lambda: IntegerType.get_unsigned(32)
_ui64 = lambda: IntegerType.get_unsigned(64)

# Standard floating point types.
_f16 = lambda: F16Type.get()
_f32 = lambda: F32Type.get()
_f64 = lambda: F64Type.get()
_bf16 = lambda: BF16Type.get()

# 8-bit float variants.
_f8e5m2 = lambda: Float8E5M2Type.get()
_f8e4m3 = lambda: Float8E4M3FNType.get()
_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()

# MLIR's NoneType.
_none = lambda: NoneType.get()
65+
66+
67+
def _i(width):
    """Return a signless IntegerType with the given bit width."""
    result = IntegerType.get_signless(width)
    return result
69+
70+
71+
def _si(width):
    """Return a signed IntegerType with the given bit width."""
    signed_ctor = IntegerType.get_signed
    return signed_ctor(width)
73+
74+
75+
def _ui(width):
    """Return an unsigned IntegerType with the given bit width."""
    unsigned_ctor = IntegerType.get_unsigned
    return unsigned_ctor(width)
77+
78+
79+
def _complex(type):
    """Return a ComplexType whose element type is `type`."""
    result = ComplexType.get(type)
    return result
81+
82+
83+
def _opaque(dialect_namespace, type_data):
    """Return an OpaqueType built from `dialect_namespace` and `type_data`."""
    result = OpaqueType.get(dialect_namespace, type_data)
    return result
85+
86+
87+
def _shaped(*shape, element_type: Type = None, type_constructor=None):
    """Shared builder for shaped types (vector/tensor/memref).

    The element type must be supplied exactly one way: either via the
    `element_type` keyword, or as the last positional argument following the
    dimension sizes. The leading positional arguments are the shape; when no
    sizes remain, `type_constructor` is invoked with the element type alone
    (the unranked variants used by `_tensor`/`_memref`).

    Raises ValueError if `type_constructor` is missing or the element type is
    supplied zero or two ways.
    """
    if type_constructor is None:
        raise ValueError("shaped is an abstract base class - cannot be constructed.")
    last_arg_is_type = bool(shape) and isinstance(shape[-1], Type)
    element_type_given = element_type is not None
    # Exactly one of the two spellings must be used. This equality check also
    # rejects the case of no element type at all (e.g. `T.tensor()`), which
    # previously fell through to `shape[-1]` and crashed with IndexError.
    # The message no longer says "tensor": this helper also serves vector and
    # memref builders.
    if element_type_given == last_arg_is_type:
        raise ValueError(
            "Either element_type must be provided explicitly XOR last arg to the type constructor must be the element type."
        )
    if element_type_given:
        el_type = element_type
        sizes = shape
    else:
        el_type = shape[-1]
        sizes = shape[:-1]
    if sizes:
        return type_constructor(sizes, el_type)
    # No dimension sizes: unranked constructors take only the element type.
    return type_constructor(el_type)
106+
107+
108+
def _vector(
    *shape,
    element_type: Type = None,
    scalable: Optional[List[bool]] = None,
    scalable_dims: Optional[List[int]] = None,
):
    """Build a VectorType; shape and element type follow `_shaped`'s rules.

    `scalable` / `scalable_dims` are forwarded unchanged to VectorType.get.
    """
    ctor = partial(VectorType.get, scalable=scalable, scalable_dims=scalable_dims)
    return _shaped(*shape, element_type=element_type, type_constructor=ctor)
121+
122+
123+
def _tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
    """Build a RankedTensorType, or an UnrankedTensorType when no sizes are given.

    A string `encoding` is wrapped into a StringAttr; it is only legal for the
    ranked case.
    """
    if encoding is not None:
        encoding = StringAttr.get(encoding)
    # No positional args, or a single positional that is already a Type,
    # means there are no dimension sizes: build the unranked variant.
    unranked = not shape or (len(shape) == 1 and isinstance(shape[-1], Type))
    if unranked:
        if encoding is not None:
            raise ValueError("UnrankedTensorType does not support encoding.")
        return _shaped(
            *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
        )
    ranked_ctor = partial(RankedTensorType.get, encoding=encoding)
    return _shaped(*shape, element_type=element_type, type_constructor=ranked_ctor)
137+
138+
139+
def _memref(
    *shape,
    element_type: Type = None,
    memory_space: Optional[int] = None,
    layout: Optional[StridedLayoutAttr] = None,
):
    """Build a MemRefType, or an UnrankedMemRefType when no sizes are given.

    An integer `memory_space` is parsed into an Attribute; `layout` is
    forwarded to MemRefType.get (ranked case only).
    """
    if memory_space is not None:
        memory_space = Attribute.parse(str(memory_space))
    # No positional args, or a single positional that is already a Type,
    # means there are no dimension sizes: build the unranked variant.
    unranked = not shape or (len(shape) == 1 and isinstance(shape[-1], Type))
    if unranked:
        ctor = partial(UnrankedMemRefType.get, memory_space=memory_space)
        return _shaped(*shape, element_type=element_type, type_constructor=ctor)
    ctor = partial(MemRefType.get, memory_space=memory_space, layout=layout)
    return _shaped(*shape, element_type=element_type, type_constructor=ctor)
160+
161+
162+
def _tuple(*elements):
    """Return a TupleType holding the given element types."""
    element_types = elements
    return TupleType.get_tuple(element_types)
164+
165+
166+
def _function(*, inputs, results):
    """Return a FunctionType mapping `inputs` to `results` (keyword-only)."""
    fn_type = FunctionType.get(inputs, results)
    return fn_type
168+
169+
170+
def __getattr__(name):
    # Module-level __getattr__ (PEP 562): resolves `types.<name>` lazily so
    # that Type objects are only constructed once an MLIR Context is active.
    if name == "__path__":
        # https://docs.python.org/3/reference/import.html#path__
        # If a module is a package (either regular or namespace), the module object’s __path__ attribute must be set.
        # This module is NOT a package and so this must be None (rather than throw the RuntimeError below).
        return None
    try:
        # Probe for an active context; Context.current raises ValueError when
        # there is none, which we surface as a friendlier RuntimeError.
        Context.current
    except ValueError:
        raise RuntimeError("Types can only be instantiated under an active context.")

    # Public name `foo` maps to private builder `_foo` in this module.
    if f"_{name}" in globals():
        builder = globals()[f"_{name}"]
        # Zero-arg lambdas (e.g. _i32) are invoked here so the attribute
        # yields a Type instance; ordinary `def` functions (e.g. _i, _vector)
        # are returned uncalled so the caller can pass arguments.
        if (
            isinstance(builder, type(lambda: None))
            and builder.__name__ == (lambda: None).__name__
        ):
            return builder()
        return builder
    raise RuntimeError(f"{name} is not a legal type.")

mlir/test/python/ir/builtin_types.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gc
44
from mlir.ir import *
55
from mlir.dialects import arith, tensor, func, memref
6+
import mlir.types as T
67

78

89
def run(f):
@@ -772,3 +773,114 @@ def testCustomTypeTypeCaster():
772773
print(t)
773774
# CHECK: OperationType(!transform.op<"foo.bar">)
774775
print(repr(t))
776+
777+
778+
# CHECK-LABEL: TEST: testTypeWrappers
@run
def testTypeWrappers():
    # Importing a type alias outside of any Context must fail loudly; the
    # `else` guards ensure the test fails rather than passing vacuously if
    # the expected RuntimeError is not raised.
    try:
        from mlir.types import i32
    except RuntimeError as e:
        assert e.args[0] == "Types can only be instantiated under an active context."
    else:
        assert False, "expected RuntimeError when importing a type outside a Context"

    try:
        from mlir.types import tensor
    except RuntimeError as e:
        assert e.args[0] == "Types can only be instantiated under an active context."
    else:
        assert False, "expected RuntimeError when importing a type outside a Context"

    def stride(strides, offset=0):
        return StridedLayoutAttr.get(offset, strides)

    with Context(), Location.unknown():
        # Unknown names must be rejected even under an active Context.
        try:
            from mlir.types import non_existent_type
        except RuntimeError as e:
            assert e.args[0] == "non_existent_type is not a legal type."
        else:
            assert False, "expected RuntimeError for a non-existent type"

        # Parameterized integer builders.
        ia = T.i(5)
        sia = T.si(6)
        uia = T.ui(7)
        assert repr(ia) == "IntegerType(i5)"
        assert repr(sia) == "IntegerType(si6)"
        assert repr(uia) == "IntegerType(ui7)"

        # Builders agree with the fixed-width lambda aliases.
        assert T.i(16) == T.i16
        assert T.si(16) == T.si16
        assert T.ui(16) == T.ui16

        c1 = T.complex(T.f16)
        c2 = T.complex(T.i32)
        assert repr(c1) == "ComplexType(complex<f16>)"
        assert repr(c2) == "ComplexType(complex<i32>)"

        vec_1 = T.vector(2, 3, T.f32)
        vec_2 = T.vector(2, 3, 4, T.f32)
        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"

        # Memrefs: plain, with memory space, and with strided layouts.
        m1 = T.memref(2, 3, 4, T.f64)
        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"

        m2 = T.memref(2, 3, 4, T.f64, memory_space=1)
        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"

        m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13]))
        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"

        m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13], 42))
        assert (
            repr(m4)
            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
        )

        S = ShapedType.get_dynamic_size()

        # Tensors: element type positionally or by keyword; no sizes -> unranked.
        t1 = T.tensor(S, 3, S, T.f64)
        assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut1 = T.tensor(T.f64)
        assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
        t2 = T.tensor(S, 3, S, element_type=T.f64)
        assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut2 = T.tensor(element_type=T.f64)
        assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"

        t3 = T.tensor(S, 3, S, T.f64, encoding="encoding")
        assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'

        v = T.vector(3, 3, 3, T.f64)
        assert repr(v) == "VectorType(vector<3x3x3xf64>)"

        # Memrefs with dynamic sizes and the unranked forms.
        m5 = T.memref(S, 3, S, T.f64)
        assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
        um1 = T.memref(T.f64)
        assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
        m6 = T.memref(S, 3, S, element_type=T.f64)
        assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
        um2 = T.memref(element_type=T.f64)
        assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"

        m7 = T.memref(S, 3, S, T.f64)
        assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
        um3 = T.memref(T.f64)
        assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"

        # Scalable vectors via the boolean-per-dim and index-list spellings.
        scalable_1 = T.vector(2, 3, T.f32, scalable=[False, True])
        scalable_2 = T.vector(2, 3, 4, T.f32, scalable=[True, False, True])
        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"

        scalable_3 = T.vector(2, 3, T.f32, scalable_dims=[1])
        scalable_4 = T.vector(2, 3, 4, T.f32, scalable_dims=[0, 2])
        assert scalable_3 == scalable_1
        assert scalable_4 == scalable_2

        opaq = T.opaque("scf", "placeholder")
        assert repr(opaq) == "OpaqueType(!scf.placeholder)"

        tup1 = T.tuple(T.i16, T.i32, T.i64)
        tup2 = T.tuple(T.f16, T.f32, T.f64)
        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"

        # Named `fn` (not `func`) to avoid shadowing the func dialect module
        # imported at the top of this file.
        fn = T.function(inputs=(T.i16, T.i32, T.i64), results=(T.f16, T.f32, T.f64))
        assert repr(fn) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"

0 commit comments

Comments
 (0)