Skip to content

Commit b84fe8f

Browse files
authored
[mlir][spirv] Add some op decorations (#72809)
NoSignedWrap, NoUnsignedWrap, FPFastMathMode.
1 parent f3e54f2 commit b84fe8f

File tree

4 files changed

+77
-4
lines changed

4 files changed

+77
-4
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
47424742
SPIRV_VendorOp<mnemonic, "NV", traits> {
47434743
}
47444744

4745+
def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
4746+
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
4747+
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
4748+
def SPIRV_FPFMM_NSZ : I32BitEnumAttrCaseBit<"NSZ", 2>;
4749+
def SPIRV_FPFMM_AllowRecip : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
4750+
def SPIRV_FPFMM_Fast : I32BitEnumAttrCaseBit<"Fast", 4>;
4751+
def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
4752+
list<Availability> availability = [
4753+
Capability<[SPIRV_C_FPFastMathModeINTEL]>
4754+
];
4755+
}
4756+
def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
4757+
list<Availability> availability = [
4758+
Capability<[SPIRV_C_FPFastMathModeINTEL]>
4759+
];
4760+
}
4761+
4762+
def SPIRV_FPFastMathModeAttr :
4763+
SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
4764+
SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
4765+
SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
4766+
SPIRV_FPFMM_AllowReassocINTEL
4767+
]>;
4768+
47454769
#endif // MLIR_DIALECT_SPIRV_IR_BASE

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
242242
auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
243243
auto symbol = opBuilder.getStringAttr(attrName);
244244
switch (static_cast<spirv::Decoration>(words[1])) {
245+
case spirv::Decoration::FPFastMathMode:
246+
if (words.size() != 3) {
247+
return emitError(unknownLoc, "OpDecorate with ")
248+
<< decorationName << " needs a single integer literal";
249+
}
250+
decorations[words[0]].set(
251+
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
252+
static_cast<FPFastMathMode>(words[2])));
253+
break;
245254
case spirv::Decoration::DescriptorSet:
246255
case spirv::Decoration::Binding:
247256
if (words.size() != 3) {
@@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
295304
case spirv::Decoration::NonReadable:
296305
case spirv::Decoration::NonWritable:
297306
case spirv::Decoration::NoPerspective:
298-
case spirv::Decoration::Restrict:
307+
case spirv::Decoration::NoSignedWrap:
308+
case spirv::Decoration::NoUnsignedWrap:
299309
case spirv::Decoration::RelaxedPrecision:
310+
case spirv::Decoration::Restrict:
300311
if (words.size() != 2) {
301312
return emitError(unknownLoc, "OpDecoration with ")
302313
<< decorationName << "needs a single target <id>";

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
206206
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
207207
}
208208

209+
static std::string getDecorationName(StringRef attrName) {
210+
// convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
211+
// expected FPFastMathMode.
212+
if (attrName == "fp_fast_math_mode")
213+
return "FPFastMathMode";
214+
215+
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
216+
}
217+
209218
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
210219
NamedAttribute attr) {
211220
auto attrName = attr.getName().strref();
212-
auto decorationName =
213-
llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
221+
auto decorationName = getDecorationName(attrName);
214222
auto decoration = spirv::symbolizeDecoration(decorationName);
215223
if (!decoration) {
216224
return emitError(
@@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
232240
args.push_back(static_cast<uint32_t>(linkageType));
233241
break;
234242
}
243+
case spirv::Decoration::FPFastMathMode:
244+
if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
245+
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
246+
break;
247+
}
248+
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
249+
<< attrName;
235250
case spirv::Decoration::Binding:
236251
case spirv::Decoration::DescriptorSet:
237252
case spirv::Decoration::Location:
@@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
256271
case spirv::Decoration::NonReadable:
257272
case spirv::Decoration::NonWritable:
258273
case spirv::Decoration::NoPerspective:
259-
case spirv::Decoration::Restrict:
274+
case spirv::Decoration::NoSignedWrap:
275+
case spirv::Decoration::NoUnsignedWrap:
260276
case spirv::Decoration::RelaxedPrecision:
277+
case spirv::Decoration::Restrict:
261278
// For unit attributes, the args list has no values so we do nothing
262279
if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
263280
break;

mlir/test/Target/SPIRV/decorations.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
5555
// CHECK: relaxed_precision
5656
spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
5757
}
58+
5859
// -----
5960

6061
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
@@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
6667
>
6768
} : !spirv.ptr<f32, Private>
6869
}
70+
71+
// -----
72+
73+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
74+
spirv.func @iadd_decorations(%arg: i32) -> i32 "None" {
75+
// CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
76+
%0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
77+
spirv.ReturnValue %0 : i32
78+
}
79+
}
80+
81+
// -----
82+
83+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
84+
spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {
85+
// CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
86+
%0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
87+
spirv.ReturnValue %0 : f32
88+
}
89+
}

0 commit comments

Comments
 (0)