
Commit 1d74652

Michael Gschwind authored and facebook-github-bot committed

Resolve recurring errors where query is c10::Half and key and value float

Summary: Resolve recurring errors where query is c10::Half and key and value are float. This should ideally work from first principles, but somehow it does not. We need to fix this, but in the meantime this ugly hack will enable us to proceed and allow others to debug other aspects of ET lowering.

Reviewed By: mavlyutovr

Differential Revision: D54167581

fbshipit-source-id: 6cc4e76e3abbf107014b5b9da00e817ee3b2ab03
1 parent 566528f commit 1d74652

File tree

1 file changed: +4 −0 lines changed

examples/models/llama2/model.py

Lines changed: 4 additions & 0 deletions
@@ -262,6 +262,10 @@ def forward(
         # tensor will be 2-dimensional, regardless of the values of L & S
         mask = torch.squeeze(mask, [0, 1])
 
+        # FIXME: This should be so automatically! MKG
+        keys = keys.to(dtype=xq.dtype)
+        values = values.to(dtype=xq.dtype)
+
         output = F.scaled_dot_product_attention(
             xq, keys, values, attn_mask=mask, dropout_p=0.0
         )
