
Commit e6b5f52

[Executorch][llama] Bug fix for custom SDPA with attention bias

When using an attention bias, don't override the sequence length for causal attention.

Differential Revision: [D73222733](https://our.internmc.facebook.com/intern/diff/D73222733/)

[ghstack-poisoned]
1 parent 97c6f04 commit e6b5f52

File tree

1 file changed (+2, -1 lines)


extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -400,7 +400,8 @@ Tensor& custom_sdpa_out_impl(
 
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
-  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
+  const int64_t num_keys_for_causal_attention =
+      attn_mask.has_value() ? -1 : start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
```
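The ternary makes -1 act as a sentinel: when the caller passes an explicit attention bias via `attn_mask`, the kernel should not cap the number of visible keys at `start_pos + seq_len`, because the mask itself already encodes which key positions each query may attend to. A minimal sketch of how a downstream consumer might interpret such a sentinel (a hypothetical helper for illustration, not code from op_sdpa.cpp):

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper illustrating the sentinel convention assumed above:
// num_keys_for_causal_attention == -1 means "no causal cap; the explicit
// attention mask governs visibility"; otherwise keys are limited to the
// causal window [0, start_pos + seq_len).
int64_t effective_num_keys(
    int64_t num_keys_for_causal_attention,
    int64_t total_kv_len) {
  return num_keys_for_causal_attention < 0
      ? total_kv_len // attention bias supplied: consider all cached keys
      : std::min(num_keys_for_causal_attention, total_kv_len);
}
```

Per the commit message, the cap was previously applied unconditionally, so an attention bias covering keys beyond the causal window could be silently truncated; the one-line change above avoids that.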
