You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[mlir][nvgpu] Improve WarpgroupAccumulator type to simplify IR (#68728)
`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type
that keeps the accumulator matrix that is used by warp-group level
matrix multiplication. It is handy to have a special type for that as
the matrix is distributed among the threads of the warp-group. However,
current transformations requires to create and use multiple
`WarpgroupAccumulator` if the shape of GEMM is larger than the supported
shape of `wgmma.mma_async` instruction. This makes IR looks dense.
This PR improves the transformation of `WarpgroupAccumulator` type in
every nvgpu Op that uses it.
**Example: Current GEMM in NVGPU-IR**
```
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}:
!nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
->
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
// Epilogue
nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
into memref<128x128xf32,3>
```
**Example: This PR simplifies the IR as below:**
```
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->
!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
// GEMM
%r1 = nvgpu.warpgroup.mma %descA, %descB, %m1 {transposeB}:
!nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
->
!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
// Epilogue
nvgpu.warpgroup.mma.store [%matrixD1, %matrixD2] to %sharedMemoryBuffer
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
into memref<128x128xf32,3>
```
0 commit comments