Skip to content

Commit 3d2078d

Browse files
committed
use new type WarpgroupAccumulator
1 parent dc24a32 commit 3d2078d

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,12 +733,13 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
733733
The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result
734734
in $matrixD to give memref.
735735

736-
[See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
736+
[See the details of register fragment layout for accumulator matrix D]
737+
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
737738

738739
Note that, the op must be run with warp group.
739740
}];
740741

741-
let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
742+
let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
742743
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
743744

744745
let assemblyFormat = [{

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
15171517
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
15181518
RewritePatternSet &patterns) {
15191519
patterns.add<
1520-
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store`
15211520
NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
15221521
NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
15231522
NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
@@ -1529,6 +1528,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
15291528
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
15301529
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
15311530
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
1531+
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
15321532
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
15331533
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
15341534
NVGPUMmaSparseSyncLowering>(converter);

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -531,13 +531,16 @@ LogicalResult WarpgroupMmaOp::verify() {
531531
}
532532

533533
LogicalResult WarpgroupMmaStoreOp::verify() {
534-
Type stype =
535-
getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
534+
Type stype = getMatrixD()
535+
.front()
536+
.getType()
537+
.cast<WarpgroupAccumulatorType>()
538+
.getFragmented();
536539

537540
for (auto result : getMatrixD()) {
538541
auto resultStype = result.getType()
539-
.cast<WarpgroupResultType>()
540-
.getTensor()
542+
.cast<WarpgroupAccumulatorType>()
543+
.getFragmented()
541544
.dyn_cast<LLVM::LLVMStructType>();
542545
if (!resultStype)
543546
return emitOpError() << "result is " << result.getType()

0 commit comments

Comments
 (0)