
Commit b341223

cccclai authored and facebook-github-bot committed
move mask as sdpa input instead of attribute (#3036)
Summary:
Pull Request resolved: #3036

sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as an input, so refactor the SDPA module to take the mask as a forward input as well, bringing its signature closer to sdpa's.

ghstack-source-id: 222650466
exported-using-ghexport

Reviewed By: mergennachin

Differential Revision: D56119739

fbshipit-source-id: d9adda66e540abc518b7ffb6a5ebd2aab1626b3b
1 parent f729b2d commit b341223
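
For reference, here is a minimal sketch (not part of this commit; the shapes and the causal-mask construction are illustrative) of how torch.nn.functional.scaled_dot_product_attention accepts the attention mask as a call argument, which is the convention this refactor moves the SDPA wrappers toward:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=1, heads=4, seqlen=8, head_dim=16.
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# The mask is an argument of the call, not state stored on a module.
causal_mask = torch.tril(torch.ones(8, 8)).bool()  # True = attend
y = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, dropout_p=0.0)
print(y.shape)  # torch.Size([1, 4, 8, 16])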

2 files changed: +12 -12 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 3 deletions
@@ -96,12 +96,10 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim

     def forward(
@@ -112,6 +110,7 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask,
     ):
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
@@ -131,7 +130,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.mask, child.dim),
+                SDPACustom(child.kv_cache, child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
examples/models/llama2/llama_transformer.py

Lines changed: 10 additions & 9 deletions
@@ -213,14 +213,14 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep

     def forward(
@@ -231,17 +231,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -287,10 +288,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
             not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v
         )
         self.SDPA = SDPA(
-            self.kv_cache,
-            self.mask,
-            self.dim,
-            self.n_rep,
+            kv_cache=self.kv_cache,
+            dim=self.dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
         )

     def forward(
@@ -314,7 +315,7 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
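
Taken together, the mask is now threaded through the call (self.SDPA(..., self.mask)) instead of being stored on the attention module. A simplified, self-contained sketch of that calling convention, with a toy key/value cache (ToyKVCache) standing in for the repo's KVCache and illustrative shapes:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyKVCache(nn.Module):
    # Hypothetical stand-in for the repo's KVCache; caches are laid out as
    # (bs, n_heads, max_seq_len, head_dim) and written at input_pos.
    def __init__(self, bsz, n_heads, max_seq_len, head_dim):
        super().__init__()
        self.register_buffer("k", torch.zeros(bsz, n_heads, max_seq_len, head_dim))
        self.register_buffer("v", torch.zeros(bsz, n_heads, max_seq_len, head_dim))

    def update(self, input_pos, k_val, v_val):
        self.k[:, :, input_pos] = k_val
        self.v[:, :, input_pos] = v_val
        return self.k, self.v

class ToySDPA(nn.Module):
    # Mirrors the refactored SDPA: the mask arrives as a forward argument.
    def __init__(self, kv_cache, dim, head_dim, n_rep):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep

    def forward(self, input_pos, q, k, v, bsz, seqlen, mask):
        q = q.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        k, v = self.kv_cache.update(input_pos, k, v)
        attn_mask = mask[None, None, input_pos]  # select mask rows for current positions
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

bsz, n_heads, max_seq_len, head_dim, seqlen = 1, 4, 16, 8, 1
sdpa = ToySDPA(ToyKVCache(bsz, n_heads, max_seq_len, head_dim),
               dim=n_heads * head_dim, head_dim=head_dim, n_rep=1)
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)
out = sdpa(torch.tensor([0]), q, k, v, bsz, seqlen, causal_mask)  # mask passed per call
print(out.shape)  # torch.Size([1, 1, 32])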
