-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V #101547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kuhar
merged 1 commit into
llvm:main
from
andfau-amd:87050-arith-truncf-rounding-mode-SPIR-V
Aug 2, 2024
Merged
[mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V #101547
kuhar
merged 1 commit into
llvm:main
from
andfau-amd:87050-arith-truncf-rounding-mode-SPIR-V
Aug 2, 2024
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
The first commit is from #101546, please review/merge that first. |
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Andrea Faulds (andfau-amd) ChangesResolves #87050. Full diff: https://github.com/llvm/llvm-project/pull/101547.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6ec97e17c5dcc..b38978272c5bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
];
}
+def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
+def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
+def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
+def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
+
+// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit
+// FPRoundingMode on a value stored to certain storage classes?
+// (The OpenCL environment also has FPRoundingMode rules, but different.)
+def SPIRV_FPRoundingModeAttr :
+ SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
+ SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
+ ]>;
+
def SPIRV_FunctionControlAttr :
SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 4c3237b24b786..f2b9a18f60eca 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
// TypeCastingOp
//===----------------------------------------------------------------------===//
+static std::optional<spirv::FPRoundingMode>
+convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
+ switch (roundingMode) {
+ case arith::RoundingMode::downward:
+ return spirv::FPRoundingMode::RTN;
+ case arith::RoundingMode::to_nearest_even:
+ return spirv::FPRoundingMode::RTE;
+ case arith::RoundingMode::toward_zero:
+ return spirv::FPRoundingMode::RTZ;
+ case arith::RoundingMode::upward:
+ return spirv::FPRoundingMode::RTP;
+ case arith::RoundingMode::to_nearest_away:
+ // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
+ // (as of SPIR-V 1.6)
+ return {};
+ }
+ llvm_unreachable("Unhandled rounding mode");
+}
+
/// Converts type-casting standard operations to SPIR-V operations.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -829,15 +848,20 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
// Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(op, adaptor.getOperands().front());
} else {
- rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
- adaptor.getOperands());
+ auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+ op, dstType, adaptor.getOperands());
if (auto roundingModeOp =
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
if (arith::RoundingModeAttr roundingMode =
roundingModeOp.getRoundingModeAttr()) {
- // TODO: Perform rounding mode attribute conversion and attach to new
- // operation when defined in the dialect.
- return failure();
+ if (auto rm =
+ convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
+ newOp->setAttr(
+ getDecorationString(spirv::Decoration::FPRoundingMode),
+ spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+ } else {
+ return failure(); // unsupported rounding mode
+ }
}
}
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d7a308548cf4d..12980879b20ab 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
static_cast<FPFastMathMode>(words[2])));
break;
+ case spirv::Decoration::FPRoundingMode:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
+ static_cast<FPRoundingMode>(words[2])));
+ break;
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4c4fef177317e..714a3edfb5657 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
// expected FPFastMathMode.
if (attrName == "fp_fast_math_mode")
return "FPFastMathMode";
+ // similar here
+ if (attrName == "fp_rounding_mode")
+ return "FPRoundingMode";
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
@@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
<< stringifyDecoration(decoration);
+ case spirv::Decoration::FPRoundingMode:
+ if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
+ args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+ break;
+ }
+ return emitError(loc, "expected FPRoundingModeAttr attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index beb2c8d2d242c..4c5b7664bb1aa 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -754,6 +754,21 @@ func.func @fptrunc2(%arg0: f32) -> f16 {
return %0 : f16
}
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTE>} : f64 to f32
+ %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTP>} : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTZ>} : f64 to f32
+ %3 = arith.truncf %arg0 toward_zero : f64 to f32
+ return
+}
+
+
// CHECK-LABEL: @sitofp1
func.func @sitofp1(%arg0 : i32) -> f32 {
// CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 195773735431e..0a29290b6a6fa 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
spirv.ReturnValue %0 : f32
}
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+ %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+ spirv.ReturnValue %0 : f16
+}
+}
|
kuhar
reviewed
Aug 1, 2024
46a8013
to
a4effd4
Compare
kuhar
reviewed
Aug 2, 2024
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
Outdated
Show resolved
Hide resolved
a4effd4
to
b6c6148
Compare
kuhar
approved these changes
Aug 2, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Resolves #87050.