Skip to content

Commit 5f808cd

Browse files
authored
Merge pull request #169 from Xilinx/liangta.float_mul_div
[PDLL] Emit warning for inexact result of floating point binary arithme…
2 parents cdc5e38 + 01f8072 commit 5f808cd

File tree

2 files changed

+31
-34
lines changed

2 files changed

+31
-34
lines changed

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,12 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
102102
} else {
103103
llvm::llvm_unreachable_internal(
104104
"encountered an unsupported unary operator");
105-
return failure();
106105
}
107106
return success();
108107
}
109108

110109
if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
111-
// auto floatType = operandFloatAttr.getType();
112-
113110
if constexpr (T == UnaryOpKind::exp2) {
114-
// auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
115-
// auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
116-
117111
auto type = operandFloatAttr.getType();
118112

119113
return TypeSwitch<Type, LogicalResult>(type)
@@ -166,9 +160,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
166160

167161
if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
168162
auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
169-
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType()) {
163+
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType())
170164
return failure();
171-
}
172165

173166
auto integerType = lhsIntAttr.getType();
174167
APInt resultAPInt;
@@ -211,7 +204,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
211204
resultAPInt = lhsIntAttr.getValue().srem(rhsIntAttr.getValue());
212205
}
213206
} else {
214-
assert(false && "Unsupported binary operator");
207+
llvm::llvm_unreachable_internal(
208+
"encounter an unsupported binary operator.");
215209
}
216210

217211
if (isOverflow)
@@ -223,9 +217,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
223217

224218
if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
225219
auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
226-
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType()) {
220+
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType())
227221
return failure();
228-
}
229222

230223
APFloat lhsVal = lhsFloatAttr.getValue();
231224
APFloat rhsVal = rhsFloatAttr.getValue();
@@ -248,13 +241,19 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
248241
} else if constexpr (T == BinaryOpKind::mod) {
249242
operationStatus = resultVal.mod(rhsVal);
250243
} else {
251-
assert(false && "Unsupported binary operator");
244+
llvm::llvm_unreachable_internal(
245+
"encounter an unsupported binary operator.");
252246
}
253247

254248
if (operationStatus != APFloat::opOK) {
255-
return failure();
256-
}
249+
if (operationStatus != APFloat::opInexact)
250+
return failure();
257251

252+
emitWarning(rewriter.getUnknownLoc())
253+
<< "Binary arithmetic operation between " << lhsVal.convertToFloat()
254+
<< " and " << rhsVal.convertToFloat()
255+
<< " produced an inexact result";
256+
}
258257
results.push_back(rewriter.getFloatAttr(floatType, resultVal));
259258
return success();
260259
}

mlir/unittests/Dialect/PDL/BuiltinTest.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,19 @@ TEST_F(BuiltinTest, div) {
253253
"Divide by zero?");
254254
}
255255

256-
auto smallF16 = rewriter.getF16FloatAttr(0.0001);
256+
auto BF16Type = rewriter.getBF16Type();
257+
auto oneBF16 = rewriter.getFloatAttr(BF16Type, 1.0);
258+
auto nineBF16 = rewriter.getFloatAttr(BF16Type, 9.0);
259+
260+
// float: inexact result
261+
// return success(), but warning is emitted.
262+
{
263+
TestPDLResultList results(1);
264+
EXPECT_TRUE(
265+
builtin::div(rewriter, results, {oneBF16, nineBF16}).succeeded());
266+
}
267+
257268
auto twoF16 = rewriter.getF16FloatAttr(2.0);
258-
auto maxValF16 = rewriter.getF16FloatAttr(
259-
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
260269
auto zeroF16 = rewriter.getF16FloatAttr(0.0);
261270
auto negzeroF16 = rewriter.getF16FloatAttr(-0.0);
262271

@@ -272,13 +281,6 @@ TEST_F(BuiltinTest, div) {
272281
EXPECT_TRUE(builtin::div(rewriter, results, {twoF16, negzeroF16}).failed());
273282
}
274283

275-
// float: overflow
276-
{
277-
TestPDLResultList results(1);
278-
EXPECT_TRUE(
279-
builtin::div(rewriter, results, {maxValF16, smallF16}).failed());
280-
}
281-
282284
// float: correctness
283285
{
284286
TestPDLResultList results(1);
@@ -456,19 +458,17 @@ TEST_F(BuiltinTest, add) {
456458
EXPECT_TRUE(builtin::add(rewriter, results, {oneI16, oneI32}).failed());
457459
}
458460

459-
auto oneF16 = rewriter.getF16FloatAttr(1.0);
460461
auto oneF32 = rewriter.getF32FloatAttr(1.0);
461462
auto zeroF32 = rewriter.getF32FloatAttr(0.0);
462463
auto negzeroF32 = rewriter.getF32FloatAttr(-0.0);
463464
auto zeroF64 = rewriter.getF64FloatAttr(0.0);
464-
465-
auto maxValF16 = rewriter.getF16FloatAttr(
466-
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
465+
auto overflowF16 = rewriter.getF16FloatAttr(32768);
467466

468467
// float: overflow
469468
{
470469
TestPDLResultList results(1);
471-
EXPECT_TRUE(builtin::add(rewriter, results, {oneF16, maxValF16}).failed());
470+
EXPECT_TRUE(
471+
builtin::add(rewriter, results, {overflowF16, overflowF16}).failed());
472472
}
473473

474474
// float: correctness
@@ -553,19 +553,17 @@ TEST_F(BuiltinTest, sub) {
553553
EXPECT_TRUE(builtin::sub(rewriter, results, {oneI16, oneI32}).failed());
554554
}
555555

556-
auto oneF16 = rewriter.getF16FloatAttr(1.0);
556+
auto oneF16 = rewriter.getF16FloatAttr(100.0);
557557
auto oneF32 = rewriter.getF32FloatAttr(1.0);
558558
auto zeroF32 = rewriter.getF32FloatAttr(0.0);
559559
auto negzeroF32 = rewriter.getF32FloatAttr(-0.0);
560560
auto zeroF64 = rewriter.getF64FloatAttr(0.0);
561-
562-
auto maxValF16 = rewriter.getF16FloatAttr(
563-
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
561+
auto minValF16 = rewriter.getF16FloatAttr(-65504);
564562

565563
// float: overflow
566564
{
567565
TestPDLResultList results(1);
568-
EXPECT_TRUE(builtin::sub(rewriter, results, {maxValF16, oneF16}).failed());
566+
EXPECT_TRUE(builtin::sub(rewriter, results, {oneF16, minValF16}).failed());
569567
}
570568

571569
// float: correctness

0 commit comments

Comments
 (0)