Commit 3145bde

Revert portion to move to next PR
1 parent 4587852 commit 3145bde

File tree

3 files changed: +5 additions, −166 deletions

examples/models/llama/export_llama_lib.py
examples/models/llama/source_transformation/quantized_kv_cache.py
examples/models/llama/source_transformation/sdpa.py


examples/models/llama/export_llama_lib.py

Lines changed: 3 additions & 43 deletions

@@ -54,7 +54,6 @@
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_quantized_kv_cache,
-    replace_torchtune_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
 
@@ -66,15 +65,10 @@
     replace_sdpa_with_coreml_sdpa,
     replace_sdpa_with_custom_op,
     replace_sdpa_with_flex_sdpa,
-    replace_sdpa_with_sdpa_only_custom_op,
     replace_sdpa_with_simple_sdpa,
 )
-
-from .source_transformation.torchtune.attention import replace_mha_with_inference_mha
-
 from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
 
-
 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -243,7 +237,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--use_sdpa_with_kv_cache",
         default=False,
         action="store_true",
-        help="Whether to use a custom sdpa + kv_cache update when kv cache is enabled.",
+        help="Whether to use sdpa_with_kv_cache update op when using kv cache",
     )
     parser.add_argument(
         "--disable_dynamic_shape",
@@ -595,18 +589,6 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        if args.use_sdpa_with_kv_cache:
-            if not args.use_kv_cache and not args.quantize_kv_cache:
-                raise ValueError(
-                    f"TorchTune-defined {args.model} only works with custom SDPA op + quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled."
-                )
-        if args.use_kv_cache:
-            if not args.quantize_kv_cache:
-                raise ValueError(
-                    f"TorchTune-defined {args.model} only works with quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled."
-                )
-
 
 def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
@@ -910,7 +892,6 @@ def _load_llama_model(
 def _get_source_transforms(  # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
-    is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
     transforms = []
 
     if args.use_spin_quant:
@@ -962,29 +943,12 @@ def _get_source_transforms(  # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
-    transforms.append(replace_mha_with_inference_mha)
     if args.use_sdpa_with_kv_cache:
-        if is_torchtune_model:
-            assert (
-                args.use_kv_cache and args.quantize_kv_cache
-            ), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune at the moment."
-            transforms.append(replace_mha_with_inference_mha)
-            transforms.append(replace_sdpa_with_sdpa_only_custom_op)
-        else:
-            transforms.append(replace_sdpa_with_custom_op)
+        transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
-        if is_torchtune_model:
-            transforms.append(
-                lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache(
-                    module,
-                    is_transposed=not args.use_sdpa_with_kv_cache,
-                    enable_dynamic_shape=args.enable_dynamic_shape,
-                )
-            )
-        else:
-            transforms.append(replace_kv_cache_with_quantized_kv_cache)
+        transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
         if args.qnn:
@@ -1019,8 +983,4 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
-    print(
-        f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}"
-    )
-
     return transforms
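For context on how a list like the one returned by _get_source_transforms is typically consumed, here is a minimal, hypothetical sketch (not taken from this repository); the apply_source_transforms helper and DummyModel are illustrative assumptions only:

import torch
from typing import Callable, List


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Each transform takes a module and returns a (possibly modified-in-place) module,
    # so the list can simply be folded over the model in order.
    for transform in transforms:
        model = transform(model)
    return model


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = apply_source_transforms(DummyModel(), transforms=[])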

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 1 addition & 63 deletions

@@ -11,7 +11,6 @@
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
-from torchtune.modules.kv_cache import KVCache as TorchTuneKVCache
 
 
 """
@@ -208,31 +207,8 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
             kv_cache.enable_dynamic_shape,
         )
 
-    @classmethod
-    def from_torchtune_float(
-        cls,
-        kv_cache,
-        cache_type: QuantizedCacheType,
-        is_transposed: bool,
-        enable_dynamic_shape: bool,
-    ):
-        cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_tranposed:
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-        else:
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
-        return cls(
-            max_batch_size,
-            max_seq_length,
-            n_heads,
-            head_dim,
-            cache_type,
-            is_transposed,
-            enable_dynamic_shape,
-        )
-
 
-def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module:
+def replace_kv_cache_with_quantized_kv_cache(module):
     logging.warning(
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
@@ -246,41 +222,3 @@ def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module:
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
     return module
-
-
-def replace_torchtune_kv_cache_with_quantized_kv_cache(
-    module: nn.Module, is_transposed: bool, enable_dynamic_shape: bool
-) -> nn.Module:
-    """
-    Replace TorchTune KVCache with Executorch's quantized KVCache.
-
-    Args:
-        is_transposed: whether q, k, and v are transposed. Should set to false when sdpa custom op source transform is enabled.
-        enable_dynamic_shape: whether dynamic shapes are enabled.
-
-    Returns:
-        The passed in model.
-    """
-    logging.warning(
-        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
-    )
-    for name, child in module.named_children():
-        if isinstance(child, TorchTuneKVCache):
-            cache_shape = child.k_cache.shape
-            if is_transposed:
-                max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-            else:
-                max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
-            setattr(
-                module,
-                name,
-                QuantizedKVCache.from_torchtune_float(
-                    child,
-                    QuantizedCacheType.AffineAsymmetric,
-                    is_transposed,
-                    enable_dynamic_shape,
-                ),
-            )
-        else:
-            replace_kv_cache_with_quantized_kv_cache(child)
-    return module
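The surviving replace_kv_cache_with_quantized_kv_cache follows the usual recursive named_children/setattr module-swap pattern. A rough, self-contained sketch of that pattern, using made-up OldBlock/NewBlock classes rather than the real KVCache/QuantizedKVCache types:

import torch.nn as nn


class OldBlock(nn.Module):
    def forward(self, x):
        return x


class NewBlock(nn.Module):
    @classmethod
    def from_float(cls, old: OldBlock) -> "NewBlock":
        # A real implementation would copy shapes/parameters from `old`.
        return cls()

    def forward(self, x):
        return x


def replace_old_with_new(module: nn.Module) -> nn.Module:
    # Walk direct children; swap matches in place, recurse into everything else.
    for name, child in module.named_children():
        if isinstance(child, OldBlock):
            setattr(module, name, NewBlock.from_float(child))
        else:
            replace_old_with_new(child)
    return module


model = nn.Sequential(OldBlock(), nn.Sequential(OldBlock()))
replace_old_with_new(model)
assert all(not isinstance(m, OldBlock) for m in model.modules())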

examples/models/llama/source_transformation/sdpa.py

Lines changed: 1 addition & 60 deletions

@@ -80,7 +80,7 @@ def forward(
             input_pos[0].item(),
             seqlen,
             None,  # Attention mask
-            0,  # Dropout probability, ignored by the code
+            0,  # dropout probability. Ignored by the code
             True,  # is_causal
         )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
@@ -105,65 +105,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module
 
 
-class SDPAOnlyCustom(torch.nn.Module):
-    """
-    Just the custom SDPA op, no KV cache update included. Can only be used
-    in conjunction with a quantized KV cache.
-    """
-
-    def __init__(
-        self,
-    ):
-        super().__init__()
-
-    def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz: int,
-        seqlen: int,
-        mask: torch.Tensor = None,
-    ):
-        # Custom op only supports float32 currently. Converting to/from float32 is
-        # faster than not having the op.
-        input_dtype = q.dtype
-        q = q.to(dtype=torch.float)
-        k = k.to(dtype=torch.float)
-        v = v.to(dtype=torch.float)
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # Dropout probability, ignored by the code.
-            True,  # is_causal
-        )
-        return output.view(bsz, seqlen, -1).to(dtype=input_dtype)
-
-
-def _replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module):
-    for name, child in module.named_children():
-        if isinstance(child, SDPA):
-            assert (
-                child.kv_cache.cache_fp_type == torch.float32
-            ), "Only float32 is supported for custom SDPA"
-            setattr(
-                module,
-                name,
-                SDPAOnlyCustom(),
-            )
-        else:
-            _replace_sdpa_with_sdpa_only_custom_op(child)
-
-
-def replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module) -> torch.nn.Module:
-    _replace_sdpa_with_sdpa_only_custom_op(module)
-    return module
-
-
 class SDPASimple(torch.nn.Module):
 
     def __init__(
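The dtype handling in the forward methods above follows a common wrapper pattern: cast inputs to float32 for a kernel that only supports float32, then cast the result back to the caller's dtype. A rough, standalone illustration of just that pattern, using plain torch.nn.functional.scaled_dot_product_attention as a stand-in for the custom op (an assumption for demonstration only):

import torch
import torch.nn.functional as F


def sdpa_float32_roundtrip(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    # Remember the caller's dtype, run the float32-only kernel, then cast back.
    input_dtype = q.dtype
    q32, k32, v32 = (t.to(dtype=torch.float32) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q32, k32, v32, is_causal=True)
    return out.to(dtype=input_dtype)


q = k = v = torch.randn(1, 4, 16, 8, dtype=torch.float16)
out = sdpa_float32_roundtrip(q, k, v)
assert out.dtype == torch.float16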
