examples/models/llama2/custom_ops/op_sdpa.cpp
+11 −2 (11 additions, 2 deletions)
@@ -219,13 +219,21 @@ void cpu_flash_attention(
   int64_t qSize = query.size(2);
   int64_t headSize = query.size(3);
   int64_t kvSize = value.size(2);
+  int64_t num_heads_kv = key.size(1);

   if (is_with_kv_cache) {
     num_head = query.size(2);
+    num_heads_kv = key.size(2);
     qSize = query.size(1);
     kvSize = value.size(1);
   }

+  ET_CHECK_MSG(
+      num_heads_kv <= num_head, "FlashAttention does not support num kv heads > num query heads. Got num query heads=%" PRId64 " num key heads:%" PRId64, num_head, num_heads_kv);
+  ET_CHECK_MSG(
+      num_head % num_heads_kv == 0, "FlashAttention: num query heads must be divisible by num kv heads but got num query heads=%" PRId64 " and num kv heads=%" PRId64, num_head, num_heads_kv);
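
The two ET_CHECK_MSG guards above enforce the grouped-query attention (GQA) precondition: the query head count must be an exact multiple of the KV head count, so query heads partition evenly across KV heads. The standalone sketch below is illustrative only (it is not code from op_sdpa.cpp, and kv_head_for_query_head is a made-up name); it shows the query-to-KV head mapping that the divisibility check makes well defined.

// Illustrative sketch: why num_head must be divisible by num_heads_kv.
// Under GQA, each query head reads one KV head, chosen by integer division
// with the group size num_head / num_heads_kv.
#include <cassert>
#include <cstdint>
#include <cstdio>

int64_t kv_head_for_query_head(
    int64_t q_head, int64_t num_head, int64_t num_heads_kv) {
  assert(num_heads_kv <= num_head);
  assert(num_head % num_heads_kv == 0);
  const int64_t group_size = num_head / num_heads_kv; // query heads per KV head
  return q_head / group_size;
}

int main() {
  // Example: 32 query heads sharing 8 KV heads -> groups of 4.
  for (int64_t q = 0; q < 32; ++q) {
    std::printf(
        "query head %lld -> kv head %lld\n",
        static_cast<long long>(q),
        static_cast<long long>(kv_head_for_query_head(q, 32, 8)));
  }
  return 0;
}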