Skip to content

[mlir][spirv] Add some op decorations #72809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "NV", traits> {
}

def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
def SPIRV_FPFMM_NSZ : I32BitEnumAttrCaseBit<"NSZ", 2>;
def SPIRV_FPFMM_AllowRecip : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
def SPIRV_FPFMM_Fast : I32BitEnumAttrCaseBit<"Fast", 4>;
def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
list<Availability> availability = [
Capability<[SPIRV_C_FPFastMathModeINTEL]>
];
}
def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
list<Availability> availability = [
Capability<[SPIRV_C_FPFastMathModeINTEL]>
];
}

def SPIRV_FPFastMathModeAttr :
SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
SPIRV_FPFMM_AllowReassocINTEL
]>;

#endif // MLIR_DIALECT_SPIRV_IR_BASE
13 changes: 12 additions & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
auto symbol = opBuilder.getStringAttr(attrName);
switch (static_cast<spirv::Decoration>(words[1])) {
case spirv::Decoration::FPFastMathMode:
if (words.size() != 3) {
return emitError(unknownLoc, "OpDecorate with ")
<< decorationName << " needs a single integer literal";
}
decorations[words[0]].set(
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
static_cast<FPFastMathMode>(words[2])));
break;
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
Expand Down Expand Up @@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
case spirv::Decoration::Restrict:
case spirv::Decoration::NoSignedWrap:
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
Expand Down
23 changes: 20 additions & 3 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}

static std::string getDecorationName(StringRef attrName) {
// convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
// expected FPFastMathMode.
if (attrName == "fp_fast_math_mode")
return "FPFastMathMode";

return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}

LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) {
auto attrName = attr.getName().strref();
auto decorationName =
llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
auto decorationName = getDecorationName(attrName);
auto decoration = spirv::symbolizeDecoration(decorationName);
if (!decoration) {
return emitError(
Expand All @@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
args.push_back(static_cast<uint32_t>(linkageType));
break;
}
case spirv::Decoration::FPFastMathMode:
if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
break;
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
<< attrName;
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
Expand All @@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
case spirv::Decoration::Restrict:
case spirv::Decoration::NoSignedWrap:
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
// For unit attributes, the args list has no values so we do nothing
if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
break;
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Target/SPIRV/decorations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: relaxed_precision
spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
}

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
Expand All @@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
>
} : !spirv.ptr<f32, Private>
}

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
spirv.func @iadd_decorations(%arg: i32) -> i32 "None" {
// CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
%0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
spirv.ReturnValue %0 : i32
}
}

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {
// CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
%0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
spirv.ReturnValue %0 : f32
}
}