Skip to content

Commit c4ba84d

Browse files
authored
[mlir][nvgpu] Fix packing accumulator matrix (#69316)
PR #68728 significantly simplified the accumulator matrix type, making it easier to work with the nvgpu dialect without worrying about the number of required structs, as this information is abstracted away in the nvgpu-to-nvvm transformation. However, we forgot to pack the inner structs back into the outer struct after initializing them, which left the accumulator matrix holding undefined values instead of zeros. This PR fixes that.
1 parent 6338932 commit c4ba84d

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,27 +1578,34 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
15781578
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
15791579
ConversionPatternRewriter &rewriter) const override {
15801580
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1581-
LLVM::LLVMStructType structType =
1581+
LLVM::LLVMStructType packStructType =
15821582
getTypeConverter()
15831583
->convertType(op.getMatrixC().getType())
15841584
.cast<LLVM::LLVMStructType>();
1585-
Type elemType = structType.getBody()
1585+
Type elemType = packStructType.getBody()
15861586
.front()
15871587
.cast<LLVM::LLVMStructType>()
15881588
.getBody()
15891589
.front();
15901590
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
1591-
Value structValue = b.create<LLVM::UndefOp>(structType);
1592-
for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
1593-
auto innerStructType = s.cast<LLVM::LLVMStructType>();
1594-
int ii = idx;
1595-
Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
1596-
for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
1597-
innerStructValue = b.create<LLVM::InsertValueOp>(
1598-
innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
1591+
Value packStruct = b.create<LLVM::UndefOp>(packStructType);
1592+
SmallVector<Value> innerStructs;
1593+
// Unpack the structs and set all values to zero
1594+
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
1595+
auto structType = s.cast<LLVM::LLVMStructType>();
1596+
Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
1597+
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
1598+
structValue = b.create<LLVM::InsertValueOp>(
1599+
structType, structValue, zero, ArrayRef<int64_t>({i}));
15991600
}
1601+
innerStructs.push_back(structValue);
16001602
}
1601-
rewriter.replaceOp(op, structValue);
1603+
// Pack the inner structs into a single struct
1604+
for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
1605+
packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
1606+
packStruct, matrix, idx);
1607+
}
1608+
rewriter.replaceOp(op, packStruct);
16021609
return success();
16031610
}
16041611
};

0 commit comments

Comments
 (0)