Skip to content

[mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation #78413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 22, 2024

Conversation

grypp
Copy link
Member

@grypp grypp commented Jan 17, 2024

This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range.

This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range.
@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2024

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range.


Full diff: https://github.com/llvm/llvm-project/pull/78413.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-3)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+28)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..9e4ae219eefd60 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
     Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
 
+    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                    TypedValue<::mlir::MemRefType> memref) {
       Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
     if (offset)
       ti = makeAdd(ti, makeConst(offset));
-    for (int i = 0; i < 2; ++i) {
+    for (size_t i = 0; i < 2; ++i) {
       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
-      for (int j = 0; j < 16; ++j) {
+      for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
-        int sIndex = i * 2 + j * 4;
+        size_t sIndex = i * 2 + j * 4;
         makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
       }
     }
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..ce81fd859fd02a 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
   return 
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+    %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
+    %matrixD128: memref<64x128xf32,3>,
+    %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, 
+    %matrixD32: memref<64x32xf32,3>,
+    %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, 
+    %matrixD64: memref<64x64xf32,3>) {
+  
+  // CHECK-COUNT-32:  memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+  nvgpu.warpgroup.mma.store %result128, %matrixD128 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+    to memref<64x128xf32,3>
+
+
+  // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+  nvgpu.warpgroup.mma.store %result32, %matrixD32 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>> 
+    to memref<64x32xf32,3>
+
+  // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+  nvgpu.warpgroup.mma.store %result64, %matrixD64 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>> 
+    to memref<64x64xf32,3>
+  return 
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2024

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

Changes

This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range.


Full diff: https://github.com/llvm/llvm-project/pull/78413.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-3)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+28)
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5f..9e4ae219eefd60b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
     Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
 
+    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                    TypedValue<::mlir::MemRefType> memref) {
       Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
     if (offset)
       ti = makeAdd(ti, makeConst(offset));
-    for (int i = 0; i < 2; ++i) {
+    for (size_t i = 0; i < 2; ++i) {
       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
-      for (int j = 0; j < 16; ++j) {
+      for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
-        int sIndex = i * 2 + j * 4;
+        size_t sIndex = i * 2 + j * 4;
         makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
       }
     }
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bdd..ce81fd859fd02ae 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
   return 
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+    %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
+    %matrixD128: memref<64x128xf32,3>,
+    %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, 
+    %matrixD32: memref<64x32xf32,3>,
+    %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, 
+    %matrixD64: memref<64x64xf32,3>) {
+  
+  // CHECK-COUNT-32:  memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+  nvgpu.warpgroup.mma.store %result128, %matrixD128 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+    to memref<64x128xf32,3>
+
+
+  // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+  nvgpu.warpgroup.mma.store %result32, %matrixD32 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>> 
+    to memref<64x32xf32,3>
+
+  // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+  nvgpu.warpgroup.mma.store %result64, %matrixD64 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>> 
+    to memref<64x64xf32,3>
+  return 
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>

@grypp grypp requested a review from jpienaar January 19, 2024 16:03
@grypp grypp merged commit 21830c9 into llvm:main Jan 22, 2024
@grypp grypp deleted the fix-epilogue branch January 22, 2024 07:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants