Commit 21830c9

[mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (#78413)

This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destination memref and the current accumulator matrix were small, the previous code computed struct indices that reached out of range.

1 parent a31a600 commit 21830c9
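Why the old computation could go out of range, in concrete numbers: the previous loop nest always ran the inner index j from 0 to 15 and computed sIndex = i * 2 + j * 4, so the extract index could reach 1 * 2 + 15 * 4 = 62 no matter how small the accumulator was. For a 64x16xf32 accumulator, whose per-thread result struct holds only 8 f32 registers (an inference from the new test below, where m64n16 produces exactly 8 stores), any extract beyond index 7 walks off the end of the struct. The fix derives the inner loop bound from the struct size itself.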

File tree

2 files changed: +152 −10 lines changed


mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 22 additions & 10 deletions
@@ -1548,12 +1548,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
-    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
-    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
-    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
-    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
-
     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                    TypedValue<::mlir::MemRefType> memref) {
       Type it = b.getIndexType();
@@ -1566,16 +1560,34 @@ struct NVGPUWarpgroupMmaStoreOpLowering
       b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
     };
 
+    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
+    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
+    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
+    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+
     Value tj = makeMul(lane4modId, c2);
     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
     if (offset)
       ti = makeAdd(ti, makeConst(offset));
-    for (int i = 0; i < 2; ++i) {
+
+    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+
+    // Number of 32-bit registers owned per thread
+    constexpr unsigned numAdjacentRegisters = 2;
+    // Number of 8x8 matrices stacked one below another per warp
+    constexpr unsigned numStackedMatrices = 2;
+
+    size_t storeCount = (structType.getBody().size() /
+                         (numStackedMatrices * numAdjacentRegisters));
+
+    for (size_t i = 0; i < numStackedMatrices; ++i) {
       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
-      for (int j = 0; j < 16; ++j) {
+      for (size_t j = 0; j < storeCount; ++j) {
         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
-        int sIndex = i * 2 + j * 4;
-        makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
+        size_t structIndex = (i * numAdjacentRegisters) +
+                             (j * (numStackedMatrices * numAdjacentRegisters));
+        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
       }
     }
   }
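To sanity-check that the patched indexing stays in bounds, here is a minimal standalone C++ sketch of the new loop nest. It has no MLIR dependencies; structSize = 8 is a hypothetical value for a small accumulator (not taken from the patch), and any multiple of 4 can be substituted:

#include <cassert>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t structSize = 8;             // assumed struct member count
  constexpr unsigned numAdjacentRegisters = 2; // 32-bit registers per thread
  constexpr unsigned numStackedMatrices = 2;   // stacked 8x8 matrices per warp

  // Patched bound: derived from the struct size instead of a hardcoded 16.
  size_t storeCount = structSize / (numStackedMatrices * numAdjacentRegisters);

  for (size_t i = 0; i < numStackedMatrices; ++i) {
    for (size_t j = 0; j < storeCount; ++j) {
      size_t structIndex = (i * numAdjacentRegisters) +
                           (j * (numStackedMatrices * numAdjacentRegisters));
      // makeExtractAndStore reads structIndex and structIndex + 1, so the
      // second element must also lie inside the struct.
      assert(structIndex + 1 < structSize);
      printf("i=%zu j=%zu -> struct[%zu], struct[%zu]\n", i, j, structIndex,
             structIndex + 1);
    }
  }
  return 0;
}

With structSize = 8 this visits struct index pairs {0,1}, {4,5}, {2,3}, {6,7}: every member exactly once, nothing out of bounds, whereas the old hardcoded j < 16 would have requested index 62.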

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 130 additions & 0 deletions
@@ -1055,6 +1055,136 @@ func.func @warpgroup_mma_store(
   return
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiple
+func.func @warpgroup_mma_store_multiple(
+    %shmem_m64n8k : memref<64x8xf32>,
+    %shmem_m64n16k : memref<64x16xf32>,
+    %shmem_m64n24k : memref<64x24xf32>,
+    %shmem_m64n32k : memref<64x32xf32>,
+    %shmem_m64n40k : memref<64x40xf32>,
+    %shmem_m64n48k : memref<64x48xf32>,
+    %shmem_m64n56k : memref<64x56xf32>,
+    %shmem_m64n64k : memref<64x64xf32>,
+    %shmem_m64n72k : memref<64x72xf32>,
+    %shmem_m64n80k : memref<64x80xf32>,
+    %shmem_m64n88k : memref<64x88xf32>,
+    %shmem_m64n96k : memref<64x96xf32>,
+    %shmem_m64n104k : memref<64x104xf32>,
+    %shmem_m64n112k : memref<64x112xf32>,
+    %shmem_m64n120k : memref<64x120xf32>,
+    %shmem_m64n128k : memref<64x128xf32>,
+    %shmem_m64n136k : memref<64x136xf32>,
+    %shmem_m64n144k : memref<64x144xf32>,
+    %shmem_m64n152k : memref<64x152xf32>,
+    %shmem_m64n160k : memref<64x160xf32>,
+    %shmem_m64n168k : memref<64x168xf32>,
+    %shmem_m64n176k : memref<64x176xf32>,
+    %shmem_m64n184k : memref<64x184xf32>,
+    %shmem_m64n192k : memref<64x192xf32>,
+    %shmem_m64n200k : memref<64x200xf32>,
+    %shmem_m64n208k : memref<64x208xf32>,
+    %shmem_m64n216k : memref<64x216xf32>,
+    %shmem_m64n224k : memref<64x224xf32>,
+    %shmem_m64n232k : memref<64x232xf32>,
+    %shmem_m64n240k : memref<64x240xf32>,
+    %shmem_m64n248k : memref<64x248xf32>,
+    %shmem_m64n256k : memref<64x256xf32>,
+    %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+    %res_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>>,
+    %res_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
+    %res_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>>,
+    %res_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>>,
+    %res_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>>,
+    %res_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
+    %res_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>>,
+    %res_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>>,
+    %res_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>>,
+    %res_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>>,
+    %res_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>>,
+    %res_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>>,
+    %res_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>>,
+    %res_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+    %res_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>>,
+    %res_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>>,
+    %res_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>>,
+    %res_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>>,
+    %res_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>>,
+    %res_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>>,
+    %res_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>>,
+    %res_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>>,
+    %res_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>>,
+    %res_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>>,
+    %res_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>>,
+    %res_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>>,
+    %res_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>>,
+    %res_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>>,
+    %res_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>>,
+    %res_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>>) {
+  // CHECK-COUNT-8: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32>
+  // CHECK-COUNT-12: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x24xf32>
+  // CHECK-COUNT-16: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x32xf32>
+  // CHECK-COUNT-20: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x40xf32>
+  // CHECK-COUNT-24: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x48xf32>
+  // CHECK-COUNT-28: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x56xf32>
+  // CHECK-COUNT-32: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf32>
+  // CHECK-COUNT-36: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x72xf32>
+  // CHECK-COUNT-40: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x80xf32>
+  // CHECK-COUNT-44: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x88xf32>
+  // CHECK-COUNT-48: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x96xf32>
+  // CHECK-COUNT-52: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x104xf32>
+  // CHECK-COUNT-56: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x112xf32>
+  // CHECK-COUNT-60: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x120xf32>
+  // CHECK-COUNT-64: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x128xf32>
+  // CHECK-COUNT-68: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x136xf32>
+  // CHECK-COUNT-72: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x144xf32>
+  // CHECK-COUNT-76: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x152xf32>
+  // CHECK-COUNT-80: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x160xf32>
+  // CHECK-COUNT-84: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x168xf32>
+  // CHECK-COUNT-88: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x176xf32>
+  // CHECK-COUNT-92: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x184xf32>
+  // CHECK-COUNT-96: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x192xf32>
+  // CHECK-COUNT-100: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x200xf32>
+  // CHECK-COUNT-104: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x208xf32>
+  // CHECK-COUNT-108: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x216xf32>
+  // CHECK-COUNT-112: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x224xf32>
+  // CHECK-COUNT-116: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x232xf32>
+  // CHECK-COUNT-120: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x240xf32>
+  // CHECK-COUNT-124: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x248xf32>
+  // CHECK-COUNT-128: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x256xf32>
+  nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32>
+  nvgpu.warpgroup.mma.store %res_m64n24k, %shmem_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>> to memref<64x24xf32>
+  nvgpu.warpgroup.mma.store %res_m64n32k, %shmem_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>> to memref<64x32xf32>
+  nvgpu.warpgroup.mma.store %res_m64n40k, %shmem_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>> to memref<64x40xf32>
+  nvgpu.warpgroup.mma.store %res_m64n48k, %shmem_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>> to memref<64x48xf32>
+  nvgpu.warpgroup.mma.store %res_m64n56k, %shmem_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>> to memref<64x56xf32>
+  nvgpu.warpgroup.mma.store %res_m64n64k, %shmem_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>> to memref<64x64xf32>
+  nvgpu.warpgroup.mma.store %res_m64n72k, %shmem_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>> to memref<64x72xf32>
+  nvgpu.warpgroup.mma.store %res_m64n80k, %shmem_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>> to memref<64x80xf32>
+  nvgpu.warpgroup.mma.store %res_m64n88k, %shmem_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>> to memref<64x88xf32>
+  nvgpu.warpgroup.mma.store %res_m64n96k, %shmem_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>> to memref<64x96xf32>
+  nvgpu.warpgroup.mma.store %res_m64n104k, %shmem_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>> to memref<64x104xf32>
+  nvgpu.warpgroup.mma.store %res_m64n112k, %shmem_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>> to memref<64x112xf32>
+  nvgpu.warpgroup.mma.store %res_m64n120k, %shmem_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>> to memref<64x120xf32>
+  nvgpu.warpgroup.mma.store %res_m64n128k, %shmem_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to memref<64x128xf32>
+  nvgpu.warpgroup.mma.store %res_m64n136k, %shmem_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>> to memref<64x136xf32>
+  nvgpu.warpgroup.mma.store %res_m64n144k, %shmem_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>> to memref<64x144xf32>
+  nvgpu.warpgroup.mma.store %res_m64n152k, %shmem_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>> to memref<64x152xf32>
+  nvgpu.warpgroup.mma.store %res_m64n160k, %shmem_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>> to memref<64x160xf32>
+  nvgpu.warpgroup.mma.store %res_m64n168k, %shmem_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>> to memref<64x168xf32>
+  nvgpu.warpgroup.mma.store %res_m64n176k, %shmem_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>> to memref<64x176xf32>
+  nvgpu.warpgroup.mma.store %res_m64n184k, %shmem_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>> to memref<64x184xf32>
+  nvgpu.warpgroup.mma.store %res_m64n192k, %shmem_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>> to memref<64x192xf32>
+  nvgpu.warpgroup.mma.store %res_m64n200k, %shmem_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>> to memref<64x200xf32>
+  nvgpu.warpgroup.mma.store %res_m64n208k, %shmem_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>> to memref<64x208xf32>
+  nvgpu.warpgroup.mma.store %res_m64n216k, %shmem_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>> to memref<64x216xf32>
+  nvgpu.warpgroup.mma.store %res_m64n224k, %shmem_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>> to memref<64x224xf32>
+  nvgpu.warpgroup.mma.store %res_m64n232k, %shmem_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>> to memref<64x232xf32>
+  nvgpu.warpgroup.mma.store %res_m64n240k, %shmem_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>> to memref<64x240xf32>
+  nvgpu.warpgroup.mma.store %res_m64n248k, %shmem_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>> to memref<64x248xf32>
+  nvgpu.warpgroup.mma.store %res_m64n256k, %shmem_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>> to memref<64x256xf32>
+  return
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
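The CHECK-COUNT values in the new test follow directly from the patched bound. If an m64nN f32 accumulator gives each thread a result struct of N/2 registers (an assumption, but consistent with every count in this test), then storeCount = (N/2) / 4, and each (i, j) iteration issues two memref.store ops, for a total of 2 * storeCount * 2 = N/2 stores. For example, m64n64: 32 registers, storeCount = 32 / 4 = 8, and 2 * 8 * 2 = 32 stores, matching CHECK-COUNT-32.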
