Skip to content

fix fused native constraints with results #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ struct PatternLowering {

/// A mapping between constraint questions that refer to values created by
/// constraints and the temporary placeholder values created for them.
DenseMap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
};
} // namespace

Expand Down Expand Up @@ -377,8 +377,9 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
auto *constrResPos = cast<ConstraintPosition>(pos);
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
loc, StringAttr::get(builder.getContext(), "placeholder"));
substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] =
placeholderValue;
substitutions.insert(
{{constrResPos->getQuestion(), constrResPos->getIndex()},
placeholderValue});
value = placeholderValue;
break;
}
Expand Down Expand Up @@ -474,11 +475,15 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
cstQuestion, result.index()};
// Check if there are substitutions to perform. If the result is never
// used no substitutions will have been generated.
if (substitutions.count(substitutionKey)) {
substitutions[substitutionKey].replaceAllUsesWith(result.value());
substitutions[substitutionKey].getDefiningOp()->erase();
}
// used or multiple calls to the same constraint have been merged,
// no substitutions will have been generated for this specific op.
auto range = substitutions.equal_range(substitutionKey);
std::for_each(range.first, range.second, [&](const auto &elem) {
Value placeholder = elem.second;
placeholder.replaceAllUsesWith(result.value());
placeholder.getDefiningOp()->erase();
});
substitutions.erase(substitutionKey);
}
break;
}
Expand Down
27 changes: 21 additions & 6 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG({
llvm::dbgs() << " * Arguments: ";
llvm::interleaveComma(args, llvm::dbgs());
llvm::dbgs() << "\n";
});

ByteCodeField numResults = read();
Expand All @@ -1450,12 +1451,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");

for (PDLValue &result : results.getResults()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
memory[read()] = result.getAsOpaquePointer();
ArrayRef<PDLValue> constraintResults = results.getResults();
LLVM_DEBUG({
if (succeeded(rewriteResult)) {
llvm::dbgs() << " * Constraint succeeded\n";
llvm::dbgs() << " * Results: ";
llvm::interleaveComma(constraintResults, llvm::dbgs());
llvm::dbgs() << "\n";
} else {
llvm::dbgs() << " * Constraint failed\n";
}
});
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
"native PDL rewrite function returned "
"unexpected number of results");
// Populate memory either with the results or with 0s to preserve memory
// structure as expected
for (int i = 0; i < numResults; i++) {
memory[read()] = succeeded(rewriteResult)
? constraintResults[i].getAsOpaquePointer()
: 0;
}
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ module @constraint_with_unused_result {

// -----

// CHECK-LABEL: module @constraint_with_result_multiple
module @constraint_with_result_multiple {
// check that native constraints work as expected even when multiple identical constraints are fused

// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
// CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
pdl.pattern : benefit(1) {
%root = operation
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
rewrite %root with "rewriter"(%attr : !pdl.attribute)
}
}

// -----

// CHECK-LABEL: module @inputs
module @inputs {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
Expand Down