Skip to content

Commit a6c940b

Browse files
committed
vulkan: remove PV matrix, helps with register usage
1 parent 876e661 commit a6c940b

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,27 +276,22 @@ void main() {
276276
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
277277
}
278278

279-
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
280-
281-
vec4 PVf[Br][D_per_thread / 4];
282279
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
283280
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
284-
PVf[r][d] = vec4(0.0);
281+
Of[r][d] = eMf[r] * Of[r][d];
285282
}
286283
}
284+
285+
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
286+
287287
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
288288
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
289289
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
290290
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
291-
PVf[r][d] += Pf[r][c] * Vf;
291+
Of[r][d] += Pf[r][c] * Vf;
292292
}
293293
}
294294
}
295-
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
296-
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
297-
Of[r][d] = eMf[r] * Of[r][d] + PVf[r][d];
298-
}
299-
}
300295

301296
barrier();
302297
}

0 commit comments

Comments
 (0)