Skip to content

[ExecuTorch] Allow using custom SDPA for non-float32 dtypes in llama demo #5548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/models/llama2/source_transformation/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def __init__(
dim: int,
):
super().__init__()
self.kv_cache = kv_cache
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
self.kv_cache = kv_cache.to(torch.float)
self.dim = dim

def forward(
Expand All @@ -36,6 +38,12 @@ def forward(
seqlen,
mask,
):
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
input_dtype = q.dtype
q = q.to(dtype=torch.float)
k = k.to(dtype=torch.float)
v = v.to(dtype=torch.float)
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
Expand All @@ -48,7 +56,7 @@ def forward(
0, # dropout probability. Ignored by the code
True, # is_causal
)
return output.view(bsz, seqlen, self.dim)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


def _replace_sdpa_with_custom_op(module: torch.nn.Module):
Expand Down
Loading