Skip to content

[MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion #143779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];

let arguments = (ins StrAttr:$pass_name,
let arguments = (ins TransformHandleTypeInterface:$target,
StrAttr:$pass_name,
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
Variadic<TransformParamTypeInterface>:$dynamic_options,
TransformHandleTypeInterface:$target);
Variadic<TransformParamTypeInterface>:$dynamic_options);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
$pass_name (`with` `options` `=`
Expand Down
18 changes: 11 additions & 7 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
self,
result: Type,
pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
Union[Attribute, Value, Operation, OpView],
Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
loc=None,
Expand All @@ -253,17 +253,21 @@ def __init__(
cur_param_operand_idx += 1
elif isinstance(value, Attribute):
options_dict[key] = value
# The following cases auto-convert Python values to attributes.
elif isinstance(value, bool):
options_dict[key] = BoolAttr.get(value)
elif isinstance(value, int):
default_int_type = IntegerType.get_signless(64, context)
options_dict[key] = IntegerAttr.get(default_int_type, value)
elif isinstance(value, str):
options_dict[key] = StringAttr.get(value)
else:
raise TypeError(f"Unsupported option type: {type(value)}")
if len(options_dict) > 0:
print(options_dict, cur_param_operand_idx)
super().__init__(
result,
_get_op_result_or_value(target),
pass_name,
dynamic_options,
target=_get_op_result_or_value(target),
options=DictAttr.get(options_dict),
loc=loc,
ip=ip,
Expand All @@ -272,13 +276,13 @@ def __init__(

def apply_registered_pass(
result: Type,
pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
Union[Attribute, Value, Operation, OpView],
Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
loc=None,
Expand Down
19 changes: 9 additions & 10 deletions mlir/test/Dialect/Transform/test-pass-application.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
"test-convergence" = true,
"max-num-rewrites" = %max_rewrites }
to %1
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
transform.yield
}
}
Expand All @@ -171,7 +171,6 @@ func.func @invalid_options_as_str() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
// expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
Expand Down Expand Up @@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @+2 {{expected '{' in options dictionary}}
transform.apply_registered_pass "canonicalize"
with options = %pass_options to %1
: (!transform.any_param, !transform.any_op) -> !transform.any_op
: (!transform.any_op, !transform.any_param) -> !transform.any_op
transform.yield
}
}
Expand All @@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
transform.apply_registered_pass "canonicalize"
with options = { "top-down" = %topdown_options } to %1
: (!transform.any_param, !transform.any_op) -> !transform.any_op
: (!transform.any_op, !transform.any_param) -> !transform.any_op
transform.yield
}
}
Expand Down Expand Up @@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} {
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
%2 = "transform.apply_registered_pass"(%1, %0) <{
%2 = "transform.apply_registered_pass"(%0, %1) <{
options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
: (!transform.any_param, !transform.any_op) -> !transform.any_op
: (!transform.any_op, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
Expand All @@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 0 is already used in options}}
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
%3 = "transform.apply_registered_pass"(%0, %1, %2) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"max-num-rewrites" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
Expand All @@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
%3 = "transform.apply_registered_pass"(%0, %1, %2) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
10 changes: 5 additions & 5 deletions mlir/test/python/dialects/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module):
)
with InsertionPoint(sequence.body):
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
)
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(),
"canonicalize",
mod.result,
"canonicalize",
options={"top-down": BoolAttr.get(False)},
)
max_iter = transform.param_constant(
Expand All @@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module):
)
transform.apply_registered_pass(
transform.AnyOpType.get(),
"canonicalize",
mod,
"canonicalize",
options={
"top-down": BoolAttr.get(False),
"max-iterations": max_iter,
"test-convergence": BoolAttr.get(True),
"test-convergence": True,
"max-rewrites": max_rewrites,
},
)
Expand All @@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module):
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
# CHECK-SAME: "test-convergence" = true,
# CHECK-SAME: "top-down" = false}
# CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
Loading