Skip to content

Commit c80e6ed

Browse files
committed
Revert "[mlir][PDL] Add support for native constraints with results (#82760)"
Due to buildbot failure https://lab.llvm.org/buildbot/#/builders/88/builds/72130 This reverts commit dca32a3.
1 parent dca32a3 commit c80e6ed

File tree

18 files changed

+98
-555
lines changed

18 files changed

+98
-555
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,20 @@ def PDL_ApplyNativeConstraintOp
3535
let description = [{
3636
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
3737
has been registered externally with the consumer of PDL, to a given set of
38-
entities and optionally return a number of values.
38+
entities.
3939

4040
Example:
4141

4242
```mlir
4343
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
4444
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
45-
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
46-
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
4745
```
4846
}];
4947

5048
let arguments = (ins StrAttr:$name,
5149
Variadic<PDL_AnyType>:$args,
5250
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
53-
let results = (outs Variadic<PDL_AnyType>:$results);
54-
let assemblyFormat = [{
55-
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
56-
}];
51+
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
5752
let hasVerifier = 1;
5853
}
5954

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
8888
let description = [{
8989
`pdl_interp.apply_constraint` operations apply a generic constraint, that
9090
has been registered with the interpreter, with a given set of positional
91-
values.
92-
The constraint function may return any number of results.
93-
On success, this operation branches to the true destination,
91+
values. On success, this operation branches to the true destination,
9492
otherwise the false destination is taken. This behavior can be reversed
9593
by setting the attribute `isNegated` to true.
9694

@@ -106,10 +104,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
106104
let arguments = (ins StrAttr:$name,
107105
Variadic<PDL_AnyType>:$args,
108106
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
109-
let results = (outs Variadic<PDL_AnyType>:$results);
110107
let assemblyFormat = [{
111-
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
112-
`->` successors
108+
$name `(` $args `:` type($args) `)` attr-dict `->` successors
113109
}];
114110
}
115111

mlir/include/mlir/IR/PDLPatternMatch.h.inc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,8 @@ protected:
318318
/// A generic PDL pattern constraint function. This function applies a
319319
/// constraint to a given set of opaque PDLValue entities. Returns success if
320320
/// the constraint successfully held, failure otherwise.
321-
using PDLConstraintFunction = std::function<LogicalResult(
322-
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
323-
321+
using PDLConstraintFunction =
322+
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
324323
/// A native PDL rewrite function. This function performs a rewrite on the
325324
/// given set of values. Any results from this rewrite that should be passed
326325
/// back to PDL should be added to the provided result list. This method is only
@@ -727,7 +726,7 @@ std::enable_if_t<
727726
PDLConstraintFunction>
728727
buildConstraintFn(ConstraintFnT &&constraintFn) {
729728
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
730-
PatternRewriter &rewriter, PDLResultList &,
729+
PatternRewriter &rewriter,
731730
ArrayRef<PDLValue> values) -> LogicalResult {
732731
auto argIndices = std::make_index_sequence<
733732
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -843,13 +842,10 @@ public:
843842
/// Register a constraint function with PDL. A constraint function may be
844843
/// specified in one of two ways:
845844
///
846-
/// * `LogicalResult (PatternRewriter &,
847-
/// PDLResultList &,
848-
/// ArrayRef<PDLValue>)`
845+
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
849846
///
850847
/// In this overload the arguments of the constraint function are passed via
851-
/// the low-level PDLValue form, and the results are manually appended to
852-
/// the given result list.
848+
/// the low-level PDLValue form.
853849
///
854850
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
855851
///
@@ -964,8 +960,8 @@ public:
964960
}
965961
};
966962
class PDLResultList {};
967-
using PDLConstraintFunction = std::function<LogicalResult(
968-
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
963+
using PDLConstraintFunction =
964+
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
969965
using PDLRewriteFunction = std::function<LogicalResult(
970966
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
971967

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ struct PatternLowering {
5050

5151
/// Generate interpreter operations for the tree rooted at the given matcher
5252
/// node, in the specified region.
53-
Block *generateMatcher(MatcherNode &node, Region &region,
54-
Block *block = nullptr);
53+
Block *generateMatcher(MatcherNode &node, Region &region);
5554

5655
/// Get or create an access to the provided positional value in the current
5756
/// block. This operation may mutate the provided block pointer if nested
@@ -149,10 +148,6 @@ struct PatternLowering {
149148
/// A mapping between pattern operations and the corresponding configuration
150149
/// set.
151150
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
152-
153-
/// A mapping from a constraint question to the ApplyConstraintOp
154-
/// that implements it.
155-
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
156151
};
157152
} // namespace
158153

@@ -187,11 +182,9 @@ void PatternLowering::lower(ModuleOp module) {
187182
firstMatcherBlock->erase();
188183
}
189184

190-
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191-
Block *block) {
185+
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
192186
// Push a new scope for the values used by this matcher.
193-
if (!block)
194-
block = &region.emplaceBlock();
187+
Block *block = &region.emplaceBlock();
195188
ValueMapScope scope(values);
196189

197190
// If this is the return node, simply insert the corresponding interpreter
@@ -371,15 +364,6 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
371364
loc, cast<ArrayAttr>(rawTypeAttr));
372365
break;
373366
}
374-
case Predicates::ConstraintResultPos: {
375-
// Due to the order of traversal, the ApplyConstraintOp has already been
376-
// created and we can find it in constraintOpMap.
377-
auto *constrResPos = cast<ConstraintPosition>(pos);
378-
auto i = constraintOpMap.find(constrResPos->getQuestion());
379-
assert(i != constraintOpMap.end());
380-
value = i->second->getResult(constrResPos->getIndex());
381-
break;
382-
}
383367
default:
384368
llvm_unreachable("Generating unknown Position getter");
385369
break;
@@ -406,11 +390,12 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
406390
args.push_back(getValueAt(currentBlock, position));
407391
}
408392

409-
// Generate a new block as success successor and get the failure successor.
410-
Block *success = &region->emplaceBlock();
393+
// Generate the matcher in the current (potentially nested) region
394+
// and get the failure successor.
395+
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
411396
Block *failure = failureBlockStack.back();
412397

413-
// Create the predicate.
398+
// Finally, create the predicate.
414399
builder.setInsertionPointToEnd(currentBlock);
415400
Predicates::Kind kind = question->getKind();
416401
switch (kind) {
@@ -462,20 +447,14 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
462447
}
463448
case Predicates::ConstraintQuestion: {
464449
auto *cstQuestion = cast<ConstraintQuestion>(question);
465-
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466-
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467-
cstQuestion->getIsNegated(), success, failure);
468-
469-
constraintOpMap.insert({cstQuestion, applyConstraintOp});
450+
builder.create<pdl_interp::ApplyConstraintOp>(
451+
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
452+
failure);
470453
break;
471454
}
472455
default:
473456
llvm_unreachable("Generating unknown Predicate operation");
474457
}
475-
476-
// Generate the matcher in the current (potentially nested) region.
477-
// This might use the results of the current predicate.
478-
generateMatcher(*boolNode->getSuccessNode(), *region, success);
479458
}
480459

481460
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

mlir/lib/Conversion/PDLToPDLInterp/Predicate.h

Lines changed: 11 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ enum Kind : unsigned {
4747
OperandPos,
4848
OperandGroupPos,
4949
AttributePos,
50-
ConstraintResultPos,
5150
ResultPos,
5251
ResultGroupPos,
5352
TypePos,
@@ -280,28 +279,6 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
280279
bool isOperandDefiningOp() const;
281280
};
282281

283-
//===----------------------------------------------------------------------===//
284-
// ConstraintPosition
285-
286-
struct ConstraintQuestion;
287-
288-
/// A position describing the result of a native constraint. It saves the
289-
/// corresponding ConstraintQuestion and result index to enable referring
290-
/// back to them
291-
struct ConstraintPosition
292-
: public PredicateBase<ConstraintPosition, Position,
293-
std::pair<ConstraintQuestion *, unsigned>,
294-
Predicates::ConstraintResultPos> {
295-
using PredicateBase::PredicateBase;
296-
297-
/// Returns the ConstraintQuestion to enable keeping track of the native
298-
/// constraint this position stems from.
299-
ConstraintQuestion *getQuestion() const { return key.first; }
300-
301-
// Returns the result index of this position
302-
unsigned getIndex() const { return key.second; }
303-
};
304-
305282
//===----------------------------------------------------------------------===//
306283
// ResultPosition
307284

@@ -470,13 +447,11 @@ struct AttributeQuestion
470447
: public PredicateBase<AttributeQuestion, Qualifier, void,
471448
Predicates::AttributeQuestion> {};
472449

473-
/// Apply a parameterized constraint to multiple position values and possibly
474-
/// produce results.
450+
/// Apply a parameterized constraint to multiple position values.
475451
struct ConstraintQuestion
476-
: public PredicateBase<
477-
ConstraintQuestion, Qualifier,
478-
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
479-
Predicates::ConstraintQuestion> {
452+
: public PredicateBase<ConstraintQuestion, Qualifier,
453+
std::tuple<StringRef, ArrayRef<Position *>, bool>,
454+
Predicates::ConstraintQuestion> {
480455
using Base::Base;
481456

482457
/// Return the name of the constraint.
@@ -485,19 +460,15 @@ struct ConstraintQuestion
485460
/// Return the arguments of the constraint.
486461
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
487462

488-
/// Return the result types of the constraint.
489-
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
490-
491463
/// Return the negation status of the constraint.
492-
bool getIsNegated() const { return std::get<3>(key); }
464+
bool getIsNegated() const { return std::get<2>(key); }
493465

494466
/// Construct an instance with the given storage allocator.
495467
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
496468
KeyTy key) {
497469
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
498470
alloc.copyInto(std::get<1>(key)),
499-
alloc.copyInto(std::get<2>(key)),
500-
std::get<3>(key)});
471+
std::get<2>(key)});
501472
}
502473

503474
/// Returns a hash suitable for the given keytype.
@@ -555,7 +526,6 @@ class PredicateUniquer : public StorageUniquer {
555526
// Register the types of Positions with the uniquer.
556527
registerParametricStorageType<AttributePosition>();
557528
registerParametricStorageType<AttributeLiteralPosition>();
558-
registerParametricStorageType<ConstraintPosition>();
559529
registerParametricStorageType<ForEachPosition>();
560530
registerParametricStorageType<OperandPosition>();
561531
registerParametricStorageType<OperandGroupPosition>();
@@ -618,12 +588,6 @@ class PredicateBuilder {
618588
return OperationPosition::get(uniquer, p);
619589
}
620590

621-
// Returns a position for a new value created by a constraint.
622-
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
623-
unsigned index) {
624-
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
625-
}
626-
627591
/// Returns an attribute position for an attribute of the given operation.
628592
Position *getAttribute(OperationPosition *p, StringRef name) {
629593
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -709,11 +673,11 @@ class PredicateBuilder {
709673
}
710674

711675
/// Create a predicate that applies a generic constraint.
712-
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
713-
ArrayRef<Type> resultTypes, bool isNegated) {
714-
return {ConstraintQuestion::get(
715-
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
716-
TrueAnswer::get(uniquer)};
676+
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
677+
bool isNegated) {
678+
return {
679+
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
680+
TrueAnswer::get(uniquer)};
717681
}
718682

719683
/// Create a predicate comparing a value with null.

0 commit comments

Comments
 (0)