Skip to content

Commit c294d67

Browse files
mgehre-amdmartin-luecke
authored andcommitted
[mlir][PDL] Add support for native constraints with results (llvm#82760)
From https://reviews.llvm.org/D153245 This adds support for native PDL (and PDLL) C++ constraints to return results. This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern. With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values. This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance `attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8]` and `attr_foo == attr_bar`. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. `attr_baz = attr_foo * attr_bar`. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well: ``` Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr; Pattern example with benefit(1) { let foo = op<test.foo>() {attr = attr_foo : Attr}; let bar = op<test.bar>(foo) {attr = attr_bar : Attr}; let attr_baz = checkAndCompute(attr_foo, attr_bar); rewrite bar with { let baz = op<test.baz> {attr=attr_baz}; replace bar with baz; }; } ``` To achieve this the following notable changes were necessary: PDLL: - Remove check in PDLL parser that prevented native constraints from returning results PDL: - Change PDL definition of pdl.apply_native_constraint to allow variadic results PDL_interp: - Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results PDLToPDLInterp Pass: The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates that are required to match all of the pdl patterns and establishes an ordering that allows creation of a single efficient matcher function to match all of them. Values that are matched and possibly used in the rewriting part of a pattern are represented as positions. This allows fusion and thus reusing a single position for multiple matching patterns. Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint. The problem is for the corresponding value to be used in the rewriting part of a pattern it has to be an input to the pdl_interp.record_match operation, which is generated early during the pass such that its surrounding block can be referred to by branching operations. In consequence the value has to be materialized after the original pdl.apply_native_constraint has been deleted but before we get the chance to generate the corresponding pdl_interp.apply_constraint operation. We solve this by emitting a placeholder value when a ConstraintPosition is evaluated. These placeholder values (due to fusion there may be multiple for one constraint result) are replaced later when the actual pdl_interp.apply_constraint operation is created. Changes since the phabricator review: - Addressed all comments - In particular, removed registerConstraintFunctionWithResults and instead changed registerConstraintFunction so that contraint functions always have results (empty by default) - Thus we don't need to reuse `rewriteFunctions` to store constraint functions with results anymore, and can instead use `constraintFunctions` - Perform a stable sort of ConstraintQuestion, so that ConstraintQuestion appear before other ConstraintQuestion that use their results. - Don't create placeholders for pdl_interp::ApplyConstraintOp. Instead generate the `pdl_interp::ApplyConstraintOp` before generating the successor block. - Fixed a test failure in the pdl python bindings Original code by @martin-luecke Co-authored-by: martin-luecke <[email protected]>
1 parent 9db9dd7 commit c294d67

File tree

13 files changed

+397
-288
lines changed

13 files changed

+397
-288
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ def PDL_ApplyNativeConstraintOp
3535
let description = [{
3636
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
3737
has been registered externally with the consumer of PDL, to a given set of
38-
entities and optionally return a number of values..
38+
entities and optionally return a number of values.
3939

4040
Example:
4141

4242
```mlir
4343
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
4444
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
45+
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
46+
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
4547
```
4648
}];
4749

@@ -50,7 +52,7 @@ def PDL_ApplyNativeConstraintOp
5052
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
5153
let results = (outs Variadic<PDL_AnyType>:$results);
5254
let assemblyFormat = [{
53-
$name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict
55+
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
5456
}];
5557
let hasVerifier = 1;
5658
}

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
8888
let description = [{
8989
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9090
has been registered with the interpreter, with a given set of positional
91-
values. On success, this operation branches to the true destination,
91+
values.
92+
The constraint function may return any number of results.
93+
On success, this operation branches to the true destination,
9294
otherwise the false destination is taken. This behavior can be reversed
9395
by setting the attribute `isNegated` to true.
9496

@@ -106,7 +108,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
106108
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
107109
let results = (outs Variadic<PDL_AnyType>:$results);
108110
let assemblyFormat = [{
109-
$name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict `->` successors
111+
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
112+
`->` successors
110113
}];
111114
}
112115

mlir/include/mlir/IR/PDLPatternMatch.h.inc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,9 @@ protected:
318318
/// A generic PDL pattern constraint function. This function applies a
319319
/// constraint to a given set of opaque PDLValue entities. Returns success if
320320
/// the constraint successfully held, failure otherwise.
321-
using PDLConstraintFunction =
322-
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
321+
using PDLConstraintFunction = std::function<LogicalResult(
322+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
323+
323324
/// A native PDL rewrite function. This function performs a rewrite on the
324325
/// given set of values. Any results from this rewrite that should be passed
325326
/// back to PDL should be added to the provided result list. This method is only
@@ -726,7 +727,7 @@ std::enable_if_t<
726727
PDLConstraintFunction>
727728
buildConstraintFn(ConstraintFnT &&constraintFn) {
728729
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
729-
PatternRewriter &rewriter,
730+
PatternRewriter &rewriter, PDLResultList &,
730731
ArrayRef<PDLValue> values) -> LogicalResult {
731732
auto argIndices = std::make_index_sequence<
732733
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -842,10 +843,13 @@ public:
842843
/// Register a constraint function with PDL. A constraint function may be
843844
/// specified in one of two ways:
844845
///
845-
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
846+
/// * `LogicalResult (PatternRewriter &,
847+
/// PDLResultList &,
848+
/// ArrayRef<PDLValue>)`
846849
///
847850
/// In this overload the arguments of the constraint function are passed via
848-
/// the low-level PDLValue form.
851+
/// the low-level PDLValue form, and the results are manually appended to
852+
/// the given result list.
849853
///
850854
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
851855
///
@@ -963,8 +967,8 @@ public:
963967
}
964968
};
965969
class PDLResultList {};
966-
using PDLConstraintFunction =
967-
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
970+
using PDLConstraintFunction = std::function<LogicalResult(
971+
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
968972
using PDLRewriteFunction = std::function<LogicalResult(
969973
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
970974

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ struct PatternLowering {
5050

5151
/// Generate interpreter operations for the tree rooted at the given matcher
5252
/// node, in the specified region.
53-
Block *generateMatcher(MatcherNode &node, Region &region);
53+
Block *generateMatcher(MatcherNode &node, Region &region,
54+
Block *block = nullptr);
5455

5556
/// Get or create an access to the provided positional value in the current
5657
/// block. This operation may mutate the provided block pointer if nested
@@ -149,9 +150,9 @@ struct PatternLowering {
149150
/// set.
150151
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
151152

152-
/// A mapping between constraint questions that refer to values created by
153-
/// constraints and the temporary placeholder values created for them.
154-
std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
153+
/// A mapping from a constraint question to the ApplyConstraintOp
154+
/// that implements it.
155+
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
155156
};
156157
} // namespace
157158

@@ -186,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
186187
firstMatcherBlock->erase();
187188
}
188189

189-
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
190+
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191+
Block *block) {
190192
// Push a new scope for the values used by this matcher.
191-
Block *block = &region.emplaceBlock();
193+
if (!block)
194+
block = &region.emplaceBlock();
192195
ValueMapScope scope(values);
193196

194197
// If this is the return node, simply insert the corresponding interpreter
@@ -369,18 +372,12 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
369372
break;
370373
}
371374
case Predicates::ConstraintResultPos: {
372-
// At this point in time the corresponding pdl.ApplyNativeConstraint op has
373-
// been deleted and the new pdl_interp.ApplyConstraint has not been created
374-
// yet. To enable use of results created by these operations we build a
375-
// placeholder value that will be replaced when the actual
376-
// pdl_interp.ApplyConstraint operation is created.
375+
// Due to the order of traversal, the ApplyConstraintOp has already been
376+
// created and we can find it in constraintOpMap.
377377
auto *constrResPos = cast<ConstraintPosition>(pos);
378-
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
379-
loc, StringAttr::get(builder.getContext(), "placeholder"));
380-
substitutions.insert(
381-
{{constrResPos->getQuestion(), constrResPos->getIndex()},
382-
placeholderValue});
383-
value = placeholderValue;
378+
auto i = constraintOpMap.find(constrResPos->getQuestion());
379+
assert(i != constraintOpMap.end());
380+
value = i->second->getResult(constrResPos->getIndex());
384381
break;
385382
}
386383
default:
@@ -409,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
409406
args.push_back(getValueAt(currentBlock, position));
410407
}
411408

412-
// Generate the matcher in the current (potentially nested) region
413-
// and get the failure successor.
414-
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
409+
// Generate a new block as success successor and get the failure successor.
410+
Block *success = &region->emplaceBlock();
415411
Block *failure = failureBlockStack.back();
416412

417-
// Finally, create the predicate.
413+
// Create the predicate.
418414
builder.setInsertionPointToEnd(currentBlock);
419415
Predicates::Kind kind = question->getKind();
420416
switch (kind) {
@@ -469,27 +465,17 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
469465
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
470466
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
471467
cstQuestion->getIsNegated(), success, failure);
472-
// Replace the generated placeholders with the results of the constraint and
473-
// erase them
474-
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
475-
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
476-
cstQuestion, result.index()};
477-
// Check if there are substitutions to perform. If the result is never
478-
// used or multiple calls to the same constraint have been merged,
479-
// no substitutions will have been generated for this specific op.
480-
auto range = substitutions.equal_range(substitutionKey);
481-
std::for_each(range.first, range.second, [&](const auto &elem) {
482-
Value placeholder = elem.second;
483-
placeholder.replaceAllUsesWith(result.value());
484-
placeholder.getDefiningOp()->erase();
485-
});
486-
substitutions.erase(substitutionKey);
487-
}
468+
469+
constraintOpMap.insert({cstQuestion, applyConstraintOp});
488470
break;
489471
}
490472
default:
491473
llvm_unreachable("Generating unknown Predicate operation");
492474
}
475+
476+
// Generate the matcher in the current (potentially nested) region.
477+
// This might use the results of the current predicate.
478+
generateMatcher(*boolNode->getSuccessNode(), *region, success);
493479
}
494480

495481
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
299299
bool isOperandDefiningOp() const;
300300
};
301301

302+
//===----------------------------------------------------------------------===//
303+
// ConstraintPosition
304+
305+
struct ConstraintQuestion;
306+
307+
/// A position describing the result of a native constraint. It saves the
308+
/// corresponding ConstraintQuestion and result index to enable referring
309+
/// back to them
310+
struct ConstraintPosition
311+
: public PredicateBase<ConstraintPosition, Position,
312+
std::pair<ConstraintQuestion *, unsigned>,
313+
Predicates::ConstraintResultPos> {
314+
using PredicateBase::PredicateBase;
315+
316+
/// Returns the ConstraintQuestion to enable keeping track of the native
317+
/// constraint this position stems from.
318+
ConstraintQuestion *getQuestion() const { return key.first; }
319+
320+
// Returns the result index of this position
321+
unsigned getIndex() const { return key.second; }
322+
};
323+
302324
//===----------------------------------------------------------------------===//
303325
// ResultPosition
304326

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
#include "mlir/IR/BuiltinOps.h"
1616
#include "mlir/Interfaces/InferTypeOpInterface.h"
1717
#include "llvm/ADT/MapVector.h"
18+
#include "llvm/ADT/SmallPtrSet.h"
1819
#include "llvm/ADT/TypeSwitch.h"
1920
#include "llvm/Support/Debug.h"
2021
#include <queue>
21-
#include "llvm/ADT/SmallPtrSet.h"
2222

2323
#define DEBUG_TYPE "pdl-predicate-tree"
2424

@@ -50,14 +50,15 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
5050
DenseMap<Value, Position *> &inputs,
5151
AttributePosition *pos) {
5252
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
53-
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
5453
predList.emplace_back(pos, builder.getIsNotNull());
5554

56-
// If the attribute has a type or value, add a constraint.
57-
if (Value type = attr.getValueType())
58-
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
59-
else if (Attribute value = attr.getValueAttr())
60-
predList.emplace_back(pos, builder.getAttributeConstraint(value));
55+
if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
56+
// If the attribute has a type or value, add a constraint.
57+
if (Value type = attr.getValueType())
58+
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
59+
else if (Attribute value = attr.getValueAttr())
60+
predList.emplace_back(pos, builder.getAttributeConstraint(value));
61+
}
6162
}
6263

6364
/// Collect all of the predicates for the given operand position.
@@ -265,34 +266,32 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
265266
DenseMap<Value, Position *> &inputs) {
266267
OperandRange arguments = op.getArgs();
267268

268-
Position *pos = nullptr;
269269
std::vector<Position *> allPositions;
270+
allPositions.reserve(arguments.size());
271+
for (Value arg : arguments)
272+
allPositions.push_back(inputs.lookup(arg));
270273

271-
// If this constraint has no arguments, this means it has no dependencies, and
272-
// the same applies to all results
273-
if (arguments.empty()) {
274-
pos = builder.getRoot();
275-
} else {
276-
allPositions.reserve(arguments.size());
277-
for (Value arg : arguments)
278-
allPositions.push_back(inputs.lookup(arg));
279-
280-
// Push the constraint to the furthest position.
281-
pos = *std::max_element(allPositions.begin(), allPositions.end(),
282-
comparePosDepth);
283-
}
284-
assert(pos && "Must have a non-null value");
285-
274+
// Push the constraint to the furthest position.
275+
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
276+
comparePosDepth);
286277
ResultRange results = op.getResults();
287278
PredicateBuilder::Predicate pred = builder.getConstraint(
288279
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
289280
op.getIsNegated());
290281

291-
// for each result register a position so it can be used later
292-
for (auto result : llvm::enumerate(results)) {
282+
// For each result register a position so it can be used later
283+
for (auto [i, result] : llvm::enumerate(results)) {
293284
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
294-
ConstraintPosition *pos = builder.getConstraintPosition(q, result.index());
295-
inputs[result.value()] = pos;
285+
ConstraintPosition *pos = builder.getConstraintPosition(q, i);
286+
auto [it, inserted] = inputs.insert({result, pos});
287+
// If this is an input value that has been visited in the tree, add a
288+
// constraint to ensure that both instances refer to the same value.
289+
if (!inserted) {
290+
auto minMaxPositions =
291+
std::minmax<Position *>(pos, it->second, comparePosDepth);
292+
predList.emplace_back(minMaxPositions.second,
293+
builder.getEqualTo(minMaxPositions.first));
294+
}
296295
}
297296
predList.emplace_back(pos, pred);
298297
}
@@ -896,15 +895,14 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
896895
}
897896

898897
/// Sorts the range begin/end with the partial order given by cmp.
899-
/// cmp must be a partial ordering.
900898
template <typename Iterator, typename Compare>
901-
void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
899+
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
902900
while (begin != end) {
903901
// Cannot compute sortBeforeOthers in the predicate of stable_partition
904902
// because stable_partition will not keep the [begin, end) range intact
905903
// while it runs.
906904
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
907-
for(auto i = begin; i != end; ++i) {
905+
for (auto i = begin; i != end; ++i) {
908906
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
909907
sortBeforeOthers.insert(*i);
910908
}
@@ -917,6 +915,28 @@ void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
917915
}
918916
}
919917

918+
/// Returns true if 'b' depends on a result of 'a'.
919+
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
920+
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
921+
if (!cqa)
922+
return false;
923+
924+
auto positionDependsOnA = [&](Position *p) {
925+
auto *cp = dyn_cast<ConstraintPosition>(p);
926+
return cp && cp->getQuestion() == cqa;
927+
};
928+
929+
if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
930+
// Does any argument of b use a?
931+
return llvm::any_of(cqb->getArgs(), positionDependsOnA);
932+
}
933+
if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
934+
return positionDependsOnA(b->position) ||
935+
positionDependsOnA(equalTo->getValue());
936+
}
937+
return positionDependsOnA(b->position);
938+
}
939+
920940
/// Given a module containing PDL pattern operations, generate a matcher tree
921941
/// using the patterns within the given module and return the root matcher node.
922942
std::unique_ptr<MatcherNode>
@@ -999,21 +1019,7 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
9991019

10001020
// Mostly keep the now established order, but also ensure that
10011021
// ConstraintQuestions come after the results they use.
1002-
stableTopologicalSort(ordered.begin(), ordered.end(),
1003-
[](OrderedPredicate *a, OrderedPredicate *b) {
1004-
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
1005-
auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
1006-
if (cqa && cqb) {
1007-
// Does any argument of b use a? Then b must be
1008-
// sorted after a.
1009-
return llvm::any_of(
1010-
cqb->getArgs(), [&](Position *p) {
1011-
auto *cp = dyn_cast<ConstraintPosition>(p);
1012-
return cp && cp->getQuestion() == cqa;
1013-
});
1014-
}
1015-
return false;
1016-
});
1022+
stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
10171023

10181024
// Build the matchers for each of the pattern predicate lists.
10191025
std::unique_ptr<MatcherNode> root;

0 commit comments

Comments
 (0)