
Commit d1a9c4c

incorporate comments
1 parent 8282a78 commit d1a9c4c

File tree

mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/CMakeLists.txt
mlir/python/mlir/extras/__init__.py
mlir/python/mlir/types.py renamed to mlir/python/mlir/extras/types.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/context_managers.py

6 files changed: +73 -116 lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 2 additions & 2 deletions

@@ -2538,8 +2538,8 @@ void mlir::python::populateIRCore(py::module &m) {
           [](py::object & /*class*/) {
             auto *context = PyThreadContextEntry::getDefaultContext();
             if (!context)
-              throw py::value_error("No current Context");
-            return context;
+              return py::none().cast<py::object>();
+            return py::cast(context);
           },
           "Gets the Context bound to the current thread or raises ValueError")
       .def_property_readonly(
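
The user-visible effect of this binding change, sketched below, is that reading Context.current with no context active now yields None instead of raising ValueError. This is an illustrative sketch, not code from the commit; it assumes the behavior matches the updated context_managers.py test further down.

    # Minimal sketch of the new Context.current behavior (assumption: mirrors
    # the updated context_managers.py test in this commit).
    from mlir.ir import Context

    assert Context.current is None  # nothing pushed yet: None, not ValueError

    with Context() as ctx:
        assert Context.current is ctx  # thread-local default context

    assert Context.current is None  # popped again once the block exits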

mlir/python/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     _mlir_libs/__init__.py
     ir.py
     passmanager.py
-    types.py
+    extras/types.py
     dialects/_ods_common.py
 
 # The main _mlir module has submodules: include stubs from each.

mlir/python/mlir/extras/__init__.py

Whitespace-only changes.

mlir/python/mlir/types.py renamed to mlir/python/mlir/extras/types.py

Lines changed: 36 additions & 60 deletions

@@ -5,11 +5,10 @@
 from functools import partial
 from typing import Optional, List
 
-from .ir import (
+from ..ir import (
     Attribute,
     BF16Type,
     ComplexType,
-    Context,
     F16Type,
     F32Type,
     F64Type,
@@ -32,55 +31,54 @@
     VectorType,
 )
 
-__all__ = []
+index = lambda: IndexType.get()
 
-_index = lambda: IndexType.get()
-_bool = lambda: IntegerType.get_signless(1)
 
-_i8 = lambda: IntegerType.get_signless(8)
-_i16 = lambda: IntegerType.get_signless(16)
-_i32 = lambda: IntegerType.get_signless(32)
-_i64 = lambda: IntegerType.get_signless(64)
+def i(width):
+    return IntegerType.get_signless(width)
 
-_si8 = lambda: IntegerType.get_signed(8)
-_si16 = lambda: IntegerType.get_signed(16)
-_si32 = lambda: IntegerType.get_signed(32)
-_si64 = lambda: IntegerType.get_signed(64)
 
-_ui8 = lambda: IntegerType.get_unsigned(8)
-_ui16 = lambda: IntegerType.get_unsigned(16)
-_ui32 = lambda: IntegerType.get_unsigned(32)
-_ui64 = lambda: IntegerType.get_unsigned(64)
+def si(width):
+    return IntegerType.get_signed(width)
 
-_f16 = lambda: F16Type.get()
-_f32 = lambda: F32Type.get()
-_f64 = lambda: F64Type.get()
-_bf16 = lambda: BF16Type.get()
 
-_f8e5m2 = lambda: Float8E5M2Type.get()
-_f8e4m3 = lambda: Float8E4M3FNType.get()
-_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
+def ui(width):
+    return IntegerType.get_unsigned(width)
 
-_none = lambda: NoneType.get()
 
+bool = lambda: i(1)
+i8 = lambda: i(8)
+i16 = lambda: i(16)
+i32 = lambda: i(32)
+i64 = lambda: i(64)
 
-def _i(width):
-    return IntegerType.get_signless(width)
+si8 = lambda: si(8)
+si16 = lambda: si(16)
+si32 = lambda: si(32)
+si64 = lambda: si(64)
 
+ui8 = lambda: ui(8)
+ui16 = lambda: ui(16)
+ui32 = lambda: ui(32)
+ui64 = lambda: ui(64)
 
-def _si(width):
-    return IntegerType.get_signed(width)
+f16 = lambda: F16Type.get()
+f32 = lambda: F32Type.get()
+f64 = lambda: F64Type.get()
+bf16 = lambda: BF16Type.get()
 
+f8E5M2 = lambda: Float8E5M2Type.get()
+f8E4M3 = lambda: Float8E4M3FNType.get()
+f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
 
-def _ui(width):
-    return IntegerType.get_unsigned(width)
+none = lambda: NoneType.get()
 
 
-def _complex(type):
+def complex(type):
     return ComplexType.get(type)
 
 
-def _opaque(dialect_namespace, type_data):
+def opaque(dialect_namespace, type_data):
     return OpaqueType.get(dialect_namespace, type_data)
 
 
@@ -105,7 +103,7 @@ def _shaped(*shape, element_type: Type = None, type_constructor=None):
     return type_constructor(type)
 
 
-def _vector(
+def vector(
     *shape,
     element_type: Type = None,
     scalable: Optional[List[bool]] = None,
@@ -120,7 +118,7 @@ def _vector(
     )
 
 
-def _tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
+def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
     if encoding is not None:
         encoding = StringAttr.get(encoding)
     if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
@@ -136,7 +134,7 @@ def _tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
     )
 
 
-def _memref(
+def memref(
    *shape,
    element_type: Type = None,
    memory_space: Optional[int] = None,
@@ -159,31 +157,9 @@ def _memref(
     )
 
 
-def _tuple(*elements):
+def tuple(*elements):
     return TupleType.get_tuple(elements)
 
 
-def _function(*, inputs, results):
+def function(*, inputs, results):
     return FunctionType.get(inputs, results)
-
-
-def __getattr__(name):
-    if name == "__path__":
-        # https://docs.python.org/3/reference/import.html#path__
-        # If a module is a package (either regular or namespace), the module object’s __path__ attribute must be set.
-        # This module is NOT a package and so this must be None (rather than throw the RuntimeError below).
-        return None
-    try:
-        Context.current
-    except ValueError:
-        raise RuntimeError("Types can only be instantiated under an active context.")
-
-    if f"_{name}" in globals():
-        builder = globals()[f"_{name}"]
-        if (
-            isinstance(builder, type(lambda: None))
-            and builder.__name__ == (lambda: None).__name__
-        ):
-            return builder()
-        return builder
-    raise RuntimeError(f"{name} is not a legal type.")
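
With the module-level __getattr__ gone, every helper in mlir.extras.types is an ordinary callable: the aliases (i32, f64, ...) are zero-argument lambdas that must be invoked explicitly, and all constructors require an active Context. A usage sketch (illustrative only, not part of the commit):

    # Usage sketch for the renamed module; assumes an active Context, since the
    # helpers call into IntegerType.get_signless() and friends.
    from mlir.ir import Context
    import mlir.extras.types as T

    with Context():
        i32 = T.i32()                       # IntegerType(i32)
        mem = T.memref(4, 8, T.f32())       # MemRefType(memref<4x8xf32>)
        fn = T.function(inputs=(i32,), results=(T.f64(),))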

mlir/test/python/ir/builtin_types.py

Lines changed: 33 additions & 46 deletions

@@ -3,7 +3,7 @@
 import gc
 from mlir.ir import *
 from mlir.dialects import arith, tensor, func, memref
-import mlir.types as T
+import mlir.extras.types as T
 
 
 def run(f):
@@ -778,109 +778,96 @@ def testCustomTypeTypeCaster():
 # CHECK-LABEL: TEST: testTypeWrappers
 @run
 def testTypeWrappers():
-    try:
-        from mlir.types import i32
-    except RuntimeError as e:
-        assert e.args[0] == "Types can only be instantiated under an active context."
-
-    try:
-        from mlir.types import tensor
-    except RuntimeError as e:
-        assert e.args[0] == "Types can only be instantiated under an active context."
-
     def stride(strides, offset=0):
         return StridedLayoutAttr.get(offset, strides)
 
     with Context(), Location.unknown():
-        try:
-            from mlir.types import non_existent_type
-        except RuntimeError as e:
-            assert e.args[0] == "non_existent_type is not a legal type."
-
         ia = T.i(5)
         sia = T.si(6)
         uia = T.ui(7)
         assert repr(ia) == "IntegerType(i5)"
         assert repr(sia) == "IntegerType(si6)"
         assert repr(uia) == "IntegerType(ui7)"
 
-        assert T.i(16) == T.i16
-        assert T.si(16) == T.si16
-        assert T.ui(16) == T.ui16
+        assert T.i(16) == T.i16()
+        assert T.si(16) == T.si16()
+        assert T.ui(16) == T.ui16()
 
-        c1 = T.complex(T.f16)
-        c2 = T.complex(T.i32)
+        c1 = T.complex(T.f16())
+        c2 = T.complex(T.i32())
         assert repr(c1) == "ComplexType(complex<f16>)"
         assert repr(c2) == "ComplexType(complex<i32>)"
 
-        vec_1 = T.vector(2, 3, T.f32)
-        vec_2 = T.vector(2, 3, 4, T.f32)
+        vec_1 = T.vector(2, 3, T.f32())
+        vec_2 = T.vector(2, 3, 4, T.f32())
         assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
         assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"
 
-        m1 = T.memref(2, 3, 4, T.f64)
+        m1 = T.memref(2, 3, 4, T.f64())
         assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"
 
-        m2 = T.memref(2, 3, 4, T.f64, memory_space=1)
+        m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
         assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"
 
-        m3 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13]))
+        m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
         assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"
 
-        m4 = T.memref(2, 3, 4, T.f64, memory_space=1, layout=stride([5, 7, 13], 42))
+        m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
         assert (
             repr(m4)
             == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
         )
 
         S = ShapedType.get_dynamic_size()
 
-        t1 = T.tensor(S, 3, S, T.f64)
+        t1 = T.tensor(S, 3, S, T.f64())
         assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
-        ut1 = T.tensor(T.f64)
+        ut1 = T.tensor(T.f64())
         assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
-        t2 = T.tensor(S, 3, S, element_type=T.f64)
+        t2 = T.tensor(S, 3, S, element_type=T.f64())
         assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
-        ut2 = T.tensor(element_type=T.f64)
+        ut2 = T.tensor(element_type=T.f64())
         assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"
 
-        t3 = T.tensor(S, 3, S, T.f64, encoding="encoding")
+        t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
         assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'
 
-        v = T.vector(3, 3, 3, T.f64)
+        v = T.vector(3, 3, 3, T.f64())
        assert repr(v) == "VectorType(vector<3x3x3xf64>)"
 
-        m5 = T.memref(S, 3, S, T.f64)
+        m5 = T.memref(S, 3, S, T.f64())
         assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
-        um1 = T.memref(T.f64)
+        um1 = T.memref(T.f64())
         assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
-        m6 = T.memref(S, 3, S, element_type=T.f64)
+        m6 = T.memref(S, 3, S, element_type=T.f64())
         assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
-        um2 = T.memref(element_type=T.f64)
+        um2 = T.memref(element_type=T.f64())
         assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"
 
-        m7 = T.memref(S, 3, S, T.f64)
+        m7 = T.memref(S, 3, S, T.f64())
         assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
-        um3 = T.memref(T.f64)
+        um3 = T.memref(T.f64())
         assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"
 
-        scalable_1 = T.vector(2, 3, T.f32, scalable=[False, True])
-        scalable_2 = T.vector(2, 3, 4, T.f32, scalable=[True, False, True])
+        scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
+        scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
         assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
         assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"
 
-        scalable_3 = T.vector(2, 3, T.f32, scalable_dims=[1])
-        scalable_4 = T.vector(2, 3, 4, T.f32, scalable_dims=[0, 2])
+        scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
+        scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
         assert scalable_3 == scalable_1
         assert scalable_4 == scalable_2
 
         opaq = T.opaque("scf", "placeholder")
         assert repr(opaq) == "OpaqueType(!scf.placeholder)"
 
-        tup1 = T.tuple(T.i16, T.i32, T.i64)
-        tup2 = T.tuple(T.f16, T.f32, T.f64)
+        tup1 = T.tuple(T.i16(), T.i32(), T.i64())
+        tup2 = T.tuple(T.f16(), T.f32(), T.f64())
         assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
         assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"
 
-        func = T.function(inputs=(T.i16, T.i32, T.i64), results=(T.f16, T.f32, T.f64))
+        func = T.function(
+            inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
+        )
         assert repr(func) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"
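
A note on the equality assertions above: checks such as T.i(16) == T.i16() pass because MLIR types are uniqued within a Context, so constructing the same type twice yields equal values. A small sketch (illustrative commentary, not code from the commit):

    # Types are uniqued per Context, so repeated construction compares equal.
    from mlir.ir import Context
    import mlir.extras.types as T

    with Context():
        assert T.i(16) == T.i16()   # both are the signless i16 of this context
        assert T.f32() == T.f32()   # repeated construction, same uniqued type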

mlir/test/python/ir/context_managers.py

Lines changed: 1 addition & 7 deletions

@@ -15,13 +15,7 @@ def run(f):
 def testContextEnterExit():
     with Context() as ctx:
         assert Context.current is ctx
-    try:
-        _ = Context.current
-    except ValueError as e:
-        # CHECK: No current Context
-        print(e)
-    else:
-        assert False, "Expected exception"
+    assert Context.current is None
 
 
 run(testContextEnterExit)
