Commit a91666d

Split sdpa into custom op and quantized kv cache
1 parent 8afb8e1 commit a91666d

File tree

6 files changed: +249 -168 lines changed


examples/models/llama2/export_llama_lib.py

Lines changed: 34 additions & 4 deletions
@@ -54,6 +54,7 @@
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_quantized_kv_cache,
+    replace_torchtune_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm

@@ -65,6 +66,7 @@
     replace_sdpa_with_coreml_sdpa,
     replace_sdpa_with_custom_op,
     replace_sdpa_with_flex_sdpa,
+    replace_sdpa_with_sdpa_only_custom_op,
     replace_sdpa_with_simple_sdpa,
 )
 from .source_transformation.torchtune.attention import replace_mha_with_inference_mha
@@ -237,7 +239,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--use_sdpa_with_kv_cache",
         default=False,
         action="store_true",
-        help="Whether to use sdpa_with_kv_cache update op when using kv cache",
+        help="Whether to use a custom sdpa + kv_cache update when kv cache is enabled.",
     )
     parser.add_argument(
         "--disable_dynamic_shape",
@@ -582,6 +584,18 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")

+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        if args.use_sdpa_with_kv_cache:
+            if not args.use_kv_cache and not args.quantize_kv_cache:
+                raise ValueError(
+                    f"TorchTune-defined {args.model} only works with custom SDPA op + quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled."
+                )
+        if args.use_kv_cache:
+            if not args.quantize_kv_cache:
+                raise ValueError(
+                    f"TorchTune-defined {args.model} only works with quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled."
+                )
+

 def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
@@ -884,6 +898,7 @@ def _load_llama_model(
 def _get_source_transforms(  # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
     transforms = []

     if args.use_spin_quant:
@@ -936,12 +951,27 @@ def _get_source_transforms(  # noqa
         transforms.append(materialze_broadcast_of_rope_freq_cis)

     if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
-        transforms.append(replace_mha_with_inference_mha)
+        if is_torchtune_model:
+            assert (
+                args.use_kv_cache and args.quantize_kv_cache
+            ), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune at the moment."
+            transforms.append(replace_mha_with_inference_mha)
+            transforms.append(replace_sdpa_with_sdpa_only_custom_op)
+        else:
+            transforms.append(replace_sdpa_with_custom_op)

     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
-        transforms.append(replace_kv_cache_with_quantized_kv_cache)
+        if is_torchtune_model:
+            transforms.append(
+                lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache(
+                    module,
+                    is_transposed=not args.use_sdpa_with_kv_cache,
+                    enable_dynamic_shape=args.enable_dynamic_shape,
+                )
+            )
+        else:
+            transforms.append(replace_kv_cache_with_quantized_kv_cache)

     if args.use_kv_cache:
         if args.qnn:
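
Note: the transforms collected here are plain callables with the List[Callable[[torch.nn.Module], torch.nn.Module]] signature shown above, so they compose by simple left-to-right application. A minimal sketch of that application loop, with model and transforms as hypothetical stand-ins rather than the exporter's real objects:

def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Apply each source transform in order; every transform mutates the module
    # in place and returns it, matching the signature used above.
    for transform in transforms:
        model = transform(model)
    return model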

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 63 additions & 1 deletion
@@ -11,6 +11,7 @@
 import torch.nn as nn
 from executorch.examples.models.llama2.llama_transformer import KVCache
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
+from torchtune.modules.kv_cache import KVCache as TorchTuneKVCache


 """
@@ -207,8 +208,31 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
             kv_cache.enable_dynamic_shape,
         )

+    @classmethod
+    def from_torchtune_float(
+        cls,
+        kv_cache,
+        cache_type: QuantizedCacheType,
+        is_transposed: bool,
+        enable_dynamic_shape: bool,
+    ):
+        cache_shape = kv_cache.k_cache.shape
+        if kv_cache.is_tranposed:
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+        else:
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        return cls(
+            max_batch_size,
+            max_seq_length,
+            n_heads,
+            head_dim,
+            cache_type,
+            is_transposed,
+            enable_dynamic_shape,
+        )

-def replace_kv_cache_with_quantized_kv_cache(module):
+
+def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module:
     logging.warning(
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
@@ -222,3 +246,41 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
     return module
+
+
+def replace_torchtune_kv_cache_with_quantized_kv_cache(
+    module: nn.Module, is_transposed: bool, enable_dynamic_shape: bool
+) -> nn.Module:
+    """
+    Replace TorchTune KVCache with Executorch's quantized KVCache.
+
+    Args:
+        is_transposed: whether q, k, and v are transposed. Should set to false when sdpa custom op source transform is enabled.
+        enable_dynamic_shape: whether dynamic shapes are enabled.
+
+    Returns:
+        The passed in model.
+    """
+    logging.warning(
+        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, TorchTuneKVCache):
+            cache_shape = child.k_cache.shape
+            if is_transposed:
+                max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+            else:
+                max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            setattr(
+                module,
+                name,
+                QuantizedKVCache.from_torchtune_float(
+                    child,
+                    QuantizedCacheType.AffineAsymmetric,
+                    is_transposed,
+                    enable_dynamic_shape,
+                ),
+            )
+        else:
+            replace_kv_cache_with_quantized_kv_cache(child)
+    return module
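
A minimal usage sketch for the new transform, assuming a TorchTune-defined model whose attention layers hold torchtune.modules.kv_cache.KVCache instances (the model variable is a placeholder, not part of this diff):

# Sketch only: swap TorchTune KV caches for ExecuTorch quantized caches.
# With the SDPA-only custom op transform applied, caches are not transposed,
# mirroring is_transposed=not args.use_sdpa_with_kv_cache in export_llama_lib.py.
model = replace_torchtune_kv_cache_with_quantized_kv_cache(
    model,
    is_transposed=False,
    enable_dynamic_shape=True,
)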

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 60 additions & 1 deletion
@@ -80,7 +80,7 @@ def forward(
             input_pos[0].item(),
             seqlen,
             None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
+            0,  # Dropout probability, ignored by the code
             True,  # is_causal
         )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
@@ -105,6 +105,65 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module


+class SDPAOnlyCustom(torch.nn.Module):
+    """
+    Just the custom SDPA op, no KV cache update included. Can only be used
+    in conjunction with a quantized KV cache.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        seqlen: int,
+        mask: torch.Tensor = None,
+    ):
+        # Custom op only supports float32 currently. Converting to/from float32 is
+        # faster than not having the op.
+        input_dtype = q.dtype
+        q = q.to(dtype=torch.float)
+        k = k.to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            input_pos[0].item(),
+            None,  # Attention mask
+            0,  # Dropout probability, ignored by the code.
+            True,  # is_causal
+        )
+        return output.view(bsz, seqlen, -1).to(dtype=input_dtype)
+
+
+def _replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            assert (
+                child.kv_cache.cache_fp_type == torch.float32
+            ), "Only float32 is supported for custom SDPA"
+            setattr(
+                module,
+                name,
+                SDPAOnlyCustom(),
+            )
+        else:
+            _replace_sdpa_with_sdpa_only_custom_op(child)
+
+
+def replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    _replace_sdpa_with_sdpa_only_custom_op(module)
+    return module
+
+
 class SDPASimple(torch.nn.Module):

     def __init__(
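
To make the split concrete: SDPACustom bundles the cache update and attention in a single sdpa_with_kv_cache call, while SDPAOnlyCustom leaves the cache update to the (quantized) KV cache and only runs custom_sdpa. A rough sketch of the resulting call sequence inside an attention forward, where kv_cache and sdpa are assumed to be a QuantizedKVCache and an SDPAOnlyCustom instance respectively:

# Illustrative only: cache update and attention are now two separate steps.
k, v = kv_cache.update(input_pos, k, v)    # quantized cache write + dequantized read
y = sdpa(input_pos, q, k, v, bsz, seqlen)  # attention via torch.ops.llama.custom_sdpa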
Lines changed: 7 additions & 67 deletions
@@ -1,9 +1,11 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
-from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import MultiHeadAttention
-from executorch.examples.models.llama2.source_transformation.torchtune.modules.sdpa import SDPA
+from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import (
+    MultiHeadAttention,
+)

-def _replace_mha_with_inference_mha(module: torch.nn.Module):
+
+def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
     for name, child in module.named_children():
         if isinstance(child, TorchTuneAttention.MultiHeadAttention):
             setattr(
@@ -18,7 +20,7 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module):
                     k_proj=child.k_proj,
                     v_proj=child.v_proj,
                     output_proj=child.output_proj,
-                    pos_embeddings=child.pos_embedding,
+                    pos_embeddings=child.pos_embeddings,
                     q_norm=child.q_norm,
                     k_norm=child.k_norm,
                     kv_cache=child.kv_cache,
@@ -30,72 +32,10 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module):
         else:
             replace_mha_with_inference_mha(child)

-def replace_mha_with_inference_mha(module: torch.nn.Module):
+def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     Replace TorchTune's MHA with an inference friendly version of MHA that
     separates out the inference-related parts for further optimization.
     """
     _replace_mha_with_inference_mha(module)
     return module
-
-# class SDPACustom(torch.nn.Module):
-#     def __init__(
-#         self,
-#         kv_cache: KVCache,
-#         dim: int,
-#     ):
-#         super().__init__()
-#         # Custom op only supports float32 currently. Converting to/from float32 is
-#         # faster than not having the op.
-#         self.kv_cache = kv_cache.to(torch.float)
-#         self.dim = dim
-
-#     def forward(
-#         self,
-#         input_pos: torch.Tensor,
-#         q: torch.Tensor,
-#         k: torch.Tensor,
-#         v: torch.Tensor,
-#         bsz,
-#         seqlen,
-#         mask,
-#     ):
-#         # Custom op only supports float32 currently. Converting to/from float32 is
-#         # faster than not having the op.
-#         input_dtype = q.dtype
-#         q = q.to(dtype=torch.float)
-#         k = k.to(dtype=torch.float)
-#         v = v.to(dtype=torch.float)
-#         output = torch.ops.llama.sdpa_with_kv_cache(
-#             q,
-#             k,
-#             v,
-#             self.kv_cache.k_cache,
-#             self.kv_cache.v_cache,
-#             input_pos[-1].item(),
-#             seqlen,
-#             None,  # Attention mask
-#             0,  # dropout probability. Ignored by the code
-#             True,  # is_causal
-#         )
-#         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
-
-
-# def _replace_sdpa_with_custom_op(module: torch.nn.Module):
-#     for name, child in module.named_children():
-#         if isinstance(child, SDPA):
-#             setattr(
-#                 module,
-#                 name,
-#                 SDPACustom(child.kv_cache, child.dim),
-#             )
-#         else:
-#             _replace_sdpa_with_custom_op(child)
-
-
-# def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-#     from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
-
-#     _replace_sdpa_with_custom_op(module)
-#     return module
-
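
For reference, a short usage sketch of the two transforms this commit pairs together for TorchTune-defined models (model is a placeholder module, not part of the commit):

# Sketch: swap TorchTune MultiHeadAttention for the inference-friendly MHA,
# then replace its SDPA submodules with the SDPA-only custom op.
model = replace_mha_with_inference_mha(model)
model = replace_sdpa_with_sdpa_only_custom_op(model)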
