Skip to content

Commit d6d4a52

Browse files
authored
[mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. (#66278)
This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension. Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.
1 parent 089b811 commit d6d4a52

File tree

4 files changed

+55
-47
lines changed

4 files changed

+55
-47
lines changed

mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ class MMAMatrixType;
3030
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
3131
RewritePatternSet &patterns);
3232

33-
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
34-
void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
35-
RewritePatternSet &patterns);
36-
37-
spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
33+
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
34+
/// using the NV Cooperative Matrix extension.
35+
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
36+
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
37+
38+
/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
39+
/// `type`.
40+
spirv::CooperativeMatrixNVType
41+
convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
3842
} // namespace mlir
3943

4044
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,14 @@ namespace {
3939
/// replace it).
4040
///
4141
/// 2) Lower the body of the spirv::ModuleOp.
42-
class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
43-
public:
42+
struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
4443
explicit GPUToSPIRVPass(bool mapMemorySpace)
4544
: mapMemorySpace(mapMemorySpace) {}
4645
void runOnOperation() override;
4746

4847
private:
4948
bool mapMemorySpace;
5049
};
51-
} // namespace
5250

5351
void GPUToSPIRVPass::runOnOperation() {
5452
MLIRContext *context = &getContext();
@@ -89,11 +87,12 @@ void GPUToSPIRVPass::runOnOperation() {
8987
options.use64bitIndex = this->use64bitIndex;
9088
SPIRVTypeConverter typeConverter(targetAttr, options);
9189
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
92-
return convertMMAToSPIRVType(type);
90+
return convertMMAToSPIRVCoopMatrixNVType(type);
9391
});
9492
RewritePatternSet patterns(context);
9593
populateGPUToSPIRVPatterns(typeConverter, patterns);
96-
populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
94+
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
95+
patterns);
9796
// TODO: Change SPIR-V conversion to be progressive and remove the following
9897
// patterns.
9998
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
@@ -105,6 +104,8 @@ void GPUToSPIRVPass::runOnOperation() {
105104
}
106105
}
107106

107+
} // namespace
108+
108109
std::unique_ptr<OperationPass<ModuleOp>>
109110
mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
110111
return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
1+
//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10-
// SPIRV Dialect ops.
10+
// SPIRV Cooperative Matrix ops.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

@@ -22,7 +22,8 @@
2222
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2323
#include "mlir/IR/TypeUtilities.h"
2424

25-
using namespace mlir;
25+
namespace mlir::nv {
26+
namespace {
2627

2728
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
2829
/// when the elementwise op directly supports the cooperative matrix type.
@@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
7071
return false;
7172
}
7273

73-
namespace {
74-
75-
/// This class implements the conversion of GPU MMA loadOp to
76-
/// CooperativeMatrixLoad op in the SPIRV dialect.
77-
struct WmmaLoadOpToSPIRVLowering
78-
: public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
74+
/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
75+
/// dialect.
76+
struct WmmaLoadOpToSPIRVLowering final
77+
: OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
7978
using OpConversionPattern::OpConversionPattern;
8079

8180
LogicalResult
@@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
9089
Value bufferPtr = spirv::getElementPtr(
9190
*getTypeConverter<const SPIRVTypeConverter>(), memrefType,
9291
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
93-
auto coopType = convertMMAToSPIRVType(retType);
92+
auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
9493
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
9594
auto i32Type = rewriter.getI32Type();
9695
auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
105104
}
106105
};
107106

108-
/// This class implements the conversion of GPU MMA StoreOp to
109-
/// CooperativeMatrixStore op in the SPIRV dialect.
110-
struct WmmaStoreOpToSPIRVLowering
111-
: public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
107+
/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
108+
/// dialect.
109+
struct WmmaStoreOpToSPIRVLowering final
110+
: OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
112111
using OpConversionPattern::OpConversionPattern;
113112

114113
LogicalResult
@@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
136135
}
137136
};
138137

139-
/// This class implements the conversion of GPU MMA Compute to
140-
/// CooperativeMatrixMulAdd op in the SPIRV dialect.
141-
struct WmmaMmaOpToSPIRVLowering
142-
: public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
138+
/// Converts GPU MMA Compute to
139+
/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
140+
struct WmmaMmaOpToSPIRVLowering final
141+
: OpConversionPattern<gpu::SubgroupMmaComputeOp> {
143142
using OpConversionPattern::OpConversionPattern;
144143

145144
LogicalResult
@@ -153,17 +152,18 @@ struct WmmaMmaOpToSPIRVLowering
153152
}
154153
};
155154

156-
/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
157-
struct WmmaConstantOpToSPIRVLowering
158-
: public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
155+
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
156+
/// ops.
157+
struct WmmaConstantOpToSPIRVLowering final
158+
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
159159
using OpConversionPattern::OpConversionPattern;
160160

161161
LogicalResult
162162
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
163163
OpAdaptor adaptor,
164164
ConversionPatternRewriter &rewriter) const override {
165165
Value cst = adaptor.getOperands()[0];
166-
auto coopType = convertMMAToSPIRVType(
166+
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
167167
cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
168168
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
169169
subgroupMmaConstantMatrixOp, coopType, cst);
@@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering
173173

174174
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
175175
/// the default case.
176-
struct WmmaElementwiseOpToSPIRVDefaultLowering
177-
: public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
176+
struct WmmaElementwiseOpToSPIRVDefaultLowering final
177+
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
178178
using OpConversionPattern::OpConversionPattern;
179179

180180
LogicalResult
@@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
186186
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
187187
return failure();
188188
}
189-
auto coopType = convertMMAToSPIRVType(
189+
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
190190
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
191191
return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
192192
adaptor.getOperands()));
@@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
195195

196196
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
197197
/// matrix times scalar case.
198-
struct WmmaElementwiseOpToSPIRVScalarMulLowering
199-
: public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
198+
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
199+
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
200200
using OpConversionPattern::OpConversionPattern;
201201

202202
LogicalResult
@@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
238238
assert(cc.getConstituents().size() == 1);
239239
scalar = cc.getConstituents().front();
240240

241-
auto coopType = convertMMAToSPIRVType(
241+
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
242242
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
243243
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
244244
elementwiseOp, coopType, ValueRange{matrix, scalar});
@@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
247247
};
248248

249249
} // namespace
250+
} // namespace mlir::nv
250251

251-
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
252252
mlir::spirv::CooperativeMatrixNVType
253-
mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
253+
mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
254254
ArrayRef<int64_t> retTypeShape = type.getShape();
255255
Type elementType = type.getElementType();
256256
return spirv::CooperativeMatrixNVType::get(
257257
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
258258
}
259259

260-
void mlir::populateGpuWMMAToSPIRVConversionPatterns(
260+
void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
261261
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
262+
using namespace mlir;
262263
MLIRContext *context = patterns.getContext();
263-
patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
264-
WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
265-
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
264+
patterns
265+
.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
266+
nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
267+
nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
266268
// Give the following patterns higher benefit to prevail over the default one.
267-
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
268-
/*benefit=*/2);
269+
patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
270+
context,
271+
/*benefit=*/2);
269272
}

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir renamed to mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
22

33
module attributes {
44
gpu.container_module,

0 commit comments

Comments
 (0)