Skip to content

Commit e923b21

Browse files
authored
[mlir][spirv][gpu] Default to KHR coop matrix. Clean up type conversion. (llvm#67485)
- Now that the KHR coop matrix implementation is robust, switch the gpu conversion pass to default to it.
- Use a populate function for MMA to coop matrix type conversions. This makes the API surface area smaller.
1 parent 6e608dc commit e923b21

File tree

4 files changed

+47
-50
lines changed

4 files changed

+47
-50
lines changed

mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@
2020
namespace mlir {
2121
class SPIRVTypeConverter;
2222

23-
namespace gpu {
24-
class MMAMatrixType;
25-
} // namespace gpu
26-
2723
/// Appends to a pattern list additional patterns for translating GPU Ops to
2824
/// SPIR-V ops. For a gpu.func to be converted, it should have a
2925
/// spirv.entry_point_abi attribute.
@@ -40,15 +36,11 @@ void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
4036
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
4137
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
4238

43-
/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
44-
/// `type`.
45-
spirv::CooperativeMatrixType
46-
convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
47-
48-
/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
49-
/// `type`.
50-
spirv::CooperativeMatrixNVType
51-
convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
39+
/// Registers conversions from `MMAMatrixType` to SPIR-V cooperative matrix
40+
/// types with the type converter. Defaults to KHR cooperative matrix types.
41+
/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
42+
void populateMMAToSPIRVCoopMatrixTypeConversion(
43+
SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
5244
} // namespace mlir
5345

5446
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H

mlir/include/mlir/Conversion/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
569569
"bool", /*default=*/"false",
570570
"Use 64-bit integers to convert index types">,
571571
Option<"useCoopMatrixNV", "use-coop-matrix-nv",
572-
"bool", /*default=*/"true",
572+
"bool", /*default=*/"false",
573573
"Use the NV cooperative matrix extension insted of the KHR extension"
574574
" to lower GPU WMMA ops">,
575575
];

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,8 @@ void GPUToSPIRVPass::runOnOperation() {
8686
SPIRVConversionOptions options;
8787
options.use64bitIndex = this->use64bitIndex;
8888
SPIRVTypeConverter typeConverter(targetAttr, options);
89-
90-
typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
91-
gpu::MMAMatrixType type) -> Type {
92-
if (useNV)
93-
return convertMMAToSPIRVCoopMatrixNVType(type);
94-
95-
return convertMMAToSPIRVCoopMatrixType(type);
96-
});
89+
populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter,
90+
this->useCoopMatrixNV);
9791

9892
RewritePatternSet patterns(context);
9993
populateGPUToSPIRVPatterns(typeConverter, patterns);

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,21 @@ struct WmmaLoadOpToSPIRVLowering final
311311
OpAdaptor adaptor,
312312
ConversionPatternRewriter &rewriter) const override {
313313
Location loc = subgroupMmaLoadMatrixOp->getLoc();
314+
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
315+
314316
gpu::MMAMatrixType retType =
315317
cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
316318
auto memrefType =
317319
cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
318-
Value bufferPtr = spirv::getElementPtr(
319-
*getTypeConverter<const SPIRVTypeConverter>(), memrefType,
320-
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
321-
auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
320+
Value bufferPtr =
321+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
322+
adaptor.getIndices(), loc, rewriter);
323+
auto coopType =
324+
typeConverter.convertType<spirv::CooperativeMatrixNVType>(retType);
325+
if (!coopType)
326+
return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
327+
"type conversion failed");
328+
322329
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
323330
auto i32Type = rewriter.getI32Type();
324331
auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -385,30 +392,6 @@ struct WmmaMmaOpToSPIRVLowering final
385392
} // namespace nv
386393
} // namespace mlir
387394

388-
mlir::spirv::CooperativeMatrixNVType
389-
mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
390-
ArrayRef<int64_t> retTypeShape = type.getShape();
391-
Type elementType = type.getElementType();
392-
return spirv::CooperativeMatrixNVType::get(
393-
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
394-
}
395-
396-
mlir::spirv::CooperativeMatrixType
397-
mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
398-
ArrayRef<int64_t> retTypeShape = type.getShape();
399-
Type elementType = type.getElementType();
400-
401-
auto use =
402-
llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
403-
.Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
404-
.Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
405-
.Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
406-
407-
return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
408-
retTypeShape[1],
409-
spirv::Scope::Subgroup, use);
410-
}
411-
412395
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
413396
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
414397
using namespace mlir;
@@ -432,3 +415,31 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
432415
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
433416
/*benefit=*/2);
434417
}
418+
419+
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
420+
mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
421+
if (useNVTypes) {
422+
typeConverter.addConversion([](gpu::MMAMatrixType type) {
423+
ArrayRef<int64_t> retTypeShape = type.getShape();
424+
Type elementType = type.getElementType();
425+
return spirv::CooperativeMatrixNVType::get(
426+
elementType, spirv::Scope::Subgroup, retTypeShape[0],
427+
retTypeShape[1]);
428+
});
429+
return;
430+
}
431+
432+
typeConverter.addConversion([](gpu::MMAMatrixType type) {
433+
ArrayRef<int64_t> retTypeShape = type.getShape();
434+
Type elementType = type.getElementType();
435+
auto use =
436+
llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
437+
.Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
438+
.Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
439+
.Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
440+
441+
return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
442+
retTypeShape[1],
443+
spirv::Scope::Subgroup, use);
444+
});
445+
}

0 commit comments

Comments
 (0)