Skip to content

Commit a8d0b33

Browse files
committed
Address @adam-smnk's review
1 parent 079b3db commit a8d0b33

File tree

4 files changed

+34
-31
lines changed

4 files changed

+34
-31
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,15 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
4141
let cppNamespace = "::mlir::transform";
4242
}
4343

44-
def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
45-
"param_operand_index" > {
46-
let mnemonic = "param_operand_index";
44+
def ParamOperandAttr : Transform_Attr<"ParamOperand", "param_operand"> {
4745
let description = [{
4846
Used to refer to a specific param-operand (via its index) from within an
4947
attribute on a transform operation.
5048
}];
5149
let parameters = (ins
5250
"IntegerAttr":$index
5351
);
54-
let assemblyFormat = "`<` $index `>`";
52+
let assemblyFormat = "`<` `index` `=` $index `>`";
5553
}
5654

5755
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,8 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
800800
optionsStream << "="; // And the key-value separator.
801801

802802
Attribute valueAttrToAppend;
803-
if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>(
804-
namedAttribute.getValue())) {
803+
if (auto paramOperandIndex =
804+
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
805805
// The corresponding value attribute is passed in via a param.
806806
// Obtain the param-operand via its specified index.
807807
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
@@ -906,16 +906,20 @@ static ParseResult parseApplyRegisteredPassOptions(
906906
return parser.emitError(parser.getCurrentLocation())
907907
<< "expected a valid attribute or operand as value associated "
908908
<< "to key '" << key << "'";
909+
// To make use of the operand, we need to store it in the options dict.
910+
// As SSA-values cannot occur in attributes, what we do instead is store
911+
// an attribute in its place that contains the index of the param-operand,
912+
// so that an attr-value associated to the param can be resolved later on.
909913
dynamicOptions.push_back(operand);
910914
auto wrappedIndex = IntegerAttr::get(
911915
IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
912-
valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(),
913-
wrappedIndex);
916+
valueAttr =
917+
transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
914918
} else if (failed(parsedValueAttr.value())) {
915919
return failure(); // NB: Attempted parse should have output error message.
916-
} else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) {
920+
} else if (isa<transform::ParamOperandAttr>(valueAttr)) {
917921
return parser.emitError(parser.getCurrentLocation())
918-
<< "the param_operand_index attribute is a marker reserved for "
922+
<< "the param_operand attribute is a marker reserved for "
919923
<< "indicating a value will be passed via params and is only used "
920924
<< "in the generic print format";
921925
}
@@ -951,7 +955,8 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
951955
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
952956
printer << namedAttribute.getName() << " = ";
953957
Attribute value = namedAttribute.getValue();
954-
if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) {
958+
if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
959+
// Resolve index of param-operand to its actual SSA-value and print that.
955960
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
956961
} else {
957962
printer.printAttribute(value);
@@ -966,9 +971,9 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
966971

967972
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
968973
for (NamedAttribute namedAttr : getOptions())
969-
if (auto paramOperandIndex =
970-
dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) {
971-
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
974+
if (auto paramOperand =
975+
dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
976+
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
972977
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
973978
return emitOpError()
974979
<< "dynamic option index " << dynamicOptionIdx
@@ -983,7 +988,7 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
983988
for (Value dynamicOption : dynamicOptions)
984989
if (dynamicOption)
985990
return emitOpError() << "a param operand does not have a corresponding "
986-
<< "param_operand_index attr in the options dict";
991+
<< "param_operand attr in the options dict";
987992

988993
return success();
989994
}

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from typing import Dict, Optional, Sequence, Union, NewType
2222

2323

24-
@register_attribute_builder("ParamOperandIndexAttr")
25-
def _paramOperandIndexAttr(x: int, context) -> Attribute:
26-
return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
24+
@register_attribute_builder("ParamOperandAttr")
25+
def _paramOperandAttr(x: int, context) -> Attribute:
26+
return Attribute.parse(f"#transform.param_operand<index={x}>", context=context)
2727

2828

2929
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -239,7 +239,7 @@ def __init__(
239239
options_dict = {}
240240
dynamic_options = []
241241

242-
ParamOperandIndexAttr = AttrBuilder.get("ParamOperandIndexAttr")
242+
ParamOperandAttr = AttrBuilder.get("ParamOperandAttr")
243243
context = (loc and loc.context) or Context.current
244244

245245
cur_param_operand_idx = 0
@@ -249,7 +249,7 @@ def __init__(
249249

250250
if isinstance(value, (Value, Operation, OpView)):
251251
dynamic_options.append(_get_op_result_or_value(value))
252-
options_dict[key] = ParamOperandIndexAttr(
252+
options_dict[key] = ParamOperandAttr(
253253
cur_param_operand_idx, context
254254
)
255255
cur_param_operand_idx += 1

mlir/test/Dialect/Transform/test-pass-application.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ func.func @invalid_options_due_to_reserved_attr() {
204204
module attributes {transform.with_named_sequence} {
205205
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
206206
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
207-
// expected-error @+2 {{the param_operand_index attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
207+
// expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
208208
%2 = transform.apply_registered_pass "canonicalize"
209-
with options = { "top-down" = #transform.param_operand_index<0> } to %1 : (!transform.any_op) -> !transform.any_op
209+
with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
210210
transform.yield
211211
}
212212
}
@@ -306,7 +306,7 @@ module attributes {transform.with_named_sequence} {
306306
// Check that the following cases are caugh in the generic format. //
307307
/////////////////////////////////////////////////////////////////////
308308

309-
// Invalid due to param_operand_index occurences in options dict not being
309+
// Invalid due to param_operand occurences in options dict not being
310310
// one-to-one with the dynamic options provided as params:
311311
// param_operand_index out of bounds w.r.t. the number of options provided via params.
312312

@@ -317,7 +317,7 @@ module attributes {transform.with_named_sequence} {
317317
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
318318
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
319319
%2 = "transform.apply_registered_pass"(%1, %0) <{
320-
options = {"max-iterations" = #transform.param_operand_index<1 : i64>,
320+
options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
321321
"test-convergence" = true,
322322
"top-down" = false},
323323
pass_name = "canonicalize"}>
@@ -328,10 +328,10 @@ module attributes {transform.with_named_sequence} {
328328

329329
// -----
330330

331-
// Invalid due to param_operand_index occurences in options dict not being
331+
// Invalid due to param_operand occurences in options dict not being
332332
// one-to-one with the dynamic options provided as params:
333333
// the first option-param is referred to twice and the second one not at all.
334-
// (The pretty-printed format supports this by passing in the same param twice.)
334+
// (In the pretty-printed format, if you want to refer to a param SSA-value twice, it counts as two param arguments.)
335335

336336
"builtin.module"() ({
337337
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
@@ -341,8 +341,8 @@ module attributes {transform.with_named_sequence} {
341341
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
342342
// expected-error @below {{dynamic option index 0 is already used in options}}
343343
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
344-
options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
345-
"max-num-rewrites" = #transform.param_operand_index<0 : i64>,
344+
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
345+
"max-num-rewrites" = #transform.param_operand<index=0 : i64>,
346346
"test-convergence" = true,
347347
"top-down" = false},
348348
pass_name = "canonicalize"}>
@@ -353,7 +353,7 @@ module attributes {transform.with_named_sequence} {
353353

354354
// -----
355355

356-
// Invalid due to param_operand_index occurences in options dict not being
356+
// Invalid due to param_operand occurences in options dict not being
357357
// one-to-one with the dynamic options provided as params:
358358
// two option-params are provide though only the first one is referred to from the options-dict.
359359

@@ -363,9 +363,9 @@ module attributes {transform.with_named_sequence} {
363363
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
364364
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
365365
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
366-
// expected-error @below {{a param operand does not have a corresponding param_operand_index attr in the options dict}}
366+
// expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
367367
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
368-
options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
368+
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
369369
"test-convergence" = true,
370370
"top-down" = false},
371371
pass_name = "canonicalize"}>

0 commit comments

Comments
 (0)