Skip to content

[mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V #101547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

andfau-amd
Copy link
Contributor

Resolves #87050.

@andfau-amd
Copy link
Contributor Author

The first commit is from #101546, please review/merge that first.

@llvmbot
Copy link
Member

llvmbot commented Aug 1, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

Changes

Resolves #87050.


Full diff: https://github.com/llvm/llvm-project/pull/101547.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+13)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+29-5)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+9)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+10)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+15)
  • (modified) mlir/test/Target/SPIRV/decorations.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6ec97e17c5dcc..b38978272c5bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
   ];
 }
 
+def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
+def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
+def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
+def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
+
+// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit
+//       FPRoundingMode on a value stored to certain storage classes?
+//       (The OpenCL environment also has FPRoundingMode rules, but different.)
+def SPIRV_FPRoundingModeAttr :
+    SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
+      SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
+    ]>;
+
 def SPIRV_FunctionControlAttr :
     SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
       SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 4c3237b24b786..f2b9a18f60eca 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
 // TypeCastingOp
 //===----------------------------------------------------------------------===//
 
+static std::optional<spirv::FPRoundingMode>
+convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
+  switch (roundingMode) {
+  case arith::RoundingMode::downward:
+    return spirv::FPRoundingMode::RTN;
+  case arith::RoundingMode::to_nearest_even:
+    return spirv::FPRoundingMode::RTE;
+  case arith::RoundingMode::toward_zero:
+    return spirv::FPRoundingMode::RTZ;
+  case arith::RoundingMode::upward:
+    return spirv::FPRoundingMode::RTP;
+  case arith::RoundingMode::to_nearest_away:
+    // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
+    // (as of SPIR-V 1.6)
+    return {};
+  }
+  llvm_unreachable("Unhandled rounding mode");
+}
+
 /// Converts type-casting standard operations to SPIR-V operations.
 template <typename Op, typename SPIRVOp>
 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -829,15 +848,20 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(op, adaptor.getOperands().front());
     } else {
-      rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
-                                                    adaptor.getOperands());
+      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+          op, dstType, adaptor.getOperands());
       if (auto roundingModeOp =
               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
         if (arith::RoundingModeAttr roundingMode =
                 roundingModeOp.getRoundingModeAttr()) {
-          // TODO: Perform rounding mode attribute conversion and attach to new
-          // operation when defined in the dialect.
-          return failure();
+          if (auto rm =
+                  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
+            newOp->setAttr(
+                getDecorationString(spirv::Decoration::FPRoundingMode),
+                spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+          } else {
+            return failure(); // unsupported rounding mode
+          }
         }
       }
     }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d7a308548cf4d..12980879b20ab 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
         symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
                                         static_cast<FPFastMathMode>(words[2])));
     break;
+  case spirv::Decoration::FPRoundingMode:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
+                                        static_cast<FPRoundingMode>(words[2])));
+    break;
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Binding:
     if (words.size() != 3) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4c4fef177317e..714a3edfb5657 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
   // expected FPFastMathMode.
   if (attrName == "fp_fast_math_mode")
     return "FPFastMathMode";
+  // similar here
+  if (attrName == "fp_rounding_mode")
+    return "FPRoundingMode";
 
   return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
 }
@@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
     }
     return emitError(loc, "expected FPFastMathModeAttr attribute for ")
            << stringifyDecoration(decoration);
+  case spirv::Decoration::FPRoundingMode:
+    if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
+      args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+      break;
+    }
+    return emitError(loc, "expected FPRoundingModeAttr attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index beb2c8d2d242c..4c5b7664bb1aa 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -754,6 +754,21 @@ func.func @fptrunc2(%arg0: f32) -> f16 {
   return %0 : f16
 }
 
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTE>} : f64 to f32
+  %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTP>} : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTZ>} : f64 to f32
+  %3 = arith.truncf %arg0 toward_zero : f64 to f32
+  return
+}
+
+
 // CHECK-LABEL: @sitofp1
 func.func @sitofp1(%arg0 : i32) -> f32 {
   // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 195773735431e..0a29290b6a6fa 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
   spirv.ReturnValue %0 : f32
 }
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+  %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+  spirv.ReturnValue %0 : f16
+}
+}

@andfau-amd andfau-amd marked this pull request as draft August 1, 2024 19:07
@andfau-amd andfau-amd force-pushed the 87050-arith-truncf-rounding-mode-SPIR-V branch from 46a8013 to a4effd4 Compare August 1, 2024 22:25
@andfau-amd andfau-amd force-pushed the 87050-arith-truncf-rounding-mode-SPIR-V branch from a4effd4 to b6c6148 Compare August 2, 2024 12:40
@andfau-amd andfau-amd marked this pull request as ready for review August 2, 2024 12:40
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuhar kuhar merged commit b537df9 into llvm:main Aug 2, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][arith][spirv] Convert arith.truncf with rounding mode to SPIR-V
3 participants