Skip to content

Commit 029f6f5

Browse files
FranklandJackmiguelcsx
authored andcommitted
[mlir][spirv]: Add Image to Vulkan Storage Class Map (#144899)
Extend the "storage class" <-> "memory space" map for the Vulkan SPIR-V environment to include the Image class. 12 is chosen as the next available value in the MemRef memory space list. Signed-off-by: Jack Frankland <[email protected]>
1 parent f7c3b68 commit 029f6f5

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ using namespace mlir;
5959
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
6060
MAP_FN(spirv::StorageClass::Input, 9) \
6161
MAP_FN(spirv::StorageClass::Output, 10) \
62-
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
62+
MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11) \
63+
MAP_FN(spirv::StorageClass::Image, 12)
6364

6465
std::optional<spirv::StorageClass>
6566
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt --allow-unregistered-dialect --map-memref-spirv-storage-class='client-api=vulkan' %s | FileCheck %s
2+
3+
// Vulkan Specific Mappings:
4+
// 8 -> UniformConstant
5+
// 9 -> Input
6+
// 10 -> Output
7+
// 11 -> PhysicalStorageBuffer
8+
// 12 -> Image
9+
10+
/// Check that Vulkan specific memory space indices get converted into the correct
11+
/// SPIR-V storage class. If mappings to OpenCL address spaces are added for these
12+
/// indices then those test case should be moved into the common test file.
13+
14+
// CHECK-LABEL: func @test_vk_specific_memory_spaces
15+
func.func @test_vk_specific_memory_spaces() {
16+
// CHECK: memref<4xi32, #spirv.storage_class<UniformConstant>>
17+
%1 = "dialect.memref_producer"() : () -> (memref<4xi32, 8>)
18+
// CHECK: memref<4xi32, #spirv.storage_class<Input>>
19+
%2 = "dialect.memref_producer"() : () -> (memref<4xi32, 9>)
20+
// CHECK: memref<4xi32, #spirv.storage_class<Output>>
21+
%3 = "dialect.memref_producer"() : () -> (memref<4xi32, 10>)
22+
// CHECK: memref<4xi32, #spirv.storage_class<PhysicalStorageBuffer>>
23+
%4 = "dialect.memref_producer"() : () -> (memref<4xi32, 11>)
24+
// CHECK: memref<4xi32, #spirv.storage_class<Image>>
25+
%5 = "dialect.memref_producer"() : () -> (memref<4xi32, 12>)
26+
return
27+
}

mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN
22
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=opencl' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=OPENCL
3+
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class -verify-diagnostics %s -o - | FileCheck %s
34

45
// Vulkan Mappings:
56
// 0 -> StorageBuffer
67
// 1 -> Generic
78
// 2 -> [null]
89
// 3 -> Workgroup
910
// 4 -> Uniform
11+
// 5 -> Private
12+
// 6 -> Function
13+
// 7 -> PushConstant
14+
// 8 -> UniformConstant
15+
// 9 -> Input
16+
// 10 -> Output
17+
// 11 -> PhysicalStorageBuffer
18+
// 12 -> Image
1019

1120
// OpenCL Mappings:
1221
// 0 -> CrossWorkgroup
1322
// 1 -> Generic
1423
// 2 -> [null]
1524
// 3 -> Workgroup
1625
// 4 -> UniformConstant
26+
// 5 -> Private
27+
// 6 -> Function
28+
// 7 -> Image
1729

1830
// VULKAN-LABEL: func @operand_result
1931
// OPENCL-LABEL: func @operand_result
@@ -30,6 +42,15 @@ func.func @operand_result() {
3042
// VULKAN: memref<*xf16, #spirv.storage_class<Uniform>>
3143
// OPENCL: memref<*xf16, #spirv.storage_class<UniformConstant>>
3244
%3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
45+
// VULKAN: memref<*xf16, #spirv.storage_class<Private>>
46+
// OPENCL: memref<*xf16, #spirv.storage_class<Private>>
47+
%4 = "dialect.memref_producer"() : () -> (memref<*xf16, 5>)
48+
// VULKAN: memref<*xf16, #spirv.storage_class<Function>>
49+
// OPENCL: memref<*xf16, #spirv.storage_class<Function>>
50+
%5 = "dialect.memref_producer"() : () -> (memref<*xf16, 6>)
51+
// VULKAN: memref<*xf16, #spirv.storage_class<PushConstant>>
52+
// OPENCL: memref<*xf16, #spirv.storage_class<Image>>
53+
%6 = "dialect.memref_producer"() : () -> (memref<*xf16, 7>)
3354

3455

3556
"dialect.memref_consumer"(%0) : (memref<f32>) -> ()
@@ -42,6 +63,15 @@ func.func @operand_result() {
4263
// VULKAN: memref<*xf16, #spirv.storage_class<Uniform>>
4364
// OPENCL: memref<*xf16, #spirv.storage_class<UniformConstant>>
4465
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
66+
// VULKAN: memref<*xf16, #spirv.storage_class<Private>>
67+
// OPENCL: memref<*xf16, #spirv.storage_class<Private>>
68+
"dialect.memref_consumer"(%4) : (memref<*xf16, 5>) -> ()
69+
// VULKAN: memref<*xf16, #spirv.storage_class<Function>>
70+
// OPENCL: memref<*xf16, #spirv.storage_class<Function>>
71+
"dialect.memref_consumer"(%5) : (memref<*xf16, 6>) -> ()
72+
// VULKAN: memref<*xf16, #spirv.storage_class<PushConstant>>
73+
// OPENCL: memref<*xf16, #spirv.storage_class<Image>>
74+
"dialect.memref_consumer"(%6) : (memref<*xf16, 7>) -> ()
4575

4676
return
4777
}
@@ -166,4 +196,4 @@ func.func @operand_result() {
166196
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
167197
return
168198
}
169-
}
199+
}

0 commit comments

Comments
 (0)