Skip to content

Commit f891e20

Browse files
fix fused native constraints with results (#36)
1 parent e62b72a commit f891e20

File tree

3 files changed

+57
-14
lines changed

3 files changed

+57
-14
lines changed

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ struct PatternLowering {
151151

152152
/// A mapping between constraint questions that refer to values created by
153153
/// constraints and the temporary placeholder values created for them.
154-
DenseMap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
154+
std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
155155
};
156156
} // namespace
157157

@@ -377,8 +377,9 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
377377
auto *constrResPos = cast<ConstraintPosition>(pos);
378378
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
379379
loc, StringAttr::get(builder.getContext(), "placeholder"));
380-
substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] =
381-
placeholderValue;
380+
substitutions.insert(
381+
{{constrResPos->getQuestion(), constrResPos->getIndex()},
382+
placeholderValue});
382383
value = placeholderValue;
383384
break;
384385
}
@@ -474,11 +475,15 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
474475
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
475476
cstQuestion, result.index()};
476477
// Check if there are substitutions to perform. If the result is never
477-
// used no substitutions will have been generated.
478-
if (substitutions.count(substitutionKey)) {
479-
substitutions[substitutionKey].replaceAllUsesWith(result.value());
480-
substitutions[substitutionKey].getDefiningOp()->erase();
481-
}
478+
// used or multiple calls to the same constraint have been merged,
479+
// no substitutions will have been generated for this specific op.
480+
auto range = substitutions.equal_range(substitutionKey);
481+
std::for_each(range.first, range.second, [&](const auto &elem) {
482+
Value placeholder = elem.second;
483+
placeholder.replaceAllUsesWith(result.value());
484+
placeholder.getDefiningOp()->erase();
485+
});
486+
substitutions.erase(substitutionKey);
482487
}
483488
break;
484489
}

mlir/lib/Rewrite/ByteCode.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
14381438
LLVM_DEBUG({
14391439
llvm::dbgs() << " * Arguments: ";
14401440
llvm::interleaveComma(args, llvm::dbgs());
1441+
llvm::dbgs() << "\n";
14411442
});
14421443

14431444
ByteCodeField numResults = read();
@@ -1450,12 +1451,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
14501451
const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
14511452
ByteCodeRewriteResultList results(numResults);
14521453
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1453-
assert(results.getResults().size() == numResults &&
1454-
"native PDL rewrite function returned unexpected number of results");
1455-
1456-
for (PDLValue &result : results.getResults()) {
1457-
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1458-
memory[read()] = result.getAsOpaquePointer();
1454+
ArrayRef<PDLValue> constraintResults = results.getResults();
1455+
LLVM_DEBUG({
1456+
if (succeeded(rewriteResult)) {
1457+
llvm::dbgs() << " * Constraint succeeded\n";
1458+
llvm::dbgs() << " * Results: ";
1459+
llvm::interleaveComma(constraintResults, llvm::dbgs());
1460+
llvm::dbgs() << "\n";
1461+
} else {
1462+
llvm::dbgs() << " * Constraint failed\n";
1463+
}
1464+
});
1465+
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1466+
"native PDL rewrite function returned "
1467+
"unexpected number of results");
1468+
// Populate memory either with the results or with 0s to preserve memory
1469+
// structure as expected
1470+
for (int i = 0; i < numResults; i++) {
1471+
memory[read()] = succeeded(rewriteResult)
1472+
? constraintResults[i].getAsOpaquePointer()
1473+
: 0;
14591474
}
14601475
// Depending on the constraint jump to the proper destination.
14611476
selectJump(succeeded(rewriteResult));

mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,29 @@ module @constraint_with_unused_result {
107107

108108
// -----
109109

110+
// CHECK-LABEL: module @constraint_with_result_multiple
111+
module @constraint_with_result_multiple {
112+
// check that native constraints work as expected even when multiple identical constraints are fused
113+
114+
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
115+
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
116+
// CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
117+
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
118+
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
119+
pdl.pattern : benefit(1) {
120+
%root = operation
121+
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
122+
rewrite %root with "rewriter"(%attr : !pdl.attribute)
123+
}
124+
pdl.pattern : benefit(1) {
125+
%root = operation
126+
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
127+
rewrite %root with "rewriter"(%attr : !pdl.attribute)
128+
}
129+
}
130+
131+
// -----
132+
110133
// CHECK-LABEL: module @inputs
111134
module @inputs {
112135
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

0 commit comments

Comments
 (0)