@@ -202,7 +202,7 @@ void main() {
202
202
#endif
203
203
204
204
#ifdef COOPMAT
205
- coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
205
+ coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row] ;
206
206
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
207
207
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
208
208
@@ -725,12 +725,12 @@ void main() {
725
725
[[unroll]] for (uint i = 0; i < BK; i += TK) {
726
726
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
727
727
// Load from shared into cache
728
- coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
729
-
730
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
731
- coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
732
-
733
- sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
728
+ coopMatLoad(cache_a[cm_row] , buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
729
+ }
730
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
731
+ coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
732
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
733
+ sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row] , cache_b, sums[cm_col * cms_per_row + cm_row]);
734
734
}
735
735
}
736
736
}
0 commit comments