Skip to content

Commit 746aa23

Browse files
authored
Merge pull request #173 from Xilinx/liangta.abs
PDLL native operator abs
2 parents 55b8b59 + 10a643c commit 746aa23

File tree

10 files changed

+231
-41
lines changed

10 files changed

+231
-41
lines changed

mlir/include/mlir/Dialect/PDL/IR/Builtins.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,16 @@ void registerBuiltins(PDLPatternModule &pdlPattern);
2828
namespace builtin {
2929
enum class BinaryOpKind {
3030
add,
31-
sub,
32-
mul,
3331
div,
3432
mod,
33+
mul,
34+
sub,
3535
};
3636

3737
enum class UnaryOpKind {
38-
log2,
38+
abs,
3939
exp2,
40+
log2,
4041
};
4142

4243
LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
@@ -48,9 +49,6 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
4849
Attribute createArrayAttr(PatternRewriter &rewriter);
4950
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
5051
Attribute element);
51-
template <BinaryOpKind T>
52-
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
53-
llvm::ArrayRef<PDLValue> args);
5452
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
5553
llvm::ArrayRef<PDLValue> args);
5654
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
@@ -65,10 +63,8 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
6563
llvm::ArrayRef<PDLValue> args);
6664
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
6765
llvm::ArrayRef<PDLValue> args);
68-
69-
template <BinaryOpKind T>
70-
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
71-
llvm::ArrayRef<PDLValue> args);
66+
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
67+
llvm::ArrayRef<PDLValue> args);
7268
} // namespace builtin
7369
} // namespace pdl
7470
} // namespace mlir

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
5959
}
6060

6161
template <UnaryOpKind T>
62-
LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
63-
ArrayRef<PDLValue> args) {
62+
LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results,
63+
ArrayRef<PDLValue> args) {
6464
assert(args.size() == 1 && "Expected one operand for unary operation");
6565
auto operandAttr = args[0].cast<Attribute>();
6666

@@ -99,6 +99,27 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
9999
getIntegerAsAttr(APSInt(operandIntAttr.getValue(), false)));
100100
else
101101
results.push_back(getIntegerAsAttr(operandIntAttr.getAPSInt()));
102+
} else if constexpr (T == UnaryOpKind::abs) {
103+
if (integerType.isSigned()) {
104+
// check overflow
105+
if (operandIntAttr.getAPSInt() ==
106+
APSInt::getMinValue(integerType.getIntOrFloatBitWidth(), false))
107+
return failure();
108+
109+
results.push_back(rewriter.getIntegerAttr(
110+
integerType, operandIntAttr.getValue().abs()));
111+
return success();
112+
}
113+
if (integerType.isSignless()) {
114+
// Overflow should not be checked.
115+
// Otherwise the purpose of signless integer is meaningless.
116+
results.push_back(rewriter.getIntegerAttr(
117+
integerType, operandIntAttr.getValue().abs()));
118+
return success();
119+
}
120+
// If unsigned, do nothing
121+
results.push_back(operandIntAttr);
122+
return success();
102123
} else {
103124
llvm::llvm_unreachable_internal(
104125
"encountered an unsupported unary operator");
@@ -140,20 +161,26 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
140161
})
141162
.Default([](Type /*type*/) { return failure(); });
142163
} else if constexpr (T == UnaryOpKind::log2) {
143-
auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle());
144-
145-
APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2());
164+
results.push_back(rewriter.getFloatAttr(
165+
operandFloatAttr.getType(),
166+
(double)operandFloatAttr.getValue().getExactLog2()));
167+
} else if constexpr (T == UnaryOpKind::abs) {
168+
auto resultVal = operandFloatAttr.getValue();
169+
resultVal.clearSign();
146170
results.push_back(
147-
rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat));
171+
rewriter.getFloatAttr(operandFloatAttr.getType(), resultVal));
172+
} else {
173+
llvm::llvm_unreachable_internal(
174+
"encountered an unsupported unary operator");
148175
}
149176
return success();
150177
}
151178
return failure();
152179
}
153180

154181
template <BinaryOpKind T>
155-
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
156-
llvm::ArrayRef<PDLValue> args) {
182+
LogicalResult static binaryOp(PatternRewriter &rewriter, PDLResultList &results,
183+
llvm::ArrayRef<PDLValue> args) {
157184
assert(args.size() == 2 && "Expected two operands for binary operation");
158185
auto lhsAttr = args[0].cast<Attribute>();
159186
auto rhsAttr = args[1].cast<Attribute>();
@@ -294,6 +321,10 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
294321
llvm::ArrayRef<PDLValue> args) {
295322
return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
296323
}
324+
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
325+
llvm::ArrayRef<PDLValue> args) {
326+
return unaryOp<UnaryOpKind::abs>(rewriter, results, args);
327+
}
297328
} // namespace builtin
298329

299330
void registerBuiltins(PDLPatternModule &pdlPattern) {
@@ -319,7 +350,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
319350
pdlPattern.registerRewriteFunction("__builtin_subRewrite", sub);
320351
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
321352
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);
322-
353+
pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs);
323354
pdlPattern.registerConstraintFunctionWithResults("__builtin_mulConstraint",
324355
mul);
325356
pdlPattern.registerConstraintFunctionWithResults("__builtin_divConstraint",
@@ -334,5 +365,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
334365
log2);
335366
pdlPattern.registerConstraintFunctionWithResults("__builtin_exp2Constraint",
336367
exp2);
368+
pdlPattern.registerConstraintFunctionWithResults("__builtin_absConstraint",
369+
abs);
337370
}
338371
} // namespace mlir::pdl

mlir/lib/Tools/PDLL/Parser/Lexer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
377377
.Case("_", Token::underscore)
378378
.Case("log2", Token::log2)
379379
.Case("exp2", Token::exp2)
380+
.Case("abs", Token::abs)
380381
.Default(Token::identifier);
381382
return Token(kind, str);
382383
}

mlir/lib/Tools/PDLL/Parser/Lexer.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,18 @@ class Token {
8080
equal,
8181
equal_arrow,
8282
semicolon,
83-
/// Paired punctuation.
84-
mul,
83+
84+
/// Arithmetic.
85+
abs,
86+
add,
8587
div,
88+
exp2,
89+
log2,
8690
mod,
87-
add,
91+
mul,
8892
sub,
89-
log2,
90-
exp2,
93+
94+
/// Paired punctuation.
9195
less,
9296
greater,
9397
l_brace,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ class Parser {
334334
FailureOr<ast::Expr *> parseLogicalAndExpr();
335335
FailureOr<ast::Expr *> parseEqualityExpr();
336336
FailureOr<ast::Expr *> parseRelationExpr();
337-
FailureOr<ast::Expr *> parseExp2Log2Expr();
337+
FailureOr<ast::Expr *> parseExp2Log2AbsExpr();
338338
FailureOr<ast::Expr *> parseAddSubExpr();
339339
FailureOr<ast::Expr *> parseMulDivModExpr();
340340
FailureOr<ast::Expr *> parseLogicalNotExpr();
@@ -624,13 +624,15 @@ class Parser {
624624
ast::UserRewriteDecl *subRewrite;
625625
ast::UserRewriteDecl *log2Rewrite;
626626
ast::UserRewriteDecl *exp2Rewrite;
627+
ast::UserRewriteDecl *absRewrite;
627628
ast::UserConstraintDecl *mulConstraint;
628629
ast::UserConstraintDecl *divConstraint;
629630
ast::UserConstraintDecl *modConstraint;
630631
ast::UserConstraintDecl *addConstraint;
631632
ast::UserConstraintDecl *subConstraint;
632633
ast::UserConstraintDecl *log2Constraint;
633634
ast::UserConstraintDecl *exp2Constraint;
635+
ast::UserConstraintDecl *absConstraint;
634636
} builtins{};
635637
};
636638
} // namespace
@@ -701,6 +703,8 @@ void Parser::declareBuiltins() {
701703
"__builtin_log2Rewrite", {"Attr"}, true);
702704
builtins.exp2Rewrite = declareBuiltin<ast::UserRewriteDecl>(
703705
"__builtin_exp2Rewrite", {"Attr"}, true);
706+
builtins.absRewrite = declareBuiltin<ast::UserRewriteDecl>(
707+
"__builtin_absRewrite", {"Attr"}, true);
704708
builtins.mulConstraint = declareBuiltin<ast::UserConstraintDecl>(
705709
"__builtin_mulConstraint", {"lhs", "rhs"}, true);
706710
builtins.divConstraint = declareBuiltin<ast::UserConstraintDecl>(
@@ -715,6 +719,8 @@ void Parser::declareBuiltins() {
715719
"__builtin_log2Constraint", {"Attr"}, true);
716720
builtins.exp2Constraint = declareBuiltin<ast::UserConstraintDecl>(
717721
"__builtin_exp2Constraint", {"Attr"}, true);
722+
builtins.absConstraint = declareBuiltin<ast::UserConstraintDecl>(
723+
"__builtin_absConstraint", {"Attr"}, true);
718724
}
719725

720726
FailureOr<ast::Module *> Parser::parseModule() {
@@ -2030,15 +2036,15 @@ FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
20302036
}
20312037

20322038
FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
2033-
auto lhs = parseExp2Log2Expr();
2039+
auto lhs = parseExp2Log2AbsExpr();
20342040
if (failed(lhs))
20352041
return failure();
20362042

20372043
for (;;) {
20382044
switch (curToken.getKind()) {
20392045
case Token::mul: {
20402046
consumeToken();
2041-
auto rhs = parseExp2Log2Expr();
2047+
auto rhs = parseExp2Log2AbsExpr();
20422048
if (failed(rhs))
20432049
return failure();
20442050
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2058,7 +2064,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20582064
}
20592065
case Token::div: {
20602066
consumeToken();
2061-
auto rhs = parseExp2Log2Expr();
2067+
auto rhs = parseExp2Log2AbsExpr();
20622068
if (failed(rhs))
20632069
return failure();
20642070
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2078,7 +2084,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20782084
}
20792085
case Token::mod: {
20802086
consumeToken();
2081-
auto rhs = parseExp2Log2Expr();
2087+
auto rhs = parseExp2Log2AbsExpr();
20822088
if (failed(rhs))
20832089
return failure();
20842090
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2100,7 +2106,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
21002106
}
21012107
}
21022108

2103-
FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
2109+
FailureOr<ast::Expr *> Parser::parseExp2Log2AbsExpr() {
21042110
FailureOr<ast::Expr *> expr = nullptr;
21052111

21062112
switch (curToken.getKind()) {
@@ -2144,6 +2150,26 @@ FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
21442150
: createBuiltinCall(curToken.getLoc(), builtins.exp2Constraint,
21452151
{*expr});
21462152
}
2153+
case Token::abs: {
2154+
consumeToken();
2155+
consumeToken(Token::l_paren);
2156+
expr = parseAddSubExpr();
2157+
if (failed(expr))
2158+
return failure();
2159+
2160+
// Check if it is in rewrite section but not in the let statement
2161+
bool inRewriteSection = parserContext == ParserContext::Rewrite;
2162+
if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let)
2163+
return emitError("cannot evaluate abs operator in rewrite section. "
2164+
"Assign to a variable with `let`");
2165+
2166+
consumeToken(Token::r_paren);
2167+
return inRewriteSection
2168+
? createBuiltinCall(curToken.getLoc(), builtins.absRewrite,
2169+
{*expr})
2170+
: createBuiltinCall(curToken.getLoc(), builtins.absConstraint,
2171+
{*expr});
2172+
}
21472173
default:
21482174
return parseLogicalNotExpr();
21492175
}

mlir/test/mlir-pdll-lsp-server/completion.test

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@
208208
// CHECK-NEXT: "kind": 8,
209209
// CHECK-NEXT: "label": "__builtin_log2Constraint",
210210
// CHECK-NEXT: "sortText": "2___builtin_log2Constraint"
211+
// CHECK-NEXT: },
212+
// CHECK-NEXT: {
213+
// CHECK-NEXT: "detail": "(Attr: Attr) -> Attr",
214+
// CHECK-NEXT: "kind": 8,
215+
// CHECK-NEXT: "label": "__builtin_absConstraint",
216+
// CHECK-NEXT: "sortText": "2___builtin_absConstraint"
211217
// CHECK-NEXT: }
212218
// CHECK-NEXT: ]
213219
// CHECK-NEXT: }

mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ Pattern TestAdd {
309309
// CHECK: apply_native_constraint "__builtin_modConstraint"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
310310
// CHECK: apply_native_constraint "__builtin_log2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
311311
// CHECK: apply_native_constraint "__builtin_exp2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
312+
// CHECK: apply_native_constraint "__builtin_absConstraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
312313

313314
Pattern TestOperatorsNotInRewriteSection {
314315
let a : Attr = attr<"4 : i32">;
@@ -320,6 +321,7 @@ Pattern TestOperatorsNotInRewriteSection {
320321
let modConstraint : Attr = a % b;
321322
let log2Constraint : Attr = log2(a);
322323
let exp2Constraint : Attr = exp2(a);
324+
let absConstraint : Attr = abs(a);
323325
replace op<test.simple> with op<test.success>;
324326
}
325327

@@ -336,6 +338,7 @@ Pattern TestOperatorsNotInRewriteSection {
336338
// CHECK: apply_native_rewrite "__builtin_modRewrite"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
337339
// CHECK: apply_native_rewrite "__builtin_log2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
338340
// CHECK: apply_native_rewrite "__builtin_exp2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
341+
// CHECK: apply_native_rewrite "__builtin_absRewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
339342
Pattern TestOperatorsInRewriteSection {
340343
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
341344
rewrite root with {
@@ -348,6 +351,7 @@ Pattern TestOperatorsInRewriteSection {
348351
let modRewrite : Attr = a % b;
349352
let log2Rewrite : Attr = log2(a);
350353
let exp2Rewrite : Attr = exp2(a);
354+
let absRewrite : Attr = abs(a);
351355
erase root;
352356
};
353357
}

mlir/test/mlir-pdll/Parser/expr-failure.pdll

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,17 @@ Pattern {
581581

582582
// -----
583583

584+
Pattern {
585+
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
586+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
587+
rewrite root with {
588+
abs(attr<"-4 : si32">);
589+
erase root;
590+
};
591+
}
592+
593+
// -----
594+
584595
// check llvm::saveAndRestore works
585596
Pattern {
586597
// CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let`
@@ -590,4 +601,30 @@ Pattern {
590601
exp2(attr<"4 : i32">);
591602
erase root;
592603
};
604+
}
605+
606+
// -----
607+
608+
// check llvm::saveAndRestore works
609+
Pattern {
610+
// CHECK: cannot evaluate log2 operator in rewrite section. Assign to a variable with `let`
611+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
612+
rewrite root with {
613+
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
614+
log2(attr<"4 : i32">);
615+
erase root;
616+
};
617+
}
618+
619+
// -----
620+
621+
// check llvm::saveAndRestore works
622+
Pattern {
623+
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
624+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
625+
rewrite root with {
626+
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
627+
abs(attr<"4 : i32">);
628+
erase root;
629+
};
593630
}

0 commit comments

Comments
 (0)