-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][SPIRV] Add definition and (de)serialization for cache controls #115461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[SPV_INTEL_cache_controls](https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.html) defines decorations for load and store cache control. Add support for this extension in the SPIR-V dialect. As several `CacheControlLoadINTEL` and `CacheControlStoreINTEL` may be applied to the same value, these are represented as array attributes. (De)Serialization takes care of this representation. Signed-off-by: Victor Perez <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Victor Perez (victor-eds) ChangesSPV_INTEL_cache_controls defines decorations for load and store cache control. Add support for this extension in the SPIR-V dialect. As several Full diff: https://github.com/llvm/llvm-project/pull/115461.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index f2a12f68d481b8..1bc3c63646fdd6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -97,6 +97,20 @@ def SPIRV_CooperativeMatrixPropertiesNVAttr :
let assemblyFormat = "`<` struct(params) `>`";
}
+def SPIRV_CacheControlLoadINTELAttr :
+ SPIRV_Attr<"CacheControlLoadINTEL", "cache_control_load_intel"> {
+ let parameters = (ins "unsigned":$cache_level,
+ "mlir::spirv::LoadCacheControl":$load_cache_control);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def SPIRV_CacheControlStoreINTELAttr :
+ SPIRV_Attr<"CacheControlStoreINTEL", "cache_control_store_intel"> {
+ let parameters = (ins "unsigned":$cache_level,
+ "mlir::spirv::StoreCacheControl":$store_cache_control);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
def SPIRV_CooperativeMatrixPropertiesNVArrayAttr :
TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesNVAttr,
"CooperativeMatrixPropertiesNV array attribute">;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 3b7da9b44a08fb..252d9319fccc5a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -400,6 +400,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
+def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -459,7 +460,8 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
- SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
+ SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
+ SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
@@ -1415,6 +1417,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}
+def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_cache_controls]>
+ ];
+}
+
def SPIRV_CapabilityAttr :
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1507,7 +1515,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
- SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL
+ SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
+ SPIRV_C_CacheControlsINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -2623,6 +2632,16 @@ def SPIRV_D_MediaBlockIOINTEL : I32EnumAttrCase<"MediaBlockIOIN
Capability<[SPIRV_C_VectorComputeINTEL]>
];
}
+def SPIRV_D_CacheControlLoadINTEL : I32EnumAttrCase<"CacheControlLoadINTEL", 6442> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_CacheControlsINTEL]>
+ ];
+}
+def SPIRV_D_CacheControlStoreINTEL : I32EnumAttrCase<"CacheControlStoreINTEL", 6443> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_CacheControlsINTEL]>
+ ];
+}
def SPIRV_DecorationAttr :
SPIRV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [
@@ -2658,7 +2677,8 @@ def SPIRV_DecorationAttr :
SPIRV_D_FuseLoopsInFunctionINTEL, SPIRV_D_AliasScopeINTEL, SPIRV_D_NoAliasINTEL,
SPIRV_D_BufferLocationINTEL, SPIRV_D_IOPipeStorageINTEL,
SPIRV_D_FunctionFloatingPointModeINTEL, SPIRV_D_SingleElementVectorINTEL,
- SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL
+ SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL,
+ SPIRV_D_CacheControlLoadINTEL, SPIRV_D_CacheControlStoreINTEL
]>;
def SPIRV_D_1D : I32EnumAttrCase<"Dim1D", 0> {
@@ -4092,6 +4112,32 @@ def SPIRV_KHR_CooperativeMatrixOperandsAttr :
SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat
]>;
+def SPIRV_INTEL_LCC_Uncached : I32EnumAttrCase<"Uncached", 0>;
+def SPIRV_INTEL_LCC_Cached : I32EnumAttrCase<"Cached", 1>;
+def SPIRV_INTEL_LCC_Streaming : I32EnumAttrCase<"Streaming", 2>;
+def SPIRV_INTEL_LCC_InvalidateAfterRead : I32EnumAttrCase<"InvalidateAfterR", 3>;
+def SPIRV_INTEL_LCC_ConstCached : I32EnumAttrCase<"ConstCached", 4>;
+
+def SPIRV_INTEL_LoadCacheControlAttr :
+ SPIRV_I32EnumAttr<"LoadCacheControl", "valid SPIR-V LoadCacheControl",
+ "load_cache_control", [
+ SPIRV_INTEL_LCC_Uncached, SPIRV_INTEL_LCC_Cached,
+ SPIRV_INTEL_LCC_Streaming, SPIRV_INTEL_LCC_InvalidateAfterRead,
+ SPIRV_INTEL_LCC_ConstCached
+ ]>;
+
+def SPIRV_INTEL_SCC_Uncached : I32EnumAttrCase<"Uncached", 0>;
+def SPIRV_INTEL_SCC_WriteThrough : I32EnumAttrCase<"WriteThrough", 1>;
+def SPIRV_INTEL_SCC_WriteBack : I32EnumAttrCase<"WriteBack", 2>;
+def SPIRV_INTEL_SCC_Streaming : I32EnumAttrCase<"Streaming", 3>;
+
+def SPIRV_INTEL_StoreCacheControlAttr :
+ SPIRV_I32EnumAttr<"StoreCacheControl", "valid SPIR-V StoreCacheControl",
+ "store_cache_control", [
+ SPIRV_INTEL_SCC_Uncached, SPIRV_INTEL_SCC_WriteThrough,
+ SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 462d3e326b6c27..e76e0595e75f03 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -226,6 +226,28 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
return success();
}
+template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
+LogicalResult deserializeCacheControlDecoration(
+ Location loc, OpBuilder &opBuilder,
+ DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
+ StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
+ if (words.size() != 4) {
+ return emitError(loc, "OpDecoration with ")
+ << decorationName << "needs a cache control integer literal and a "
+ << cacheControlKind << " cache control literal";
+ }
+ unsigned cacheLevel = words[2];
+ auto cacheControlAttr = static_cast<EnumTy>(words[3]);
+ auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
+ SmallVector<Attribute> attrs;
+ if (auto attrList =
+ llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
+ attrs.append(attrList.begin(), attrList.end());
+ attrs.push_back(value);
+ decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
+ return success();
+}
+
LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// TODO: This function should also be auto-generated. For now, since only a
// few decorations are processed/handled in a meaningful manner, going with a
@@ -339,6 +361,24 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
decorations[words[0]].set(
symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
break;
+ case spirv::Decoration::CacheControlLoadINTEL: {
+ LogicalResult res = deserializeCacheControlDecoration<
+ CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
+ unknownLoc, opBuilder, decorations, words, symbol, decorationName,
+ "load");
+ if (failed(res))
+ return res;
+ break;
+ }
+ case spirv::Decoration::CacheControlStoreINTEL: {
+ LogicalResult res = deserializeCacheControlDecoration<
+ CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
+ unknownLoc, opBuilder, decorations, words, symbol, decorationName,
+ "store");
+ if (failed(res))
+ return res;
+ break;
+ }
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index f355982e9ed884..57bb374278e9b2 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -217,10 +217,43 @@ static std::string getDecorationName(StringRef attrName) {
// similar here
if (attrName == "fp_rounding_mode")
return "FPRoundingMode";
+ // convertToCamelFromSnakeCase will not capitalize "INTEL".
+ if (attrName == "cache_control_load_intel")
+ return "CacheControlLoadINTEL";
+ if (attrName == "cache_control_store_intel")
+ return "CacheControlStoreINTEL";
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
+template <typename AttrTy, typename EmitF>
+LogicalResult processDecorationList(Location loc, Decoration decoration,
+ Attribute attrList, StringRef attrName,
+ EmitF emitter) {
+ auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
+ if (!arrayAttr) {
+ return emitError(loc, "expecting array attribute of ")
+ << attrName << " for " << stringifyDecoration(decoration);
+ }
+ if (arrayAttr.empty()) {
+ return emitError(loc, "expecting non-empty array attribute of ")
+ << attrName << " for " << stringifyDecoration(decoration);
+ }
+ for (Attribute attr : arrayAttr.getValue()) {
+ auto cacheControlAttr = dyn_cast<AttrTy>(attr);
+ if (!cacheControlAttr) {
+ return emitError(loc, "expecting array attribute of ")
+ << attrName << " for " << stringifyDecoration(decoration);
+ }
+ // This named attribute encodes several decorations. Emit one per
+ // element in the array.
+ LogicalResult res = emitter(cacheControlAttr);
+ if (failed(res))
+ return res;
+ }
+ return success();
+}
+
LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
Decoration decoration,
Attribute attr) {
@@ -294,6 +327,26 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
return emitError(loc,
"expected unit attribute or decoration attribute for ")
<< stringifyDecoration(decoration);
+ case spirv::Decoration::CacheControlLoadINTEL:
+ return processDecorationList<CacheControlLoadINTELAttr>(
+ loc, decoration, attr, "CacheControlLoadINTEL",
+ [&](CacheControlLoadINTELAttr attr) {
+ unsigned cacheLevel = attr.getCacheLevel();
+ LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
+ return emitDecoration(
+ resultID, decoration,
+ {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
+ });
+ case spirv::Decoration::CacheControlStoreINTEL:
+ return processDecorationList<CacheControlStoreINTELAttr>(
+ loc, decoration, attr, "CacheControlStoreINTEL",
+ [&](CacheControlStoreINTELAttr attr) {
+ unsigned cacheLevel = attr.getCacheLevel();
+ StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
+ return emitDecoration(
+ resultID, decoration,
+ {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
+ });
default:
return emitError(loc, "unhandled decoration ")
<< stringifyDecoration(decoration);
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 53a1015de75bcc..66c70e816d4134 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -69,3 +69,21 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
spirv.Return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.CacheControls
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @foo() "None" {
+ // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 0a29290b6a6fab..ba84e2621d7a30 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -107,3 +107,17 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
spirv.ReturnValue %0 : f16
}
}
+
+// -----
+
+// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @foo() "None" {
+ // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall. Could we add some tests to cover the failure cases? Not sure how easy that'd be.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comments, @kuhar. I added failure tests covering encoding. I couldn't see similar tests for decoding, so I assume that isn't covered and input is assumed to be valid SPIR-V.
I'd like to get this merged today if there aren't additional comments from reviewers. |
SPV_INTEL_cache_controls defines decorations for load and store cache control. Add support for this extension in the SPIR-V dialect.
As several
CacheControlLoadINTEL
andCacheControlStoreINTEL
may be applied to the same value, these are represented as array attributes. (De)Serialization takes care of this representation.