Skip to content

Commit 876e661

Browse files
committed
vulkan: use vector loads in scalar flash attention shader
1 parent 3a8d954 commit 876e661

File tree

2 files changed

+45
-32
lines changed

2 files changed

+45
-32
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
19111911
auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
19121912

19131913
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
1914-
const uint32_t D_split = std::min(device->subgroup_size, 16u);
1914+
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
1915+
const uint32_t D_lsb = D ^ (D & (D-1));
1916+
uint32_t D_split = std::min(std::min(device->subgroup_size, 16u), D_lsb / 4);
19151917

19161918
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
19171919
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,11 @@ layout (push_constant) uniform parameter {
6464
} p;
6565

6666
layout (binding = 0) readonly buffer Q {float data_q[];};
67+
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
6768
layout (binding = 1) readonly buffer K {float16_t data_k[];};
69+
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
6870
layout (binding = 2) readonly buffer V {float16_t data_v[];};
71+
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
6972
layout (binding = 3) readonly buffer M {float16_t data_m[];};
7073
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
7174

@@ -161,19 +164,19 @@ void main() {
161164
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
162165

163166
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
164-
float Qf[Br][D_per_thread];
167+
vec4 Qf[Br][D_per_thread / 4];
165168
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
166169
if (i * Br + r < N) {
167-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
168-
Qf[r][d] = float(data_q[q_offset + (i * Br + r) * q_stride + d * D_split + d_tid]) * p.scale;
170+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
171+
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d * D_split + d_tid]) * p.scale;
169172
}
170173
}
171174
}
172175

173-
float Of[Br][D_per_thread];
174-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
176+
vec4 Of[Br][D_per_thread / 4];
177+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
175178
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
176-
Of[r][d] = 0.0;
179+
Of[r][d] = vec4(0.0);
177180
}
178181
}
179182

@@ -212,10 +215,10 @@ void main() {
212215
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
213216

214217
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
215-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
216-
float K_Tf = float(data_k[k_offset + (j * Bc + c * cols_per_iter + col_tid) * k_stride + d * D_split + d_tid]);
218+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
219+
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
217220
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
218-
Sf[r][c] += Qf[r][d] * K_Tf;
221+
Sf[r][c] += dot(Qf[r][d], K_Tf);
219222
}
220223
}
221224
}
@@ -275,21 +278,21 @@ void main() {
275278

276279
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
277280

278-
float PVf[Br][D_per_thread];
279-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
281+
vec4 PVf[Br][D_per_thread / 4];
282+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
280283
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
281-
PVf[r][d] = 0.0;
284+
PVf[r][d] = vec4(0.0);
282285
}
283286
}
284287
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
285-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
286-
float Vf = float(data_v[v_offset + (j * Bc + c * cols_per_iter + col_tid) * v_stride + d * D_split + d_tid]);
288+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
289+
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
287290
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
288291
PVf[r][d] += Pf[r][c] * Vf;
289292
}
290293
}
291294
}
292-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
295+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
293296
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
294297
Of[r][d] = eMf[r] * Of[r][d] + PVf[r][d];
295298
}
@@ -337,21 +340,23 @@ void main() {
337340
Lf[r] = tmpsh[d_tid];
338341
barrier();
339342

340-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
343+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
341344

342345
Of[r][d] = eMf * Of[r][d];
343-
tmpsh[tid] = Of[r][d];
346+
[[unroll]] for (uint32_t c = 0; c < 4; ++c) {
347+
tmpsh[tid] = Of[r][d][c];
344348

345-
barrier();
346-
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
347-
if (tid < s) {
348-
Of[r][d] += tmpsh[tid + s];
349-
tmpsh[tid] = Of[r][d];
349+
barrier();
350+
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
351+
if (tid < s) {
352+
Of[r][d][c] += tmpsh[tid + s];
353+
tmpsh[tid] = Of[r][d][c];
354+
}
355+
barrier();
350356
}
357+
Of[r][d][c] = tmpsh[d_tid];
351358
barrier();
352359
}
353-
Of[r][d] = tmpsh[d_tid];
354-
barrier();
355360
}
356361
}
357362

@@ -363,8 +368,10 @@ void main() {
363368

364369
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
365370
if (r < N) {
366-
for (uint32_t d = 0; d < D_per_thread; ++d) {
367-
perElemOpGqaStore(r, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
371+
for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
372+
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
373+
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
374+
}
368375
}
369376
}
370377
}
@@ -385,7 +392,7 @@ void main() {
385392
Lfrcp[r] = 1.0 / Lf[r];
386393
}
387394

388-
[[unroll]] for (uint32_t d = 0; d < D_per_thread; ++d) {
395+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
389396
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
390397
Of[r][d] *= Lfrcp[r];
391398
}
@@ -396,16 +403,20 @@ void main() {
396403
if (p.gqa_ratio > 1) {
397404
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
398405
if (r < N) {
399-
for (uint32_t d = 0; d < D_per_thread; ++d) {
400-
perElemOpGqaStore(r, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
406+
for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
407+
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
408+
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
409+
}
401410
}
402411
}
403412
}
404413
} else {
405414
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
406415
if (i * Br + r < N) {
407-
for (uint32_t d = 0; d < D_per_thread; ++d) {
408-
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + d * D_split + d_tid] = D_TYPE(Of[r][d]);
416+
for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
417+
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
418+
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
419+
}
409420
}
410421
}
411422
}

0 commit comments

Comments (0)