Skip to content

Commit 4f78f85

Browse files
authored
[MLIR][SPIRV] Add definition and (de)serialization for cache controls (#115461)
[SPV_INTEL_cache_controls](https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.html) defines decorations for load and store cache control. Add support for this extension in the SPIR-V dialect. As several `CacheControlLoadINTEL` and `CacheControlStoreINTEL` may be applied to the same value, these are represented as array attributes. (De)Serialization takes care of this representation. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 4548bff commit 4f78f85

File tree

6 files changed

+218
-4
lines changed

6 files changed

+218
-4
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def SPIRV_CooperativeMatrixPropertiesNVAttr :
9797
let assemblyFormat = "`<` struct(params) `>`";
9898
}
9999

100+
def SPIRV_CacheControlLoadINTELAttr :
101+
SPIRV_Attr<"CacheControlLoadINTEL", "cache_control_load_intel"> {
102+
let parameters = (ins "unsigned":$cache_level,
103+
"mlir::spirv::LoadCacheControl":$load_cache_control);
104+
let assemblyFormat = "`<` struct(params) `>`";
105+
}
106+
107+
def SPIRV_CacheControlStoreINTELAttr :
108+
SPIRV_Attr<"CacheControlStoreINTEL", "cache_control_store_intel"> {
109+
let parameters = (ins "unsigned":$cache_level,
110+
"mlir::spirv::StoreCacheControl":$store_cache_control);
111+
let assemblyFormat = "`<` struct(params) `>`";
112+
}
113+
100114
def SPIRV_CooperativeMatrixPropertiesNVArrayAttr :
101115
TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesNVAttr,
102116
"CooperativeMatrixPropertiesNV array attribute">;

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp
400400
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
401401
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
402402
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
403+
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
403404

404405
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
405406
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -459,7 +460,8 @@ def SPIRV_ExtensionAttr :
459460
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
460461
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
461462
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
462-
SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
463+
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
464+
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
463465
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
464466
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
465467
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
@@ -1415,6 +1417,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
14151417
];
14161418
}
14171419

1420+
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
1421+
list<Availability> availability = [
1422+
Extension<[SPV_INTEL_cache_controls]>
1423+
];
1424+
}
1425+
14181426
def SPIRV_CapabilityAttr :
14191427
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
14201428
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1507,7 +1515,8 @@ def SPIRV_CapabilityAttr :
15071515
SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
15081516
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
15091517
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
1510-
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL
1518+
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
1519+
SPIRV_C_CacheControlsINTEL
15111520
]>;
15121521

15131522
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -2623,6 +2632,16 @@ def SPIRV_D_MediaBlockIOINTEL : I32EnumAttrCase<"MediaBlockIOIN
26232632
Capability<[SPIRV_C_VectorComputeINTEL]>
26242633
];
26252634
}
2635+
def SPIRV_D_CacheControlLoadINTEL : I32EnumAttrCase<"CacheControlLoadINTEL", 6442> {
2636+
list<Availability> availability = [
2637+
Capability<[SPIRV_C_CacheControlsINTEL]>
2638+
];
2639+
}
2640+
def SPIRV_D_CacheControlStoreINTEL : I32EnumAttrCase<"CacheControlStoreINTEL", 6443> {
2641+
list<Availability> availability = [
2642+
Capability<[SPIRV_C_CacheControlsINTEL]>
2643+
];
2644+
}
26262645

26272646
def SPIRV_DecorationAttr :
26282647
SPIRV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [
@@ -2658,7 +2677,8 @@ def SPIRV_DecorationAttr :
26582677
SPIRV_D_FuseLoopsInFunctionINTEL, SPIRV_D_AliasScopeINTEL, SPIRV_D_NoAliasINTEL,
26592678
SPIRV_D_BufferLocationINTEL, SPIRV_D_IOPipeStorageINTEL,
26602679
SPIRV_D_FunctionFloatingPointModeINTEL, SPIRV_D_SingleElementVectorINTEL,
2661-
SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL
2680+
SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL,
2681+
SPIRV_D_CacheControlLoadINTEL, SPIRV_D_CacheControlStoreINTEL
26622682
]>;
26632683

26642684
def SPIRV_D_1D : I32EnumAttrCase<"Dim1D", 0> {
@@ -4092,6 +4112,32 @@ def SPIRV_KHR_CooperativeMatrixOperandsAttr :
40924112
SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat
40934113
]>;
40944114

4115+
def SPIRV_INTEL_LCC_Uncached : I32EnumAttrCase<"Uncached", 0>;
4116+
def SPIRV_INTEL_LCC_Cached : I32EnumAttrCase<"Cached", 1>;
4117+
def SPIRV_INTEL_LCC_Streaming : I32EnumAttrCase<"Streaming", 2>;
4118+
def SPIRV_INTEL_LCC_InvalidateAfterRead : I32EnumAttrCase<"InvalidateAfterR", 3>;
4119+
def SPIRV_INTEL_LCC_ConstCached : I32EnumAttrCase<"ConstCached", 4>;
4120+
4121+
def SPIRV_INTEL_LoadCacheControlAttr :
4122+
SPIRV_I32EnumAttr<"LoadCacheControl", "valid SPIR-V LoadCacheControl",
4123+
"load_cache_control", [
4124+
SPIRV_INTEL_LCC_Uncached, SPIRV_INTEL_LCC_Cached,
4125+
SPIRV_INTEL_LCC_Streaming, SPIRV_INTEL_LCC_InvalidateAfterRead,
4126+
SPIRV_INTEL_LCC_ConstCached
4127+
]>;
4128+
4129+
def SPIRV_INTEL_SCC_Uncached : I32EnumAttrCase<"Uncached", 0>;
4130+
def SPIRV_INTEL_SCC_WriteThrough : I32EnumAttrCase<"WriteThrough", 1>;
4131+
def SPIRV_INTEL_SCC_WriteBack : I32EnumAttrCase<"WriteBack", 2>;
4132+
def SPIRV_INTEL_SCC_Streaming : I32EnumAttrCase<"Streaming", 3>;
4133+
4134+
def SPIRV_INTEL_StoreCacheControlAttr :
4135+
SPIRV_I32EnumAttr<"StoreCacheControl", "valid SPIR-V StoreCacheControl",
4136+
"store_cache_control", [
4137+
SPIRV_INTEL_SCC_Uncached, SPIRV_INTEL_SCC_WriteThrough,
4138+
SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming
4139+
]>;
4140+
40954141
//===----------------------------------------------------------------------===//
40964142
// SPIR-V attribute definitions
40974143
//===----------------------------------------------------------------------===//

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,28 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
226226
return success();
227227
}
228228

229+
template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
230+
LogicalResult deserializeCacheControlDecoration(
231+
Location loc, OpBuilder &opBuilder,
232+
DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
233+
StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
234+
if (words.size() != 4) {
235+
return emitError(loc, "OpDecoration with ")
236+
<< decorationName << "needs a cache control integer literal and a "
237+
<< cacheControlKind << " cache control literal";
238+
}
239+
unsigned cacheLevel = words[2];
240+
auto cacheControlAttr = static_cast<EnumTy>(words[3]);
241+
auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
242+
SmallVector<Attribute> attrs;
243+
if (auto attrList =
244+
llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
245+
llvm::append_range(attrs, attrList);
246+
attrs.push_back(value);
247+
decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
248+
return success();
249+
}
250+
229251
LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
230252
// TODO: This function should also be auto-generated. For now, since only a
231253
// few decorations are processed/handled in a meaningful manner, going with a
@@ -339,6 +361,24 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
339361
decorations[words[0]].set(
340362
symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
341363
break;
364+
case spirv::Decoration::CacheControlLoadINTEL: {
365+
LogicalResult res = deserializeCacheControlDecoration<
366+
CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367+
unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368+
"load");
369+
if (failed(res))
370+
return res;
371+
break;
372+
}
373+
case spirv::Decoration::CacheControlStoreINTEL: {
374+
LogicalResult res = deserializeCacheControlDecoration<
375+
CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376+
unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377+
"store");
378+
if (failed(res))
379+
return res;
380+
break;
381+
}
342382
default:
343383
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
344384
}

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,42 @@ static std::string getDecorationName(StringRef attrName) {
217217
// similar here
218218
if (attrName == "fp_rounding_mode")
219219
return "FPRoundingMode";
220+
// convertToCamelFromSnakeCase will not capitalize "INTEL".
221+
if (attrName == "cache_control_load_intel")
222+
return "CacheControlLoadINTEL";
223+
if (attrName == "cache_control_store_intel")
224+
return "CacheControlStoreINTEL";
220225

221226
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
222227
}
223228

229+
template <typename AttrTy, typename EmitF>
230+
LogicalResult processDecorationList(Location loc, Decoration decoration,
231+
Attribute attrList, StringRef attrName,
232+
EmitF emitter) {
233+
auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
234+
if (!arrayAttr) {
235+
return emitError(loc, "expecting array attribute of ")
236+
<< attrName << " for " << stringifyDecoration(decoration);
237+
}
238+
if (arrayAttr.empty()) {
239+
return emitError(loc, "expecting non-empty array attribute of ")
240+
<< attrName << " for " << stringifyDecoration(decoration);
241+
}
242+
for (Attribute attr : arrayAttr.getValue()) {
243+
auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244+
if (!cacheControlAttr) {
245+
return emitError(loc, "expecting array attribute of ")
246+
<< attrName << " for " << stringifyDecoration(decoration);
247+
}
248+
// This named attribute encodes several decorations. Emit one per
249+
// element in the array.
250+
if (failed(emitter(cacheControlAttr)))
251+
return failure();
252+
}
253+
return success();
254+
}
255+
224256
LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
225257
Decoration decoration,
226258
Attribute attr) {
@@ -294,6 +326,26 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
294326
return emitError(loc,
295327
"expected unit attribute or decoration attribute for ")
296328
<< stringifyDecoration(decoration);
329+
case spirv::Decoration::CacheControlLoadINTEL:
330+
return processDecorationList<CacheControlLoadINTELAttr>(
331+
loc, decoration, attr, "CacheControlLoadINTEL",
332+
[&](CacheControlLoadINTELAttr attr) {
333+
unsigned cacheLevel = attr.getCacheLevel();
334+
LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335+
return emitDecoration(
336+
resultID, decoration,
337+
{cacheLevel, static_cast<uint32_t>(loadCacheControl)});
338+
});
339+
case spirv::Decoration::CacheControlStoreINTEL:
340+
return processDecorationList<CacheControlStoreINTELAttr>(
341+
loc, decoration, attr, "CacheControlStoreINTEL",
342+
[&](CacheControlStoreINTELAttr attr) {
343+
unsigned cacheLevel = attr.getCacheLevel();
344+
StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345+
return emitDecoration(
346+
resultID, decoration,
347+
{cacheLevel, static_cast<uint32_t>(storeCacheControl)});
348+
});
297349
default:
298350
return emitError(loc, "unhandled decoration ")
299351
<< stringifyDecoration(decoration);

mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,21 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
6969
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
7070
spirv.Return
7171
}
72+
73+
// -----
74+
75+
//===----------------------------------------------------------------------===//
76+
// spirv.INTEL.CacheControls
77+
//===----------------------------------------------------------------------===//
78+
79+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
80+
spirv.func @foo() "None" {
81+
// CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
82+
%0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
83+
// CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
84+
%1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
85+
spirv.Return
86+
}
87+
}
88+
89+
// -----

mlir/test/Target/SPIRV/decorations.mlir

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s
1+
// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip -verify-diagnostics %s | FileCheck %s
22

33
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
44
// CHECK: location = 0 : i32
@@ -107,3 +107,47 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
107107
spirv.ReturnValue %0 : f16
108108
}
109109
}
110+
111+
// -----
112+
113+
// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
114+
115+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
116+
spirv.func @cache_controls() "None" {
117+
// CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
118+
%0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
119+
// CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
120+
%1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
121+
spirv.Return
122+
}
123+
}
124+
125+
// -----
126+
127+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
128+
spirv.func @cache_controls_invalid_type() "None" {
129+
// expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}}
130+
%0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function>
131+
spirv.Return
132+
}
133+
}
134+
135+
// -----
136+
137+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
138+
spirv.func @cache_controls_invalid_type() "None" {
139+
// expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
140+
%0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function>
141+
spirv.Return
142+
}
143+
}
144+
145+
// -----
146+
147+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
148+
spirv.func @cache_controls_invalid_type() "None" {
149+
// expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
150+
%0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function>
151+
spirv.Return
152+
}
153+
}

0 commit comments

Comments
 (0)