
Commit 2d9ba14
Update on "[Executorch][llama] Allow custom sdpa op replacement pass to leverage attention mask"

Previously we assumed that the custom sdpa op always performs causal attention. This diff adds an option to the module-swap pass so the custom sdpa op can consume an explicit attention mask instead of assuming causal attention.

Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)

[ghstack-poisoned]
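As a rough illustration of what "leverage attention mask" means here, the sketch below shows a custom SDPA wrapper with an opt-in mask path. The names `CustomSDPA`, `use_attention_mask`, and `max_context_len` are illustrative assumptions based on the diff context, not a statement of the actual ExecuTorch API:

```python
# Hypothetical sketch, not the actual ExecuTorch implementation.
import torch
import torch.nn.functional as F


class CustomSDPA(torch.nn.Module):
    # `use_attention_mask` is the assumed name of the new option;
    # when False, behavior falls back to the old causal-only path.
    def __init__(self, max_context_len: int, use_attention_mask: bool = False):
        super().__init__()
        self.max_context_len = max_context_len
        self.use_attention_mask = use_attention_mask

    def forward(self, q, k, v, start_pos: int, mask: torch.Tensor):
        if self.use_attention_mask:
            # Same checks as in the diff: start_pos must be a valid,
            # in-range index for the narrow() below.
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_context_len)
            seq_length = q.size(2)
            # Keep only the mask rows for the positions being decoded.
            mask = mask.narrow(0, start_pos, seq_length)
            return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # Previous behavior: always-causal attention, mask ignored.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```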
2 parents: e7109f4 + 9435ba7

File tree

  • examples/models/llama/source_transformation/sdpa.py

1 file changed (0 additions, 1 deletion)

examples/models/llama/source_transformation/sdpa.py

Lines changed: 0 additions & 1 deletion
@@ -47,7 +47,6 @@ def forward(
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_context_len)
             seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
             mask = mask.narrow(0, start_pos, seq_length)
         else:
             mask = mask[input_pos]
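To make the `narrow()` call in the hunk above concrete, here is a small self-contained example; the causal-mask construction is illustrative, not taken from the file:

```python
import torch

# Build a toy causal mask the way decoder KV caches often do:
# -inf above the diagonal, 0 elsewhere.
max_context_len = 8
mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")),
    diagonal=1,
)

# As in the hunk: slice `seq_length` rows starting at `start_pos`,
# so only the rows for the tokens currently being decoded remain.
start_pos, seq_length = 3, 2
sliced = mask.narrow(0, start_pos, seq_length)
print(sliced.shape)  # torch.Size([2, 8]): rows 3 and 4 of the full mask
```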
