Skip to content

Commit 537b2aa

Browse files
authored
[mlir][python] meta region_op (#75673)
1 parent 11140cc commit 537b2aa

File tree

14 files changed

+429
-13
lines changed

14 files changed

+429
-13
lines changed

mlir/python/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
2121
_mlir_libs/__init__.py
2222
ir.py
2323
passmanager.py
24-
extras/types.py
2524
dialects/_ods_common.py
2625

2726
# The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
3029
_mlir_libs/_mlir/passmanager.pyi
3130
)
3231

32+
declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
33+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
34+
ADD_TO_PARENT MLIRPythonSources.Core.Python
35+
SOURCES
36+
extras/types.py
37+
extras/meta.py
38+
)
39+
3340
declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
3441
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
3542
ADD_TO_PARENT MLIRPythonSources

mlir/python/mlir/dialects/arith.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from ._ods_common import (
1212
get_default_loc_context as _get_default_loc_context,
1313
_cext as _ods_cext,
14+
get_op_result_or_op_results as _get_op_result_or_op_results,
15+
SubClassValueT as _SubClassValueT,
1416
)
1517

1618
from typing import Any, List, Union
@@ -75,3 +77,9 @@ def literal_value(self) -> Union[int, float]:
7577
return FloatAttr(self.value).value
7678
else:
7779
raise ValueError("only integer and float constants have literal values")
80+
81+
82+
def constant(
83+
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
84+
) -> _SubClassValueT:
85+
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))

mlir/python/mlir/dialects/builtin.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
from typing import Dict, Optional
6+
57
from ._builtin_ops_gen import *
68
from ._builtin_ops_gen import _Dialect
9+
from ..extras.meta import region_op
710

811
try:
912
from ..ir import *
@@ -23,3 +26,23 @@ def __init__(self, *, loc=None, ip=None):
2326
@property
2427
def body(self):
2528
return self.regions[0].blocks[0]
29+
30+
31+
@region_op
32+
def module(
33+
*,
34+
sym_name=None,
35+
sym_visibility=None,
36+
attrs: Optional[Dict[str, Attribute]] = None,
37+
loc=None,
38+
ip=None,
39+
):
40+
mod = ModuleOp.__base__(
41+
sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip
42+
)
43+
if attrs is None:
44+
attrs = {}
45+
for attr_name, attr in attrs.items():
46+
mod.operation.attributes[attr_name] = attr
47+
48+
return mod

mlir/python/mlir/dialects/func.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ def emit_call_op(*call_args):
243243
return decorator
244244

245245

246+
func = FuncOp.from_py_func
247+
248+
246249
@_ods_cext.register_operation(_Dialect, replace=True)
247250
class CallOp(CallOp):
248251
"""Specialization for the call op class."""

mlir/python/mlir/dialects/pdl.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._pdl_ops_gen import *
66
from ._pdl_ops_gen import _Dialect
77
from .._mlir_libs._mlirDialectsPDL import *
8+
from .._mlir_libs._mlirDialectsPDL import OperationType
89

910

1011
try:
@@ -13,7 +14,7 @@
1314
except ImportError as e:
1415
raise RuntimeError("Error loading imports from extension module") from e
1516

16-
from typing import Union, Optional, Sequence, Mapping
17+
from typing import Union, Optional, Sequence, Mapping, NewType
1718
from ._ods_common import (
1819
get_op_result_or_value as _get_value,
1920
get_op_results_or_values as _get_values,
@@ -220,3 +221,10 @@ def __init__(
220221
constantTypes = []
221222
result = pdl.RangeType.get(pdl.TypeType.get())
222223
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
224+
225+
226+
OperationTypeT = NewType("OperationType", OperationType)
227+
228+
229+
def op_t() -> OperationTypeT:
230+
return OperationTypeT(OperationType.get())

mlir/python/mlir/dialects/scf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def for_(
120120
params = [start, stop, step]
121121
for i, p in enumerate(params):
122122
if isinstance(p, int):
123-
p = constant(IntegerAttr.get(IndexType.get(), p))
123+
p = constant(IndexType.get(), p)
124124
elif isinstance(p, float):
125125
raise ValueError(f"{p=} must be int.")
126126
params[i] = p

mlir/python/mlir/dialects/tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from ._tensor_ops_gen import *
66
from ._tensor_ops_gen import _Dialect
7+
from ..extras.meta import region_op
78

89
try:
910
from ..ir import *
@@ -40,3 +41,9 @@ def __init__(
4041
dynamic_sizes.append(s)
4142
result_type = RankedTensorType.get(static_sizes, element_type)
4243
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
44+
45+
46+
generate = region_op(
47+
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
48+
terminator=lambda args: YieldOp(args[0]),
49+
)

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
except ImportError as e:
1919
raise RuntimeError("Error loading imports from extension module") from e
2020

21-
from typing import Optional, Sequence, Union
21+
from typing import Optional, Sequence, Union, NewType
2222

2323

2424
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -175,15 +175,15 @@ def __init__(
175175
result_types: Sequence[Type],
176176
sym_visibility=None,
177177
arg_attrs=None,
178-
res_attrs=None
178+
res_attrs=None,
179179
):
180180
function_type = FunctionType.get(input_types, result_types)
181181
super().__init__(
182182
sym_name=sym_name,
183183
function_type=TypeAttr.get(function_type),
184184
sym_visibility=sym_visibility,
185185
arg_attrs=arg_attrs,
186-
res_attrs=res_attrs
186+
res_attrs=res_attrs,
187187
)
188188
self.regions[0].blocks.append(*input_types)
189189

@@ -212,3 +212,10 @@ def __init__(
212212
if operands is None:
213213
operands = []
214214
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
215+
216+
217+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
218+
219+
220+
def any_op_t() -> AnyOpTypeT:
221+
return AnyOpTypeT(AnyOpType.get())

mlir/python/mlir/dialects/transform/extras/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@
44

55
from typing import Callable, Optional, Sequence, Union
66

7+
from ....extras.meta import region_op
78
from .... import ir
8-
from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
9+
from .. import (
10+
AnyOpType,
11+
OperationType,
12+
NamedSequenceOp,
13+
YieldOp,
14+
SequenceOp,
15+
ApplyPatternsOp,
16+
)
917
from .. import structured
1018

1119

@@ -147,3 +155,8 @@ def test_match_ops_single(module: OpHandle):
147155

148156
if dump_script:
149157
print(named_sequence_op)
158+
159+
160+
sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
161+
named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
162+
apply_patterns = region_op(ApplyPatternsOp)

mlir/python/mlir/extras/meta.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
import inspect
6+
from functools import wraps
7+
8+
from ..dialects._ods_common import get_op_result_or_op_results
9+
from ..ir import Type, InsertionPoint
10+
11+
12+
def op_region_builder(op, op_region, terminator=None):
13+
def builder_wrapper(body_builder):
14+
# Add a block with block args having types determined by type hints on the wrapped function.
15+
if len(op_region.blocks) == 0:
16+
sig = inspect.signature(body_builder)
17+
types = [p.annotation for p in sig.parameters.values()]
18+
if not (
19+
len(types) == len(sig.parameters)
20+
and all(isinstance(t, Type) for t in types)
21+
):
22+
raise ValueError(
23+
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
24+
)
25+
26+
op_region.blocks.append(*types)
27+
28+
with InsertionPoint(op_region.blocks[0]):
29+
results = body_builder(*list(op_region.blocks[0].arguments))
30+
31+
with InsertionPoint(list(op_region.blocks)[-1]):
32+
if terminator is not None:
33+
res = []
34+
if isinstance(results, (tuple, list)):
35+
res.extend(results)
36+
elif results is not None:
37+
res.append(results)
38+
terminator(res)
39+
40+
return get_op_result_or_op_results(op)
41+
42+
return builder_wrapper
43+
44+
45+
def region_op(op_constructor, terminator=None):
46+
"""Decorator to define an MLIR Op specified as a python function.
47+
48+
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
49+
active for the current thread (i.e. established in a `with` block).
50+
51+
Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
52+
53+
When applied as a decorator to a Python function, an entry block will
54+
be constructed for the Op with types as specified **as type hints on the args of the function**.
55+
The block arguments will be passed positionally to the Python function.
56+
57+
If a terminator is specified then the return from the decorated function will be passed
58+
to the terminator as the last statement in the entry block. Note, the API for the terminator
59+
is a (possibly empty) list; terminator accepting single values should be wrapped in a
60+
`lambda args: term(args[0])`
61+
62+
The identifier (name) of the function will become:
63+
1. A single value result if the Op returns a single value;
64+
2. An OpResultList (as a list) if the Op returns multiple values;
65+
3. The Operation if the Op returns no results.
66+
67+
See examples in tensor.py and transform.extras.
68+
"""
69+
70+
def op_decorator(*args, **kwargs):
71+
op = op_constructor(*args, **kwargs)
72+
op_region = op.regions[0]
73+
74+
return op_region_builder(op, op_region, terminator)
75+
76+
@wraps(op_decorator)
77+
def maybe_no_args(*args, **kwargs):
78+
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
79+
return op_decorator()(args[0])
80+
else:
81+
return op_decorator(*args, **kwargs)
82+
83+
return maybe_no_args

mlir/test/python/dialects/arith_dialect.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,20 +75,20 @@ def __str__(self):
7575
f64_t = F64Type.get()
7676

7777
with InsertionPoint(module.body):
78-
a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
78+
a = arith.constant(f16_t, 42.42)
7979
# CHECK: ArithValue(%cst = arith.constant 4.240
8080
print(a)
8181

8282
b = a + a
8383
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
8484
print(b)
8585

86-
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
86+
a = arith.constant(f32_t, 42.42)
8787
b = a - a
8888
# CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
8989
print(b)
9090

91-
a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
91+
a = arith.constant(f64_t, 42.42)
9292
b = a * a
9393
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
9494
print(b)

mlir/test/python/dialects/tensor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import mlir.dialects.arith as arith
55
import mlir.dialects.func as func
66
import mlir.dialects.tensor as tensor
7+
from mlir.extras import types as T
78

89

910
def run(f):
@@ -139,3 +140,37 @@ def default_builder():
139140
t = tensor.FromElementsOp(RankedTensorType.get((1, 2), f32), [c0, c1])
140141
# CHECK: %{{.*}} = "tensor.from_elements"(%[[C0]], %[[C1]]) : (f32, f32) -> tensor<1x2xf32>
141142
print(t)
143+
144+
145+
# CHECK-LABEL: TEST: testGenerateRegionOp
146+
@run
147+
def testGenerateRegionOp():
148+
S = ShapedType.get_dynamic_size()
149+
with Context(), Location.unknown():
150+
module = Module.create()
151+
with InsertionPoint(module.body):
152+
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
153+
# CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
154+
one = arith.constant(T.index(), 1)
155+
two = arith.constant(T.index(), 2)
156+
157+
@tensor.generate(T.tensor(S, 3, S, T.index()), dynamic_extents=[one, two])
158+
def generate_one(i: T.index(), j: T.index(), k: T.index()):
159+
ij = arith.addi(i, j)
160+
ijk = arith.addi(ij, k)
161+
return ijk
162+
163+
assert (
164+
isinstance(generate_one, Value)
165+
and generate_one.owner.name == "tensor.generate"
166+
)
167+
168+
# CHECK: %[[GENERATED:.*]] = tensor.generate
169+
# CHECK-SAME: %[[VAL_0]],
170+
# CHECK-SAME: %[[VAL_1]] {
171+
# CHECK: ^bb0(%[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index):
172+
# CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
173+
# CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_3]] : index
174+
# CHECK: tensor.yield %[[VAL_5]] : index
175+
# CHECK: } : tensor<?x3x?xindex>
176+
print(module)

0 commit comments

Comments
 (0)