Skip to content

Commit 9622fbe

Browse files
committed
Vulkan: Unroll more loops for more mul mat mat performance
1 parent: a0deeee · commit: 9622fbe

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ void main() {
196196
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
197197
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
198198

199-
for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
199+
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
200200
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
201201
}
202202
#else
@@ -209,7 +209,7 @@ void main() {
209209
}
210210
#endif
211211

212-
[[dont_unroll]] for (uint block = start_k; block < end_k; block += BK) {
212+
for (uint block = start_k; block < end_k; block += BK) {
213213
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
214214

215215
#if defined(DATA_A_F32) || defined(DATA_A_F16)
@@ -506,20 +506,20 @@ void main() {
506506
pos_b += BK / LOAD_VEC_B;
507507

508508
#ifdef COOPMAT
509-
for (uint i = 0; i < BK; i += TK) {
510-
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
509+
[[unroll]] for (uint i = 0; i < BK; i += TK) {
510+
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
511511
// Load from shared into cache
512512
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
513513

514-
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
514+
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
515515
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
516516

517517
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
518518
}
519519
}
520520
}
521521
#else
522-
for (uint i = 0; i < BK; i++) {
522+
[[unroll]] for (uint i = 0; i < BK; i++) {
523523
// Load from shared into cache
524524
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
525525
[[unroll]] for (uint j = 0; j < TM; j++) {

0 commit comments

Comments (0)