16
16
#include " mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
17
17
#include " mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
18
18
#include " mlir/Dialect/Arith/Transforms/Passes.h"
19
+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
19
20
#include " mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
20
21
#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
21
22
#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -40,6 +41,35 @@ using namespace mlir;
40
41
41
42
namespace {
42
43
44
+ // / Map memRef memory space to SPIR-V storage class.
45
+ void mapToMemRef (Operation *op, spirv::TargetEnvAttr &targetAttr) {
46
+ spirv::TargetEnv targetEnv (targetAttr);
47
+ bool targetEnvSupportsKernelCapability =
48
+ targetEnv.allows (spirv::Capability::Kernel);
49
+ spirv::MemorySpaceToStorageClassMap memorySpaceMap =
50
+ targetEnvSupportsKernelCapability
51
+ ? spirv::mapMemorySpaceToOpenCLStorageClass
52
+ : spirv::mapMemorySpaceToVulkanStorageClass;
53
+ spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
54
+ spirv::convertMemRefTypesAndAttrs (op, converter);
55
+ }
56
+
57
+ // / Populate patterns for each dialect.
58
+ void populateConvertToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
59
+ ScfToSPIRVContext &scfToSPIRVContext,
60
+ RewritePatternSet &patterns) {
61
+ arith::populateCeilFloorDivExpandOpsPatterns (patterns);
62
+ arith::populateArithToSPIRVPatterns (typeConverter, patterns);
63
+ populateBuiltinFuncToSPIRVPatterns (typeConverter, patterns);
64
+ populateFuncToSPIRVPatterns (typeConverter, patterns);
65
+ populateGPUToSPIRVPatterns (typeConverter, patterns);
66
+ index::populateIndexToSPIRVPatterns (typeConverter, patterns);
67
+ populateMemRefToSPIRVPatterns (typeConverter, patterns);
68
+ populateVectorToSPIRVPatterns (typeConverter, patterns);
69
+ populateSCFToSPIRVPatterns (typeConverter, scfToSPIRVContext, patterns);
70
+ ub::populateUBToSPIRVConversionPatterns (typeConverter, patterns);
71
+ }
72
+
43
73
// / A pass to perform the SPIR-V conversion.
44
74
struct ConvertToSPIRVPass final
45
75
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
@@ -57,38 +87,46 @@ struct ConvertToSPIRVPass final
57
87
if (runVectorUnrolling && failed (spirv::unrollVectorsInFuncBodies (op)))
58
88
return signalPassFailure ();
59
89
60
- spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault (op);
61
- std::unique_ptr<ConversionTarget> target =
62
- SPIRVConversionTarget::get (targetAttr);
63
- SPIRVTypeConverter typeConverter (targetAttr);
64
- RewritePatternSet patterns (context);
65
- ScfToSPIRVContext scfToSPIRVContext;
66
-
67
- // Map MemRef memory space to SPIR-V storage class.
68
- spirv::TargetEnv targetEnv (targetAttr);
69
- bool targetEnvSupportsKernelCapability =
70
- targetEnv.allows (spirv::Capability::Kernel);
71
- spirv::MemorySpaceToStorageClassMap memorySpaceMap =
72
- targetEnvSupportsKernelCapability
73
- ? spirv::mapMemorySpaceToOpenCLStorageClass
74
- : spirv::mapMemorySpaceToVulkanStorageClass;
75
- spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
76
- spirv::convertMemRefTypesAndAttrs (op, converter);
90
+ // Generic conversion.
91
+ if (!convertGPUModules) {
92
+ spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault (op);
93
+ std::unique_ptr<ConversionTarget> target =
94
+ SPIRVConversionTarget::get (targetAttr);
95
+ SPIRVTypeConverter typeConverter (targetAttr);
96
+ RewritePatternSet patterns (context);
97
+ ScfToSPIRVContext scfToSPIRVContext;
98
+ mapToMemRef (op, targetAttr);
99
+ populateConvertToSPIRVPatterns (typeConverter, scfToSPIRVContext,
100
+ patterns);
101
+ if (failed (applyPartialConversion (op, *target, std::move (patterns))))
102
+ return signalPassFailure ();
103
+ return ;
104
+ }
77
105
78
- // Populate patterns for each dialect.
79
- arith::populateCeilFloorDivExpandOpsPatterns (patterns);
80
- arith::populateArithToSPIRVPatterns (typeConverter, patterns);
81
- populateBuiltinFuncToSPIRVPatterns (typeConverter, patterns);
82
- populateFuncToSPIRVPatterns (typeConverter, patterns);
83
- populateGPUToSPIRVPatterns (typeConverter, patterns);
84
- index::populateIndexToSPIRVPatterns (typeConverter, patterns);
85
- populateMemRefToSPIRVPatterns (typeConverter, patterns);
86
- populateVectorToSPIRVPatterns (typeConverter, patterns);
87
- populateSCFToSPIRVPatterns (typeConverter, scfToSPIRVContext, patterns);
88
- ub::populateUBToSPIRVConversionPatterns (typeConverter, patterns);
89
-
90
- if (failed (applyPartialConversion (op, *target, std::move (patterns))))
91
- return signalPassFailure ();
106
+ // Clone each GPU kernel module for conversion, given that the GPU
107
+ // launch op still needs the original GPU kernel module.
108
+ SmallVector<Operation *, 1 > gpuModules;
109
+ OpBuilder builder (context);
110
+ op->walk ([&](gpu::GPUModuleOp gpuModule) {
111
+ builder.setInsertionPoint (gpuModule);
112
+ gpuModules.push_back (builder.clone (*gpuModule));
113
+ });
114
+ // Run conversion for each module independently as they can have
115
+ // different TargetEnv attributes.
116
+ for (Operation *gpuModule : gpuModules) {
117
+ spirv::TargetEnvAttr targetAttr =
118
+ spirv::lookupTargetEnvOrDefault (gpuModule);
119
+ std::unique_ptr<ConversionTarget> target =
120
+ SPIRVConversionTarget::get (targetAttr);
121
+ SPIRVTypeConverter typeConverter (targetAttr);
122
+ RewritePatternSet patterns (context);
123
+ ScfToSPIRVContext scfToSPIRVContext;
124
+ mapToMemRef (gpuModule, targetAttr);
125
+ populateConvertToSPIRVPatterns (typeConverter, scfToSPIRVContext,
126
+ patterns);
127
+ if (failed (applyFullConversion (gpuModule, *target, std::move (patterns))))
128
+ return signalPassFailure ();
129
+ }
92
130
}
93
131
};
94
132
0 commit comments