Commit cab6335

swolchok authored and facebook-github-bot committed
Allow using custom SDPA for non-float32 dtypes in llama demo (#5548)
Summary: Pull Request resolved: #5548

Converting the input to and from float32 is faster than not using the op. h/t to torchchat, which does this already (though it had a bug, which I sent a patch for).

Reviewed By: kimishpatel
Differential Revision: D63158951
fbshipit-source-id: 58c90d141ee403536c03a3b731f8547790fc9440
1 parent f68a138 commit cab6335

File tree

1 file changed: +10 −2 lines changed
  • examples/models/llama2/source_transformation


examples/models/llama2/source_transformation/sdpa.py

Lines changed: 10 additions & 2 deletions
@@ -23,7 +23,9 @@ def __init__(
         dim: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        self.kv_cache = kv_cache.to(torch.float)
         self.dim = dim
 
     def forward(
@@ -36,6 +38,12 @@ def forward(
         seqlen,
         mask,
     ):
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        input_dtype = q.dtype
+        q = q.to(dtype=torch.float)
+        k = k.to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
@@ -48,7 +56,7 @@ def forward(
             0,  # dropout probability. Ignored by the code
             True,  # is_causal
         )
-        return output.view(bsz, seqlen, self.dim)
+        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(module: torch.nn.Module):
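As a hedged illustration of the pattern this commit applies (a minimal sketch, not the actual ExecuTorch module), the code below runs a float32-only attention kernel on half-precision inputs by converting to float32 and converting the result back. Here torch.nn.functional.scaled_dot_product_attention stands in for the custom llama.sdpa_with_kv_cache op, and the function name sdpa_float32_roundtrip is an invented placeholder.

import torch
import torch.nn.functional as F


def sdpa_float32_roundtrip(q, k, v):
    # Remember the caller's dtype (e.g. torch.bfloat16), run the kernel in
    # float32, then convert the output back so downstream layers are unchanged.
    input_dtype = q.dtype
    q, k, v = (t.to(dtype=torch.float) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.to(dtype=input_dtype)


# Example: bfloat16 inputs go in, bfloat16 output comes back out.
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
print(sdpa_float32_roundtrip(q, k, v).dtype)  # torch.bfloat16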
