[mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. #66278

Merged 1 commit on Sep 13, 2023

@kuhar kuhar commented Sep 13, 2023

This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension.

Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Full diff: https://github.com/llvm/llvm-project/pull/66278.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h (+9-5)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+6-5)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+39-36)
  • (renamed) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir (+1-1)
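
For readers unfamiliar with the helper renamed in the diff below, here is a minimal standalone sketch of what `convertMMAToSPIRVCoopMatrixNVType` does with its input: it forwards the element type, fixes the scope to `Subgroup`, and uses the two shape dimensions as rows and columns. The `Toy*` types and `toyConvertMMAToCoopMatrixNV` are hypothetical stand-ins invented for illustration, not the real MLIR classes, so this compiles without MLIR.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-in for gpu::MMAMatrixType (not the real MLIR class).
struct ToyMMAMatrixType {
  std::array<int64_t, 2> shape; // {rows, columns}
  std::string elementType;      // e.g. "f16"
};

// Hypothetical stand-in for spirv::Scope; only the value used here.
enum class ToyScope { Subgroup };

// Hypothetical stand-in for spirv::CooperativeMatrixNVType.
struct ToyCoopMatrixNVType {
  std::string elementType;
  ToyScope scope;
  int64_t rows;
  int64_t columns;
};

// Mirrors the structure of the helper in the diff: the element type is
// forwarded, the scope is always Subgroup, and shape[0] / shape[1] become
// the row and column counts.
ToyCoopMatrixNVType toyConvertMMAToCoopMatrixNV(const ToyMMAMatrixType &type) {
  return ToyCoopMatrixNVType{type.elementType, ToyScope::Subgroup,
                             type.shape[0], type.shape[1]};
}
```

Calling `toyConvertMMAToCoopMatrixNV({{16, 8}, "f16"})` yields a value with 16 rows, 8 columns, element type `"f16"`, and subgroup scope.
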
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 3c3281513f60d89..6c4643da1884900 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,15 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
-void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
-                                              RewritePatternSet &patterns);
-
-spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the NV Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
+/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixNVType
+convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index f37c70a771f5916..d0ce58597f980d4 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -39,8 +39,7 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
-public:
+struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
   explicit GPUToSPIRVPass(bool mapMemorySpace)
       : mapMemorySpace(mapMemorySpace) {}
   void runOnOperation() override;
@@ -48,7 +47,6 @@ class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
 private:
   bool mapMemorySpace;
 };
-} // namespace
 
 void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
@@ -89,11 +87,12 @@ void GPUToSPIRVPass::runOnOperation() {
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
     typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToSPIRVType(type);
+      return convertMMAToSPIRVCoopMatrixNVType(type);
     });
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
+    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                         patterns);
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
@@ -105,6 +104,8 @@ void GPUToSPIRVPass::runOnOperation() {
   }
 }
 
+} // namespace
+
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
   return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 3851fb728b6654b..bf3fff027fe384a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -1,4 +1,4 @@
-//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
-// SPIRV Dialect ops.
+// SPIRV Cooperative Matrix ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -22,7 +22,8 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/TypeUtilities.h"
 
-using namespace mlir;
+namespace mlir::nv {
+namespace {
 
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
@@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
-namespace {
-
-/// This class implements the conversion of GPU MMA loadOp to
-/// CooperativeMatrixLoad op in the SPIRV dialect.
-struct WmmaLoadOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
     Value bufferPtr = spirv::getElementPtr(
         *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
         adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
-    auto coopType = convertMMAToSPIRVType(retType);
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
     int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
   }
 };
 
-/// This class implements the conversion of GPU MMA StoreOp to
-/// CooperativeMatrixStore op in the SPIRV dialect.
-struct WmmaStoreOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
   }
 };
 
-/// This class implements the conversion of GPU MMA Compute to
-/// CooperativeMatrixMulAdd op in the SPIRV dialect.
-struct WmmaMmaOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+/// Converts GPU MMA Compute to
+/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -153,9 +152,10 @@ struct WmmaMmaOpToSPIRVLowering
   }
 };
 
-/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
-struct WmmaConstantOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
+/// ops.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -163,7 +163,7 @@ struct WmmaConstantOpToSPIRVLowering
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         subgroupMmaConstantMatrixOp, coopType, cst);
@@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering
 
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering
-    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
       if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
         return failure();
     }
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
                                        adaptor.getOperands()));
@@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
 
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering
-    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
     assert(cc.getConstituents().size() == 1);
     scalar = cc.getConstituents().front();
 
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
         elementwiseOp, coopType, ValueRange{matrix, scalar});
@@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
 };
 
 } // namespace
+} // namespace mlir::nv
 
-/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
 mlir::spirv::CooperativeMatrixNVType
-mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
   ArrayRef<int64_t> retTypeShape = type.getShape();
   Type elementType = type.getElementType();
   return spirv::CooperativeMatrixNVType::get(
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
 
-void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  using namespace mlir;
   MLIRContext *context = patterns.getContext();
-  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
-               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  patterns
+      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
+           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
-                                                          /*benefit=*/2);
+  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
+                                                              context,
+                                                              /*benefit=*/2);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
similarity index 98%
rename from mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
rename to mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
index a53eca65fc98699..5811c791f308d1e 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
   gpu.container_module,
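
The `/*benefit=*/2` argument in the `WmmaOpsToSPIRV.cpp` hunk relies on MLIR trying higher-benefit patterns before lower-benefit ones (the default benefit is 1). The following standalone sketch models only that ordering. `ToyPattern`, `toyApplyFirstMatch`, `toyElementwisePatterns`, and the match predicates are all hypothetical and much simpler than the real conversion driver and patterns.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy pattern: a name, a benefit, and a match predicate.
struct ToyPattern {
  std::string name;
  unsigned benefit;
  std::function<bool(const std::string &)> matches;
};

// Tries patterns in order of decreasing benefit and returns the name of the
// first one that matches, or "" if none does.
std::string toyApplyFirstMatch(std::vector<ToyPattern> patterns,
                               const std::string &opKind) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns)
    if (p.matches(opKind))
      return p.name;
  return "";
}

// Two toy patterns registered in the same order as in the diff: the default
// one first (benefit 1), then the scalar-multiply one with benefit 2. The
// op-kind strings are invented for this sketch.
std::vector<ToyPattern> toyElementwisePatterns() {
  return {
      {"Default", 1, [](const std::string &) { return true; }},
      {"ScalarMul", 2,
       [](const std::string &kind) { return kind == "mulf-by-splat"; }},
  };
}
```

With these toy patterns, `toyApplyFirstMatch(toyElementwisePatterns(), "mulf-by-splat")` selects `ScalarMul` even though `Default` was registered first and also matches, while any other kind falls through to `Default`.
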

@kuhar kuhar merged commit d6d4a52 into llvm:main Sep 13, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…lvm#66278)

Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Sep 21, 2023
Local branch amd-gfx b396737 Merged main:25e8105bff4d into amd-gfx:0d570d01ad3e
Remote branch main d6d4a52 [mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. (llvm#66278)