Skip to content

[MLIR][Arith] Add rounding mode attribute to truncf #86152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 1, 2024

Conversation

victor-eds
Copy link
Contributor

Add rounding mode attribute to arith. This attribute can be used in different FP arith operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes.

As this is not supported in other dialects, conversion should fail for now in case this attribute is present.

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-arith

Author: Victor Perez (victor-eds)

Changes

Add rounding mode attribute to arith. This attribute can be used in different FP arith operations to control rounding mode. Rounding modes correspond to IEEE 754-specified rounding modes.

As this is not supported in other dialects, conversion should fail for now in case this attribute is present.


Full diff: https://github.com/llvm/llvm-project/pull/86152.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithBase.td (+25)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+18-3)
  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td (+33)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+3)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+27-1)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+9)
  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+6)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c43c880b0..a9d976d9e4e28c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These correspond to LLVM's values defined in:
+// llvm/include/llvm/ADT/FloatingPointMode.h
+
+def Arith_RToNearestTiesToEven       // Round to nearest, ties to even
+    : I32EnumAttrCase<"tonearesteven", 0>;
+def Arith_RDownward                  // Round toward -inf
+    : I32EnumAttrCase<"downward", 1>;
+def Arith_RUpward                    // Round toward +inf
+    : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero                // Round toward 0
+    : I32EnumAttrCase<"towardzero", 3>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+    : I32EnumAttrCase<"tonearestaway", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+    "RoundingMode", "Floating point rounding mode",
+    [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+     Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+  let cppNamespace = "::mlir::arith";
+}
+
 #endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d0395d1f..ead19c69a0831c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
 // TruncFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+    Arith_Op<"truncf",
+      [Pure, SameOperandsAndResultShape,
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+    Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to narrower floating-point";
   let description = [{
     Truncate a floating-point value to a smaller floating-point-typed value.
     The destination type must be strictly narrower than the source type.
-    If the value cannot be exactly represented, it is rounded using the default
-    rounding mode. When operating on vectors, casts elementwise.
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+    When operating on vectors, casts elementwise.
   }];
+  let builders = [
+    OpBuilder<(ins "Type":$out, "Value":$in), [{
+      $_state.addOperands(in);
+      $_state.addTypes(out);
+    }]>
+  ];
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c32ef205..82d6c9ad6b03da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
   ];
 }
 
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+  let description = [{
+    Access to op rounding mode.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a RoundingModeAttr attribute for the operation",
+      /*returnType=*/  "RoundingModeAttr",
+      /*methodName=*/  "getRoundingModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRoundingmodeAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the RoundingModeAttr attribute for
+                         the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getRoundingModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "roundingmode";
+      }]
+    >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13ae362e92..0113a3df0b8e3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
 }
 
 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+  // Only supporting default rounding mode as of now.
+  if (op.getRoundingmodeAttr())
+    return failure();
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
     if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a75c5b3e..9d1961486303a2 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
 
 namespace {
 
+/// Operations whose conversion will depend on whether they are passed a
+/// rounding mode attribute or not.
+///
+/// \tparam SourceOp is the source operation; \tparam TargetOp, the operation it
+/// will lower to; \tparam AttrConvert is the attribute conversion to convert
+/// the rounding mode attribute.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                   AttrConvert>::VectorConvertToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+      return failure();
+    return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                      AttrConvert>::matchAndRewrite(op, adaptor,
+                                                                    rewriter);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
-    VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+    ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+                                          false>;
 using TruncIOpLowering =
     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
 using UIToFPOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..843d9c0afaadc1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
     } else {
       rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
                                                     adaptor.getOperands());
+      if (auto roundingModeOp =
+              dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+        if (arith::RoundingModeAttr roundingMode =
+                roundingModeOp.getRoundingModeAttr()) {
+          // TODO: Perform rounding mode attribute conversion and attach to new
+          // operation when defined in the dialect.
+          return failure();
+        }
+      }
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7f246daf99ff3c..b45be8b6bd8a4c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,6 +261,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573e324b5f..e4b23f073117e6 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
   return %0 : vector<[8]xbf16>
 }
 
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+  %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+  %3 = arith.truncf %arg0 towardzero : f64 to f32
+  %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+  return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
 // CHECK-LABEL: test_uitofp
 func.func @test_uitofp(%arg0 : i32) -> f32 {
   %0 = arith.uitofp %arg0 : i32 to f32

@victor-eds victor-eds requested a review from kuhar March 22, 2024 09:41
Add rounding mode attribute to `arith`. This attribute can be used in
different FP `arith` operations to control rounding mode. Rounding
modes correspond to IEEE 754-specified rounding modes.

As this is not supported in other dialects, conversion should fail for
now in case this attribute is present.

Signed-off-by: Victor Perez <[email protected]>
@victor-eds victor-eds force-pushed the arith-trunc-constrained branch from 485c360 to 3b756e4 Compare March 28, 2024 17:11
Copy link

github-actions bot commented Mar 28, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@victor-eds
Copy link
Contributor Author

@kuhar that should address all your comments. Thanks!

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

@victor-eds
Copy link
Contributor Author

CI error caused by:

C:\BuildTools\VC\Tools\MSVC\14.29.30133\include\variant(177): fatal error C1060: compiler is out of heap space

only happening in Windows.

@victor-eds victor-eds merged commit 8827ff9 into llvm:main Apr 1, 2024
@victor-eds victor-eds deleted the arith-trunc-constrained branch April 1, 2024 09:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants