Skip to content

Commit 47b9419

Browse files
committed
[mlir][spirv] Add definitions and (de)serialization for FPRoundingMode
1 parent b6b0a24 commit 47b9419

File tree

4 files changed

+42
-0
lines changed

4 files changed

+42
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
32493249
];
32503250
}
32513251

3252+
def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
3253+
def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
3254+
def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
3255+
def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
3256+
3257+
// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit
3258+
// FPRoundingMode on a value stored to certain storage classes?
3259+
// (The OpenCL environment also has FPRoundingMode rules, but different.)
3260+
def SPIRV_FPRoundingModeAttr :
3261+
SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
3262+
SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
3263+
]>;
3264+
32523265
def SPIRV_FunctionControlAttr :
32533266
SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
32543267
SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
250250
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
251251
static_cast<FPFastMathMode>(words[2])));
252252
break;
253+
case spirv::Decoration::FPRoundingMode:
254+
if (words.size() != 3) {
255+
return emitError(unknownLoc, "OpDecorate with ")
256+
<< decorationName << " needs a single integer literal";
257+
}
258+
decorations[words[0]].set(
259+
symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
260+
static_cast<FPRoundingMode>(words[2])));
261+
break;
253262
case spirv::Decoration::DescriptorSet:
254263
case spirv::Decoration::Binding:
255264
if (words.size() != 3) {

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
214214
// expected FPFastMathMode.
215215
if (attrName == "fp_fast_math_mode")
216216
return "FPFastMathMode";
217+
// similar here
218+
if (attrName == "fp_rounding_mode")
219+
return "FPRoundingMode";
217220

218221
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
219222
}
@@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
242245
}
243246
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
244247
<< stringifyDecoration(decoration);
248+
case spirv::Decoration::FPRoundingMode:
249+
if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
250+
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
251+
break;
252+
}
253+
return emitError(loc, "expected FPRoundingModeAttr attribute for ")
254+
<< stringifyDecoration(decoration);
245255
case spirv::Decoration::Binding:
246256
case spirv::Decoration::DescriptorSet:
247257
case spirv::Decoration::Location:

mlir/test/Target/SPIRV/decorations.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
9797
spirv.ReturnValue %0 : f32
9898
}
9999
}
100+
101+
// -----
102+
103+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
104+
spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
105+
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
106+
%0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
107+
spirv.ReturnValue %0 : f16
108+
}
109+
}

0 commit comments

Comments
 (0)