Skip to content

Commit 5294ad1

Browse files
authored
[mlir][arith] Improve extf folder (#80232)
* Use APFloat conversion function to avoid losing information by converting to `double`. This would be the case with large types like `f80` or `f128`. * Check for potential information loss. This is intended for small floating point types that may have values not present in larger ones (e.g., f8m2e5fnuz and f16). * Support folding vector constants.
1 parent 0d6ed83 commit 5294ad1

File tree

2 files changed

+49
-13
lines changed

2 files changed

+49
-13
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/IR/PatternMatch.h"
2323
#include "mlir/IR/TypeUtilities.h"
24+
#include "mlir/Support/LogicalResult.h"
2425

2526
#include "llvm/ADT/APFloat.h"
2627
#include "llvm/ADT/APInt.h"
@@ -1258,6 +1259,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
12581259
srcType.getIntOrFloatBitWidth());
12591260
}
12601261

1262+
/// Attempts to convert `sourceValue` to an APFloat value with
1263+
/// `targetSemantics`, without any information loss or rounding.
1264+
static FailureOr<APFloat>
1265+
convertFloatValue(APFloat sourceValue,
1266+
const llvm::fltSemantics &targetSemantics) {
1267+
bool losesInfo = false;
1268+
auto status = sourceValue.convert(
1269+
targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
1270+
if (losesInfo || status != APFloat::opOK)
1271+
return failure();
1272+
1273+
return sourceValue;
1274+
}
1275+
12611276
//===----------------------------------------------------------------------===//
12621277
// ExtUIOp
12631278
//===----------------------------------------------------------------------===//
@@ -1321,14 +1336,21 @@ LogicalResult arith::ExtSIOp::verify() {
13211336
// ExtFOp
13221337
//===----------------------------------------------------------------------===//
13231338

1324-
/// Always fold extension of FP constants.
1339+
/// Fold extension of float constants when there is no information loss due the
1340+
/// difference in fp semantics.
13251341
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1326-
auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
1327-
if (!constOperand)
1328-
return {};
1329-
1330-
// Convert to target type via 'double'.
1331-
return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
1342+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1343+
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1344+
return constFoldCastOp<FloatAttr, FloatAttr>(
1345+
adaptor.getOperands(), getType(),
1346+
[&targetSemantics](const APFloat &a, bool &castStatus) {
1347+
FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1348+
if (failed(result)) {
1349+
castStatus = false;
1350+
return a;
1351+
}
1352+
return *result;
1353+
});
13321354
}
13331355

13341356
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1403,12 +1425,13 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
14031425
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
14041426
return constFoldCastOp<FloatAttr, FloatAttr>(
14051427
adaptor.getOperands(), getType(),
1406-
[&targetSemantics](APFloat a, bool &castStatus) {
1407-
bool losesInfo = false;
1408-
auto status = a.convert(
1409-
targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
1410-
castStatus = !losesInfo && status == APFloat::opOK;
1411-
return a;
1428+
[&targetSemantics](const APFloat &a, bool &castStatus) {
1429+
FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1430+
if (failed(result)) {
1431+
castStatus = false;
1432+
return a;
1433+
}
1434+
return *result;
14121435
});
14131436
}
14141437

@@ -1496,6 +1519,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
14961519
return apf;
14971520
});
14981521
}
1522+
14991523
//===----------------------------------------------------------------------===//
15001524
// FPToUIOp
15011525
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,18 @@ func.func @extFPConstant() -> f64 {
701701
return %0 : f64
702702
}
703703

704+
// CHECK-LABEL: @extFPVectorConstant
705+
// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf128>
706+
// CHECK: return %[[cres]]
707+
func.func @extFPVectorConstant() -> vector<2xf128> {
708+
%cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf80>
709+
%0 = arith.extf %cst : vector<2xf80> to vector<2xf128>
710+
return %0 : vector<2xf128>
711+
}
712+
713+
// TODO: We should also add a test for not folding arith.extf on information loss.
714+
// This may happen when extending f8E5M2FNUZ to f16.
715+
704716
// CHECK-LABEL: @truncConstant
705717
// CHECK: %[[cres:.+]] = arith.constant -2 : i16
706718
// CHECK: return %[[cres]]

0 commit comments

Comments
 (0)