
Commit 5fe2784

JacobSzwejbka authored and facebook-github-bot committed
fix bug in mask slicing for 1 length sequence and no kvcache (#2742)
Summary: No reason to squeeze anymore since the mask is always 2d
Reviewed By: mergennachin
Differential Revision: D55465010
1 parent 45c2557 commit 5fe2784

File tree: 1 file changed (+1, −7 lines)

examples/models/llama2/llama_transformer.py

Lines changed: 1 addition & 7 deletions
@@ -301,14 +301,8 @@ def forward(
         v = v.repeat_interleave(self.n_rep, dim=1)

         assert hasattr(self, "mask")
-        mask = self.mask[:seqlen, :seqlen]

-        # this is needed to support xnnpack which requires mask shape to be 2d.
-        # this is a temporary workaround. once we update xnnpack we should be able to handle this.
-        # shape before: [1, 1, l, s], after: [l, s]
-        # we make sure to specify the dimensions to be squeezed [0, 1] to ensure that the output
-        # tensor will be 2-dimensional, regarldess of the values of l & s
-        mask = torch.squeeze(mask, [0, 1])
+        mask = self.mask[:seqlen, :seqlen]

         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

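For context, a minimal standalone sketch of why the old squeeze workaround misbehaved once the mask became 2-D, and why the plain slice is enough for F.scaled_dot_product_attention. This is not code from the commit; the max_seq_len value, the causal mask construction, and the q/k/v shapes are hypothetical stand-ins for the model's own self.mask and attention inputs.

    import torch
    import torch.nn.functional as F

    max_seq_len = 8  # hypothetical size, for illustration only
    # Additive causal mask as a stand-in for self.mask, which is always 2-D.
    mask = torch.full((max_seq_len, max_seq_len), float("-inf")).triu(1)

    seqlen = 1  # the no-kv-cache, single-token case from the commit title
    sliced = mask[:seqlen, :seqlen]
    print(sliced.shape)  # torch.Size([1, 1]) -- already 2-D

    # Old workaround: squeeze dims [0, 1]. For seqlen > 1 this is a no-op,
    # but for seqlen == 1 both dims have size 1 and the mask collapses to 0-D,
    # which is no longer the 2-D mask the comment intended to produce.
    squeezed = torch.squeeze(sliced, [0, 1])
    print(squeezed.shape)  # torch.Size([])

    # After the fix, the 2-D slice is passed directly and broadcasts cleanly
    # against the (bsz, n_heads, seqlen, seqlen) attention weights.
    q = k = v = torch.randn(1, 4, seqlen, 16)  # hypothetical (bsz, heads, seqlen, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=sliced, dropout_p=0.0)
    print(out.shape)  # torch.Size([1, 4, 1, 16])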
0 commit comments