Skip to content

FXML-2302: Support for negating PDLL constraints #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def PDL_ApplyNativeConstraintOp
```
}];

let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args, DefaultValuedAttr<BoolAttr,"false">:$isNegated);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict";
let hasVerifier = 1;
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest
```
}];

let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args, DefaultValuedAttr<BoolAttr,"false">:$isNegated);
let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors
Expand Down
13 changes: 10 additions & 3 deletions mlir/include/mlir/Tools/PDLL/AST/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
private llvm::TrailingObjects<CallExpr, Expr *> {
public:
static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
ArrayRef<Expr *> arguments, Type resultType);
ArrayRef<Expr *> arguments, Type resultType,
bool isNegated = false);

/// Return the callable of this call.
Expr *getCallableExpr() const { return callable; }
Expand All @@ -403,9 +404,13 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
return const_cast<CallExpr *>(this)->getArguments();
}

bool getIsNegated() const { return isNegated; }

private:
CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
: Base(loc, type), callable(callable), numArgs(numArgs) {}
CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
bool isNegated)
: Base(loc, type), callable(callable), numArgs(numArgs),
isNegated(isNegated) {}

/// The callable of this call.
Expr *callable;
Expand All @@ -415,6 +420,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,

/// TrailingObject utilities.
friend llvm::TrailingObjects<CallExpr, Expr *>;

bool isNegated;
};

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
auto *cstQuestion = cast<ConstraintQuestion>(question);
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
success, failure);
cstQuestion->getIsNegated(), success, failure);
// Replace the generated placeholders with the results of the constraint and
// erase them
for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
Expand Down
15 changes: 10 additions & 5 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ struct AttributeQuestion
struct ConstraintQuestion
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>>,
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
Predicates::ConstraintQuestion> {
using Base::Base;

Expand All @@ -485,13 +485,18 @@ struct ConstraintQuestion
/// Return the result types of the constraint.
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }

bool getIsNegated() const { return std::get<3>(key); }

/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
alloc.copyInto(std::get<2>(key))});
alloc.copyInto(std::get<2>(key)),
std::get<3>(key)});
}

static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
};

/// Compare the equality of two values.
Expand Down Expand Up @@ -698,9 +703,9 @@ class PredicateBuilder {

/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
ArrayRef<Type> resultTypes) {
return {ConstraintQuestion::get(uniquer,
std::make_tuple(name, args, resultTypes)),
ArrayRef<Type> resultTypes, bool isNegated) {
return {ConstraintQuestion::get(
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
TrueAnswer::get(uniquer)};
}

Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
comparePosDepth);
ResultRange results = op.getResults();
PredicateBuilder::Predicate pred = builder.getConstraint(
op.getName(), allPositions, SmallVector<Type>(results.getTypes()));
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
op.getIsNegated());

// for each result register a position so it can be used later
for (auto result : llvm::enumerate(results)) {
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op,
// TODO: Handle result ranges
writer.append(result);
}
writer.append(ByteCodeField(op.getIsNegated()));
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
Expand Down Expand Up @@ -1447,7 +1448,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx];
LogicalResult rewriteResult = constraintFn(rewriter, args);
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
ByteCodeField isNegated = read();
selectJump(isNegated != succeeded(rewriteResult));
} else {
const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
ByteCodeRewriteResultList results(numResults);
Expand All @@ -1474,7 +1476,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
: 0;
}
// Depending on the constraint jump to the proper destination.
selectJump(succeeded(rewriteResult));
ByteCodeField isNegated = read();
selectJump(isNegated != succeeded(rewriteResult));
}
}

Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
void NodePrinter::printImpl(const CallExpr *expr) {
os << "CallExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
os << ">";
if (expr->getIsNegated())
os << " negated";
os << "\n";
printChildren(expr->getCallableExpr());
printChildren("Arguments", expr->getArguments());
}
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Tools/PDLL/AST/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,13 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
//===----------------------------------------------------------------------===//

CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
ArrayRef<Expr *> arguments, Type resultType) {
ArrayRef<Expr *> arguments, Type resultType,
bool isNegated) {
unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));

CallExpr *expr =
new (rawData) CallExpr(loc, resultType, callable, arguments.size());
CallExpr *expr = new (rawData)
CallExpr(loc, resultType, callable, arguments.size(), isNegated);
std::uninitialized_copy(arguments.begin(), arguments.end(),
expr->getArguments().begin());
return expr;
Expand Down
22 changes: 13 additions & 9 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,14 @@ class CodeGen {
Value genExprImpl(const ast::TypeExpr *expr);

SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
Location loc, ValueRange inputs);
Location loc, ValueRange inputs,
bool isNegated = false);
SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
Location loc, ValueRange inputs);
template <typename PDLOpT, typename T>
SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
ValueRange inputs);
ValueRange inputs,
bool isNegated = false);

//===--------------------------------------------------------------------===//
// Fields
Expand Down Expand Up @@ -419,7 +421,7 @@ SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
// Generate the PDL based on the type of callable.
const ast::Decl *callable = callableExpr->getDecl();
if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
return genConstraintCall(decl, loc, arguments);
return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
return genRewriteCall(decl, loc, arguments);
llvm_unreachable("unhandled CallExpr callable");
Expand Down Expand Up @@ -553,15 +555,15 @@ Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {

SmallVector<Value>
CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
ValueRange inputs) {
ValueRange inputs, bool isNegated) {
// Apply any constraints defined on the arguments to the input values.
for (auto it : llvm::zip(decl->getInputs(), inputs))
applyVarConstraints(std::get<0>(it), std::get<1>(it));

// Generate the constraint call.
SmallVector<Value> results =
genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
inputs);
genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
decl, loc, inputs, isNegated);

// Apply any constraints defined on the results of the constraint.
for (auto it : llvm::zip(decl->getResults(), results))
Expand All @@ -576,9 +578,9 @@ SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
}

template <typename PDLOpT, typename T>
SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
Location loc,
ValueRange inputs) {
SmallVector<Value>
CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
ValueRange inputs, bool isNegated) {
const ast::CompoundStmt *cstBody = decl->getBody();

// If the decl doesn't have a statement body, it is a native decl.
Expand All @@ -593,6 +595,8 @@ SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
}
Operation *pdlOp = builder.create<PDLOpT>(
loc, resultTypes, decl->getName().getName(), inputs);
if (isNegated)
pdlOp->setAttr("isNegated", builder.getBoolAttr(true));
return pdlOp->getResults();
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ Token Lexer::lexToken() {
return formToken(Token::l_paren, tokStart);
case ')':
return formToken(Token::r_paren, tokStart);
case '!':
return formToken(Token::exclam, tokStart);
case '/':
if (*curPtr == '/') {
lexComment();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class Token {
equal,
equal_arrow,
semicolon,
exclam,
/// Paired punctuation.
less,
greater,
Expand Down
36 changes: 27 additions & 9 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <string>
#include <optional>
#include <string>

using namespace mlir;
using namespace mlir::pdll;
Expand Down Expand Up @@ -316,13 +316,15 @@ class Parser {
/// Identifier expressions.
FailureOr<ast::Expr *> parseArrayAttrExpr();
FailureOr<ast::Expr *> parseAttributeExpr();
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
bool isNegated = false);
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
FailureOr<ast::Expr *> parseDictAttrExpr();
FailureOr<ast::Expr *> parseIdentifierExpr();
FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseNegatedExpr();
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
FailureOr<ast::Expr *>
Expand Down Expand Up @@ -406,7 +408,7 @@ class Parser {

FailureOr<ast::CallExpr *>
createCallExpr(SMRange loc, ast::Expr *parentExpr,
MutableArrayRef<ast::Expr *> arguments);
MutableArrayRef<ast::Expr *> arguments, bool isNegated);
FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
Expand Down Expand Up @@ -1829,6 +1831,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::l_square:
lhsExpr = parseArrayAttrExpr();
break;
case Token::exclam:
lhsExpr = parseNegatedExpr();
break;
case Token::string_block:
return emitError("expected expression. If you are trying to create an "
"ArrayAttr, use a space between `[` and `{`.");
Expand Down Expand Up @@ -1912,7 +1917,8 @@ FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
return ast::AttributeExpr::create(ctx, loc, attrExpr);
}

FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
bool isNegated) {
consumeToken(Token::l_paren);

// Parse the arguments of the call.
Expand All @@ -1936,7 +1942,7 @@ FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
return failure();

return createCallExpr(loc, parentExpr, arguments);
return createCallExpr(loc, parentExpr, arguments, isNegated);
}

FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
Expand Down Expand Up @@ -2061,6 +2067,16 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
return createMemberAccessExpr(parentExpr, memberName, loc);
}

FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
consumeToken(Token::exclam);
if (!curToken.is(Token::identifier))
return emitError("expected native constraint");
FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
if (failed(identifierExpr))
return failure();
return parseCallExpr(*identifierExpr, /*isNegated = */ true);
}

FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
SMRange loc = curToken.getLoc();

Expand Down Expand Up @@ -2789,7 +2805,8 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {

FailureOr<ast::CallExpr *>
Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
MutableArrayRef<ast::Expr *> arguments) {
MutableArrayRef<ast::Expr *> arguments,
bool isNegated = false) {
ast::Type parentType = parentExpr->getType();

ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
Expand All @@ -2803,6 +2820,8 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
if (isa<ast::UserConstraintDecl>(callableDecl))
return emitError(
loc, "unable to invoke `Constraint` within a rewrite section");
if (isNegated)
return emitError(loc, "negation of Rewrites is not supported");
} else if (isa<ast::UserRewriteDecl>(callableDecl)) {
return emitError(loc, "unable to invoke `Rewrite` within a match section");
}
Expand Down Expand Up @@ -2835,7 +2854,7 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
}

return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
callableDecl->getResultType());
callableDecl->getResultType(), isNegated);
}

FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
Expand Down Expand Up @@ -2959,8 +2978,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
OpResultTypeContext resultTypeContext,
SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
SmallVectorImpl<ast::Expr *> &results,
unsigned numRegions) {
SmallVectorImpl<ast::Expr *> &results, unsigned numRegions) {
std::optional<StringRef> opNameRef = name->getName();
const ods::Operation *odsOp = lookupODSOperation(opNameRef);

Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ module @constraint_with_result_multiple {

// -----

// CHECK-LABEL: module @negated_constraint
module @negated_constraint {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: pdl_interp.apply_constraint "constraint"(%[[ROOT]] : !pdl.operation) {isNegated = true}
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
pdl.pattern : benefit(1) {
%root = operation
pdl.apply_native_constraint "constraint"(%root : !pdl.operation) {isNegated = true}
rewrite %root with "rewriter"
}
}

// -----

// CHECK-LABEL: module @inputs
module @inputs {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
Expand Down
Loading