Skip to content

Commit 8827ff9

Browse files
authored
[MLIR][Arith] Add rounding mode attribute to truncf (#86152)
Add rounding mode attribute to `arith`. This attribute can be used in different FP `arith` operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes. Use in `arith.truncf` folding. As this is not supported in dialects other than LLVM, conversion should fail for now in case this attribute is present. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent da1d3d8 commit 8827ff9

File tree

13 files changed

+309
-14
lines changed

13 files changed

+309
-14
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
3636
LLVM::IntegerOverflowFlagsAttr
3737
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
3838

39+
/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
40+
/// mode enum value.
41+
LLVM::RoundingMode
42+
convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
43+
44+
/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
45+
/// mode attribute.
46+
LLVM::RoundingModeAttr
47+
convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
48+
49+
/// Returns an attribute for the default LLVM FP exception behavior.
50+
LLVM::FPExceptionBehaviorAttr
51+
getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
52+
3953
// Attribute converter that populates a NamedAttrList by removing the fastmath
4054
// attribute from the source operation attributes, and replacing it with an
4155
// equivalent LLVM fastmath attribute.
@@ -89,6 +103,40 @@ class AttrConvertOverflowToLLVM {
89103
private:
90104
NamedAttrList convertedAttr;
91105
};
106+
107+
template <typename SourceOp, typename TargetOp>
108+
class AttrConverterConstrainedFPToLLVM {
109+
static_assert(TargetOp::template hasTrait<
110+
LLVM::FPExceptionBehaviorOpInterface::Trait>(),
111+
"Target constrained FP operations must implement "
112+
"LLVM::FPExceptionBehaviorOpInterface");
113+
114+
public:
115+
AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
116+
// Copy the source attributes.
117+
convertedAttr = NamedAttrList{srcOp->getAttrs()};
118+
119+
if constexpr (TargetOp::template hasTrait<
120+
LLVM::RoundingModeOpInterface::Trait>()) {
121+
// Get the name of the rounding mode attribute.
122+
StringRef arithAttrName = srcOp.getRoundingModeAttrName();
123+
// Remove the source attribute.
124+
auto arithAttr =
125+
cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName));
126+
// Set the target attribute.
127+
convertedAttr.set(TargetOp::getRoundingModeAttrName(),
128+
convertArithRoundingModeAttrToLLVM(arithAttr));
129+
}
130+
convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
131+
getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
132+
}
133+
134+
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
135+
136+
private:
137+
NamedAttrList convertedAttr;
138+
};
139+
92140
} // namespace arith
93141
} // namespace mlir
94142

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
156156
let assemblyFormat = "`<` $value `>`";
157157
}
158158

159+
//===----------------------------------------------------------------------===//
160+
// Arith_RoundingMode
161+
//===----------------------------------------------------------------------===//
162+
163+
// These correspond to LLVM's values defined in:
164+
// llvm/include/llvm/ADT/FloatingPointMode.h
165+
166+
def Arith_RToNearestTiesToEven // Round to nearest, ties to even
167+
: I32EnumAttrCase<"to_nearest_even", 0>;
168+
def Arith_RDownward // Round toward -inf
169+
: I32EnumAttrCase<"downward", 1>;
170+
def Arith_RUpward // Round toward +inf
171+
: I32EnumAttrCase<"upward", 2>;
172+
def Arith_RTowardZero // Round toward 0
173+
: I32EnumAttrCase<"toward_zero", 3>;
174+
def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
175+
: I32EnumAttrCase<"to_nearest_away", 4>;
176+
177+
def Arith_RoundingModeAttr : I32EnumAttr<
178+
"RoundingMode", "Floating point rounding mode",
179+
[Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
180+
Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
181+
let cppNamespace = "::mlir::arith";
182+
}
183+
159184
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
12271227
// TruncFOp
12281228
//===----------------------------------------------------------------------===//
12291229

1230-
def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
1230+
def Arith_TruncFOp :
1231+
Arith_Op<"truncf",
1232+
[Pure, SameOperandsAndResultShape,
1233+
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
1234+
DeclareOpInterfaceMethods<CastOpInterface>]>,
1235+
Arguments<(ins FloatLike:$in,
1236+
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
1237+
Results<(outs FloatLike:$out)> {
12311238
let summary = "cast from floating-point to narrower floating-point";
12321239
let description = [{
12331240
Truncate a floating-point value to a smaller floating-point-typed value.
12341241
The destination type must be strictly narrower than the source type.
1235-
If the value cannot be exactly represented, it is rounded using the default
1236-
rounding mode. When operating on vectors, casts elementwise.
1242+
If the value cannot be exactly represented, it is rounded using the
1243+
provided rounding mode or the default one if no rounding mode is provided.
1244+
When operating on vectors, casts elementwise.
12371245
}];
1246+
let builders = [
1247+
OpBuilder<(ins "Type":$out, "Value":$in), [{
1248+
$_state.addOperands(in);
1249+
$_state.addTypes(out);
1250+
}]>
1251+
];
12381252

12391253
let hasFolder = 1;
12401254
let hasVerifier = 1;
1255+
let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
12411256
}
12421257

12431258
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
106106
];
107107
}
108108

109+
def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
110+
let description = [{
111+
Access to op rounding mode.
112+
}];
113+
114+
let cppNamespace = "::mlir::arith";
115+
116+
let methods = [
117+
InterfaceMethod<
118+
/*desc=*/ "Returns a RoundingModeAttr attribute for the operation",
119+
/*returnType=*/ "RoundingModeAttr",
120+
/*methodName=*/ "getRoundingModeAttr",
121+
/*args=*/ (ins),
122+
/*methodBody=*/ [{}],
123+
/*defaultImpl=*/ [{
124+
auto op = cast<ConcreteOp>(this->getOperation());
125+
return op.getRoundingmodeAttr();
126+
}]
127+
>,
128+
StaticInterfaceMethod<
129+
/*desc=*/ [{Returns the name of the RoundingModeAttr attribute for
130+
the operation}],
131+
/*returnType=*/ "StringRef",
132+
/*methodName=*/ "getRoundingModeAttrName",
133+
/*args=*/ (ins),
134+
/*methodBody=*/ [{}],
135+
/*defaultImpl=*/ [{
136+
return "roundingmode";
137+
}]
138+
>
139+
];
140+
}
141+
109142
#endif // ARITH_OPS_INTERFACES

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,34 @@ LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
5555
return LLVM::IntegerOverflowFlagsAttr::get(
5656
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
5757
}
58+
59+
LLVM::RoundingMode
60+
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
61+
switch (roundingMode) {
62+
case arith::RoundingMode::downward:
63+
return LLVM::RoundingMode::TowardNegative;
64+
case arith::RoundingMode::to_nearest_away:
65+
return LLVM::RoundingMode::NearestTiesToAway;
66+
case arith::RoundingMode::to_nearest_even:
67+
return LLVM::RoundingMode::NearestTiesToEven;
68+
case arith::RoundingMode::toward_zero:
69+
return LLVM::RoundingMode::TowardZero;
70+
case arith::RoundingMode::upward:
71+
return LLVM::RoundingMode::TowardPositive;
72+
}
73+
llvm_unreachable("Unhandled rounding mode");
74+
}
75+
76+
LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
77+
arith::RoundingModeAttr roundingModeAttr) {
78+
assert(roundingModeAttr && "Expecting valid attribute");
79+
return LLVM::RoundingModeAttr::get(
80+
roundingModeAttr.getContext(),
81+
convertArithRoundingModeToLLVM(roundingModeAttr.getValue()));
82+
}
83+
84+
LLVM::FPExceptionBehaviorAttr
85+
mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
86+
return LLVM::FPExceptionBehaviorAttr::get(&context,
87+
LLVM::FPExceptionBehavior::Ignore);
88+
}

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
175175
}
176176

177177
LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
178+
// Only supporting default rounding mode as of now.
179+
if (op.getRoundingmodeAttr())
180+
return failure();
178181
Type outType = op.getOut().getType();
179182
if (auto outVecType = outType.dyn_cast<VectorType>()) {
180183
if (outVecType.isScalable())

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,31 @@ using namespace mlir;
2828

2929
namespace {
3030

31+
/// Operations whose conversion will depend on whether they are passed a
32+
/// rounding mode attribute or not.
33+
///
34+
/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
35+
/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
36+
/// attribute.
37+
template <typename SourceOp, typename TargetOp, bool Constrained,
38+
template <typename, typename> typename AttrConvert =
39+
AttrConvertPassThrough>
40+
struct ConstrainedVectorConvertToLLVMPattern
41+
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
42+
using VectorConvertToLLVMPattern<SourceOp, TargetOp,
43+
AttrConvert>::VectorConvertToLLVMPattern;
44+
45+
LogicalResult
46+
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
47+
ConversionPatternRewriter &rewriter) const override {
48+
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
49+
return failure();
50+
return VectorConvertToLLVMPattern<SourceOp, TargetOp,
51+
AttrConvert>::matchAndRewrite(op, adaptor,
52+
rewriter);
53+
}
54+
};
55+
3156
//===----------------------------------------------------------------------===//
3257
// Straightforward Op Lowerings
3358
//===----------------------------------------------------------------------===//
@@ -112,7 +137,11 @@ using SubIOpLowering =
112137
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
113138
arith::AttrConvertOverflowToLLVM>;
114139
using TruncFOpLowering =
115-
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
140+
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
141+
false>;
142+
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
143+
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
144+
arith::AttrConverterConstrainedFPToLLVM>;
116145
using TruncIOpLowering =
117146
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
118147
using UIToFPOpLowering =
@@ -537,6 +566,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
537566
SubFOpLowering,
538567
SubIOpLowering,
539568
TruncFOpLowering,
569+
ConstrainedTruncFOpLowering,
540570
TruncIOpLowering,
541571
UIToFPOpLowering,
542572
XOrIOpLowering

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
805805
} else {
806806
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
807807
adaptor.getOperands());
808+
if (auto roundingModeOp =
809+
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
810+
if (arith::RoundingModeAttr roundingMode =
811+
roundingModeOp.getRoundingModeAttr()) {
812+
// TODO: Perform rounding mode attribute conversion and attach to new
813+
// operation when defined in the dialect.
814+
return failure();
815+
}
816+
}
808817
}
809818
return success();
810819
}

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,29 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
9191
llvm_unreachable("unknown cmpi predicate kind");
9292
}
9393

94+
/// Equivalent to
95+
/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
96+
///
97+
/// Not possible to implement as chain of calls as this would introduce a
98+
/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
99+
/// on the LLVM dialect and on translation to LLVM.
100+
static llvm::RoundingMode
101+
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
102+
switch (roundingMode) {
103+
case RoundingMode::downward:
104+
return llvm::RoundingMode::TowardNegative;
105+
case RoundingMode::to_nearest_away:
106+
return llvm::RoundingMode::NearestTiesToAway;
107+
case RoundingMode::to_nearest_even:
108+
return llvm::RoundingMode::NearestTiesToEven;
109+
case RoundingMode::toward_zero:
110+
return llvm::RoundingMode::TowardZero;
111+
case RoundingMode::upward:
112+
return llvm::RoundingMode::TowardPositive;
113+
}
114+
llvm_unreachable("Unhandled rounding mode");
115+
}
116+
94117
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
95118
return arith::CmpIPredicateAttr::get(pred.getContext(),
96119
invertPredicate(pred.getValue()));
@@ -1233,13 +1256,12 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
12331256
}
12341257

12351258
/// Attempts to convert `sourceValue` to an APFloat value with
1236-
/// `targetSemantics`, without any information loss or rounding.
1237-
static FailureOr<APFloat>
1238-
convertFloatValue(APFloat sourceValue,
1239-
const llvm::fltSemantics &targetSemantics) {
1259+
/// `targetSemantics` and `roundingMode`, without any information loss.
1260+
static FailureOr<APFloat> convertFloatValue(
1261+
APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1262+
llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
12401263
bool losesInfo = false;
1241-
auto status = sourceValue.convert(
1242-
targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
1264+
auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
12431265
if (losesInfo || status != APFloat::opOK)
12441266
return failure();
12451267

@@ -1391,15 +1413,19 @@ LogicalResult arith::TruncIOp::verify() {
13911413
//===----------------------------------------------------------------------===//
13921414

13931415
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1394-
/// can be represented without precision loss or rounding. This is because the
1395-
/// semantics of `arith.truncf` do not assume a specific rounding mode.
1416+
/// can be represented without precision loss.
13961417
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
13971418
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
13981419
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
13991420
return constFoldCastOp<FloatAttr, FloatAttr>(
14001421
adaptor.getOperands(), getType(),
1401-
[&targetSemantics](const APFloat &a, bool &castStatus) {
1402-
FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1422+
[this, &targetSemantics](const APFloat &a, bool &castStatus) {
1423+
RoundingMode roundingMode =
1424+
getRoundingmode().value_or(RoundingMode::to_nearest_even);
1425+
llvm::RoundingMode llvmRoundingMode =
1426+
convertArithRoundingModeToLLVMIR(roundingMode);
1427+
FailureOr<APFloat> result =
1428+
convertFloatValue(a, targetSemantics, llvmRoundingMode);
14031429
if (failed(result)) {
14041430
castStatus = false;
14051431
return a;

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,11 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
253253
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
254254
}
255255

256+
if (op.getRoundingmodeAttr()) {
257+
return rewriter.notifyMatchFailure(
258+
op, "only applicable to default rounding mode.");
259+
}
260+
256261
Type i16Ty = b.getI16Type();
257262
Type i32Ty = b.getI32Type();
258263
Type f32Ty = b.getF32Type();

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,21 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
289289
return
290290
}
291291

292+
// CHECK-LABEL: experimental_constrained_fptrunc
293+
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
294+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
295+
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
296+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
297+
%1 = arith.truncf %arg0 downward : f64 to f32
298+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
299+
%2 = arith.truncf %arg0 upward : f64 to f32
300+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
301+
%3 = arith.truncf %arg0 toward_zero : f64 to f32
302+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
303+
%4 = arith.truncf %arg0 to_nearest_away : f64 to f32
304+
return
305+
}
306+
292307
// Check sign and zero extension and truncation of integers.
293308
// CHECK-LABEL: @integer_extension_and_truncation
294309
func.func @integer_extension_and_truncation(%arg0 : i3) {

0 commit comments

Comments
 (0)