Skip to content

Commit c26847d

Browse files
authored
[mlir][drr] Allow variadic in rewrite side (#93340)
Enables writing patterns where one has op creation with variadic in result pattern more easily. Signed-off-by: Jacques Pienaar <[email protected]>
1 parent f239490 commit c26847d

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,12 @@ def : Pat<
16961696
(MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
16971697
ConstantStrAttr<StrAttr, "MatchVariadic">)>;
16981698

1699+
def : Pat<
1700+
(MixedVOperandOp5 $input1a, $input1b, $input2, $attr1,
1701+
ConstantStrAttr<StrAttr, "MatchInverseVariadic">),
1702+
(MixedVOperandOp3 $input2, (variadic $input1b), (variadic $input1a),
1703+
ConstantAttr<I32Attr, "1">:$attr1)>;
1704+
16991705
def : Pat<
17001706
(MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a),
17011707
(MixedVOperandInOutI32Op $input1b)),

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,14 @@ func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
527527
return
528528
}
529529

530+
// CHECK-LABEL: @testReplaceVariadic
531+
func.func @testReplaceVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
532+
// CHECK: "test.mixed_variadic_in3"(%arg2, %arg1, %arg0) <{count = 1 : i32}>
533+
"test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchInverseVariadic"}> : (i32, i32, i32) -> ()
534+
535+
return
536+
}
537+
530538
// CHECK-LABEL: @testMatchVariadicSubDag
531539
func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () {
532540
// CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ class PatternEmitter {
159159
// Returns the symbol of the old value serving as the replacement.
160160
StringRef handleReplaceWithValue(DagNode tree);
161161

162+
// Emits the C++ statement to replace the matched DAG with an array of
163+
// matched values.
164+
std::string handleVariadic(DagNode tree, int depth);
165+
162166
// Trailing directives are used at the end of DAG node argument lists to
163167
// specify additional behaviour for op matchers and creators, etc.
164168
struct TrailingDirectives {
@@ -1241,6 +1245,9 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
12411245
if (resultTree.isReplaceWithValue())
12421246
return handleReplaceWithValue(resultTree).str();
12431247

1248+
if (resultTree.isVariadic())
1249+
return handleVariadic(resultTree, depth);
1250+
12441251
// Normal op creation.
12451252
auto symbol = handleOpCreation(resultTree, resultIndex, depth);
12461253
if (resultTree.getSymbol().empty()) {
@@ -1251,6 +1258,26 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree,
12511258
return symbol;
12521259
}
12531260

1261+
std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
1262+
assert(tree.isVariadic());
1263+
1264+
auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
1265+
symbolInfoMap.bindValue(name);
1266+
os << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
1267+
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1268+
if (auto child = tree.getArgAsNestedDag(i)) {
1269+
os << name << ".push_back(" << handleResultPattern(child, i, depth + 1)
1270+
<< ");\n";
1271+
} else {
1272+
os << name << ".push_back("
1273+
<< handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))
1274+
<< ");\n";
1275+
}
1276+
}
1277+
1278+
return name;
1279+
}
1280+
12541281
StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
12551282
assert(tree.isReplaceWithValue());
12561283

0 commit comments

Comments
 (0)