Skip to content

Commit abf4234

Browse files
committed
PDLToPDLInterp: Ensure dependencies between native constraints and their arguments
1 parent c4065a7 commit abf4234

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,19 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
884884
*root = std::make_unique<ExitNode>();
885885
}
886886

887+
/// Sorts the range begin/end with the partial order given by cmp.
888+
/// cmp must be a partial ordering.
889+
template <typename Iterator, typename Compare>
890+
void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
891+
while (begin != end) {
892+
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
893+
return std::none_of(begin, end, [&](auto const &b) { return cmp(b, a); });
894+
});
895+
assert(next != begin && "not a partial ordering");
896+
begin = next;
897+
}
898+
}
899+
887900
/// Given a module containing PDL pattern operations, generate a matcher tree
888901
/// using the patterns within the given module and return the root matcher node.
889902
std::unique_ptr<MatcherNode>
@@ -964,6 +977,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
964977
return *lhs < *rhs;
965978
});
966979

980+
// Mostly keep the now established order, but also ensure that
981+
// ConstraintQuestions come after the results they use.
982+
stableTopologicalSort(ordered.begin(), ordered.end(),
983+
[](OrderedPredicate *a, OrderedPredicate *b) {
984+
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
985+
auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
986+
if (cqa && cqb) {
987+
// Does any argument of b use a? Then b must be
988+
// sorted after a.
989+
return llvm::any_of(
990+
cqb->getArgs(), [&](Position *p) {
991+
auto *cp = dyn_cast<ConstraintPosition>(p);
992+
return cp && cp->getQuestion() == cqa;
993+
});
994+
}
995+
return false;
996+
});
997+
967998
// Build the matchers for each of the pattern predicate lists.
968999
std::unique_ptr<MatcherNode> root;
9691000
for (OrderedPredicateList &list : lists)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
2+
3+
// Ensuse that the dependency between add & less
4+
// causes them to be in the correct order.
5+
// CHECK: apply_constraint "__builtin_add"
6+
// CHECK: apply_constraint "__builtin_less"
7+
8+
module {
9+
pdl.pattern @test : benefit(1) {
10+
%0 = attribute
11+
%1 = types
12+
%2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
13+
%3 = attribute = 0 : i32
14+
%4 = attribute = 1 : i32
15+
%5 = apply_native_constraint "__builtin_add"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
16+
apply_native_constraint "__builtin_less"(%0, %5 : !pdl.attribute, !pdl.attribute)
17+
rewrite %2 {
18+
replace %2 with %2
19+
}
20+
}
21+
}

0 commit comments

Comments
 (0)