Commit 3d7dcd5

kimishpatel authored and larryliu0820 committed
[executorch][llama] support mqa
Summary: This diff adds support for multi query attention for sdpa with kv cache

Reviewed By: iseeyuan

Differential Revision: D56212419
1 parent 1f4b631 commit 3d7dcd5

File tree

1 file changed: +11 -2 lines changed


examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 11 additions & 2 deletions
@@ -219,13 +219,21 @@ void cpu_flash_attention(
   int64_t qSize = query.size(2);
   int64_t headSize = query.size(3);
   int64_t kvSize = value.size(2);
+  int64_t num_heads_kv = key.size(1);
 
   if (is_with_kv_cache) {
     num_head = query.size(2);
+    num_heads_kv = key.size(2);
     qSize = query.size(1);
     kvSize = value.size(1);
   }
 
+  ET_CHECK_MSG(
+      num_heads_kv <= num_head, "FlashAttention does not support num kv heads > num query heads. Got num query heads=%" PRId64 " num key heads=%" PRId64, num_head, num_heads_kv);
+  ET_CHECK_MSG(
+      num_head % num_heads_kv == 0, "FlashAttention: num query heads must be divisible by num kv heads but got num query heads=%" PRId64 " and num kv heads=%" PRId64, num_head, num_heads_kv);
+  int64_t num_reps = num_head / num_heads_kv;
+
   bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
   if (has_attn_mask) {
     /*
@@ -365,6 +373,7 @@ void cpu_flash_attention(
       fill_stub(
           qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
       int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+      auto j_kv = j / num_reps;
       for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
         int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
         // Calculate scale * q @ k.T
@@ -376,7 +385,7 @@ void cpu_flash_attention(
           qBlockSize,
           headSize,
           static_cast<accum_t>(1),
-          k_data + i * kStrideB + j * kStrideH + n * kStrideN,
+          k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
           kStrideN,
           q_data + i * qStrideB + j * qStrideH + m * qStrideM,
           qStrideM,
@@ -460,7 +469,7 @@ void cpu_flash_attention(
           qBlockSize,
           kvBlockSize,
           static_cast<accum_t>(1),
-          v_data + i * vStrideB + j * vStrideH + n * vStrideN,
+          v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
           vStrideN,
           conditional_data_ptr(qk_data, qk_reduced_data),
           kvBlockSize,
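
For readers skimming the diff: the new num_reps / j_kv indexing is what lets several query heads share a single key/value head (multi-query attention when there is one kv head, grouped-query attention when there are several), and the two ET_CHECK_MSG calls enforce that the query heads split evenly across the kv heads. A minimal standalone sketch of that mapping, using made-up head counts rather than the real kernel code, might look like this:

// Illustrative sketch only, not the ExecuTorch kernel: shows the
// query-head -> kv-head mapping that j_kv = j / num_reps implements.
#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_head = 8;      // query heads (made-up value)
  const int64_t num_heads_kv = 2;  // key/value heads; must divide num_head
  const int64_t num_reps = num_head / num_heads_kv;  // query heads per kv head

  for (int64_t j = 0; j < num_head; ++j) {
    const int64_t j_kv = j / num_reps;  // same mapping used for the K/V pointer offsets
    std::printf("query head %" PRId64 " -> kv head %" PRId64 "\n", j, j_kv);
  }
  return 0;
}

In the kernel itself this mapping only changes how the K and V base pointers are computed (j_kv * kStrideH and j_kv * vStrideH); the Q pointer and the rest of the flash-attention loop are untouched.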
