Skip to content

Commit b49bb7b

Browse files
author
Mogball
committed
[MLIR][PDL] Add support for representing and lowering negated constraints
This commit enables modelling negation of native constraints. This is accomplished through an attribute `isNegated` on the operations `pdl.apply_native_constraint` and `pdl_interp.apply_constraint` and according adjustments to the conversion in the ConvertPDLToPDLInterpPass. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D153871
1 parent bfbea45 commit b49bb7b

File tree

6 files changed

+42
-11
lines changed

6 files changed

+42
-11
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def PDL_ApplyNativeConstraintOp
4545
```
4646
}];
4747

48-
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
48+
let arguments = (ins StrAttr:$name,
49+
Variadic<PDL_AnyType>:$args,
50+
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
4951
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
5052
let hasVerifier = 1;
5153
}

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
9090
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9191
has been registered with the interpreter, with a given set of positional
9292
values. On success, this operation branches to the true destination,
93-
otherwise the false destination is taken.
93+
otherwise the false destination is taken. This behavior can be reversed
94+
by setting the attribute `isNegated` to true.
9495

9596
Example:
9697

@@ -101,7 +102,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
101102
```
102103
}];
103104

104-
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
105+
let arguments = (ins StrAttr:$name,
106+
Variadic<PDL_AnyType>:$args,
107+
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
105108
let assemblyFormat = [{
106109
$name `(` $args `:` type($args) `)` attr-dict `->` successors
107110
}];

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,9 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
447447
}
448448
case Predicates::ConstraintQuestion: {
449449
auto *cstQuestion = cast<ConstraintQuestion>(question);
450-
builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
451-
args, success, failure);
450+
builder.create<pdl_interp::ApplyConstraintOp>(
451+
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
452+
failure);
452453
break;
453454
}
454455
default:

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ struct AttributeQuestion
450450
/// Apply a parameterized constraint to multiple position values.
451451
struct ConstraintQuestion
452452
: public PredicateBase<ConstraintQuestion, Qualifier,
453-
std::tuple<StringRef, ArrayRef<Position *>>,
453+
std::tuple<StringRef, ArrayRef<Position *>, bool>,
454454
Predicates::ConstraintQuestion> {
455455
using Base::Base;
456456

@@ -460,11 +460,20 @@ struct ConstraintQuestion
460460
/// Return the arguments of the constraint.
461461
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
462462

463+
/// Return the negation status of the constraint.
464+
bool getIsNegated() const { return std::get<2>(key); }
465+
463466
/// Construct an instance with the given storage allocator.
464467
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
465468
KeyTy key) {
466469
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
467-
alloc.copyInto(std::get<1>(key))});
470+
alloc.copyInto(std::get<1>(key)),
471+
std::get<2>(key)});
472+
}
473+
474+
/// Returns a hash suitable for the given keytype.
475+
static llvm::hash_code hashKey(const KeyTy &key) {
476+
return llvm::hash_value(key);
468477
}
469478
};
470479

@@ -664,9 +673,11 @@ class PredicateBuilder {
664673
}
665674

666675
/// Create a predicate that applies a generic constraint.
667-
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
668-
return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
669-
TrueAnswer::get(uniquer)};
676+
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
677+
bool isNegated) {
678+
return {
679+
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
680+
TrueAnswer::get(uniquer)};
670681
}
671682

672683
/// Create a predicate comparing a value with null.

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
273273
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
274274
comparePosDepth);
275275
PredicateBuilder::Predicate pred =
276-
builder.getConstraint(op.getName(), allPositions);
276+
builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
277277
predList.emplace_back(pos, pred);
278278
}
279279

mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ module @constraints {
7979

8080
// -----
8181

82+
// CHECK-LABEL: module @negated_constraint
83+
module @negated_constraint {
84+
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
85+
// CHECK: pdl_interp.apply_constraint "constraint"(%[[ROOT]] : !pdl.operation) {isNegated = true}
86+
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
87+
pdl.pattern : benefit(1) {
88+
%root = operation
89+
pdl.apply_native_constraint "constraint"(%root : !pdl.operation) {isNegated = true}
90+
rewrite %root with "rewriter"
91+
}
92+
}
93+
94+
// -----
95+
8296
// CHECK-LABEL: module @inputs
8397
module @inputs {
8498
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

0 commit comments

Comments
 (0)