Skip to content

Commit f144e15

Browse files
committed
Extend
1 parent a7e82a7 commit f144e15

File tree

4 files changed

+127
-8
lines changed

4 files changed

+127
-8
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
3636
LLVM::IntegerOverflowFlagsAttr
3737
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
3838

39+
/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
40+
/// mode enum value.
41+
LLVM::RoundingMode
42+
convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
43+
44+
/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
45+
/// mode attribute.
46+
LLVM::RoundingModeAttr
47+
convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
48+
49+
/// Returns an attribute for the default LLVM FP exception behavior.
50+
LLVM::FPExceptionBehaviorAttr
51+
getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
52+
3953
// Attribute converter that populates a NamedAttrList by removing the fastmath
4054
// attribute from the source operation attributes, and replacing it with an
4155
// equivalent LLVM fastmath attribute.

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,34 @@ LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
5555
return LLVM::IntegerOverflowFlagsAttr::get(
5656
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
5757
}
58+
59+
LLVM::RoundingMode
60+
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
61+
switch (roundingMode) {
62+
case arith::RoundingMode::downward:
63+
return LLVM::RoundingMode::TowardNegative;
64+
case arith::RoundingMode::tonearestaway:
65+
return LLVM::RoundingMode::NearestTiesToAway;
66+
case arith::RoundingMode::tonearesteven:
67+
return LLVM::RoundingMode::NearestTiesToEven;
68+
case arith::RoundingMode::towardzero:
69+
return LLVM::RoundingMode::TowardZero;
70+
case arith::RoundingMode::upward:
71+
return LLVM::RoundingMode::TowardPositive;
72+
}
73+
llvm_unreachable("Unhandled rounding mode");
74+
}
75+
76+
LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
77+
arith::RoundingModeAttr roundingModeAttr) {
78+
assert(roundingModeAttr && "Expecting valid attribute");
79+
return LLVM::RoundingModeAttr::get(
80+
roundingModeAttr.getContext(),
81+
convertArithRoundingModeToLLVM(roundingModeAttr.getValue()));
82+
}
83+
84+
LLVM::FPExceptionBehaviorAttr
85+
mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
86+
return LLVM::FPExceptionBehaviorAttr::get(&context,
87+
LLVM::FPExceptionBehavior::Ignore);
88+
}

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,29 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
9191
llvm_unreachable("unknown cmpi predicate kind");
9292
}
9393

94+
/// Equivalent to
95+
/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
96+
///
97+
/// Not possible to implement as chain of calls as this would introduce a
98+
/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
99+
/// on the LLVM dialect and on translation to LLVM.
100+
static llvm::RoundingMode
101+
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
102+
switch (roundingMode) {
103+
case RoundingMode::downward:
104+
return llvm::RoundingMode::TowardNegative;
105+
case RoundingMode::tonearestaway:
106+
return llvm::RoundingMode::NearestTiesToAway;
107+
case RoundingMode::tonearesteven:
108+
return llvm::RoundingMode::NearestTiesToEven;
109+
case RoundingMode::towardzero:
110+
return llvm::RoundingMode::TowardZero;
111+
case RoundingMode::upward:
112+
return llvm::RoundingMode::TowardPositive;
113+
}
114+
llvm_unreachable("Unhandled rounding mode");
115+
}
116+
94117
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
95118
return arith::CmpIPredicateAttr::get(pred.getContext(),
96119
invertPredicate(pred.getValue()));
@@ -1233,13 +1256,12 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
12331256
}
12341257

12351258
/// Attempts to convert `sourceValue` to an APFloat value with
1236-
/// `targetSemantics`, without any information loss or rounding.
1237-
static FailureOr<APFloat>
1238-
convertFloatValue(APFloat sourceValue,
1239-
const llvm::fltSemantics &targetSemantics) {
1259+
/// `targetSemantics` and `roundingMode`, without any information loss.
1260+
static FailureOr<APFloat> convertFloatValue(
1261+
APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1262+
llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
12401263
bool losesInfo = false;
1241-
auto status = sourceValue.convert(
1242-
targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
1264+
auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
12431265
if (losesInfo || status != APFloat::opOK)
12441266
return failure();
12451267

@@ -1398,8 +1420,15 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
13981420
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
13991421
return constFoldCastOp<FloatAttr, FloatAttr>(
14001422
adaptor.getOperands(), getType(),
1401-
[&targetSemantics](const APFloat &a, bool &castStatus) {
1402-
FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1423+
[this, &targetSemantics](const APFloat &a, bool &castStatus) {
1424+
FailureOr<APFloat> result;
1425+
if (std::optional<RoundingMode> roundingMode = getRoundingmode()) {
1426+
llvm::RoundingMode llvmRoundingMode =
1427+
convertArithRoundingModeToLLVMIR(*roundingMode);
1428+
result = convertFloatValue(a, targetSemantics, llvmRoundingMode);
1429+
} else {
1430+
result = convertFloatValue(a, targetSemantics);
1431+
}
14031432
if (failed(result)) {
14041433
castStatus = false;
14051434
return a;

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,51 @@ func.func @truncFPConstant() -> bf16 {
757757
return %0 : bf16
758758
}
759759

760+
// CHECK-LABEL: @truncFPToNearestEvenConstant
761+
// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
762+
// CHECK: return %[[cres]]
763+
func.func @truncFPToNearestEvenConstant() -> bf16 {
764+
%cst = arith.constant 1.000000e+00 : f32
765+
%0 = arith.truncf %cst tonearesteven : f32 to bf16
766+
return %0 : bf16
767+
}
768+
769+
// CHECK-LABEL: @truncFPDownwardConstant
770+
// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
771+
// CHECK: return %[[cres]]
772+
func.func @truncFPDownwardConstant() -> bf16 {
773+
%cst = arith.constant 1.000000e+00 : f32
774+
%0 = arith.truncf %cst downward : f32 to bf16
775+
return %0 : bf16
776+
}
777+
778+
// CHECK-LABEL: @truncFPUpwardConstant
779+
// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
780+
// CHECK: return %[[cres]]
781+
func.func @truncFPUpwardConstant() -> bf16 {
782+
%cst = arith.constant 1.000000e+00 : f32
783+
%0 = arith.truncf %cst upward : f32 to bf16
784+
return %0 : bf16
785+
}
786+
787+
// CHECK-LABEL: @truncFPTowardZeroConstant
788+
// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
789+
// CHECK: return %[[cres]]
790+
func.func @truncFPTowardZeroConstant() -> bf16 {
791+
%cst = arith.constant 1.000000e+00 : f32
792+
%0 = arith.truncf %cst towardzero : f32 to bf16
793+
return %0 : bf16
794+
}
795+
796+
// CHECK-LABEL: @truncFPToNearestAwayConstant
797+
// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
798+
// CHECK: return %[[cres]]
799+
func.func @truncFPToNearestAwayConstant() -> bf16 {
800+
%cst = arith.constant 1.000000e+00 : f32
801+
%0 = arith.truncf %cst tonearestaway : f32 to bf16
802+
return %0 : bf16
803+
}
804+
760805
// CHECK-LABEL: @truncFPVectorConstant
761806
// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
762807
// CHECK: return %[[cres]]

0 commit comments

Comments
 (0)