1
- // ===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering- -----===//
1
+ // ===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10
- // SPIRV Dialect ops.
10
+ // SPIRV Cooperative Matrix ops.
11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
13
22
22
#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
23
23
#include " mlir/IR/TypeUtilities.h"
24
24
25
- using namespace mlir ;
25
+ namespace mlir ::nv {
26
+ namespace {
26
27
27
28
// / Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
28
29
// / when the elementwise op directly supports with cooperative matrix type.
@@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
70
71
return false ;
71
72
}
72
73
73
- namespace {
74
-
75
- // / This class implements the conversion of GPU MMA loadOp to
76
- // / CooperativeMatrixLoad op in the SPIRV dialect.
77
- struct WmmaLoadOpToSPIRVLowering
78
- : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
74
+ // / Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
75
+ // / dialect.
76
+ struct WmmaLoadOpToSPIRVLowering final
77
+ : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
79
78
using OpConversionPattern::OpConversionPattern;
80
79
81
80
LogicalResult
@@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
90
89
Value bufferPtr = spirv::getElementPtr (
91
90
*getTypeConverter<const SPIRVTypeConverter>(), memrefType,
92
91
adaptor.getSrcMemref (), adaptor.getIndices (), loc, rewriter);
93
- auto coopType = convertMMAToSPIRVType (retType);
92
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType (retType);
94
93
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension ().getSExtValue ();
95
94
auto i32Type = rewriter.getI32Type ();
96
95
auto strideValue = rewriter.create <spirv::ConstantOp>(
@@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
105
104
}
106
105
};
107
106
108
- // / This class implements the conversion of GPU MMA StoreOp to
109
- // / CooperativeMatrixStore op in the SPIRV dialect.
110
- struct WmmaStoreOpToSPIRVLowering
111
- : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
107
+ // / Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
108
+ // / dialect.
109
+ struct WmmaStoreOpToSPIRVLowering final
110
+ : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
112
111
using OpConversionPattern::OpConversionPattern;
113
112
114
113
LogicalResult
@@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
136
135
}
137
136
};
138
137
139
- // / This class implements the conversion of GPU MMA Compute to
140
- // / CooperativeMatrixMulAdd op in the SPIRV dialect.
141
- struct WmmaMmaOpToSPIRVLowering
142
- : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
138
+ // / Converts GPU MMA Compute to
139
+ // / NVCooperativeMatrixMulAdd op in the SPIRV dialect.
140
+ struct WmmaMmaOpToSPIRVLowering final
141
+ : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
143
142
using OpConversionPattern::OpConversionPattern;
144
143
145
144
LogicalResult
@@ -153,17 +152,18 @@ struct WmmaMmaOpToSPIRVLowering
153
152
}
154
153
};
155
154
156
- // / Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
157
- struct WmmaConstantOpToSPIRVLowering
158
- : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
155
+ // / Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
156
+ // / ops.
157
+ struct WmmaConstantOpToSPIRVLowering final
158
+ : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
159
159
using OpConversionPattern::OpConversionPattern;
160
160
161
161
LogicalResult
162
162
matchAndRewrite (gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
163
163
OpAdaptor adaptor,
164
164
ConversionPatternRewriter &rewriter) const override {
165
165
Value cst = adaptor.getOperands ()[0 ];
166
- auto coopType = convertMMAToSPIRVType (
166
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType (
167
167
cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType ()));
168
168
rewriter.replaceOpWithNewOp <spirv::CompositeConstructOp>(
169
169
subgroupMmaConstantMatrixOp, coopType, cst);
@@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering
173
173
174
174
// / Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
175
175
// / the default case.
176
- struct WmmaElementwiseOpToSPIRVDefaultLowering
177
- : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
176
+ struct WmmaElementwiseOpToSPIRVDefaultLowering final
177
+ : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
178
178
using OpConversionPattern::OpConversionPattern;
179
179
180
180
LogicalResult
@@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
186
186
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType ()))
187
187
return failure ();
188
188
}
189
- auto coopType = convertMMAToSPIRVType (
189
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType (
190
190
cast<gpu::MMAMatrixType>(elementwiseOp.getType ()));
191
191
return success (createElementwiseOp (rewriter, elementwiseOp, coopType,
192
192
adaptor.getOperands ()));
@@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
195
195
196
196
// / Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
197
197
// / matrix times scalar case.
198
- struct WmmaElementwiseOpToSPIRVScalarMulLowering
199
- : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
198
+ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
199
+ : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
200
200
using OpConversionPattern::OpConversionPattern;
201
201
202
202
LogicalResult
@@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
238
238
assert (cc.getConstituents ().size () == 1 );
239
239
scalar = cc.getConstituents ().front ();
240
240
241
- auto coopType = convertMMAToSPIRVType (
241
+ auto coopType = convertMMAToSPIRVCoopMatrixNVType (
242
242
cast<gpu::MMAMatrixType>(elementwiseOp.getType ()));
243
243
rewriter.replaceOpWithNewOp <spirv::MatrixTimesScalarOp>(
244
244
elementwiseOp, coopType, ValueRange{matrix, scalar});
@@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
247
247
};
248
248
249
249
} // namespace
250
+ } // namespace mlir::nv
250
251
251
- // / Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
252
252
mlir::spirv::CooperativeMatrixNVType
253
- mlir::convertMMAToSPIRVType (gpu::MMAMatrixType type) {
253
+ mlir::convertMMAToSPIRVCoopMatrixNVType (gpu::MMAMatrixType type) {
254
254
ArrayRef<int64_t > retTypeShape = type.getShape ();
255
255
Type elementType = type.getElementType ();
256
256
return spirv::CooperativeMatrixNVType::get (
257
257
elementType, spirv::Scope::Subgroup, retTypeShape[0 ], retTypeShape[1 ]);
258
258
}
259
259
260
- void mlir::populateGpuWMMAToSPIRVConversionPatterns (
260
+ void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns (
261
261
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
262
+ using namespace mlir ;
262
263
MLIRContext *context = patterns.getContext ();
263
- patterns.add <WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
264
- WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
265
- WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
264
+ patterns
265
+ .add <nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
266
+ nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
267
+ nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
266
268
// Give the following patterns higher benefit to prevail over the default one.
267
- patterns.add <WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
268
- /* benefit=*/ 2 );
269
+ patterns.add <nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
270
+ context,
271
+ /* benefit=*/ 2 );
269
272
}
0 commit comments