Skip to content

Commit 2bf2468

Browse files
angelz913kuhar
andauthored
[mlir][spirv] Integrate convert-to-spirv into mlir-vulkan-runner (#106082)
**Description** This PR adds a new option for `convert-to-spirv` pass to clone and convert only GPU kernel modules for integration testing. The reason for using pass options instead of two separate passes is that they both consist of `memref` types conversion and individual dialect patterns, except they run on different scopes. The PR also replaces the `gpu-to-spirv` pass with the `convert-to-spirv` pass (with the new option) in `mlir-vulkan-runner`. **Future Plan** Use nesting pass pipelines in `mlir-vulkan-runner` instead of adding this option. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent d0a6434 commit 2bf2468

File tree

6 files changed

+189
-33
lines changed

6 files changed

+189
-33
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
5050
"Run function signature conversion to convert vector types">,
5151
Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
5252
/*default=*/"true",
53-
"Run vector unrolling to convert vector types in function bodies">
53+
"Run vector unrolling to convert vector types in function bodies">,
54+
Option<"convertGPUModules", "convert-gpu-modules", "bool",
55+
/*default=*/"false",
56+
"Clone and convert GPU modules">
5457
];
5558
}
5659

mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
1515
MLIRArithToSPIRV
1616
MLIRArithTransforms
1717
MLIRFuncToSPIRV
18+
MLIRGPUDialect
1819
MLIRGPUToSPIRV
1920
MLIRIndexToSPIRV
2021
MLIRIR

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
1717
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1818
#include "mlir/Dialect/Arith/Transforms/Passes.h"
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1920
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
2021
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2122
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -40,6 +41,35 @@ using namespace mlir;
4041

4142
namespace {
4243

44+
/// Map memRef memory space to SPIR-V storage class.
45+
void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
46+
spirv::TargetEnv targetEnv(targetAttr);
47+
bool targetEnvSupportsKernelCapability =
48+
targetEnv.allows(spirv::Capability::Kernel);
49+
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
50+
targetEnvSupportsKernelCapability
51+
? spirv::mapMemorySpaceToOpenCLStorageClass
52+
: spirv::mapMemorySpaceToVulkanStorageClass;
53+
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
54+
spirv::convertMemRefTypesAndAttrs(op, converter);
55+
}
56+
57+
/// Populate patterns for each dialect.
58+
void populateConvertToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
59+
ScfToSPIRVContext &scfToSPIRVContext,
60+
RewritePatternSet &patterns) {
61+
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
62+
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
63+
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
64+
populateFuncToSPIRVPatterns(typeConverter, patterns);
65+
populateGPUToSPIRVPatterns(typeConverter, patterns);
66+
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
67+
populateMemRefToSPIRVPatterns(typeConverter, patterns);
68+
populateVectorToSPIRVPatterns(typeConverter, patterns);
69+
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
70+
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
71+
}
72+
4373
/// A pass to perform the SPIR-V conversion.
4474
struct ConvertToSPIRVPass final
4575
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
@@ -57,38 +87,46 @@ struct ConvertToSPIRVPass final
5787
if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
5888
return signalPassFailure();
5989

60-
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
61-
std::unique_ptr<ConversionTarget> target =
62-
SPIRVConversionTarget::get(targetAttr);
63-
SPIRVTypeConverter typeConverter(targetAttr);
64-
RewritePatternSet patterns(context);
65-
ScfToSPIRVContext scfToSPIRVContext;
66-
67-
// Map MemRef memory space to SPIR-V storage class.
68-
spirv::TargetEnv targetEnv(targetAttr);
69-
bool targetEnvSupportsKernelCapability =
70-
targetEnv.allows(spirv::Capability::Kernel);
71-
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
72-
targetEnvSupportsKernelCapability
73-
? spirv::mapMemorySpaceToOpenCLStorageClass
74-
: spirv::mapMemorySpaceToVulkanStorageClass;
75-
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
76-
spirv::convertMemRefTypesAndAttrs(op, converter);
90+
// Generic conversion.
91+
if (!convertGPUModules) {
92+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
93+
std::unique_ptr<ConversionTarget> target =
94+
SPIRVConversionTarget::get(targetAttr);
95+
SPIRVTypeConverter typeConverter(targetAttr);
96+
RewritePatternSet patterns(context);
97+
ScfToSPIRVContext scfToSPIRVContext;
98+
mapToMemRef(op, targetAttr);
99+
populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
100+
patterns);
101+
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
102+
return signalPassFailure();
103+
return;
104+
}
77105

78-
// Populate patterns for each dialect.
79-
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
80-
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
81-
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
82-
populateFuncToSPIRVPatterns(typeConverter, patterns);
83-
populateGPUToSPIRVPatterns(typeConverter, patterns);
84-
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
85-
populateMemRefToSPIRVPatterns(typeConverter, patterns);
86-
populateVectorToSPIRVPatterns(typeConverter, patterns);
87-
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
88-
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
89-
90-
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
91-
return signalPassFailure();
106+
// Clone each GPU kernel module for conversion, given that the GPU
107+
// launch op still needs the original GPU kernel module.
108+
SmallVector<Operation *, 1> gpuModules;
109+
OpBuilder builder(context);
110+
op->walk([&](gpu::GPUModuleOp gpuModule) {
111+
builder.setInsertionPoint(gpuModule);
112+
gpuModules.push_back(builder.clone(*gpuModule));
113+
});
114+
// Run conversion for each module independently as they can have
115+
// different TargetEnv attributes.
116+
for (Operation *gpuModule : gpuModules) {
117+
spirv::TargetEnvAttr targetAttr =
118+
spirv::lookupTargetEnvOrDefault(gpuModule);
119+
std::unique_ptr<ConversionTarget> target =
120+
SPIRVConversionTarget::get(targetAttr);
121+
SPIRVTypeConverter typeConverter(targetAttr);
122+
RewritePatternSet patterns(context);
123+
ScfToSPIRVContext scfToSPIRVContext;
124+
mapToMemRef(gpuModule, targetAttr);
125+
populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
126+
patterns);
127+
if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
128+
return signalPassFailure();
129+
}
92130
}
93131
};
94132

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: mlir-opt -convert-to-spirv="convert-gpu-modules=true run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
2+
3+
module attributes {
4+
gpu.container_module,
5+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
6+
} {
7+
// CHECK-LABEL: func.func @main
8+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
9+
// CHECK: gpu.launch_func @[[$KERNELS_1:.*]]::@[[$BUILTIN_WG_ID_X:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
10+
// CHECK: gpu.launch_func @[[$KERNELS_2:.*]]::@[[$BUILTIN_WG_ID_Y:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
11+
func.func @main() {
12+
%c1 = arith.constant 1 : index
13+
gpu.launch_func @kernels_1::@builtin_workgroup_id_x
14+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
15+
gpu.launch_func @KERNELS_2::@builtin_workgroup_id_y
16+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
17+
return
18+
}
19+
20+
// CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
21+
// CHECK: spirv.func @[[$BUILTIN_WG_ID_X]]
22+
// CHECK: spirv.mlir.addressof
23+
// CHECK: spirv.Load "Input"
24+
// CHECK: spirv.CompositeExtract
25+
gpu.module @kernels_1 {
26+
gpu.func @builtin_workgroup_id_x() kernel
27+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
28+
%0 = gpu.block_id x
29+
gpu.return
30+
}
31+
}
32+
// CHECK: gpu.module @[[$KERNELS_1]]
33+
// CHECK: gpu.func @[[$BUILTIN_WG_ID_X]]
34+
// CHECK gpu.block_id x
35+
// CHECK: gpu.return
36+
37+
// CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
38+
// CHECK: spirv.func @[[$BUILTIN_WG_ID_Y]]
39+
// CHECK: spirv.mlir.addressof
40+
// CHECK: spirv.Load "Input"
41+
// CHECK: spirv.CompositeExtract
42+
gpu.module @KERNELS_2 {
43+
gpu.func @builtin_workgroup_id_y() kernel
44+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
45+
%0 = gpu.block_id y
46+
gpu.return
47+
}
48+
}
49+
// CHECK: gpu.module @[[$KERNELS_2]]
50+
// CHECK: gpu.func @[[$BUILTIN_WG_ID_Y]]
51+
// CHECK gpu.block_id y
52+
// CHECK: gpu.return
53+
}
54+
55+
// -----
56+
57+
module attributes {
58+
gpu.container_module,
59+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
60+
} {
61+
// CHECK-LABEL: func.func @main
62+
// CHECK-SAME: %[[ARG0:.*]]: memref<2xi32>, %[[ARG1:.*]]: memref<4xi32>
63+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
64+
// CHECK: gpu.launch_func @[[$KERNEL_MODULE:.*]]::@[[$KERNEL_FUNC:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]]) args(%[[ARG0]] : memref<2xi32>, %[[ARG1]] : memref<4xi32>)
65+
func.func @main(%arg0 : memref<2xi32>, %arg2 : memref<4xi32>) {
66+
%c1 = arith.constant 1 : index
67+
gpu.launch_func @kernels::@kernel_foo
68+
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
69+
args(%arg0 : memref<2xi32>, %arg2 : memref<4xi32>)
70+
return
71+
}
72+
73+
// CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
74+
// CHECK: spirv.func @[[$KERNEL_FUNC]]
75+
// CHECK-SAME: %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
76+
// CHECK-SAME: %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
77+
gpu.module @kernels {
78+
gpu.func @kernel_foo(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>)
79+
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
80+
// CHECK: spirv.Constant
81+
// CHECK: spirv.Constant dense<0>
82+
%idx0 = arith.constant 0 : index
83+
%vec0 = arith.constant dense<[0, 0]> : vector<2xi32>
84+
// CHECK: spirv.AccessChain
85+
// CHECK: spirv.Load "StorageBuffer"
86+
%val = memref.load %arg0[%idx0] : memref<2xi32>
87+
// CHECK: spirv.CompositeInsert
88+
%vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32>
89+
// CHECK: spirv.VectorShuffle
90+
%shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32>
91+
// CHECK: spirv.CompositeExtract
92+
%res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32>
93+
// CHECK: spirv.AccessChain
94+
// CHECK: spirv.Store "StorageBuffer"
95+
memref.store %res, %arg1[%idx0]: memref<4xi32>
96+
// CHECK: spirv.Return
97+
gpu.return
98+
}
99+
}
100+
// CHECK: gpu.module @[[$KERNEL_MODULE]]
101+
// CHECK: gpu.func @[[$KERNEL_FUNC]]
102+
// CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32>
103+
// CHECK: arith.constant
104+
// CHECK: memref.load
105+
// CHECK: vector.insertelement
106+
// CHECK: vector.shuffle
107+
// CHECK: vector.extractelement
108+
// CHECK: memref.store
109+
// CHECK: gpu.return
110+
}

mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
1516
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
1617
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
1718
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
@@ -64,7 +65,9 @@ static LogicalResult runMLIRPasses(Operation *op,
6465
passManager.addPass(createGpuKernelOutliningPass());
6566
passManager.addPass(memref::createFoldMemRefAliasOpsPass());
6667

67-
passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
68+
ConvertToSPIRVPassOptions convertToSPIRVOptions{};
69+
convertToSPIRVOptions.convertGPUModules = true;
70+
passManager.addPass(createConvertToSPIRVPass(convertToSPIRVOptions));
6871
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
6972
modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());
7073
modulePM.addPass(spirv::createSPIRVUpdateVCEPass());

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8410,6 +8410,7 @@ cc_library(
84108410
":ArithTransforms",
84118411
":ConversionPassIncGen",
84128412
":FuncToSPIRV",
8413+
":GPUDialect",
84138414
":GPUToSPIRV",
84148415
":IR",
84158416
":IndexToSPIRV",

0 commit comments

Comments
 (0)