Skip to content

Commit eeb49de

Browse files
committed
[MLIR][Arith] Add rounding mode attribute to truncf
Add rounding mode attribute to `arith`. This attribute can be used in different FP `arith` operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes. As this is not supported in other dialects, conversion should fail for now in case this attribute is present. Signed-off-by: Victor Perez <[email protected]>
1 parent 77cbc9b commit eeb49de

File tree

8 files changed

+131
-4
lines changed

8 files changed

+131
-4
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
156156
let assemblyFormat = "`<` $value `>`";
157157
}
158158

159+
//===----------------------------------------------------------------------===//
160+
// Arith_RoundingMode
161+
//===----------------------------------------------------------------------===//
162+
163+
// These correspond to LLVM's values defined in:
164+
// llvm/include/llvm/ADT/FloatingPointMode.h
165+
166+
def Arith_RToNearestTiesToEven // Round to nearest, ties to even
167+
: I32EnumAttrCase<"tonearesteven", 0>;
168+
def Arith_RDownward // Round toward -inf
169+
: I32EnumAttrCase<"downward", 1>;
170+
def Arith_RUpward // Round toward +inf
171+
: I32EnumAttrCase<"upward", 2>;
172+
def Arith_RTowardZero // Round toward 0
173+
: I32EnumAttrCase<"towardzero", 3>;
174+
def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
175+
: I32EnumAttrCase<"tonearestaway", 4>;
176+
177+
def Arith_RoundingModeAttr : I32EnumAttr<
178+
"RoundingMode", "Floating point rounding mode",
179+
[Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
180+
Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
181+
let cppNamespace = "::mlir::arith";
182+
}
183+
159184
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
12271227
// TruncFOp
12281228
//===----------------------------------------------------------------------===//
12291229

1230-
def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
1230+
def Arith_TruncFOp :
1231+
Arith_Op<"truncf",
1232+
[Pure, SameOperandsAndResultShape,
1233+
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
1234+
DeclareOpInterfaceMethods<CastOpInterface>]>,
1235+
Arguments<(ins FloatLike:$in,
1236+
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
1237+
Results<(outs FloatLike:$out)> {
12311238
let summary = "cast from floating-point to narrower floating-point";
12321239
let description = [{
12331240
Truncate a floating-point value to a smaller floating-point-typed value.
12341241
The destination type must be strictly narrower than the source type.
1235-
If the value cannot be exactly represented, it is rounded using the default
1236-
rounding mode. When operating on vectors, casts elementwise.
1242+
If the value cannot be exactly represented, it is rounded using the
1243+
provided rounding mode or the default one if no rounding mode is provided.
1244+
When operating on vectors, casts elementwise.
12371245
}];
1246+
let builders = [
1247+
OpBuilder<(ins "Type":$out, "Value":$in), [{
1248+
$_state.addOperands(in);
1249+
$_state.addTypes(out);
1250+
}]>
1251+
];
12381252

12391253
let hasFolder = 1;
12401254
let hasVerifier = 1;
1255+
let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
12411256
}
12421257

12431258
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
106106
];
107107
}
108108

109+
def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
110+
let description = [{
111+
Access to op rounding mode.
112+
}];
113+
114+
let cppNamespace = "::mlir::arith";
115+
116+
let methods = [
117+
InterfaceMethod<
118+
/*desc=*/ "Returns a RoundingModeAttr attribute for the operation",
119+
/*returnType=*/ "RoundingModeAttr",
120+
/*methodName=*/ "getRoundingModeAttr",
121+
/*args=*/ (ins),
122+
/*methodBody=*/ [{}],
123+
/*defaultImpl=*/ [{
124+
auto op = cast<ConcreteOp>(this->getOperation());
125+
return op.getRoundingmodeAttr();
126+
}]
127+
>,
128+
StaticInterfaceMethod<
129+
/*desc=*/ [{Returns the name of the RoundingModeAttr attribute for
130+
the operation}],
131+
/*returnType=*/ "StringRef",
132+
/*methodName=*/ "getRoundingModeAttrName",
133+
/*args=*/ (ins),
134+
/*methodBody=*/ [{}],
135+
/*defaultImpl=*/ [{
136+
return "roundingmode";
137+
}]
138+
>
139+
];
140+
}
141+
109142
#endif // ARITH_OPS_INTERFACES

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
175175
}
176176

177177
LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
178+
// Only supporting default rounding mode as of now.
179+
if (op.getRoundingmodeAttr())
180+
return failure();
178181
Type outType = op.getOut().getType();
179182
if (auto outVecType = outType.dyn_cast<VectorType>()) {
180183
if (outVecType.isScalable())

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,31 @@ using namespace mlir;
2828

2929
namespace {
3030

31+
/// Operations whose conversion will depend on whether they are passed a
32+
/// rounding mode attribute or not.
33+
///
34+
/// \tparam SourceOp is the source operation; \tparam TargetOp, the operation it
35+
/// will lower to; \tparam AttrConvert is the attribute conversion to convert
36+
/// the rounding mode attribute.
37+
template <typename SourceOp, typename TargetOp, bool Constrained,
38+
template <typename, typename> typename AttrConvert =
39+
AttrConvertPassThrough>
40+
struct ConstrainedVectorConvertToLLVMPattern
41+
: public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
42+
using VectorConvertToLLVMPattern<SourceOp, TargetOp,
43+
AttrConvert>::VectorConvertToLLVMPattern;
44+
45+
LogicalResult
46+
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
47+
ConversionPatternRewriter &rewriter) const override {
48+
if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
49+
return failure();
50+
return VectorConvertToLLVMPattern<SourceOp, TargetOp,
51+
AttrConvert>::matchAndRewrite(op, adaptor,
52+
rewriter);
53+
}
54+
};
55+
3156
//===----------------------------------------------------------------------===//
3257
// Straightforward Op Lowerings
3358
//===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
112137
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
113138
arith::AttrConvertOverflowToLLVM>;
114139
using TruncFOpLowering =
115-
VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
140+
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
141+
false>;
116142
using TruncIOpLowering =
117143
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
118144
using UIToFPOpLowering =

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
805805
} else {
806806
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
807807
adaptor.getOperands());
808+
if (auto roundingModeOp =
809+
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
810+
if (arith::RoundingModeAttr roundingMode =
811+
roundingModeOp.getRoundingModeAttr()) {
812+
// TODO: Perform rounding mode attribute conversion and attach to new
813+
// operation when defined in the dialect.
814+
return failure();
815+
}
816+
}
808817
}
809818
return success();
810819
}

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
253253
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
254254
}
255255

256+
if (op.getRoundingmodeAttr()) {
257+
return rewriter.notifyMatchFailure(
258+
op, "only applicable to default rounding mode.");
259+
}
260+
261+
Type i1Ty = b.getI1Type();
256262
Type i16Ty = b.getI16Type();
257263
Type i32Ty = b.getI32Type();
258264
Type f32Ty = b.getF32Type();

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
703703
return %0 : vector<[8]xbf16>
704704
}
705705

706+
// CHECK-LABEL: test_truncf_rounding_mode
707+
func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
708+
%0 = arith.truncf %arg0 tonearesteven : f64 to f32
709+
%1 = arith.truncf %arg0 downward : f64 to f32
710+
%2 = arith.truncf %arg0 upward : f64 to f32
711+
%3 = arith.truncf %arg0 towardzero : f64 to f32
712+
%4 = arith.truncf %arg0 tonearestaway : f64 to f32
713+
return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
714+
}
715+
706716
// CHECK-LABEL: test_uitofp
707717
func.func @test_uitofp(%arg0 : i32) -> f32 {
708718
%0 = arith.uitofp %arg0 : i32 to f32

0 commit comments

Comments
 (0)