Skip to content

[mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. #66278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ class MMAMatrixType;
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);

/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);

spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
/// using the NV Cooperative Matrix extension.
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);

/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
/// `type`.
spirv::CooperativeMatrixNVType
convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
} // namespace mlir

#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
11 changes: 6 additions & 5 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,14 @@ namespace {
/// replace it).
///
/// 2) Lower the body of the spirv::ModuleOp.
class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
public:
struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
explicit GPUToSPIRVPass(bool mapMemorySpace)
: mapMemorySpace(mapMemorySpace) {}
void runOnOperation() override;

private:
bool mapMemorySpace;
};
} // namespace

void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
Expand Down Expand Up @@ -89,11 +87,12 @@ void GPUToSPIRVPass::runOnOperation() {
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
return convertMMAToSPIRVType(type);
return convertMMAToSPIRVCoopMatrixNVType(type);
});
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
patterns);
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
Expand All @@ -105,6 +104,8 @@ void GPUToSPIRVPass::runOnOperation() {
}
}

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
Expand Down
75 changes: 39 additions & 36 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Dialect ops.
// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

Expand All @@ -22,7 +22,8 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;
namespace mlir::nv {
namespace {

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports with cooperative matrix type.
Expand Down Expand Up @@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}

namespace {

/// This class implements the conversion of GPU MMA loadOp to
/// CooperativeMatrixLoad op in the SPIRV dialect.
struct WmmaLoadOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
/// dialect.
struct WmmaLoadOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
Expand All @@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
Value bufferPtr = spirv::getElementPtr(
*getTypeConverter<const SPIRVTypeConverter>(), memrefType,
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
auto coopType = convertMMAToSPIRVType(retType);
auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
Expand All @@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
}
};

/// This class implements the conversion of GPU MMA StoreOp to
/// CooperativeMatrixStore op in the SPIRV dialect.
struct WmmaStoreOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
/// dialect.
struct WmmaStoreOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
Expand Down Expand Up @@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
}
};

/// This class implements the conversion of GPU MMA Compute to
/// CooperativeMatrixMulAdd op in the SPIRV dialect.
struct WmmaMmaOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
/// Converts GPU MMA Compute to
/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
struct WmmaMmaOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaComputeOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
Expand All @@ -153,17 +152,18 @@ struct WmmaMmaOpToSPIRVLowering
}
};

/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
struct WmmaConstantOpToSPIRVLowering
: public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
/// ops.
struct WmmaConstantOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value cst = adaptor.getOperands()[0];
auto coopType = convertMMAToSPIRVType(
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
subgroupMmaConstantMatrixOp, coopType, cst);
Expand All @@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering
: public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
struct WmmaElementwiseOpToSPIRVDefaultLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
Expand All @@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}
auto coopType = convertMMAToSPIRVType(
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
adaptor.getOperands()));
Expand All @@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering
: public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
Expand Down Expand Up @@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();

auto coopType = convertMMAToSPIRVType(
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
elementwiseOp, coopType, ValueRange{matrix, scalar});
Expand All @@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
};

} // namespace
} // namespace mlir::nv

/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
mlir::spirv::CooperativeMatrixNVType
mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
ArrayRef<int64_t> retTypeShape = type.getShape();
Type elementType = type.getElementType();
return spirv::CooperativeMatrixNVType::get(
elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
}

void mlir::populateGpuWMMAToSPIRVConversionPatterns(
void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
patterns
.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
/*benefit=*/2);
patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
context,
/*benefit=*/2);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s

module attributes {
gpu.container_module,
Expand Down