-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation #78413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range.
@llvm/pr-subscribers-mlir Author: Guray Ozen (grypp) ChangesThis PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range. Full diff: https://github.com/llvm/llvm-project/pull/78413.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..9e4ae219eefd60 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
- for (int i = 0; i < 2; ++i) {
+ for (size_t i = 0; i < 2; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
- for (int j = 0; j < 16; ++j) {
+ for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
- int sIndex = i * 2 + j * 4;
+ size_t sIndex = i * 2 + j * 4;
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
}
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..ce81fd859fd02a 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
return
}
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+ %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ %matrixD128: memref<64x128xf32,3>,
+ %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
+ %matrixD32: memref<64x32xf32,3>,
+ %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
+ %matrixD64: memref<64x64xf32,3>) {
+
+ // CHECK-COUNT-32: memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+ nvgpu.warpgroup.mma.store %result128, %matrixD128 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ to memref<64x128xf32,3>
+
+
+ // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+ nvgpu.warpgroup.mma.store %result32, %matrixD32 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>>
+ to memref<64x32xf32,3>
+
+ // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+ nvgpu.warpgroup.mma.store %result64, %matrixD64 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>>
+ to memref<64x64xf32,3>
+ return
+}
+
func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
|
@llvm/pr-subscribers-mlir-gpu Author: Guray Ozen (grypp) ChangesThis PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range. Full diff: https://github.com/llvm/llvm-project/pull/78413.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5f..9e4ae219eefd60b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
- for (int i = 0; i < 2; ++i) {
+ for (size_t i = 0; i < 2; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
- for (int j = 0; j < 16; ++j) {
+ for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
- int sIndex = i * 2 + j * 4;
+ size_t sIndex = i * 2 + j * 4;
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
}
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bdd..ce81fd859fd02ae 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
return
}
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+ %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ %matrixD128: memref<64x128xf32,3>,
+ %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
+ %matrixD32: memref<64x32xf32,3>,
+ %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
+ %matrixD64: memref<64x64xf32,3>) {
+
+ // CHECK-COUNT-32: memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+ nvgpu.warpgroup.mma.store %result128, %matrixD128 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ to memref<64x128xf32,3>
+
+
+ // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+ nvgpu.warpgroup.mma.store %result32, %matrixD32 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>>
+ to memref<64x32xf32,3>
+
+ // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+ nvgpu.warpgroup.mma.store %result64, %matrixD64 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>>
+ to memref<64x64xf32,3>
+ return
+}
+
func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
|
This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destionation memref and current accumulator matrix were small, the previous code was reaching out of range.