[mlir][nvgpu] Fix packing accumulator matrix #69316

Merged
merged 1 commit into from
Oct 17, 2023


@grypp grypp commented Oct 17, 2023

PR #68728 significantly simplified the accumulator matrix type, making it easier to work with the nvgpu dialect without worrying about the number of required structs, since that information is abstracted away in the nvgpu-to-nvvm transformation.

However, we forgot to pack the structs after initialization: the zero-initialized inner structs were never inserted back into the outer struct, so the accumulator matrix held undefined values, which is wrong. This PR addresses that.

llvm#68728 significantly simplified the accumulator matrix, but we forgot to pack the struct after initialization. This PR fixes that.
llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/69316.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+18-11)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2d43230938526b9..91b6a25c6dfc03b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1576,27 +1576,34 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
   matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    LLVM::LLVMStructType structType =
+    LLVM::LLVMStructType packStructType =
         getTypeConverter()
             ->convertType(op.getMatrixC().getType())
             .cast<LLVM::LLVMStructType>();
-    Type elemType = structType.getBody()
+    Type elemType = packStructType.getBody()
                         .front()
                         .cast<LLVM::LLVMStructType>()
                         .getBody()
                         .front();
     Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
-    Value structValue = b.create<LLVM::UndefOp>(structType);
-    for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
-      auto innerStructType = s.cast<LLVM::LLVMStructType>();
-      int ii = idx;
-      Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
-      for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
-        innerStructValue = b.create<LLVM::InsertValueOp>(
-            innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
+    SmallVector<Value> innerStructs;
+    // Unpack the structs and set all values to zero
+    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
+      auto structType = s.cast<LLVM::LLVMStructType>();
+      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
+      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
+        structValue = b.create<LLVM::InsertValueOp>(
+            structType, structValue, zero, ArrayRef<int64_t>({i}));
       }
+      innerStructs.push_back(structValue);
     }
-    rewriter.replaceOp(op, structValue);
+    // Pack the inner structs into a single struct
+    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
+      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
+                                                 packStruct, matrix, idx);
+    }
+    rewriter.replaceOp(op, packStruct);
     return success();
   }
 };
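
The bug fixed by the diff above comes from the value semantics of `llvm.extractvalue` and `llvm.insertvalue`: each op produces a new value and leaves its operand unchanged. The following is a plain C++ analogy, not MLIR code and not part of this PR. The names `Inner`, `Packed`, `extractValue`, `insertElement`, `insertStruct`, `initWithoutPacking`, `initWithPacking`, and `fullyDefined` are all hypothetical, and `std::nullopt` stands in for an undefined element.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-ins for LLVM struct values: an element is either
// undefined (std::nullopt) or holds a concrete float.
using Inner = std::vector<std::optional<float>>;
using Packed = std::vector<Inner>;

// Value-semantics analogues of llvm.extractvalue / llvm.insertvalue:
// each returns a new value and never mutates its operand.
Inner extractValue(const Packed &packed, std::size_t idx) {
  assert(idx < packed.size());
  return packed[idx];
}

Inner insertElement(Inner inner, float value, std::size_t idx) {
  assert(idx < inner.size());
  inner[idx] = value;
  return inner;
}

Packed insertStruct(Packed packed, const Inner &inner, std::size_t idx) {
  assert(idx < packed.size());
  packed[idx] = inner;
  return packed;
}

// Analogue of the old lowering: the inner structs are zeroed, but the
// results are dropped, so the returned outer struct is still undefined.
Packed initWithoutPacking(std::size_t numInner, std::size_t numElems) {
  Packed packed(numInner, Inner(numElems)); // every element undefined
  for (std::size_t i = 0; i < numInner; ++i) {
    Inner inner = extractValue(packed, i);
    for (std::size_t j = 0; j < numElems; ++j)
      inner = insertElement(inner, 0.0f, j);
    // `inner` goes out of scope here without being packed back.
  }
  return packed;
}

// Analogue of the fixed lowering: collect the zeroed inner structs, then
// insert each one back into the outer struct before returning it.
Packed initWithPacking(std::size_t numInner, std::size_t numElems) {
  Packed packed(numInner, Inner(numElems)); // every element undefined
  std::vector<Inner> innerStructs;
  for (std::size_t i = 0; i < numInner; ++i) {
    Inner inner = extractValue(packed, i);
    for (std::size_t j = 0; j < numElems; ++j)
      inner = insertElement(inner, 0.0f, j);
    innerStructs.push_back(inner);
  }
  for (std::size_t i = 0; i < innerStructs.size(); ++i)
    packed = insertStruct(packed, innerStructs[i], i);
  return packed;
}

// True when no element of the outer struct is left undefined.
bool fullyDefined(const Packed &packed) {
  for (const Inner &inner : packed)
    for (const std::optional<float> &elem : inner)
      if (!elem.has_value())
        return false;
  return true;
}
```

In this sketch, `initWithoutPacking` returns a value whose elements are all still undefined, while `initWithPacking` returns one whose elements are all zero, mirroring the difference between replacing the op with the original `UndefOp` result and replacing it with the repacked `packStruct`.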
