Skip to content

PDLL native operator abs #173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ void registerBuiltins(PDLPatternModule &pdlPattern);
namespace builtin {
enum class BinaryOpKind {
add,
sub,
mul,
div,
mod,
mul,
sub,
};

enum class UnaryOpKind {
log2,
abs,
exp2,
log2,
};

LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
Expand All @@ -48,9 +49,6 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
template <BinaryOpKind T>
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
Expand All @@ -65,10 +63,8 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);

template <BinaryOpKind T>
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
} // namespace builtin
} // namespace pdl
} // namespace mlir
Expand Down
51 changes: 42 additions & 9 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
}

template <UnaryOpKind T>
LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
assert(args.size() == 1 && "Expected one operand for unary operation");
auto operandAttr = args[0].cast<Attribute>();

Expand Down Expand Up @@ -99,6 +99,27 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
getIntegerAsAttr(APSInt(operandIntAttr.getValue(), false)));
else
results.push_back(getIntegerAsAttr(operandIntAttr.getAPSInt()));
} else if constexpr (T == UnaryOpKind::abs) {
if (integerType.isSigned()) {
// check overflow
if (operandIntAttr.getAPSInt() ==
APSInt::getMinValue(integerType.getIntOrFloatBitWidth(), false))
return failure();

results.push_back(rewriter.getIntegerAttr(
integerType, operandIntAttr.getValue().abs()));
return success();
}
if (integerType.isSignless()) {
// Overflow should not be checked.
// Otherwise the purpose of signless integer is meaningless.
results.push_back(rewriter.getIntegerAttr(
integerType, operandIntAttr.getValue().abs()));
return success();
}
// If unsigned, do nothing
results.push_back(operandIntAttr);
return success();
} else {
llvm::llvm_unreachable_internal(
"encountered an unsupported unary operator");
Expand Down Expand Up @@ -140,20 +161,26 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
})
.Default([](Type /*type*/) { return failure(); });
} else if constexpr (T == UnaryOpKind::log2) {
auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle());

APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2());
results.push_back(rewriter.getFloatAttr(
operandFloatAttr.getType(),
(double)operandFloatAttr.getValue().getExactLog2()));
} else if constexpr (T == UnaryOpKind::abs) {
auto resultVal = operandFloatAttr.getValue();
resultVal.clearSign();
results.push_back(
rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat));
rewriter.getFloatAttr(operandFloatAttr.getType(), resultVal));
} else {
llvm::llvm_unreachable_internal(
"encountered an unsupported unary operator");
}
return success();
}
return failure();
}

template <BinaryOpKind T>
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
LogicalResult static binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
assert(args.size() == 2 && "Expected two operands for binary operation");
auto lhsAttr = args[0].cast<Attribute>();
auto rhsAttr = args[1].cast<Attribute>();
Expand Down Expand Up @@ -294,6 +321,10 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
}
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
return unaryOp<UnaryOpKind::abs>(rewriter, results, args);
}
} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
Expand All @@ -319,7 +350,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
pdlPattern.registerRewriteFunction("__builtin_subRewrite", sub);
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);

pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs);
pdlPattern.registerConstraintFunctionWithResults("__builtin_mulConstraint",
mul);
pdlPattern.registerConstraintFunctionWithResults("__builtin_divConstraint",
Expand All @@ -334,5 +365,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
log2);
pdlPattern.registerConstraintFunctionWithResults("__builtin_exp2Constraint",
exp2);
pdlPattern.registerConstraintFunctionWithResults("__builtin_absConstraint",
abs);
}
} // namespace mlir::pdl
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
.Case("_", Token::underscore)
.Case("log2", Token::log2)
.Case("exp2", Token::exp2)
.Case("abs", Token::abs)
.Default(Token::identifier);
return Token(kind, str);
}
Expand Down
14 changes: 9 additions & 5 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,18 @@ class Token {
equal,
equal_arrow,
semicolon,
/// Paired punctuation.
mul,

/// Arithmetic.
abs,
add,
div,
exp2,
log2,
mod,
add,
mul,
sub,
log2,
exp2,

/// Paired punctuation.
less,
greater,
l_brace,
Expand Down
38 changes: 32 additions & 6 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class Parser {
FailureOr<ast::Expr *> parseLogicalAndExpr();
FailureOr<ast::Expr *> parseEqualityExpr();
FailureOr<ast::Expr *> parseRelationExpr();
FailureOr<ast::Expr *> parseExp2Log2Expr();
FailureOr<ast::Expr *> parseExp2Log2AbsExpr();
FailureOr<ast::Expr *> parseAddSubExpr();
FailureOr<ast::Expr *> parseMulDivModExpr();
FailureOr<ast::Expr *> parseLogicalNotExpr();
Expand Down Expand Up @@ -624,13 +624,15 @@ class Parser {
ast::UserRewriteDecl *subRewrite;
ast::UserRewriteDecl *log2Rewrite;
ast::UserRewriteDecl *exp2Rewrite;
ast::UserRewriteDecl *absRewrite;
ast::UserConstraintDecl *mulConstraint;
ast::UserConstraintDecl *divConstraint;
ast::UserConstraintDecl *modConstraint;
ast::UserConstraintDecl *addConstraint;
ast::UserConstraintDecl *subConstraint;
ast::UserConstraintDecl *log2Constraint;
ast::UserConstraintDecl *exp2Constraint;
ast::UserConstraintDecl *absConstraint;
} builtins{};
};
} // namespace
Expand Down Expand Up @@ -701,6 +703,8 @@ void Parser::declareBuiltins() {
"__builtin_log2Rewrite", {"Attr"}, true);
builtins.exp2Rewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_exp2Rewrite", {"Attr"}, true);
builtins.absRewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_absRewrite", {"Attr"}, true);
builtins.mulConstraint = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_mulConstraint", {"lhs", "rhs"}, true);
builtins.divConstraint = declareBuiltin<ast::UserConstraintDecl>(
Expand All @@ -715,6 +719,8 @@ void Parser::declareBuiltins() {
"__builtin_log2Constraint", {"Attr"}, true);
builtins.exp2Constraint = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_exp2Constraint", {"Attr"}, true);
builtins.absConstraint = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_absConstraint", {"Attr"}, true);
}

FailureOr<ast::Module *> Parser::parseModule() {
Expand Down Expand Up @@ -2030,15 +2036,15 @@ FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
}

FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
auto lhs = parseExp2Log2Expr();
auto lhs = parseExp2Log2AbsExpr();
if (failed(lhs))
return failure();

for (;;) {
switch (curToken.getKind()) {
case Token::mul: {
consumeToken();
auto rhs = parseExp2Log2Expr();
auto rhs = parseExp2Log2AbsExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
Expand All @@ -2058,7 +2064,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
}
case Token::div: {
consumeToken();
auto rhs = parseExp2Log2Expr();
auto rhs = parseExp2Log2AbsExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
Expand All @@ -2078,7 +2084,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
}
case Token::mod: {
consumeToken();
auto rhs = parseExp2Log2Expr();
auto rhs = parseExp2Log2AbsExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
Expand All @@ -2100,7 +2106,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
}
}

FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
FailureOr<ast::Expr *> Parser::parseExp2Log2AbsExpr() {
FailureOr<ast::Expr *> expr = nullptr;

switch (curToken.getKind()) {
Expand Down Expand Up @@ -2144,6 +2150,26 @@ FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
: createBuiltinCall(curToken.getLoc(), builtins.exp2Constraint,
{*expr});
}
case Token::abs: {
consumeToken();
consumeToken(Token::l_paren);
expr = parseAddSubExpr();
if (failed(expr))
return failure();

// Check if it is in rewrite section but not in the let statement
bool inRewriteSection = parserContext == ParserContext::Rewrite;
if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let)
return emitError("cannot evaluate abs operator in rewrite section. "
"Assign to a variable with `let`");

consumeToken(Token::r_paren);
return inRewriteSection
? createBuiltinCall(curToken.getLoc(), builtins.absRewrite,
{*expr})
: createBuiltinCall(curToken.getLoc(), builtins.absConstraint,
{*expr});
}
default:
return parseLogicalNotExpr();
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/completion.test
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@
// CHECK-NEXT: "kind": 8,
// CHECK-NEXT: "label": "__builtin_log2Constraint",
// CHECK-NEXT: "sortText": "2___builtin_log2Constraint"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "(Attr: Attr) -> Attr",
// CHECK-NEXT: "kind": 8,
// CHECK-NEXT: "label": "__builtin_absConstraint",
// CHECK-NEXT: "sortText": "2___builtin_absConstraint"
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ Pattern TestAdd {
// CHECK: apply_native_constraint "__builtin_modConstraint"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_constraint "__builtin_log2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_constraint "__builtin_exp2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_constraint "__builtin_absConstraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute

Pattern TestOperatorsNotInRewriteSection {
let a : Attr = attr<"4 : i32">;
Expand All @@ -320,6 +321,7 @@ Pattern TestOperatorsNotInRewriteSection {
let modConstraint : Attr = a % b;
let log2Constraint : Attr = log2(a);
let exp2Constraint : Attr = exp2(a);
let absConstraint : Attr = abs(a);
replace op<test.simple> with op<test.success>;
}

Expand All @@ -336,6 +338,7 @@ Pattern TestOperatorsNotInRewriteSection {
// CHECK: apply_native_rewrite "__builtin_modRewrite"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_rewrite "__builtin_log2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_rewrite "__builtin_exp2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_rewrite "__builtin_absRewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
Pattern TestOperatorsInRewriteSection {
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
Expand All @@ -348,6 +351,7 @@ Pattern TestOperatorsInRewriteSection {
let modRewrite : Attr = a % b;
let log2Rewrite : Attr = log2(a);
let exp2Rewrite : Attr = exp2(a);
let absRewrite : Attr = abs(a);
erase root;
};
}
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,17 @@ Pattern {

// -----

Pattern {
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
abs(attr<"-4 : si32">);
erase root;
};
}

// -----

// check llvm::saveAndRestore works
Pattern {
// CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let`
Expand All @@ -590,4 +601,30 @@ Pattern {
exp2(attr<"4 : i32">);
erase root;
};
}

// -----

// check llvm::saveAndRestore works
Pattern {
// CHECK: cannot evaluate log2 operator in rewrite section. Assign to a variable with `let`
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
log2(attr<"4 : i32">);
erase root;
};
}

// -----

// check llvm::saveAndRestore works
Pattern {
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
abs(attr<"4 : i32">);
erase root;
};
}
Loading