Skip to content

Commit 579ced4

Browse files
authored
[MLIR][Python] Add structured.fuseop to python interpreter (llvm#120601)
Implements a python interface for structured.fuseOp allowing more freedom with inputs.
1 parent 8584991 commit 579ced4

File tree

2 files changed

+107
-0
lines changed

2 files changed

+107
-0
lines changed

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,77 @@ def __init__(
140140
)
141141

142142

143+
@_ods_cext.register_operation(_Dialect, replace=True)
144+
class FuseOp(FuseOp):
145+
"""Specialization for FuseOp class."""
146+
147+
@overload
148+
def __init__(
149+
self,
150+
loop_types: Union[Type, Sequence[Type]],
151+
target: Union[Operation, Value, OpView],
152+
*,
153+
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
154+
tile_interchange: OptionalIntList = None,
155+
apply_cleanup: Optional[bool] = False,
156+
loc=None,
157+
ip=None,
158+
):
159+
...
160+
161+
@overload
162+
def __init__(
163+
self,
164+
target: Union[Operation, Value, OpView],
165+
*,
166+
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
167+
tile_interchange: OptionalIntList = None,
168+
apply_cleanup: Optional[bool] = False,
169+
loc=None,
170+
ip=None,
171+
):
172+
...
173+
174+
def __init__(
175+
self,
176+
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
177+
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
178+
*,
179+
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
180+
tile_interchange: OptionalIntList = None,
181+
apply_cleanup: Optional[bool] = False,
182+
loc=None,
183+
ip=None,
184+
):
185+
tile_sizes = tile_sizes if tile_sizes else []
186+
tile_interchange = tile_interchange if tile_interchange else []
187+
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
188+
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
189+
num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
190+
191+
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
192+
loop_types = [transform.AnyOpType.get()] * num_loops
193+
target = loop_types_or_target
194+
assert target_or_none is None, "Cannot construct FuseOp with two targets."
195+
else:
196+
loop_types = (
197+
([loop_types_or_target] * num_loops)
198+
if isinstance(loop_types_or_target, Type)
199+
else loop_types_or_target
200+
)
201+
target = target_or_none
202+
super().__init__(
203+
target.type,
204+
loop_types,
205+
target,
206+
tile_sizes=tile_sizes,
207+
tile_interchange=tile_interchange,
208+
apply_cleanup=apply_cleanup,
209+
loc=loc,
210+
ip=ip,
211+
)
212+
213+
143214
@_ods_cext.register_operation(_Dialect, replace=True)
144215
class GeneralizeOp(GeneralizeOp):
145216
"""Specialization for GeneralizeOp class."""

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,42 @@ def testFuseIntoContainingOpCompact(target):
101101
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
102102

103103

104+
@run
105+
@create_sequence
106+
def testFuseOpCompact(target):
107+
structured.FuseOp(
108+
target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True
109+
)
110+
# CHECK-LABEL: TEST: testFuseOpCompact
111+
# CHECK: transform.sequence
112+
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
113+
# CHECK-SAME: interchange [0, 1] apply_cleanup = true
114+
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
115+
116+
117+
@run
118+
@create_sequence
119+
def testFuseOpNoArg(target):
120+
structured.FuseOp(target)
121+
# CHECK-LABEL: TEST: testFuseOpNoArg
122+
# CHECK: transform.sequence
123+
# CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
124+
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
125+
126+
127+
@run
128+
@create_sequence
129+
def testFuseOpAttributes(target):
130+
attr = DenseI64ArrayAttr.get([4, 8])
131+
ichange = DenseI64ArrayAttr.get([0, 1])
132+
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
133+
# CHECK-LABEL: TEST: testFuseOpAttributes
134+
# CHECK: transform.sequence
135+
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
136+
# CHECK-SAME: interchange [0, 1]
137+
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
138+
139+
104140
@run
105141
@create_sequence
106142
def testGeneralize(target):

0 commit comments

Comments
 (0)