Skip to content

Commit 615958f

Browse files
committed
vulkan: for scalar FA, select between 1 and 8 rows
1 parent 00784e3 commit 615958f

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,7 +1590,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
15901590

15911591
// number of rows/cols for flash attention shader
15921592
static constexpr uint32_t flash_attention_num_small_rows = 32;
1593-
static constexpr uint32_t scalar_flash_attention_num_small_rows = 8;
1593+
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1594+
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
15941595

15951596
static uint32_t get_fa_num_small_rows(bool scalar) {
15961597
return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
@@ -1599,8 +1600,16 @@ static uint32_t get_fa_num_small_rows(bool scalar) {
15991600
static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
16001601
GGML_UNUSED(clamp);
16011602

1603+
if (scalar) {
1604+
if (small_rows) {
1605+
return {scalar_flash_attention_num_small_rows, 64};
1606+
} else {
1607+
return {scalar_flash_attention_num_large_rows, 32};
1608+
}
1609+
}
1610+
16021611
// small rows, large cols
1603-
if (small_rows || scalar) {
1612+
if (small_rows) {
16041613
return {get_fa_num_small_rows(scalar), 32};
16051614
}
16061615

@@ -5729,8 +5738,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
57295738
assert(q->type == GGML_TYPE_F32);
57305739
assert(k->type == v->type);
57315740

5732-
vk_pipeline *pipelines;
57335741
bool scalar = !ctx->device->coopmat2;
5742+
5743+
uint32_t gqa_ratio = 1;
5744+
uint32_t qk_ratio = neq2 / nek2;
5745+
uint32_t workgroups_x = (uint32_t)neq1;
5746+
uint32_t workgroups_y = (uint32_t)neq2;
5747+
uint32_t workgroups_z = (uint32_t)neq3;
5748+
5749+
// For scalar FA, we can use the "large" size to accommodate qga.
5750+
// For coopmat FA, we always use the small size (which is still pretty large for gqa).
5751+
const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
5752+
5753+
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5754+
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5755+
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
5756+
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5757+
// and change addressing calculations to index Q's dimension 2.
5758+
gqa_ratio = qk_ratio;
5759+
N = gqa_ratio;
5760+
workgroups_y /= N;
5761+
}
5762+
5763+
vk_pipeline *pipelines;
57345764
// XXX TODO other backends may be changing accumulator precision to default to f32 soon
57355765
bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
57365766
bool small_rows = N <= get_fa_num_small_rows(scalar);
@@ -5776,24 +5806,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
57765806
vk_pipeline pipeline = pipelines[aligned];
57775807
assert(pipeline);
57785808

5779-
uint32_t gqa_ratio = 1;
5780-
uint32_t qk_ratio = neq2 / nek2;
5781-
uint32_t workgroups_x = (uint32_t)neq1;
5782-
uint32_t workgroups_y = (uint32_t)neq2;
5783-
uint32_t workgroups_z = (uint32_t)neq3;
5784-
5785-
const uint32_t max_gqa = get_fa_num_small_rows(scalar);
5786-
5787-
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5788-
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5789-
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
5790-
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5791-
// and change addressing calculations to index Q's dimension 2.
5792-
gqa_ratio = qk_ratio;
5793-
N = gqa_ratio;
5794-
workgroups_y /= N;
5795-
}
5796-
57975809
uint32_t split_kv = KV;
57985810
uint32_t split_k = 1;
57995811

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,9 @@ void main() {
295295
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
296296
uint32_t c = (idx + tid) % Bc;
297297
uint32_t r = (idx + tid) / Bc;
298-
masksh[c][r] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
298+
if (idx + tid < Bc * Br) {
299+
masksh[c][r] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
300+
}
299301
}
300302
barrier();
301303

0 commit comments

Comments
 (0)