Commit a8f04ae

move mask as sdpa input instead of attribute
sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as a call input, so refactor the SDPA module to accept the mask in forward() rather than storing it as an attribute, bringing the module's inputs closer to sdpa's.

Differential Revision: [D56119739](https://our.internmc.facebook.com/intern/diff/D56119739/)

ghstack-source-id: 222465699
Pull Request resolved: #3036
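For reference, a minimal sketch of the sdpa call this refactor aligns with: the mask is supplied per call via the `attn_mask` argument rather than held as module state. Shapes below are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# Arbitrary shapes for illustration: (bsz, n_heads, seqlen, head_dim).
q = k = v = torch.randn(1, 4, 8, 16)

# An additive causal mask; passed to sdpa per call, not stored on a module.
attn_mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
```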
1 parent 21fdc4e commit a8f04ae

2 files changed: +6 -9 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 3 deletions
@@ -96,12 +96,10 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim

     def forward(
@@ -112,6 +110,7 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask,
     ):
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
@@ -131,7 +130,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.mask, child.dim),
+                SDPACustom(child.kv_cache, child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
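With the constructor reduced to `SDPACustom(child.kv_cache, child.dim)`, the swap no longer needs to read a mask attribute off the child. The snippet below is a toy sketch of the same recursive setattr-based module swap, using stand-in Linear/Identity modules rather than the real SDPA/SDPACustom classes, just to illustrate the traversal pattern.

```python
import torch

def _replace_linear_with_identity(module: torch.nn.Module):
    # Same traversal pattern as _replace_sdpa_with_custom_op: walk direct
    # children, swap matching ones in place via setattr, recurse otherwise.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(module, name, torch.nn.Identity())
        else:
            _replace_linear_with_identity(child)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
_replace_linear_with_identity(model)
print(model)  # the Linear child is now an Identity
```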

examples/models/llama2/llama_transformer.py

Lines changed: 4 additions & 6 deletions
@@ -213,13 +213,11 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
         self.n_rep = n_rep

@@ -231,17 +229,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]

         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -288,7 +287,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
             )
             self.SDPA = SDPA(
                 self.kv_cache,
-                self.mask,
                 self.dim,
                 self.n_rep,
             )
@@ -314,7 +312,7 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
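A runnable sketch of the refactored eager call path, under the assumption (as in typical llama implementations) that the caller holds a (max_seq_len, max_seq_len) causal mask and a kv cache already updated for the current step; all sizes below are illustrative, not the model's real values.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only.
bsz, n_heads, n_kv_heads, head_dim = 1, 8, 2, 16
max_seq_len, seqlen = 32, 1
n_rep = n_heads // n_kv_heads

# Causal mask owned by the caller (the Attention module keeps self.mask)
# and now handed to SDPA.forward as an argument instead of stored on SDPA.
mask = torch.triu(torch.full((max_seq_len, max_seq_len), float("-inf")), diagonal=1)
input_pos = torch.tensor([3])  # current decode position

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, max_seq_len, head_dim)  # stands in for cached keys
v = torch.randn(bsz, n_kv_heads, max_seq_len, head_dim)  # stands in for cached values

# Mirrors the new SDPA.forward body: slice the mask at the current positions,
# expand grouped kv heads, and pass the slice as attn_mask.
attn_mask = mask[None, None, input_pos]  # (1, 1, seqlen, max_seq_len)
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

out = y.transpose(1, 2).contiguous().view(bsz, seqlen, n_heads * head_dim)
print(out.shape)  # torch.Size([1, 1, 128])
```

Passing the mask per call matches the functional signature and keeps SDPA stateless apart from the kv cache, which also simplifies swapping it for SDPACustom, whose forward now takes the same extra mask argument.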
