-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Arith] Add rounding mode attribute to truncf
#86152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Victor Perez (victor-eds) ChangesAdd rounding mode attribute to As this is not supported in other dialects, conversion should fail for now in case this attribute is present. Full diff: https://github.com/llvm/llvm-project/pull/86152.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c43c880b0..a9d976d9e4e28c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
let assemblyFormat = "`<` $value `>`";
}
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These correspond to LLVM's values defined in:
+// llvm/include/llvm/ADT/FloatingPointMode.h
+
+def Arith_RToNearestTiesToEven // Round to nearest, ties to even
+ : I32EnumAttrCase<"tonearesteven", 0>;
+def Arith_RDownward // Round toward -inf
+ : I32EnumAttrCase<"downward", 1>;
+def Arith_RUpward // Round toward +inf
+ : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero // Round toward 0
+ : I32EnumAttrCase<"towardzero", 3>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+ : I32EnumAttrCase<"tonearestaway", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+ "RoundingMode", "Floating point rounding mode",
+ [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+ Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+ let cppNamespace = "::mlir::arith";
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d0395d1f..ead19c69a0831c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
// TruncFOp
//===----------------------------------------------------------------------===//
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+ Arith_Op<"truncf",
+ [Pure, SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
- If the value cannot be exactly represented, it is rounded using the default
- rounding mode. When operating on vectors, casts elementwise.
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
+ When operating on vectors, casts elementwise.
}];
+ let builders = [
+ OpBuilder<(ins "Type":$out, "Value":$in), [{
+ $_state.addOperands(in);
+ $_state.addTypes(out);
+ }]>
+ ];
let hasFolder = 1;
let hasVerifier = 1;
+ let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c32ef205..82d6c9ad6b03da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
];
}
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+ let description = [{
+ Access to op rounding mode.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a RoundingModeAttr attribute for the operation",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the RoundingModeAttr attribute for
+ the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >
+ ];
+}
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13ae362e92..0113a3df0b8e3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
}
LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+ // Only supporting default rounding mode as of now.
+ if (op.getRoundingmodeAttr())
+ return failure();
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a75c5b3e..9d1961486303a2 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
namespace {
+/// Operations whose conversion will depend on whether they are passed a
+/// rounding mode attribute or not.
+///
+/// \tparam SourceOp is the source operation; \tparam TargetOp, the operation it
+/// will lower to; \tparam AttrConvert is the attribute conversion to convert
+/// the rounding mode attribute.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+ template <typename, typename> typename AttrConvert =
+ AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+ using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::VectorConvertToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+ return failure();
+ return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::matchAndRewrite(op, adaptor,
+ rewriter);
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
- VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+ ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+ false>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using UIToFPOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..843d9c0afaadc1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
} else {
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
adaptor.getOperands());
+ if (auto roundingModeOp =
+ dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+ if (arith::RoundingModeAttr roundingMode =
+ roundingModeOp.getRoundingModeAttr()) {
+ // TODO: Perform rounding mode attribute conversion and attach to new
+ // operation when defined in the dialect.
+ return failure();
+ }
+ }
}
return success();
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7f246daf99ff3c..b45be8b6bd8a4c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,6 +261,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573e324b5f..e4b23f073117e6 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
return %0 : vector<[8]xbf16>
}
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+ %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+ %3 = arith.truncf %arg0 towardzero : f64 to f32
+ %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+ return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: test_uitofp
func.func @test_uitofp(%arg0 : i32) -> f32 {
%0 = arith.uitofp %arg0 : i32 to f32
|
Add rounding mode attribute to `arith`. This attribute can be used in different FP `arith` operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes. As this is not supported in other dialects, conversion should fail for now in case this attribute is present. Signed-off-by: Victor Perez <[email protected]>
485c360
to
3b756e4
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
@kuhar that should address all your comments. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
CI error caused by:
only happening in Windows. |
Add rounding mode attribute to
arith
. This attribute can be used in different FParith
operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes.As this is not supported in other dialects, conversion should fail for now in case this attribute is present.