Skip to content

Commit 781ab12

Browse files
committed
[mlir] Move supplemental patterns before op replacement
This moves the C++ code generated from supplemental patterns before op replacement. It is necessary for supplemental patterns to be able to access the source op.
1 parent 1d82c76 commit 781ab12

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,9 +1172,22 @@ void PatternEmitter::emitRewriteLogic() {
11721172
os << val << ";\n";
11731173
}
11741174

1175+
auto processSupplementalPatterns = [&]() {
1176+
int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1177+
for (int i = 0, offset = -numSupplementalPatterns;
1178+
i < numSupplementalPatterns; ++i) {
1179+
DagNode resultTree = pattern.getSupplementalPattern(i);
1180+
auto val = handleResultPattern(resultTree, offset++, 0);
1181+
if (resultTree.isNativeCodeCall() &&
1182+
resultTree.getNumReturnsOfNativeCode() == 0)
1183+
os << val << ";\n";
1184+
}
1185+
};
1186+
11751187
if (numExpectedResults == 0) {
11761188
assert(replStartIndex >= numResultPatterns &&
11771189
"invalid auxiliary vs. replacement pattern division!");
1190+
processSupplementalPatterns();
11781191
// No result to replace. Just erase the op.
11791192
os << "rewriter.eraseOp(op0);\n";
11801193
} else {
@@ -1196,20 +1209,10 @@ void PatternEmitter::emitRewriteLogic() {
11961209
" tblgen_repl_values.push_back(v);\n}\n",
11971210
"\n");
11981211
}
1212+
processSupplementalPatterns();
11991213
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
12001214
}
12011215

1202-
// Process supplemtal patterns.
1203-
int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1204-
for (int i = 0, offset = -numSupplementalPatterns;
1205-
i < numSupplementalPatterns; ++i) {
1206-
DagNode resultTree = pattern.getSupplementalPattern(i);
1207-
auto val = handleResultPattern(resultTree, offset++, 0);
1208-
if (resultTree.isNativeCodeCall() &&
1209-
resultTree.getNumReturnsOfNativeCode() == 0)
1210-
os << val << ";\n";
1211-
}
1212-
12131216
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
12141217
}
12151218

0 commit comments

Comments
 (0)