Skip to content

Commit fe7bf4b

Browse files
authored
[MLIR][Transform] apply_registered_pass op's options as a dict (#143159)
Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs). Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the `addToPipeline`-pass API.
1 parent ec8d68b commit fe7bf4b

File tree

10 files changed

+469
-116
lines changed

10 files changed

+469
-116
lines changed

mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
2020
mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
2121
add_public_tablegen_target(MLIRTransformDialectEnumIncGen)
2222
add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
23+
mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls)
24+
mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs)
25+
add_public_tablegen_target(MLIRTransformDialectAttributesIncGen)
26+
add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen)
2327

2428
add_mlir_dialect(TransformOps transform)
2529
add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)

mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@
1717

1818
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
1919

20+
#define GET_ATTRDEF_CLASSES
21+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h.inc"
22+
2023
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H

mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
1111

1212
include "mlir/IR/EnumAttr.td"
13+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
14+
15+
class Transform_Attr<string name, string attrMnemonic,
16+
list<Trait> traits = [],
17+
string baseCppClass = "::mlir::Attribute">
18+
: AttrDef<Transform_Dialect, name, traits, baseCppClass> {
19+
let mnemonic = attrMnemonic;
20+
}
1321

1422
def PropagateFailuresCase : I32EnumAttrCase<"Propagate", 1, "propagate">;
1523
def SuppressFailuresCase : I32EnumAttrCase<"Suppress", 2, "suppress">;
@@ -33,4 +41,15 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
3341
let cppNamespace = "::mlir::transform";
3442
}
3543

44+
def ParamOperandAttr : Transform_Attr<"ParamOperand", "param_operand"> {
45+
let description = [{
46+
Used to refer to a specific param-operand (via its index) from within an
47+
attribute on a transform operation.
48+
}];
49+
let parameters = (ins
50+
"IntegerAttr":$index
51+
);
52+
let assemblyFormat = "`<` `index` `=` $index `>`";
53+
}
54+
3655
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def Transform_Dialect : Dialect {
1919
let cppNamespace = "::mlir::transform";
2020

2121
let hasOperationAttrVerify = 1;
22+
let useDefaultAttributePrinterParser = 1;
2223
let extraClassDeclaration = [{
2324
/// Symbol name for the default entry point "named sequence".
2425
constexpr const static ::llvm::StringLiteral

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,23 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
405405
let description = [{
406406
This transform applies the specified pass or pass pipeline to the targeted
407407
ops. The name of the pass/pipeline is specified as a string attribute, as
408-
set during pass/pipeline registration. Optionally, pass options may be
409-
specified as (space-separated) string attributes with the option to pass
410-
these attributes via params. The pass options syntax is identical to the one
411-
used with "mlir-opt".
408+
set during pass/pipeline registration.
409+
410+
Optionally, pass options may be specified via a DictionaryAttr. This
411+
dictionary is converted to a string -- formatted `key=value ...` -- which
412+
is expected to be in the exact format used by the pass on the commandline.
413+
Values are either attributes or (SSA-values of) Transform Dialect params.
414+
For example:
415+
416+
```mlir
417+
transform.apply_registered_pass "canonicalize"
418+
with options = { "top-down" = false,
419+
"max-iterations" = %max_iter,
420+
"test-convergence" = true,
421+
"max-num-rewrites" = %max_rewrites }
422+
to %module
423+
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
424+
```
412425

413426
This op first looks for a pass pipeline with the specified name. If no such
414427
pipeline exists, it looks for a pass with the specified name. If no such
@@ -422,7 +435,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
422435
}];
423436

424437
let arguments = (ins StrAttr:$pass_name,
425-
DefaultValuedAttr<ArrayAttr, "{}">:$options,
438+
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
426439
Variadic<TransformParamTypeInterface>:$dynamic_options,
427440
TransformHandleTypeInterface:$target);
428441
let results = (outs TransformHandleTypeInterface:$result);

mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@
88

99
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1010
#include "mlir/Analysis/CallGraph.h"
11+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
1112
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1213
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
1314
#include "mlir/Dialect/Transform/IR/Utils.h"
1415
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1516
#include "mlir/IR/DialectImplementation.h"
1617
#include "llvm/ADT/SCCIterator.h"
18+
#include "llvm/ADT/TypeSwitch.h"
1719

1820
using namespace mlir;
1921

2022
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
2123

24+
#define GET_ATTRDEF_CLASSES
25+
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
26+
2227
#ifndef NDEBUG
2328
void transform::detail::checkImplementsTransformOpInterface(
2429
StringRef name, MLIRContext *context) {
@@ -66,6 +71,10 @@ void transform::TransformDialect::initialize() {
6671
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
6772
>();
6873
initializeTypes();
74+
addAttributes<
75+
#define GET_ATTRDEF_LIST
76+
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
77+
>();
6978
initializeLibraryModule();
7079
}
7180

0 commit comments

Comments
 (0)