Skip to content

Commit dd473f1

Browse files
authored
[mlir][python] simplify extensions (#69642)
#68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests.
1 parent dda3ed9 commit dd473f1

File tree

8 files changed

+13
-217
lines changed

8 files changed

+13
-217
lines changed

mlir/python/mlir/dialects/affine.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,3 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._affine_ops_gen import *
6-
from ._affine_ops_gen import _Dialect
7-
8-
try:
9-
from ..ir import *
10-
from ._ods_common import (
11-
get_op_result_or_value as _get_op_result_or_value,
12-
get_op_results_or_values as _get_op_results_or_values,
13-
_cext as _ods_cext,
14-
)
15-
except ImportError as e:
16-
raise RuntimeError("Error loading imports from extension module") from e
17-
18-
from typing import Optional, Sequence, Union
19-
20-
21-
@_ods_cext.register_operation(_Dialect, replace=True)
22-
class AffineStoreOp(AffineStoreOp):
23-
"""Specialization for the Affine store operation."""
24-
25-
def __init__(
26-
self,
27-
value: Union[Operation, OpView, Value],
28-
memref: Union[Operation, OpView, Value],
29-
map: AffineMap = None,
30-
*,
31-
map_operands=None,
32-
loc=None,
33-
ip=None,
34-
):
35-
"""Creates an affine store operation.
36-
37-
- `value`: the value to store into the memref.
38-
- `memref`: the buffer to store into.
39-
- `map`: the affine map that maps the map_operands to the index of the
40-
memref.
41-
- `map_operands`: the list of arguments to substitute the dimensions,
42-
then symbols in the affine map, in increasing order.
43-
"""
44-
map = map if map is not None else []
45-
map_operands = map_operands if map_operands is not None else []
46-
indicies = [_get_op_result_or_value(op) for op in map_operands]
47-
_ods_successors = None
48-
super().__init__(
49-
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
50-
)

mlir/python/mlir/dialects/bufferization.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,4 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._bufferization_ops_gen import *
6-
from ._bufferization_ops_gen import _Dialect
76
from ._bufferization_enum_gen import *
8-
9-
try:
10-
from typing import Sequence, Union
11-
from ..ir import *
12-
from ._ods_common import get_default_loc_context, _cext as _ods_cext
13-
14-
from typing import Any, List, Union
15-
except ImportError as e:
16-
raise RuntimeError("Error loading imports from extension module") from e
17-
18-
19-
@_ods_cext.register_operation(_Dialect, replace=True)
20-
class AllocTensorOp(AllocTensorOp):
21-
"""Extends the bufferization.alloc_tensor op."""
22-
23-
def __init__(
24-
self,
25-
tensor_type: Type,
26-
dynamic_sizes: Sequence[Value],
27-
copy: Value,
28-
size_hint: Value,
29-
escape: BoolAttr,
30-
*,
31-
loc=None,
32-
ip=None,
33-
):
34-
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
35-
super().__init__(
36-
tensor_type,
37-
dynamic_sizes,
38-
copy=copy,
39-
size_hint=size_hint,
40-
loc=loc,
41-
ip=ip,
42-
)

mlir/python/mlir/dialects/func.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
class ConstantOp(ConstantOp):
2727
"""Specialization for the constant op class."""
2828

29-
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
30-
super().__init__(result, value, loc=loc, ip=ip)
31-
3229
@property
3330
def type(self):
3431
return self.results[0].type

mlir/python/mlir/dialects/memref.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,3 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._memref_ops_gen import *
6-
from ._memref_ops_gen import _Dialect
7-
8-
try:
9-
from ..ir import *
10-
from ._ods_common import (
11-
get_op_result_or_value as _get_op_result_or_value,
12-
get_op_results_or_values as _get_op_results_or_values,
13-
_cext as _ods_cext,
14-
)
15-
except ImportError as e:
16-
raise RuntimeError("Error loading imports from extension module") from e
17-
18-
from typing import Optional, Sequence, Union
19-
20-
21-
@_ods_cext.register_operation(_Dialect, replace=True)
22-
class LoadOp(LoadOp):
23-
"""Specialization for the MemRef load operation."""
24-
25-
def __init__(
26-
self,
27-
memref: Union[Operation, OpView, Value],
28-
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
29-
*,
30-
loc=None,
31-
ip=None,
32-
):
33-
"""Creates a memref load operation.
34-
35-
Args:
36-
memref: the buffer to load from.
37-
indices: the list of subscripts, may be empty for zero-dimensional
38-
buffers.
39-
loc: user-visible location of the operation.
40-
ip: insertion point.
41-
"""
42-
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
43-
super().__init__(memref, indices_resolved, loc=loc, ip=ip)

mlir/python/mlir/dialects/pdl.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,6 @@
2121
)
2222

2323

24-
@_ods_cext.register_operation(_Dialect, replace=True)
25-
class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
26-
"""Specialization for PDL apply native constraint op class."""
27-
28-
def __init__(
29-
self,
30-
name: Union[str, StringAttr],
31-
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
32-
*,
33-
loc=None,
34-
ip=None,
35-
):
36-
if args is None:
37-
args = []
38-
args = _get_values(args)
39-
super().__init__(name, args, loc=loc, ip=ip)
40-
41-
42-
@_ods_cext.register_operation(_Dialect, replace=True)
43-
class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
44-
"""Specialization for PDL apply native rewrite op class."""
45-
46-
def __init__(
47-
self,
48-
results: Sequence[Type],
49-
name: Union[str, StringAttr],
50-
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
51-
*,
52-
loc=None,
53-
ip=None,
54-
):
55-
if args is None:
56-
args = []
57-
args = _get_values(args)
58-
super().__init__(results, name, args, loc=loc, ip=ip)
59-
60-
6124
@_ods_cext.register_operation(_Dialect, replace=True)
6225
class AttributeOp(AttributeOp):
6326
"""Specialization for PDL attribute op class."""
@@ -75,21 +38,6 @@ def __init__(
7538
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
7639

7740

78-
@_ods_cext.register_operation(_Dialect, replace=True)
79-
class EraseOp(EraseOp):
80-
"""Specialization for PDL erase op class."""
81-
82-
def __init__(
83-
self,
84-
operation: Optional[Union[OpView, Operation, Value]] = None,
85-
*,
86-
loc=None,
87-
ip=None,
88-
):
89-
operation = _get_value(operation)
90-
super().__init__(operation, loc=loc, ip=ip)
91-
92-
9341
@_ods_cext.register_operation(_Dialect, replace=True)
9442
class OperandOp(OperandOp):
9543
"""Specialization for PDL operand op class."""
@@ -216,23 +164,6 @@ def __init__(
216164
super().__init__(result, parent, index, loc=loc, ip=ip)
217165

218166

219-
@_ods_cext.register_operation(_Dialect, replace=True)
220-
class ResultsOp(ResultsOp):
221-
"""Specialization for PDL results op class."""
222-
223-
def __init__(
224-
self,
225-
result: Type,
226-
parent: Union[OpView, Operation, Value],
227-
index: Optional[Union[IntegerAttr, int]] = None,
228-
*,
229-
loc=None,
230-
ip=None,
231-
):
232-
parent = _get_value(parent)
233-
super().__init__(result, parent, index=index, loc=loc, ip=ip)
234-
235-
236167
@_ods_cext.register_operation(_Dialect, replace=True)
237168
class RewriteOp(RewriteOp):
238169
"""Specialization for PDL rewrite op class."""

mlir/python/mlir/dialects/scf.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@
2020
from typing import Optional, Sequence, Union
2121

2222

23-
_ForOp = ForOp
24-
25-
2623
@_ods_cext.register_operation(_Dialect, replace=True)
27-
class ForOp(_ForOp):
24+
class ForOp(ForOp):
2825
"""Specialization for the SCF for op class."""
2926

3027
def __init__(
@@ -50,17 +47,8 @@ def __init__(
5047
iter_args = _get_op_results_or_values(iter_args)
5148

5249
results = [arg.type for arg in iter_args]
53-
super(_ForOp, self).__init__(
54-
self.build_generic(
55-
regions=1,
56-
results=results,
57-
operands=[
58-
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
59-
]
60-
+ list(iter_args),
61-
loc=loc,
62-
ip=ip,
63-
)
50+
super().__init__(
51+
results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
6452
)
6553
self.regions[0].blocks.append(self.operands[0].type, *results)
6654

@@ -83,28 +71,23 @@ def inner_iter_args(self):
8371
return self.body.arguments[1:]
8472

8573

86-
_IfOp = IfOp
87-
88-
8974
@_ods_cext.register_operation(_Dialect, replace=True)
90-
class IfOp(_IfOp):
75+
class IfOp(IfOp):
9176
"""Specialization for the SCF if op class."""
9277

93-
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
78+
def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
9479
"""Creates an SCF `if` operation.
9580
9681
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
9782
- `hasElse` determines whether the if operation has the else branch.
9883
"""
84+
if results_ is None:
85+
results_ = []
9986
operands = []
10087
operands.append(cond)
10188
results = []
10289
results.extend(results_)
103-
super(_IfOp, self).__init__(
104-
self.build_generic(
105-
regions=2, results=results, operands=operands, loc=loc, ip=ip
106-
)
107-
)
90+
super().__init__(results, cond)
10891
self.regions[0].blocks.append(*[])
10992
if hasElse:
11093
self.regions[1].blocks.append(*[])

mlir/test/python/dialects/affine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def affine_store_test(arg0):
3737
a1 = arith.ConstantOp(f32, 2.1)
3838

3939
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
40-
affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
40+
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
4141

4242
return mem
4343

mlir/test/python/dialects/func.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def testFunctionCalls():
8484
qux = func.FuncOp("qux", ([], [F32Type.get()]))
8585
qux.sym_visibility = StringAttr.get("private")
8686

87+
con = func.ConstantOp(qux.type, qux.sym_name.value)
88+
assert con.type == qux.type
89+
8790
with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
8891
func.CallOp(foo, [])
8992
func.CallOp([IndexType.get()], "bar", [])
@@ -94,6 +97,7 @@ def testFunctionCalls():
9497
# CHECK: func private @foo()
9598
# CHECK: func private @bar() -> index
9699
# CHECK: func private @qux() -> f32
100+
# CHECK: %f = func.constant @qux : () -> f32
97101
# CHECK: func @caller() {
98102
# CHECK: call @foo() : () -> ()
99103
# CHECK: %0 = call @bar() : () -> index

0 commit comments

Comments
 (0)