Skip to content

Commit 3bf9d0e

Browse files
committed
[mlir][spirv]: Add Image to Vulkan Storage Class Map
Extend the "storage class" <-> "memory space" map for the Vulkan SPIR-V environment to include the Image class. 12 is chosen as the next available value in the MemRef memory space list. Extend the pass testing to include missing memory scopes and add a new test file for the memory address indices which only support a mapping in the Vuklan environment. It appears that previously there was a missing CHECK line for the default pass behavior so that has been added. Signed-off-by: Jack Frankland <[email protected]>
1 parent c309abd commit 3bf9d0e

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ using namespace mlir;
5959
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
6060
MAP_FN(spirv::StorageClass::Input, 9) \
6161
MAP_FN(spirv::StorageClass::Output, 10) \
62-
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
62+
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11) \
63+
MAP_FN(spirv::StorageClass::Image, 12)
6364

6465
std::optional<spirv::StorageClass>
6566
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {

mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN
22
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=opencl' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=OPENCL
3+
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class -verify-diagnostics %s -o - | FileCheck %s
34

45
// Vulkan Mappings:
56
// 0 -> StorageBuffer
67
// 1 -> Generic
78
// 2 -> [null]
89
// 3 -> Workgroup
910
// 4 -> Uniform
11+
// 5 -> Private
12+
// 6 -> Function
13+
// 7 -> PushConstant
14+
// 8 -> UniformConstant
15+
// 9 -> Input
16+
// 10 -> Output
17+
// 11 -> PhysicalStorageBuffer
18+
// 12 -> Image
1019

1120
// OpenCL Mappings:
1221
// 0 -> CrossWorkgroup
1322
// 1 -> Generic
1423
// 2 -> [null]
1524
// 3 -> Workgroup
1625
// 4 -> UniformConstant
26+
// 5 -> Private
27+
// 6 -> Function
28+
// 7 -> Image
1729

1830
// VULKAN-LABEL: func @operand_result
1931
// OPENCL-LABEL: func @operand_result
@@ -30,6 +42,15 @@ func.func @operand_result() {
3042
// VULKAN: memref<*xf16, #spirv.storage_class<Uniform>>
3143
// OPENCL: memref<*xf16, #spirv.storage_class<UniformConstant>>
3244
%3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
45+
// VULKAN: memref<*xf16, #spirv.storage_class<Private>>
46+
// OPENCL: memref<*xf16, #spirv.storage_class<Private>>
47+
%4 = "dialect.memref_producer"() : () -> (memref<*xf16, 5>)
48+
// VULKAN: memref<*xf16, #spirv.storage_class<Function>>
49+
// OPENCL: memref<*xf16, #spirv.storage_class<Function>>
50+
%5 = "dialect.memref_producer"() : () -> (memref<*xf16, 6>)
51+
// VULKAN: memref<*xf16, #spirv.storage_class<PushConstant>>
52+
// OPENCL: memref<*xf16, #spirv.storage_class<Image>>
53+
%6 = "dialect.memref_producer"() : () -> (memref<*xf16, 7>)
3354

3455

3556
"dialect.memref_consumer"(%0) : (memref<f32>) -> ()
@@ -42,6 +63,15 @@ func.func @operand_result() {
4263
// VULKAN: memref<*xf16, #spirv.storage_class<Uniform>>
4364
// OPENCL: memref<*xf16, #spirv.storage_class<UniformConstant>>
4465
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
66+
// VULKAN: memref<*xf16, #spirv.storage_class<Private>>
67+
// OPENCL: memref<*xf16, #spirv.storage_class<Private>>
68+
"dialect.memref_consumer"(%4) : (memref<*xf16, 5>) -> ()
69+
// VULKAN: memref<*xf16, #spirv.storage_class<Function>>
70+
// OPENCL: memref<*xf16, #spirv.storage_class<Function>>
71+
"dialect.memref_consumer"(%5) : (memref<*xf16, 6>) -> ()
72+
// VULKAN: memref<*xf16, #spirv.storage_class<PushConstant>>
73+
// OPENCL: memref<*xf16, #spirv.storage_class<Image>>
74+
"dialect.memref_consumer"(%6) : (memref<*xf16, 7>) -> ()
4575

4676
return
4777
}
@@ -166,4 +196,4 @@ func.func @operand_result() {
166196
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
167197
return
168198
}
169-
}
199+
}

0 commit comments

Comments
 (0)