-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add some op decorations #72809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv Author: Ivan Butygin (Hardcode84) ChangesNoSignedWrap, NoUnsignedWrap, FPFastMathMode. Full diff: https://github.com/llvm/llvm-project/pull/72809.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..8eaf2a98a58560e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "NV", traits> {
}
+def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
+def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
+def SPIRV_FPFMM_NSZ : I32BitEnumAttrCaseBit<"NSZ", 2>;
+def SPIRV_FPFMM_AllowRecip : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
+def SPIRV_FPFMM_Fast : I32BitEnumAttrCaseBit<"Fast", 4>;
+def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_FPFastMathModeINTEL]>
+ ];
+}
+def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_FPFastMathModeINTEL]>
+ ];
+}
+
+def SPIRV_FPFastMathModeAttr :
+ SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
+ SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
+ SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
+ SPIRV_FPFMM_AllowReassocINTEL
+ ]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ce8b3ab3894606c..89e2e7ad52fa7d1 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
auto symbol = opBuilder.getStringAttr(attrName);
switch (static_cast<spirv::Decoration>(words[1])) {
+ case spirv::Decoration::FPFastMathMode:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
+ static_cast<FPFastMathMode>(words[2])));
+ break;
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
@@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
- case spirv::Decoration::Restrict:
+ case spirv::Decoration::NoSignedWrap:
+ case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
+ case spirv::Decoration::Restrict:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 22fcc4939317be9..9e9a16456cc1022 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}
+static std::string getDecorationName(StringRef attrName) {
+ // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
+ // expected FPFastMathMode.
+ if (attrName == "fp_fast_math_mode")
+ return "FPFastMathMode";
+
+ return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+}
+
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) {
auto attrName = attr.getName().strref();
- auto decorationName =
- llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+ auto decorationName = getDecorationName(attrName);
auto decoration = spirv::symbolizeDecoration(decorationName);
if (!decoration) {
return emitError(
@@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
args.push_back(static_cast<uint32_t>(linkageType));
break;
}
+ case spirv::Decoration::FPFastMathMode:
+ if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+ args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+ break;
+ }
+ return emitError(loc, "expected FPFastMathModeAttr attribute for ")
+ << attrName;
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
@@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
- case spirv::Decoration::Restrict:
+ case spirv::Decoration::NoSignedWrap:
+ case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
+ case spirv::Decoration::Restrict:
// For unit attributes, the args list has no values so we do nothing
if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
break;
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index aadf64c340b3445..04cb059f931863d 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: relaxed_precision
spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
}
+
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
@@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
>
} : !spirv.ptr<f32, Private>
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: i32) -> i32 "None" {
+ // CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
+ %0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
+ spirv.ReturnValue %0 : i32
+}
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: f32) -> f32 "None" {
+ // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
+ %0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
+ spirv.ReturnValue %0 : f32
+}
+}
|
@llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesNoSignedWrap, NoUnsignedWrap, FPFastMathMode. Full diff: https://github.com/llvm/llvm-project/pull/72809.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..8eaf2a98a58560e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "NV", traits> {
}
+def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
+def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
+def SPIRV_FPFMM_NSZ : I32BitEnumAttrCaseBit<"NSZ", 2>;
+def SPIRV_FPFMM_AllowRecip : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
+def SPIRV_FPFMM_Fast : I32BitEnumAttrCaseBit<"Fast", 4>;
+def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_FPFastMathModeINTEL]>
+ ];
+}
+def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_FPFastMathModeINTEL]>
+ ];
+}
+
+def SPIRV_FPFastMathModeAttr :
+ SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
+ SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
+ SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
+ SPIRV_FPFMM_AllowReassocINTEL
+ ]>;
+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ce8b3ab3894606c..89e2e7ad52fa7d1 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
auto symbol = opBuilder.getStringAttr(attrName);
switch (static_cast<spirv::Decoration>(words[1])) {
+ case spirv::Decoration::FPFastMathMode:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
+ static_cast<FPFastMathMode>(words[2])));
+ break;
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
@@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
- case spirv::Decoration::Restrict:
+ case spirv::Decoration::NoSignedWrap:
+ case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
+ case spirv::Decoration::Restrict:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 22fcc4939317be9..9e9a16456cc1022 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}
+static std::string getDecorationName(StringRef attrName) {
+ // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
+ // expected FPFastMathMode.
+ if (attrName == "fp_fast_math_mode")
+ return "FPFastMathMode";
+
+ return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+}
+
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) {
auto attrName = attr.getName().strref();
- auto decorationName =
- llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+ auto decorationName = getDecorationName(attrName);
auto decoration = spirv::symbolizeDecoration(decorationName);
if (!decoration) {
return emitError(
@@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
args.push_back(static_cast<uint32_t>(linkageType));
break;
}
+ case spirv::Decoration::FPFastMathMode:
+ if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+ args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+ break;
+ }
+ return emitError(loc, "expected FPFastMathModeAttr attribute for ")
+ << attrName;
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
@@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
- case spirv::Decoration::Restrict:
+ case spirv::Decoration::NoSignedWrap:
+ case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
+ case spirv::Decoration::Restrict:
// For unit attributes, the args list has no values so we do nothing
if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
break;
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index aadf64c340b3445..04cb059f931863d 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: relaxed_precision
spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
}
+
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
@@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
>
} : !spirv.ptr<f32, Private>
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: i32) -> i32 "None" {
+ // CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
+ %0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
+ spirv.ReturnValue %0 : i32
+}
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: f32) -> f32 "None" {
+ // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
+ %0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
+ spirv.ReturnValue %0 : f32
+}
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
NoSignedWrap, NoUnsignedWrap, FPFastMathMode
cd91ba2
to
bf3a911
Compare
NoSignedWrap, NoUnsignedWrap, FPFastMathMode.