-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion #143779
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ns auto-conversion
@llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesFull diff: https://github.com/llvm/llvm-project/pull/143779.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index f75ba27e58e76..0aa750e625436 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];
- let arguments = (ins StrAttr:$pass_name,
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ StrAttr:$pass_name,
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
- Variadic<TransformParamTypeInterface>:$dynamic_options,
- TransformHandleTypeInterface:$target);
+ Variadic<TransformParamTypeInterface>:$dynamic_options);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
$pass_name (`with` `options` `=`
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 10a04b0cc14e0..bfe96b1b3e5d4 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
self,
result: Type,
- pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
+ pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView],
+ Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
loc=None,
@@ -253,17 +253,21 @@ def __init__(
cur_param_operand_idx += 1
elif isinstance(value, Attribute):
options_dict[key] = value
+ # The following cases auto-convert Python values to attributes.
+ elif isinstance(value, bool):
+ options_dict[key] = BoolAttr.get(value)
+ elif isinstance(value, int):
+ default_int_type = IntegerType.get_signless(64, context)
+ options_dict[key] = IntegerAttr.get(default_int_type, value)
elif isinstance(value, str):
options_dict[key] = StringAttr.get(value)
else:
raise TypeError(f"Unsupported option type: {type(value)}")
- if len(options_dict) > 0:
- print(options_dict, cur_param_operand_idx)
super().__init__(
result,
+ _get_op_result_or_value(target),
pass_name,
dynamic_options,
- target=_get_op_result_or_value(target),
options=DictAttr.get(options_dict),
loc=loc,
ip=ip,
@@ -272,13 +276,13 @@ def __init__(
def apply_registered_pass(
result: Type,
- pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
+ pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView],
+ Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
loc=None,
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 6e6d4eb7e249f..1d1be9eda3496 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
"test-convergence" = true,
"max-num-rewrites" = %max_rewrites }
to %1
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
transform.yield
}
}
@@ -171,7 +171,6 @@ func.func @invalid_options_as_str() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
// expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
@@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @+2 {{expected '{' in options dictionary}}
transform.apply_registered_pass "canonicalize"
with options = %pass_options to %1
- : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
transform.yield
}
}
@@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
transform.apply_registered_pass "canonicalize"
with options = { "top-down" = %topdown_options } to %1
- : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
transform.yield
}
}
@@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} {
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
- %2 = "transform.apply_registered_pass"(%1, %0) <{
+ %2 = "transform.apply_registered_pass"(%0, %1) <{
options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
- : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
@@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 0 is already used in options}}
- %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+ %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"max-num-rewrites" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
@@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
- %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+ %3 = "transform.apply_registered_pass"(%0, %1, %2) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 48bc9bad37a1e..eeb95605d7a9a 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module):
)
with InsertionPoint(sequence.body):
mod = transform.ApplyRegisteredPassOp(
- transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+ transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
)
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(),
- "canonicalize",
mod.result,
+ "canonicalize",
options={"top-down": BoolAttr.get(False)},
)
max_iter = transform.param_constant(
@@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module):
)
transform.apply_registered_pass(
transform.AnyOpType.get(),
- "canonicalize",
mod,
+ "canonicalize",
options={
"top-down": BoolAttr.get(False),
"max-iterations": max_iter,
- "test-convergence": BoolAttr.get(True),
+ "test-convergence": True,
"max-rewrites": max_rewrites,
},
)
@@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module):
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
# CHECK-SAME: "test-convergence" = true,
# CHECK-SAME: "top-down" = false}
- # CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
|
Merging without review so that downstream users don't need to deal with op's arg order having been different from what it was before: #142683 |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/30544 Here is the relevant piece of the build log for the reference
|
Please keep the habit of providing a comprehensive description of the PR. The title barely gives an idea of the "what" but we're missing the context on this commit. Also, skipping review when something is broken (or you're addressing trivial post-commit comments, or ...) is always OK, but your mention of a downstream convenience does not seem like an urgency to land to me: downstream need to be able to handle their integration with local patches or delays in fixes if an issue does not reproduce or show upstream (but here without the description I can't really say about the seriousness of the issue). |
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
…ns auto-conversion (llvm#143779)
No description provided.