Skip to content

Commit 461d7cf

Browse files
committed
[MLIR][Transform] friendlier Python-bindings apply_registered_pass op
In particular, use similar syntax for providing options as in the (pretty-)printed IR.
1 parent b6364ab commit 461d7cf

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,41 @@ def __init__(
214214
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
215215

216216

217+
@_ods_cext.register_operation(_Dialect, replace=True)
218+
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
219+
def __init__(
220+
self,
221+
result: Type,
222+
pass_name: Union[str, StringAttr],
223+
target: Value,
224+
*,
225+
options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
226+
loc=None,
227+
ip=None,
228+
):
229+
static_options = []
230+
dynamic_options = []
231+
for opt in options:
232+
if isinstance(opt, str):
233+
static_options.append(StringAttr.get(opt))
234+
elif isinstance(opt, StringAttr):
235+
static_options.append(opt)
236+
elif isinstance(opt, Value):
237+
static_options.append(UnitAttr.get())
238+
dynamic_options.append(_get_op_result_or_value(opt))
239+
else:
240+
raise TypeError(f"Unsupported option type: {type(opt)}")
241+
super().__init__(
242+
result,
243+
pass_name,
244+
dynamic_options,
245+
target=_get_op_result_or_value(target),
246+
options=static_options,
247+
loc=loc,
248+
ip=ip,
249+
)
250+
251+
217252
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
218253

219254

mlir/test/python/dialects/transform.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
254254
# CHECK: %[[FIRST:.+]] = pdl_match
255255
# CHECK: %[[SECOND:.+]] = pdl_match
256256
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
257+
258+
259+
@run
260+
def testApplyRegisteredPassOp(module: Module):
261+
sequence = transform.SequenceOp(
262+
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
263+
)
264+
with InsertionPoint(sequence.body):
265+
mod = transform.ApplyRegisteredPassOp(
266+
transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
267+
)
268+
mod = transform.ApplyRegisteredPassOp(
269+
transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
270+
)
271+
max_iter = transform.param_constant(
272+
transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
273+
)
274+
max_rewrites = transform.param_constant(
275+
transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
276+
)
277+
transform.ApplyRegisteredPassOp(
278+
transform.AnyOpType.get(),
279+
"canonicalize",
280+
mod,
281+
options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
282+
)
283+
transform.YieldOp()
284+
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
285+
# CHECK: transform.sequence
286+
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
287+
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
288+
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
289+
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
290+
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
291+
# CHECK-SAME: with options = "top-down=false" %[[MAX_ITER]]
292+
# CHECK-SAME: "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

0 commit comments

Comments
 (0)