Commit 3ea88cd

Update on "move mask as sdpa input instead of attribute"
sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as an input argument, so refactor the SDPA module to match: pass the mask into forward() instead of storing it as a module attribute.

Differential Revision: [D56119739](https://our.internmc.facebook.com/intern/diff/D56119739/)

[ghstack-poisoned]
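For illustration only, a minimal sketch of the pattern this change adopts (not the actual ExecuTorch module; KV-cache update and q/k/v reshaping are omitted, and the class and argument names here are hypothetical): the mask arrives as a forward() argument, is sliced for the current positions, and is handed straight to torch.nn.functional.scaled_dot_product_attention as attn_mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDPASketch(nn.Module):
    """Simplified SDPA wrapper: the attention mask is a forward() argument
    rather than a module attribute (no self.mask)."""

    def __init__(self, dim: int, head_dim: int, n_rep: int):
        super().__init__()
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep  # times to repeat KV heads (grouped-query attention)

    def forward(self, q, k, v, input_pos, mask):
        # Slice the caller-provided mask for the current token positions,
        # mirroring `attn_mask = mask[None, None, input_pos]` in the diff below.
        attn_mask = mask[None, None, input_pos]
        # Expand KV heads so they match the number of query heads.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        # The mask flows directly into the built-in SDPA call.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

The caller (presumably the attention block that owns the causal mask) would then pass the mask explicitly on every call rather than relying on state stored inside the SDPA module.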
1 parent fa5197a commit 3ea88cd

File tree

1 file changed: +7 -4 lines


examples/models/llama2/llama_transformer.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -214,11 +214,13 @@ def __init__(
         self,
         kv_cache: KVCache,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep
 
     def forward(
@@ -236,7 +238,7 @@ def forward(
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        attn_mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
@@ -286,9 +288,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
             not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
         )
         self.SDPA = SDPA(
-            self.kv_cache,
-            self.dim,
-            self.n_rep,
+            kv_cache=self.kv_cache,
+            dim=self.dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
         )
 
     def forward(
```
