Skip to content

Commit 4d6acde

Browse files
rolfmoreltomtor
authored andcommitted
[MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion (llvm#143779)
1 parent 9fdbfa5 commit 4d6acde

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
434434
of targeted ops.
435435
}];
436436

437-
let arguments = (ins StrAttr:$pass_name,
437+
let arguments = (ins TransformHandleTypeInterface:$target,
438+
StrAttr:$pass_name,
438439
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
439-
Variadic<TransformParamTypeInterface>:$dynamic_options,
440-
TransformHandleTypeInterface:$target);
440+
Variadic<TransformParamTypeInterface>:$dynamic_options);
441441
let results = (outs TransformHandleTypeInterface:$result);
442442
let assemblyFormat = [{
443443
$pass_name (`with` `options` `=`

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
224224
def __init__(
225225
self,
226226
result: Type,
227-
pass_name: Union[str, StringAttr],
228227
target: Union[Operation, Value, OpView],
228+
pass_name: Union[str, StringAttr],
229229
*,
230230
options: Optional[
231231
Dict[
232232
Union[str, StringAttr],
233-
Union[Attribute, Value, Operation, OpView],
233+
Union[Attribute, Value, Operation, OpView, str, int, bool],
234234
]
235235
] = None,
236236
loc=None,
@@ -253,17 +253,21 @@ def __init__(
253253
cur_param_operand_idx += 1
254254
elif isinstance(value, Attribute):
255255
options_dict[key] = value
256+
# The following cases auto-convert Python values to attributes.
257+
elif isinstance(value, bool):
258+
options_dict[key] = BoolAttr.get(value)
259+
elif isinstance(value, int):
260+
default_int_type = IntegerType.get_signless(64, context)
261+
options_dict[key] = IntegerAttr.get(default_int_type, value)
256262
elif isinstance(value, str):
257263
options_dict[key] = StringAttr.get(value)
258264
else:
259265
raise TypeError(f"Unsupported option type: {type(value)}")
260-
if len(options_dict) > 0:
261-
print(options_dict, cur_param_operand_idx)
262266
super().__init__(
263267
result,
268+
_get_op_result_or_value(target),
264269
pass_name,
265270
dynamic_options,
266-
target=_get_op_result_or_value(target),
267271
options=DictAttr.get(options_dict),
268272
loc=loc,
269273
ip=ip,
@@ -272,13 +276,13 @@ def __init__(
272276

273277
def apply_registered_pass(
274278
result: Type,
275-
pass_name: Union[str, StringAttr],
276279
target: Union[Operation, Value, OpView],
280+
pass_name: Union[str, StringAttr],
277281
*,
278282
options: Optional[
279283
Dict[
280284
Union[str, StringAttr],
281-
Union[Attribute, Value, Operation, OpView],
285+
Union[Attribute, Value, Operation, OpView, str, int, bool],
282286
]
283287
] = None,
284288
loc=None,

mlir/test/Dialect/Transform/test-pass-application.mlir

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} {
157157
"test-convergence" = true,
158158
"max-num-rewrites" = %max_rewrites }
159159
to %1
160-
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
160+
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
161161
transform.yield
162162
}
163163
}
@@ -171,7 +171,6 @@ func.func @invalid_options_as_str() {
171171
module attributes {transform.with_named_sequence} {
172172
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
173173
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
174-
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
175174
// expected-error @+2 {{expected '{' in options dictionary}}
176175
%2 = transform.apply_registered_pass "canonicalize"
177176
with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
@@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} {
256255
// expected-error @+2 {{expected '{' in options dictionary}}
257256
transform.apply_registered_pass "canonicalize"
258257
with options = %pass_options to %1
259-
: (!transform.any_param, !transform.any_op) -> !transform.any_op
258+
: (!transform.any_op, !transform.any_param) -> !transform.any_op
260259
transform.yield
261260
}
262261
}
@@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} {
276275
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
277276
transform.apply_registered_pass "canonicalize"
278277
with options = { "top-down" = %topdown_options } to %1
279-
: (!transform.any_param, !transform.any_op) -> !transform.any_op
278+
: (!transform.any_op, !transform.any_param) -> !transform.any_op
280279
transform.yield
281280
}
282281
}
@@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} {
316315
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
317316
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
318317
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
319-
%2 = "transform.apply_registered_pass"(%1, %0) <{
318+
%2 = "transform.apply_registered_pass"(%0, %1) <{
320319
options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
321320
"test-convergence" = true,
322321
"top-down" = false},
323322
pass_name = "canonicalize"}>
324-
: (!transform.any_param, !transform.any_op) -> !transform.any_op
323+
: (!transform.any_op, !transform.any_param) -> !transform.any_op
325324
"transform.yield"() : () -> ()
326325
}) : () -> ()
327326
}) {transform.with_named_sequence} : () -> ()
@@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} {
340339
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
341340
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
342341
// expected-error @below {{dynamic option index 0 is already used in options}}
343-
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
342+
%3 = "transform.apply_registered_pass"(%0, %1, %2) <{
344343
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
345344
"max-num-rewrites" = #transform.param_operand<index=0 : i64>,
346345
"test-convergence" = true,
347346
"top-down" = false},
348347
pass_name = "canonicalize"}>
349-
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
348+
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
350349
"transform.yield"() : () -> ()
351350
}) : () -> ()
352351
}) {transform.with_named_sequence} : () -> ()
@@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} {
364363
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
365364
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
366365
// expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
367-
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
366+
%3 = "transform.apply_registered_pass"(%0, %1, %2) <{
368367
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
369368
"test-convergence" = true,
370369
"top-down" = false},
371370
pass_name = "canonicalize"}>
372-
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
371+
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
373372
"transform.yield"() : () -> ()
374373
}) : () -> ()
375374
}) {transform.with_named_sequence} : () -> ()

mlir/test/python/dialects/transform.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module):
263263
)
264264
with InsertionPoint(sequence.body):
265265
mod = transform.ApplyRegisteredPassOp(
266-
transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
266+
transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
267267
)
268268
mod = transform.ApplyRegisteredPassOp(
269269
transform.AnyOpType.get(),
270-
"canonicalize",
271270
mod.result,
271+
"canonicalize",
272272
options={"top-down": BoolAttr.get(False)},
273273
)
274274
max_iter = transform.param_constant(
@@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module):
281281
)
282282
transform.apply_registered_pass(
283283
transform.AnyOpType.get(),
284-
"canonicalize",
285284
mod,
285+
"canonicalize",
286286
options={
287287
"top-down": BoolAttr.get(False),
288288
"max-iterations": max_iter,
289-
"test-convergence": BoolAttr.get(True),
289+
"test-convergence": True,
290290
"max-rewrites": max_rewrites,
291291
},
292292
)
@@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module):
305305
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
306306
# CHECK-SAME: "test-convergence" = true,
307307
# CHECK-SAME: "top-down" = false}
308-
# CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
308+
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op

0 commit comments

Comments
 (0)