Commit 59ba811

[ExecuTorch] Allow using custom SDPA for non-float32 dtypes in llama demo
Converting the input to and from float32 is faster than not using the op. h/t to torchchat, which does this already (though it had a bug, which I sent a patch for).

Differential Revision: [D63158951](https://our.internmc.facebook.com/intern/diff/D63158951/)

ghstack-source-id: 244181863
Pull Request resolved: #5548
1 parent b2517d6 · commit 59ba811
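
The change itself is a small dtype round trip. Below is a minimal, self-contained sketch of the pattern the message describes; run_sdpa_in_float32 is a hypothetical helper name, and torch.nn.functional.scaled_dot_product_attention stands in for torch.ops.llama.sdpa_with_kv_cache, which only exists in an ExecuTorch build with the custom op registered.

import torch
import torch.nn.functional as F


def run_sdpa_in_float32(q, k, v, is_causal=True):
    # Remember the caller's dtype, cast to float32 for the float32-only kernel,
    # then cast the result back so downstream modules see the original dtype.
    input_dtype = q.dtype
    q = q.to(dtype=torch.float)
    k = k.to(dtype=torch.float)
    v = v.to(dtype=torch.float)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return out.to(dtype=input_dtype)


q = k = v = torch.randn(1, 8, 16, 64, dtype=torch.float16)
assert run_sdpa_in_float32(q, k, v).dtype == torch.float16

Whether the cast round trip actually beats skipping the custom op depends on the backend; the commit message reports that it does for this demo.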

File tree

1 file changed: +10 -2 lines

  • examples/models/llama2/source_transformation/sdpa.py


examples/models/llama2/source_transformation/sdpa.py

Lines changed: 10 additions & 2 deletions
@@ -23,7 +23,9 @@ def __init__(
         dim: int,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        self.kv_cache = kv_cache.to(torch.float)
         self.dim = dim
 
     def forward(
@@ -36,6 +38,12 @@ def forward(
         seqlen,
         mask,
     ):
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        input_dtype = q.dtype
+        q = q.to(dtype=torch.float)
+        k = k.to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
@@ -48,7 +56,7 @@ def forward(
             0, # dropout probability. Ignored by the code
             True, # is_causal
         )
-        return output.view(bsz, seqlen, self.dim)
+        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(module: torch.nn.Module):
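
The diff stops at the signature of _replace_sdpa_with_custom_op, whose body is unchanged by this commit. For context, a source transformation like this typically walks the module tree and swaps the eager SDPA module for the custom-op wrapper. The sketch below only illustrates that general pattern with stand-in classes (SDPA, SDPACustom, replace_sdpa_sketch are simplified stand-ins, not the actual ExecuTorch implementation).

import torch
from torch import nn


class SDPA(nn.Module):
    # Stand-in for the eager SDPA module defined earlier in sdpa.py; the real
    # class holds more state and implements reference attention in forward().
    def __init__(self, kv_cache: torch.Tensor, dim: int):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim


class SDPACustom(nn.Module):
    # Stand-in mirroring the diff above: the custom op only supports float32,
    # so the cache is stored as float32 up front.
    def __init__(self, kv_cache: torch.Tensor, dim: int):
        super().__init__()
        self.kv_cache = kv_cache.to(torch.float)
        self.dim = dim


def replace_sdpa_sketch(module: nn.Module) -> nn.Module:
    # Recursively swap eager SDPA instances for the custom-op wrapper.
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(module, name, SDPACustom(child.kv_cache, child.dim))
        else:
            replace_sdpa_sketch(child)
    return module


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = SDPA(torch.zeros(1, 8, 128, 64, dtype=torch.float16), dim=512)


block = replace_sdpa_sketch(TinyBlock())
assert isinstance(block.attn, SDPACustom)
assert block.attn.kv_cache.dtype == torch.float32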
