@@ -64,8 +64,11 @@ layout (push_constant) uniform parameter {
64
64
} p;
65
65
66
66
layout (binding = 0) readonly buffer Q {float data_q[];};
67
+ layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
67
68
layout (binding = 1) readonly buffer K {float16_t data_k[];};
69
+ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
68
70
layout (binding = 2) readonly buffer V {float16_t data_v[];};
71
+ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
69
72
layout (binding = 3) readonly buffer M {float16_t data_m[];};
70
73
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
71
74
@@ -161,19 +164,19 @@ void main() {
161
164
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
162
165
163
166
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
164
- float Qf[Br][D_per_thread];
167
+ vec4 Qf[Br][D_per_thread / 4 ];
165
168
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
166
169
if (i * Br + r < N) {
167
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
168
- Qf[r][d] = float(data_q [q_offset + (i * Br + r) * q_stride + d * D_split + d_tid]) * p.scale;
170
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
171
+ Qf[r][d] = vec4(data_qv4 [q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
169
172
}
170
173
}
171
174
}
172
175
173
- float Of[Br][D_per_thread];
174
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
176
+ vec4 Of[Br][D_per_thread / 4 ];
177
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
175
178
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
176
- Of[r][d] = 0.0;
179
+ Of[r][d] = vec4( 0.0) ;
177
180
}
178
181
}
179
182
@@ -212,10 +215,10 @@ void main() {
212
215
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
213
216
214
217
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
215
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
216
- float K_Tf = float(data_k [k_offset + (j * Bc + c * cols_per_iter + col_tid) * k_stride + d * D_split + d_tid]);
218
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
219
+ vec4 K_Tf = vec4(data_kv4 [k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
217
220
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
218
- Sf[r][c] += Qf[r][d] * K_Tf;
221
+ Sf[r][c] += dot( Qf[r][d], K_Tf) ;
219
222
}
220
223
}
221
224
}
@@ -275,21 +278,21 @@ void main() {
275
278
276
279
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
277
280
278
- float PVf[Br][D_per_thread];
279
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
281
+ vec4 PVf[Br][D_per_thread / 4 ];
282
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
280
283
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
281
- PVf[r][d] = 0.0;
284
+ PVf[r][d] = vec4( 0.0) ;
282
285
}
283
286
}
284
287
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
285
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
286
- float Vf = float(data_v [v_offset + (j * Bc + c * cols_per_iter + col_tid) * v_stride + d * D_split + d_tid]);
288
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
289
+ vec4 Vf = vec4(data_vv4 [v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
287
290
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
288
291
PVf[r][d] += Pf[r][c] * Vf;
289
292
}
290
293
}
291
294
}
292
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
295
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
293
296
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
294
297
Of[r][d] = eMf[r] * Of[r][d] + PVf[r][d];
295
298
}
@@ -337,21 +340,23 @@ void main() {
337
340
Lf[r] = tmpsh[d_tid];
338
341
barrier();
339
342
340
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
343
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
341
344
342
345
Of[r][d] = eMf * Of[r][d];
343
- tmpsh[tid] = Of[r][d];
346
+ [[unroll]] for (uint32_t c = 0; c < 4; ++c) {
347
+ tmpsh[tid] = Of[r][d][c];
344
348
345
- barrier();
346
- [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
347
- if (tid < s) {
348
- Of[r][d] += tmpsh[tid + s];
349
- tmpsh[tid] = Of[r][d];
349
+ barrier();
350
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
351
+ if (tid < s) {
352
+ Of[r][d][c] += tmpsh[tid + s];
353
+ tmpsh[tid] = Of[r][d][c];
354
+ }
355
+ barrier();
350
356
}
357
+ Of[r][d][c] = tmpsh[d_tid];
351
358
barrier();
352
359
}
353
- Of[r][d] = tmpsh[d_tid];
354
- barrier();
355
360
}
356
361
}
357
362
@@ -363,8 +368,10 @@ void main() {
363
368
364
369
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
365
370
if (r < N) {
366
- for (uint32_t d = 0; d < D_per_thread; ++d) {
367
- perElemOpGqaStore(r, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
371
+ for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
372
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
373
+ perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
374
+ }
368
375
}
369
376
}
370
377
}
@@ -385,7 +392,7 @@ void main() {
385
392
Lfrcp[r] = 1.0 / Lf[r];
386
393
}
387
394
388
- [[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
395
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4 ; ++d) {
389
396
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
390
397
Of[r][d] *= Lfrcp[r];
391
398
}
@@ -396,16 +403,20 @@ void main() {
396
403
if (p.gqa_ratio > 1) {
397
404
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
398
405
if (r < N) {
399
- for (uint32_t d = 0; d < D_per_thread; ++d) {
400
- perElemOpGqaStore(r, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
406
+ for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
407
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
408
+ perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
409
+ }
401
410
}
402
411
}
403
412
}
404
413
} else {
405
414
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
406
415
if (i * Br + r < N) {
407
- for (uint32_t d = 0; d < D_per_thread; ++d) {
408
- data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + d * D_split + d_tid] = D_TYPE(Of[r][d]);
416
+ for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
417
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
418
+ data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
419
+ }
409
420
}
410
421
}
411
422
}
0 commit comments