
Commit d8e1d04

[ExecuTorch][Llama] Split custom sdpa op and kv cache
Summary: This enables us to more easily swap modules with model definitions from torchtune.

Test Plan: CI

ghstack-source-id: bdcb0fd
Pull Request resolved: #7412
1 parent 49cc399 commit d8e1d04

File tree (3 files changed: +86 −34 lines)

examples/models/llama/export_llama_lib.py
examples/models/llama/source_transformation/quantized_kv_cache.py
examples/models/llama/source_transformation/sdpa.py

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@
     get_quant_weight_transform,
 )
 from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
@@ -1045,6 +1046,7 @@ def _get_source_transforms(  # noqa
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
     if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_kv_cache_with_custom_kv_cache)
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
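
Usage sketch (not part of the commit): each source transform appended above is a callable that takes an nn.Module and returns the in-place modified module, so the new KV-cache swap and the custom-SDPA swap compose by simple chaining. The helper below is illustrative only; the real list is built and applied inside export_llama_lib based on flags such as --use_sdpa_with_kv_cache.

import torch.nn as nn


def apply_source_transforms(model: nn.Module, transforms) -> nn.Module:
    # Apply each nn.Module -> nn.Module transform in order; with
    # --use_sdpa_with_kv_cache the list now contains
    # replace_kv_cache_with_custom_kv_cache followed by
    # replace_sdpa_with_custom_op, so caches are swapped before SDPA modules.
    for transform in transforms:
        model = transform(model)
    return model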

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 74 additions & 5 deletions
@@ -6,6 +6,7 @@
 
 import logging
 from enum import Enum
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -44,7 +45,6 @@ def __init__(
             QuantizedCacheType.AffineSymmetric,
             QuantizedCacheType.AffineAsymmetric,
         ):
-
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
@@ -81,10 +81,11 @@ def __init__(
         )
 
     def _quantize(self, value):
-        scales, zero_points = (
-            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
-                value, self.quantized_cache_dtype
-            )
+        (
+            scales,
+            zero_points,
+        ) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+            value, self.quantized_cache_dtype
         )
         quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
             value,
@@ -262,3 +263,71 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
     return module
+
+
+class CustomKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+        self.max_seq_length = max_seq_length
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [S], k_val: [B, S, H, D]
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+        return self.k_cache, self.v_cache
+
+
+def replace_kv_cache_with_custom_kv_cache(module):
+    r"""
+    Replace KVCache with CustomKVCache. This modifies the model in place.
+    At the moment custom kv cache only supports cache with shape
+    [B, S, H, D] as opposed to [B, H, S, D].
+    This is because the custom op treats the second dim as the sequence dim.
+    Future work: support [B, H, S, D].
+    """
+    logging.warning(
+        "Replacing KVCache with CustomKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            cache_shape = child.k_cache.shape
+            cache_dtype = child.k_cache.dtype
+            assert (
+                child.is_transposed is False
+            ), "CustomKVCache does not support transposed cache"
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            setattr(
+                module,
+                name,
+                CustomKVCache(
+                    max_batch_size,
+                    max_seq_length,
+                    n_heads,
+                    head_dim,
+                    dtype=cache_dtype,
+                ),
+            )
+        else:
+            replace_kv_cache_with_custom_kv_cache(child)
+    return module
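
For intuition, here is a plain-PyTorch sketch (not code from this commit) of the write that torch.ops.llama.update_cache performs on the [B, S, H, D] layout used by CustomKVCache: the new k/v values are copied into the cache starting at start_pos along the sequence dimension.

import torch


def reference_update(cache: torch.Tensor, val: torch.Tensor, start_pos: int) -> None:
    # Copy val into cache[:, start_pos : start_pos + S] along dim 1 (sequence).
    seq_len = val.size(1)
    cache[:, start_pos : start_pos + seq_len] = val


k_cache = torch.zeros(1, 16, 4, 8)  # [B=1, S_max=16, H=4, D=8]
k_val = torch.randn(1, 3, 4, 8)  # three new tokens
reference_update(k_cache, k_val, start_pos=5)
assert torch.equal(k_cache[:, 5:8], k_val)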

examples/models/llama/source_transformation/sdpa.py

Lines changed: 10 additions & 29 deletions
@@ -56,33 +56,16 @@ def forward(
 
         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
-            # updated quantize cache, scale and zero points
-            # returns dequantized kv cache
-            # Not most optimal. Optimizations to follow next
-            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-            output = torch.ops.llama.custom_sdpa(
-                q,
-                k_cache,
-                v_cache,
-                input_pos[0].item(),
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
-            )
-        else:
-            output = torch.ops.llama.sdpa_with_kv_cache(
-                q,
-                k,
-                v,
-                k_cache,
-                v_cache,
-                input_pos[0].item(),
-                seqlen,
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
-            )
+        k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k_cache,
+            v_cache,
+            input_pos[0].item(),
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal
+        )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
@@ -106,7 +89,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 
 
 class SDPASimple(torch.nn.Module):
-
     def __init__(
         self,
         kv_cache: KVCache,
@@ -166,7 +148,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class SDPAFlex(torch.nn.Module):
-
     def __init__(
         self,
         kv_cache: KVCache,
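
The quantized/non-quantized branch can be dropped because both the quantized KV cache and the new CustomKVCache now expose the same update(input_pos, k_val, v_val) -> (k_cache, v_cache) method, whose outputs feed torch.ops.llama.custom_sdpa. A minimal sketch of that assumed shared interface, written as a typing.Protocol for illustration (not code from this commit):

from typing import Protocol, Tuple

import torch


class KVCacheLike(Protocol):
    # The simplified forward above only touches these attributes and this
    # method, so quantized and non-quantized caches are interchangeable.
    k_cache: torch.Tensor
    v_cache: torch.Tensor

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]: ...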
