
Commit 6388c87

[Executorch][llama] Allow custom sdpa op replacement pass to leverage attention mask
Pull Request resolved: #10285

Previously we assumed that the custom sdpa op always does causal attention. This diff adds an option to this module-swap pass so that the custom sdpa op can leverage an explicit attention mask instead of relying on causal attention.

ghstack-source-id: 278977227
@exported-using-ghexport
Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)
1 parent 7584a22 commit 6388c87
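For context, a minimal usage sketch of the updated pass (not part of this commit): replace_sdpa_with_custom_op and its new use_attention_mask flag come from the diff below, while the import path and the placeholder model are assumptions made for illustration.

import torch

# Assumed import path; the pass itself lives in
# examples/models/llama/source_transformation/sdpa.py in the ExecuTorch repo.
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# Placeholder module; in practice this is the llama model whose SDPA modules
# the pass swaps for SDPACustom.
model = torch.nn.Module()

# Default: unchanged behavior; the custom sdpa op applies causal masking itself.
model = replace_sdpa_with_custom_op(model)

# New option from this commit: pass the attention mask through to the custom
# sdpa op (is_causal=False) instead of relying on its internal causal masking.
model = replace_sdpa_with_custom_op(model, use_attention_mask=True)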

1 file changed: examples/models/llama/source_transformation/sdpa.py (+49, -14 lines)

examples/models/llama/source_transformation/sdpa.py

Lines changed: 49 additions & 14 deletions
@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -36,6 +42,15 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_context_len)
+            seq_length = q.size(2)
+            mask = mask.narrow(0, start_pos, seq_length)
+        else:
+            mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +62,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)


-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa

-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module