Skip to content

[mlir][spirv] Add some op decorations #72809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2023
Merged

Conversation

Hardcode84
Copy link
Contributor

NoSignedWrap, NoUnsignedWrap, FPFastMathMode.

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2023

@llvm/pr-subscribers-mlir-spirv

Author: Ivan Butygin (Hardcode84)

Changes

NoSignedWrap, NoUnsignedWrap, FPFastMathMode.


Full diff: https://github.com/llvm/llvm-project/pull/72809.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+24)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+12-1)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+20-3)
  • (modified) mlir/test/Target/SPIRV/decorations.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..8eaf2a98a58560e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
+def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
+def SPIRV_FPFMM_NSZ          : I32BitEnumAttrCaseBit<"NSZ", 2>;
+def SPIRV_FPFMM_AllowRecip   : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
+def SPIRV_FPFMM_Fast         : I32BitEnumAttrCaseBit<"Fast", 4>;
+def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_FPFastMathModeINTEL]>
+  ];
+}
+def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_FPFastMathModeINTEL]>
+  ];
+}
+
+def SPIRV_FPFastMathModeAttr :
+    SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
+      SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
+      SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
+      SPIRV_FPFMM_AllowReassocINTEL
+    ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ce8b3ab3894606c..89e2e7ad52fa7d1 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
   auto symbol = opBuilder.getStringAttr(attrName);
   switch (static_cast<spirv::Decoration>(words[1])) {
+  case spirv::Decoration::FPFastMathMode:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
+                                        static_cast<FPFastMathMode>(words[2])));
+    break;
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Binding:
     if (words.size() != 3) {
@@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
-  case spirv::Decoration::Restrict:
+  case spirv::Decoration::NoSignedWrap:
+  case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
+  case spirv::Decoration::Restrict:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 22fcc4939317be9..9e9a16456cc1022 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
 }
 
+static std::string getDecorationName(StringRef attrName) {
+  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
+  // expected FPFastMathMode.
+  if (attrName == "fp_fast_math_mode")
+    return "FPFastMathMode";
+
+  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+}
+
 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                             NamedAttribute attr) {
   auto attrName = attr.getName().strref();
-  auto decorationName =
-      llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+  auto decorationName = getDecorationName(attrName);
   auto decoration = spirv::symbolizeDecoration(decorationName);
   if (!decoration) {
     return emitError(
@@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
     args.push_back(static_cast<uint32_t>(linkageType));
     break;
   }
+  case spirv::Decoration::FPFastMathMode:
+    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+      args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+      break;
+    }
+    return emitError(loc, "expected FPFastMathModeAttr attribute for ")
+           << attrName;
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
@@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
-  case spirv::Decoration::Restrict:
+  case spirv::Decoration::NoSignedWrap:
+  case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
+  case spirv::Decoration::Restrict:
     // For unit attributes, the args list has no values so we do nothing
     if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
       break;
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index aadf64c340b3445..04cb059f931863d 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: relaxed_precision
   spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
 }
+
 // -----
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
@@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     >
   } : !spirv.ptr<f32, Private>
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: i32) -> i32 "None" {
+  // CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
+  %0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
+  spirv.ReturnValue %0 : i32
+}
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: f32) -> f32 "None" {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
+  %0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
+  spirv.ReturnValue %0 : f32
+}
+}

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2023

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

NoSignedWrap, NoUnsignedWrap, FPFastMathMode.


Full diff: https://github.com/llvm/llvm-project/pull/72809.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+24)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+12-1)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+20-3)
  • (modified) mlir/test/Target/SPIRV/decorations.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..8eaf2a98a58560e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
+def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
+def SPIRV_FPFMM_NSZ          : I32BitEnumAttrCaseBit<"NSZ", 2>;
+def SPIRV_FPFMM_AllowRecip   : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
+def SPIRV_FPFMM_Fast         : I32BitEnumAttrCaseBit<"Fast", 4>;
+def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_FPFastMathModeINTEL]>
+  ];
+}
+def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_FPFastMathModeINTEL]>
+  ];
+}
+
+def SPIRV_FPFastMathModeAttr :
+    SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
+      SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
+      SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
+      SPIRV_FPFMM_AllowReassocINTEL
+    ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ce8b3ab3894606c..89e2e7ad52fa7d1 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
   auto symbol = opBuilder.getStringAttr(attrName);
   switch (static_cast<spirv::Decoration>(words[1])) {
+  case spirv::Decoration::FPFastMathMode:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
+                                        static_cast<FPFastMathMode>(words[2])));
+    break;
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Binding:
     if (words.size() != 3) {
@@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
-  case spirv::Decoration::Restrict:
+  case spirv::Decoration::NoSignedWrap:
+  case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
+  case spirv::Decoration::Restrict:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 22fcc4939317be9..9e9a16456cc1022 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
 }
 
+static std::string getDecorationName(StringRef attrName) {
+  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
+  // expected FPFastMathMode.
+  if (attrName == "fp_fast_math_mode")
+    return "FPFastMathMode";
+
+  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+}
+
 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                             NamedAttribute attr) {
   auto attrName = attr.getName().strref();
-  auto decorationName =
-      llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+  auto decorationName = getDecorationName(attrName);
   auto decoration = spirv::symbolizeDecoration(decorationName);
   if (!decoration) {
     return emitError(
@@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
     args.push_back(static_cast<uint32_t>(linkageType));
     break;
   }
+  case spirv::Decoration::FPFastMathMode:
+    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+      args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+      break;
+    }
+    return emitError(loc, "expected FPFastMathModeAttr attribute for ")
+           << attrName;
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
@@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
-  case spirv::Decoration::Restrict:
+  case spirv::Decoration::NoSignedWrap:
+  case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
+  case spirv::Decoration::Restrict:
     // For unit attributes, the args list has no values so we do nothing
     if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
       break;
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index aadf64c340b3445..04cb059f931863d 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: relaxed_precision
   spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
 }
+
 // -----
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
@@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     >
   } : !spirv.ptr<f32, Private>
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: i32) -> i32 "None" {
+  // CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
+  %0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
+  spirv.ReturnValue %0 : i32
+}
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: f32) -> f32 "None" {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
+  %0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
+  spirv.ReturnValue %0 : f32
+}
+}

@Hardcode84
Copy link
Contributor Author

@mshahneo @silee2 there are some intersections with private patches you aren't upstreamed yet, so FYI

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

NoSignedWrap, NoUnsignedWrap, FPFastMathMode
@Hardcode84 Hardcode84 merged commit b84fe8f into llvm:main Nov 21, 2023
@Hardcode84 Hardcode84 deleted the spirv-fmf-decor branch November 21, 2023 12:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants