Skip to content

Commit a296376

Browse files
committed
feat: pdll native operator for abs
1 parent 5f808cd commit a296376

File tree

10 files changed

+214
-34
lines changed

10 files changed

+214
-34
lines changed

mlir/include/mlir/Dialect/PDL/IR/Builtins.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ enum class BinaryOpKind {
3737
enum class UnaryOpKind {
3838
log2,
3939
exp2,
40+
abs,
4041
};
4142

4243
LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
@@ -48,9 +49,6 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
4849
Attribute createArrayAttr(PatternRewriter &rewriter);
4950
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
5051
Attribute element);
51-
template <BinaryOpKind T>
52-
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
53-
llvm::ArrayRef<PDLValue> args);
5452
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
5553
llvm::ArrayRef<PDLValue> args);
5654
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
@@ -65,10 +63,8 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
6563
llvm::ArrayRef<PDLValue> args);
6664
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
6765
llvm::ArrayRef<PDLValue> args);
68-
69-
template <BinaryOpKind T>
70-
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
71-
llvm::ArrayRef<PDLValue> args);
66+
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
67+
llvm::ArrayRef<PDLValue> args);
7268
} // namespace builtin
7369
} // namespace pdl
7470
} // namespace mlir

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cassert>
22
#include <cstdint>
3+
#include <iostream>
34
#include <llvm/ADT/APFloat.h>
45
#include <llvm/ADT/APInt.h>
56
#include <llvm/ADT/APSInt.h>
@@ -59,8 +60,8 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
5960
}
6061

6162
template <UnaryOpKind T>
62-
LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
63-
ArrayRef<PDLValue> args) {
63+
LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results,
64+
ArrayRef<PDLValue> args) {
6465
assert(args.size() == 1 && "Expected one operand for unary operation");
6566
auto operandAttr = args[0].cast<Attribute>();
6667

@@ -99,6 +100,33 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
99100
getIntegerAsAttr(APSInt(operandIntAttr.getValue(), false)));
100101
else
101102
results.push_back(getIntegerAsAttr(operandIntAttr.getAPSInt()));
103+
} else if constexpr (T == UnaryOpKind::abs) {
104+
if (integerType.isSigned()) {
105+
// check overflow
106+
if (operandIntAttr.getAPSInt() ==
107+
APSInt::getMinValue(integerType.getIntOrFloatBitWidth(), false))
108+
return failure();
109+
110+
results.push_back(rewriter.getIntegerAttr(
111+
integerType, std::abs(operandIntAttr.getSInt())));
112+
return success();
113+
}
114+
if (integerType.isSignless()) {
115+
auto resultVal = rewriter.getIntegerAttr(
116+
integerType, std::abs(operandIntAttr.getInt()));
117+
results.push_back(rewriter.getIntegerAttr(
118+
integerType, std::abs(operandIntAttr.getInt())));
119+
120+
std::cout << "Input: "
121+
<< (uint8_t)operandIntAttr.getValue().getZExtValue()
122+
<< std::endl;
123+
std::cout << "Result store in IntegerAttr: " << resultVal.getInt()
124+
<< std::endl;
125+
return success();
126+
}
127+
// If unsigned, don't do anything
128+
results.push_back(operandIntAttr);
129+
return success();
102130
} else {
103131
llvm::llvm_unreachable_internal(
104132
"encountered an unsupported unary operator");
@@ -140,20 +168,25 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
140168
})
141169
.Default([](Type /*type*/) { return failure(); });
142170
} else if constexpr (T == UnaryOpKind::log2) {
143-
auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle());
144-
145-
APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2());
146-
results.push_back(
147-
rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat));
171+
results.push_back(rewriter.getFloatAttr(
172+
operandFloatAttr.getType(),
173+
(double)operandFloatAttr.getValue().getExactLog2()));
174+
} else if constexpr (T == UnaryOpKind::abs) {
175+
results.push_back(rewriter.getFloatAttr(
176+
operandFloatAttr.getType(),
177+
std::abs(operandFloatAttr.getValue().convertToFloat())));
178+
} else {
179+
llvm::llvm_unreachable_internal(
180+
"encountered an unsupported unary operator");
148181
}
149182
return success();
150183
}
151184
return failure();
152185
}
153186

154187
template <BinaryOpKind T>
155-
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
156-
llvm::ArrayRef<PDLValue> args) {
188+
LogicalResult static binaryOp(PatternRewriter &rewriter, PDLResultList &results,
189+
llvm::ArrayRef<PDLValue> args) {
157190
assert(args.size() == 2 && "Expected two operands for binary operation");
158191
auto lhsAttr = args[0].cast<Attribute>();
159192
auto rhsAttr = args[1].cast<Attribute>();
@@ -294,6 +327,10 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
294327
llvm::ArrayRef<PDLValue> args) {
295328
return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
296329
}
330+
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
331+
llvm::ArrayRef<PDLValue> args) {
332+
return unaryOp<UnaryOpKind::abs>(rewriter, results, args);
333+
}
297334
} // namespace builtin
298335

299336
void registerBuiltins(PDLPatternModule &pdlPattern) {
@@ -319,7 +356,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
319356
pdlPattern.registerRewriteFunction("__builtin_subRewrite", sub);
320357
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
321358
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);
322-
359+
pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs);
323360
pdlPattern.registerConstraintFunctionWithResults("__builtin_mulConstraint",
324361
mul);
325362
pdlPattern.registerConstraintFunctionWithResults("__builtin_divConstraint",
@@ -334,5 +371,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
334371
log2);
335372
pdlPattern.registerConstraintFunctionWithResults("__builtin_exp2Constraint",
336373
exp2);
374+
pdlPattern.registerConstraintFunctionWithResults("__builtin_absConstraint",
375+
abs);
337376
}
338377
} // namespace mlir::pdl

mlir/lib/Tools/PDLL/Parser/Lexer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
377377
.Case("_", Token::underscore)
378378
.Case("log2", Token::log2)
379379
.Case("exp2", Token::exp2)
380+
.Case("abs", Token::abs)
380381
.Default(Token::identifier);
381382
return Token(kind, str);
382383
}

mlir/lib/Tools/PDLL/Parser/Lexer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class Token {
8888
sub,
8989
log2,
9090
exp2,
91+
abs,
9192
less,
9293
greater,
9394
l_brace,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ class Parser {
334334
FailureOr<ast::Expr *> parseLogicalAndExpr();
335335
FailureOr<ast::Expr *> parseEqualityExpr();
336336
FailureOr<ast::Expr *> parseRelationExpr();
337-
FailureOr<ast::Expr *> parseExp2Log2Expr();
337+
FailureOr<ast::Expr *> parseExp2Log2AbsExpr();
338338
FailureOr<ast::Expr *> parseAddSubExpr();
339339
FailureOr<ast::Expr *> parseMulDivModExpr();
340340
FailureOr<ast::Expr *> parseLogicalNotExpr();
@@ -624,13 +624,15 @@ class Parser {
624624
ast::UserRewriteDecl *subRewrite;
625625
ast::UserRewriteDecl *log2Rewrite;
626626
ast::UserRewriteDecl *exp2Rewrite;
627+
ast::UserRewriteDecl *absRewrite;
627628
ast::UserConstraintDecl *mulConstraint;
628629
ast::UserConstraintDecl *divConstraint;
629630
ast::UserConstraintDecl *modConstraint;
630631
ast::UserConstraintDecl *addConstraint;
631632
ast::UserConstraintDecl *subConstraint;
632633
ast::UserConstraintDecl *log2Constraint;
633634
ast::UserConstraintDecl *exp2Constraint;
635+
ast::UserConstraintDecl *absConstraint;
634636
} builtins{};
635637
};
636638
} // namespace
@@ -701,6 +703,8 @@ void Parser::declareBuiltins() {
701703
"__builtin_log2Rewrite", {"Attr"}, true);
702704
builtins.exp2Rewrite = declareBuiltin<ast::UserRewriteDecl>(
703705
"__builtin_exp2Rewrite", {"Attr"}, true);
706+
builtins.absRewrite = declareBuiltin<ast::UserRewriteDecl>(
707+
"__builtin_absRewrite", {"Attr"}, true);
704708
builtins.mulConstraint = declareBuiltin<ast::UserConstraintDecl>(
705709
"__builtin_mulConstraint", {"lhs", "rhs"}, true);
706710
builtins.divConstraint = declareBuiltin<ast::UserConstraintDecl>(
@@ -715,6 +719,8 @@ void Parser::declareBuiltins() {
715719
"__builtin_log2Constraint", {"Attr"}, true);
716720
builtins.exp2Constraint = declareBuiltin<ast::UserConstraintDecl>(
717721
"__builtin_exp2Constraint", {"Attr"}, true);
722+
builtins.absConstraint = declareBuiltin<ast::UserConstraintDecl>(
723+
"__builtin_absConstraint", {"Attr"}, true);
718724
}
719725

720726
FailureOr<ast::Module *> Parser::parseModule() {
@@ -2030,15 +2036,15 @@ FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
20302036
}
20312037

20322038
FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
2033-
auto lhs = parseExp2Log2Expr();
2039+
auto lhs = parseExp2Log2AbsExpr();
20342040
if (failed(lhs))
20352041
return failure();
20362042

20372043
for (;;) {
20382044
switch (curToken.getKind()) {
20392045
case Token::mul: {
20402046
consumeToken();
2041-
auto rhs = parseExp2Log2Expr();
2047+
auto rhs = parseExp2Log2AbsExpr();
20422048
if (failed(rhs))
20432049
return failure();
20442050
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2058,7 +2064,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20582064
}
20592065
case Token::div: {
20602066
consumeToken();
2061-
auto rhs = parseExp2Log2Expr();
2067+
auto rhs = parseExp2Log2AbsExpr();
20622068
if (failed(rhs))
20632069
return failure();
20642070
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2078,7 +2084,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20782084
}
20792085
case Token::mod: {
20802086
consumeToken();
2081-
auto rhs = parseExp2Log2Expr();
2087+
auto rhs = parseExp2Log2AbsExpr();
20822088
if (failed(rhs))
20832089
return failure();
20842090
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2100,7 +2106,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
21002106
}
21012107
}
21022108

2103-
FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
2109+
FailureOr<ast::Expr *> Parser::parseExp2Log2AbsExpr() {
21042110
FailureOr<ast::Expr *> expr = nullptr;
21052111

21062112
switch (curToken.getKind()) {
@@ -2144,6 +2150,26 @@ FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
21442150
: createBuiltinCall(curToken.getLoc(), builtins.exp2Constraint,
21452151
{*expr});
21462152
}
2153+
case Token::abs: {
2154+
consumeToken();
2155+
consumeToken(Token::l_paren);
2156+
expr = parseAddSubExpr();
2157+
if (failed(expr))
2158+
return failure();
2159+
2160+
// Check if it is in rewrite section but not in the let statement
2161+
bool inRewriteSection = parserContext == ParserContext::Rewrite;
2162+
if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let)
2163+
return emitError("cannot evaluate abs operator in rewrite section. "
2164+
"Assign to a variable with `let`");
2165+
2166+
consumeToken(Token::r_paren);
2167+
return inRewriteSection
2168+
? createBuiltinCall(curToken.getLoc(), builtins.absRewrite,
2169+
{*expr})
2170+
: createBuiltinCall(curToken.getLoc(), builtins.absConstraint,
2171+
{*expr});
2172+
}
21472173
default:
21482174
return parseLogicalNotExpr();
21492175
}

mlir/test/mlir-pdll-lsp-server/completion.test

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@
208208
// CHECK-NEXT: "kind": 8,
209209
// CHECK-NEXT: "label": "__builtin_log2Constraint",
210210
// CHECK-NEXT: "sortText": "2___builtin_log2Constraint"
211+
// CHECK-NEXT: },
212+
// CHECK-NEXT: {
213+
// CHECK-NEXT: "detail": "(Attr: Attr) -> Attr",
214+
// CHECK-NEXT: "kind": 8,
215+
// CHECK-NEXT: "label": "__builtin_absConstraint",
216+
// CHECK-NEXT: "sortText": "2___builtin_absConstraint"
211217
// CHECK-NEXT: }
212218
// CHECK-NEXT: ]
213219
// CHECK-NEXT: }

mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ Pattern TestAdd {
309309
// CHECK: apply_native_constraint "__builtin_modConstraint"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
310310
// CHECK: apply_native_constraint "__builtin_log2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
311311
// CHECK: apply_native_constraint "__builtin_exp2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
312+
// CHECK: apply_native_constraint "__builtin_absConstraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
312313

313314
Pattern TestOperatorsNotInRewriteSection {
314315
let a : Attr = attr<"4 : i32">;
@@ -320,6 +321,7 @@ Pattern TestOperatorsNotInRewriteSection {
320321
let modConstraint : Attr = a % b;
321322
let log2Constraint : Attr = log2(a);
322323
let exp2Constraint : Attr = exp2(a);
324+
let absConstraint : Attr = abs(a);
323325
replace op<test.simple> with op<test.success>;
324326
}
325327

@@ -336,6 +338,7 @@ Pattern TestOperatorsNotInRewriteSection {
336338
// CHECK: apply_native_rewrite "__builtin_modRewrite"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
337339
// CHECK: apply_native_rewrite "__builtin_log2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
338340
// CHECK: apply_native_rewrite "__builtin_exp2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
341+
// CHECK: apply_native_rewrite "__builtin_absRewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
339342
Pattern TestOperatorsInRewriteSection {
340343
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
341344
rewrite root with {
@@ -348,6 +351,7 @@ Pattern TestOperatorsInRewriteSection {
348351
let modRewrite : Attr = a % b;
349352
let log2Rewrite : Attr = log2(a);
350353
let exp2Rewrite : Attr = exp2(a);
354+
let absRewrite : Attr = abs(a);
351355
erase root;
352356
};
353357
}

mlir/test/mlir-pdll/Parser/expr-failure.pdll

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,17 @@ Pattern {
581581

582582
// -----
583583

584+
Pattern {
585+
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
586+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
587+
rewrite root with {
588+
abs(attr<"-4 : si32">);
589+
erase root;
590+
};
591+
}
592+
593+
// -----
594+
584595
// check llvm::saveAndRestore works
585596
Pattern {
586597
// CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let`
@@ -590,4 +601,30 @@ Pattern {
590601
exp2(attr<"4 : i32">);
591602
erase root;
592603
};
604+
}
605+
606+
// -----
607+
608+
// check llvm::saveAndRestore works
609+
Pattern {
610+
// CHECK: cannot evaluate log2 operator in rewrite section. Assign to a variable with `let`
611+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
612+
rewrite root with {
613+
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
614+
log2(attr<"4 : i32">);
615+
erase root;
616+
};
617+
}
618+
619+
// -----
620+
621+
// check llvm::saveAndRestore works
622+
Pattern {
623+
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
624+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
625+
rewrite root with {
626+
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
627+
abs(attr<"4 : i32">);
628+
erase root;
629+
};
593630
}

0 commit comments

Comments
 (0)