
Commit 679a6e7

Add et version of TorchTune MHA for swapping with custom op
1 parent af45b02 commit 679a6e7

File tree

4 files changed: +477 -0 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,7 @@
     replace_sdpa_with_flex_sdpa,
     replace_sdpa_with_simple_sdpa,
 )
+from .source_transformation.torchtune.attention import replace_mha_with_inference_mha

 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"

@@ -934,6 +935,7 @@ def _get_source_transforms( # noqa

     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_sdpa_with_custom_op)
+        transforms.append(replace_mha_with_inference_mha)

     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
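
For reference, each entry appended to transforms is a callable that takes the eager nn.Module and returns the rewritten module, and the collected transforms are applied in order before export. A minimal sketch of that pattern (illustrative only; apply_source_transforms is a placeholder name, not part of this diff):

import torch

def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Apply each source transform in order; every transform returns the
    # (possibly rewritten) module, e.g. replace_mha_with_inference_mha.
    for transform in transforms:
        model = transform(model)
    return model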
examples/models/llama2/source_transformation/torchtune/attention.py

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import MultiHeadAttention
from executorch.examples.models.llama2.source_transformation.torchtune.modules.sdpa import SDPA


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    # Recursively swap every TorchTune MultiHeadAttention child for the
    # ExecuTorch inference-friendly MultiHeadAttention, reusing the child's
    # existing projections, norms, cache, and config.
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
            _replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference-friendly version of MHA that
    separates out the inference-related parts for further optimization.
    """
    _replace_mha_with_inference_mha(module)
    return module


# class SDPACustom(torch.nn.Module):
#     def __init__(
#         self,
#         kv_cache: KVCache,
#         dim: int,
#     ):
#         super().__init__()
#         # Custom op only supports float32 currently. Converting to/from float32 is
#         # faster than not having the op.
#         self.kv_cache = kv_cache.to(torch.float)
#         self.dim = dim

#     def forward(
#         self,
#         input_pos: torch.Tensor,
#         q: torch.Tensor,
#         k: torch.Tensor,
#         v: torch.Tensor,
#         bsz,
#         seqlen,
#         mask,
#     ):
#         # Custom op only supports float32 currently. Converting to/from float32 is
#         # faster than not having the op.
#         input_dtype = q.dtype
#         q = q.to(dtype=torch.float)
#         k = k.to(dtype=torch.float)
#         v = v.to(dtype=torch.float)
#         output = torch.ops.llama.sdpa_with_kv_cache(
#             q,
#             k,
#             v,
#             self.kv_cache.k_cache,
#             self.kv_cache.v_cache,
#             input_pos[-1].item(),
#             seqlen,
#             None,  # Attention mask
#             0,  # Dropout probability. Ignored by the code.
#             True,  # is_causal
#         )
#         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


# def _replace_sdpa_with_custom_op(module: torch.nn.Module):
#     for name, child in module.named_children():
#         if isinstance(child, SDPA):
#             setattr(
#                 module,
#                 name,
#                 SDPACustom(child.kv_cache, child.dim),
#             )
#         else:
#             _replace_sdpa_with_custom_op(child)


# def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
#     from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa

#     _replace_sdpa_with_custom_op(module)
#     return module
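
A minimal usage sketch of the new transform (illustrative only, not part of the commit): the toy wrapper module, layer sizes, and projection layers below are arbitrary example values.

import torch.nn as nn
import torchtune.modules.attention as TorchTuneAttention

from executorch.examples.models.llama2.source_transformation.torchtune.attention import (
    replace_mha_with_inference_mha,
)

# Arbitrary example dimensions (num_heads * head_dim == embed_dim).
embed_dim, num_heads, num_kv_heads, head_dim = 256, 8, 8, 32

# Toy container holding a TorchTune MultiHeadAttention as a child module.
toy = nn.Module()
toy.attn = TorchTuneAttention.MultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
    output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
    max_seq_len=128,
)

# Swap the TorchTune MHA for the inference-friendly ExecuTorch version.
toy = replace_mha_with_inference_mha(toy)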
