@@ -196,7 +196,7 @@ void main() {
196
196
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
197
197
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
198
198
199
- for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
199
+ [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
200
200
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
201
201
}
202
202
#else
@@ -209,7 +209,7 @@ void main() {
209
209
}
210
210
#endif
211
211
212
- [[dont_unroll]] for (uint block = start_k; block < end_k; block += BK) {
212
+ for (uint block = start_k; block < end_k; block += BK) {
213
213
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
214
214
215
215
#if defined(DATA_A_F32) || defined(DATA_A_F16)
@@ -506,20 +506,20 @@ void main() {
506
506
pos_b += BK / LOAD_VEC_B;
507
507
508
508
#ifdef COOPMAT
509
- for (uint i = 0; i < BK; i += TK) {
510
- for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
509
+ [[unroll]] for (uint i = 0; i < BK; i += TK) {
510
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
511
511
// Load from shared into cache
512
512
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
513
513
514
- for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
514
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
515
515
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
516
516
517
517
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
518
518
}
519
519
}
520
520
}
521
521
#else
522
- for (uint i = 0; i < BK; i++) {
522
+ [[unroll]] for (uint i = 0; i < BK; i++) {
523
523
// Load from shared into cache
524
524
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
525
525
[[unroll]] for (uint j = 0; j < TM; j++) {
0 commit comments