Skip to content

Commit 8cb5a88

Browse files
committed
fix: float type for exp2 operator & Parser::NativeOperatorContext
1 parent cf79307 commit 8cb5a88

File tree

4 files changed

+98
-44
lines changed

4 files changed

+98
-44
lines changed

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <llvm/ADT/APInt.h>
55
#include <llvm/ADT/APSInt.h>
66
#include <llvm/ADT/ArrayRef.h>
7+
#include <llvm/ADT/TypeSwitch.h>
78
#include <llvm/Support/Casting.h>
89
#include <llvm/Support/ErrorHandling.h>
910
#include <mlir/Dialect/PDL/IR/Builtins.h>
@@ -102,21 +103,49 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
102103
}
103104

104105
if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
105-
auto floatType = operandFloatAttr.getType();
106+
// auto floatType = operandFloatAttr.getType();
106107

107108
if constexpr (T == UnaryOpKind::exp2) {
108-
auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
109-
auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
110-
111-
APFloat resultFloat(
112-
std::exp(operandFloatAttr.getValue().convertToFloat()));
113-
// check overflow
114-
if (resultFloat < minVal || resultFloat > maxVal)
115-
return failure();
116-
results.push_back(rewriter.getFloatAttr(floatType, resultFloat));
109+
// auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
110+
// auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
111+
112+
auto type = operandFloatAttr.getType();
113+
114+
return TypeSwitch<Type, LogicalResult>(type)
115+
.template Case<Float64Type>([&results, &rewriter,
116+
&operandFloatAttr](auto floatType) {
117+
APFloat resultAPFloat(
118+
std::exp2(operandFloatAttr.getValue().convertToDouble()));
119+
120+
// check overflow
121+
if (!resultAPFloat.isNormal())
122+
return failure();
123+
124+
results.push_back(rewriter.getFloatAttr(floatType, resultAPFloat));
125+
return success();
126+
})
127+
.template Case<Float32Type, Float16Type, BFloat16Type>(
128+
[&results, &rewriter, &operandFloatAttr](auto floatType) {
129+
APFloat resultAPFloat(
130+
std::exp2(operandFloatAttr.getValue().convertToFloat()));
131+
132+
// check overflow and underflow
133+
// If overflow happens, resultAPFloat is inf
134+
// If underflow happens, resultAPFloat is 0
135+
if (!resultAPFloat.isNormal())
136+
return failure();
137+
138+
results.push_back(
139+
rewriter.getFloatAttr(floatType, resultAPFloat));
140+
return success();
141+
})
142+
.Default([](Type /*type*/) { return failure(); });
117143
} else if constexpr (T == UnaryOpKind::log2) {
144+
auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle());
145+
118146
APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2());
119-
results.push_back(rewriter.getFloatAttr(floatType, resultFloat));
147+
results.push_back(
148+
rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat));
120149
}
121150
return success();
122151
}

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1980,9 +1980,9 @@ FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
19801980
// Check if it is in rewrite section but not in the let statement
19811981
bool inRewriteSection = parserContext == ParserContext::Rewrite;
19821982
if (inRewriteSection &&
1983-
nativeOperatorContext == NativeOperatorContext::Generic)
1984-
return emitError(
1985-
"nodiscard rule for add operator is applied in rewrite section");
1983+
nativeOperatorContext != NativeOperatorContext::Let)
1984+
return emitError("cannot evaluate add operator in rewrite section. "
1985+
"Assign to a variable with `let`");
19861986

19871987
lhs = inRewriteSection ? createBuiltinCall(curToken.getLoc(),
19881988
builtins.addRewrite, args)
@@ -2000,9 +2000,9 @@ FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
20002000
// Check if it is in rewrite section but not in the let statement
20012001
bool inRewriteSection = parserContext == ParserContext::Rewrite;
20022002
if (inRewriteSection &&
2003-
nativeOperatorContext == NativeOperatorContext::Generic)
2004-
return emitError(
2005-
"nodiscard rule for sub operator is applied in rewrite section");
2003+
nativeOperatorContext != NativeOperatorContext::Let)
2004+
return emitError("cannot evaluate sub operator in rewrite section. "
2005+
"Assign to a variable with `let`");
20062006

20072007
lhs = inRewriteSection ? createBuiltinCall(curToken.getLoc(),
20082008
builtins.subRewrite, args)
@@ -2033,9 +2033,9 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20332033
// Check if it is in rewrite section but not in the let statement
20342034
bool inRewriteSection = parserContext == ParserContext::Rewrite;
20352035
if (inRewriteSection &&
2036-
nativeOperatorContext == NativeOperatorContext::Generic)
2037-
return emitError(
2038-
"nodiscard rule for mul operator is applied in rewrite section");
2036+
nativeOperatorContext != NativeOperatorContext::Let)
2037+
return emitError("cannot evaluate mul operator in rewrite section. "
2038+
"Assign to a variable with `let`");
20392039

20402040
lhs = inRewriteSection ? createBuiltinCall(curToken.getLoc(),
20412041
builtins.mulRewrite, args)
@@ -2053,9 +2053,9 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20532053
// Check if it is in rewrite section but not in the let statement
20542054
bool inRewriteSection = parserContext == ParserContext::Rewrite;
20552055
if (inRewriteSection &&
2056-
nativeOperatorContext == NativeOperatorContext::Generic)
2057-
return emitError(
2058-
"nodiscard rule for div operator is applied in rewrite section");
2056+
nativeOperatorContext != NativeOperatorContext::Let)
2057+
return emitError("cannot evaluate div operator in rewrite section. "
2058+
"Assign to a variable with `let`");
20592059

20602060
lhs = inRewriteSection ? createBuiltinCall(curToken.getLoc(),
20612061
builtins.divRewrite, args)
@@ -2071,9 +2071,9 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
20712071
SmallVector<ast::Expr *> args{*lhs, *rhs};
20722072
bool inRewriteSection = parserContext == ParserContext::Rewrite;
20732073
if (inRewriteSection &&
2074-
nativeOperatorContext == NativeOperatorContext::Generic)
2075-
return emitError(
2076-
"nodiscard rule for mod operator is applied in rewrite section");
2074+
nativeOperatorContext != NativeOperatorContext::Let)
2075+
return emitError("cannot evaluate mod operator in rewrite section. "
2076+
"Assign to a variable with `let`");
20772077

20782078
lhs = inRewriteSection ? createBuiltinCall(curToken.getLoc(),
20792079
builtins.modRewrite, args)
@@ -2100,10 +2100,9 @@ FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
21002100

21012101
// Check if it is in rewrite section but not in the let statement
21022102
bool inRewriteSection = parserContext == ParserContext::Rewrite;
2103-
if (inRewriteSection &&
2104-
nativeOperatorContext == NativeOperatorContext::Generic)
2105-
return emitError(
2106-
"nodiscard rule for log2 operator is applied in rewrite section");
2103+
if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let)
2104+
return emitError("cannot evaluate log2 operator in rewrite section. "
2105+
"Assign to a variable with `let`");
21072106

21082107
consumeToken(Token::r_paren);
21092108
return inRewriteSection
@@ -2121,10 +2120,9 @@ FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
21212120

21222121
// Check if it is in rewrite section but not in the let statement
21232122
bool inRewriteSection = parserContext == ParserContext::Rewrite;
2124-
if (inRewriteSection &&
2125-
nativeOperatorContext == NativeOperatorContext::Generic)
2126-
return emitError(
2127-
"nodiscard rule for exp2 operator is applied in rewrite section");
2123+
if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let)
2124+
return emitError("cannot evaluate exp2 operator in rewrite section. "
2125+
"Assign to a variable with `let`");
21282126

21292127
consumeToken(Token::r_paren);
21302128
return inRewriteSection
@@ -2762,8 +2760,6 @@ FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
27622760
break;
27632761
case Token::kw_let: {
27642762
stmt = parseLetStmt();
2765-
llvm::SaveAndRestore saveCtx(nativeOperatorContext,
2766-
NativeOperatorContext::Generic);
27672763
break;
27682764
}
27692765
case Token::kw_replace:

mlir/test/mlir-pdll/Parser/expr-failure.pdll

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ Pattern {
505505
// -----
506506

507507
Pattern {
508-
// CHECK: nodiscard rule for add operator is applied in rewrite section
508+
// CHECK: cannot evaluate add operator in rewrite section. Assign to a variable with `let`
509509
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
510510
rewrite root with {
511511
attr<"4 : i32"> + attr<"5 : i32">;
@@ -516,7 +516,7 @@ Pattern {
516516
// -----
517517

518518
Pattern {
519-
// CHECK: nodiscard rule for sub operator is applied in rewrite section
519+
// CHECK: cannot evaluate sub operator in rewrite section. Assign to a variable with `let`
520520
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
521521
rewrite root with {
522522
attr<"4 : i32"> - attr<"5 : i32">;
@@ -527,7 +527,7 @@ Pattern {
527527
// -----
528528

529529
Pattern {
530-
// CHECK: nodiscard rule for mul operator is applied in rewrite section
530+
// CHECK: cannot evaluate mul operator in rewrite section. Assign to a variable with `let`
531531
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
532532
rewrite root with {
533533
attr<"4 : i32"> * attr<"5 : i32">;
@@ -538,7 +538,7 @@ Pattern {
538538
// -----
539539

540540
Pattern {
541-
// CHECK: nodiscard rule for div operator is applied in rewrite section
541+
// CHECK: cannot evaluate div operator in rewrite section. Assign to a variable with `let`
542542
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
543543
rewrite root with {
544544
attr<"4 : i32"> / attr<"5 : i32">;
@@ -549,7 +549,7 @@ Pattern {
549549
// -----
550550

551551
Pattern {
552-
// CHECK: nodiscard rule for mod operator is applied in rewrite section
552+
// CHECK: cannot evaluate mod operator in rewrite section. Assign to a variable with `let`
553553
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
554554
rewrite root with {
555555
attr<"4 : i32"> % attr<"5 : i32">;
@@ -560,7 +560,7 @@ Pattern {
560560
// -----
561561

562562
Pattern {
563-
// CHECK: nodiscard rule for log2 operator is applied in rewrite section
563+
// CHECK: cannot evaluate log2 operator in rewrite section. Assign to a variable with `let`
564564
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
565565
rewrite root with {
566566
log2(attr<"4 : i32">);
@@ -571,10 +571,23 @@ Pattern {
571571
// -----
572572

573573
Pattern {
574-
// CHECK: nodiscard rule for exp2 operator is applied in rewrite section
574+
// CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let`
575575
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
576576
rewrite root with {
577577
exp2(attr<"4 : i32">);
578578
erase root;
579579
};
580+
}
581+
582+
// -----
583+
584+
// check llvm::saveAndRestore works
585+
Pattern {
586+
// CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let`
587+
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
588+
rewrite root with {
589+
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
590+
exp2(attr<"4 : i32">);
591+
erase root;
592+
};
580593
}

mlir/unittests/Dialect/PDL/BuiltinTest.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,26 @@ TEST_F(BuiltinTest, exp2) {
678678
(uint64_t)2);
679679
}
680680

681-
// Check unsigned integer overflow
681+
// unsigned integer: overflow
682682
{
683683
TestPDLResultList results(1);
684684
EXPECT_TRUE(builtin::exp2(rewriter, results, {eightUInt8}).failed());
685685
}
686+
687+
auto hundredFortyF32 = rewriter.getF32FloatAttr(140.0);
688+
689+
// Float: overflow
690+
{
691+
TestPDLResultList results(1);
692+
EXPECT_TRUE(builtin::exp2(rewriter, results, {hundredFortyF32}).failed());
693+
}
694+
695+
// Float: underflow
696+
auto minusHundredFiftyF32 = rewriter.getF32FloatAttr(-150.0);
697+
{
698+
TestPDLResultList results(1);
699+
EXPECT_TRUE(
700+
builtin::exp2(rewriter, results, {minusHundredFiftyF32}).failed());
701+
}
686702
}
687703
} // namespace

0 commit comments

Comments
 (0)