
Commit 096d2db

kimishpatel authored and YIWENX14 committed
[ExecuTorch][Llama] Split custom sdpa op and kv cache (#7412)
Summary: This enables us to more easily swap modules with model definitions from torchtune.

Test Plan: CI

[ghstack-poisoned]
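Schematically, the split replaces the fused sdpa_with_kv_cache call with a cache-module update followed by custom_sdpa, so the cache module itself becomes the swappable piece. A minimal sketch mirroring the sdpa.py change below (the torch.ops.llama.* custom ops must be registered, e.g. by loading ExecuTorch's llama custom-op library, before this runs):

import torch

def sdpa_with_split_cache(kv_cache, q, k, v, input_pos):
    # Step 1: the (swappable) cache module writes the new k/v values into its
    # buffers and returns the full caches.
    k_cache, v_cache = kv_cache.update(input_pos, k, v)
    # Step 2: attention runs as a separate custom op over the updated caches.
    return torch.ops.llama.custom_sdpa(
        q,
        k_cache,
        v_cache,
        input_pos[0].item(),
        None,  # attention mask
        0,     # dropout probability (ignored by the op)
        True,  # is_causal
    )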
1 parent 17acc3b commit 096d2db

4 files changed (+101 -36 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,7 @@
     get_quant_weight_transform,
 )
 from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
@@ -1082,6 +1083,7 @@ def _get_source_transforms( # noqa
         transforms.append(materialze_broadcast_of_rope_freq_cis)

     if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_kv_cache_with_custom_kv_cache)
         transforms.append(replace_sdpa_with_custom_op)

     if args.quantize_kv_cache:
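For context, the transforms collected in _get_source_transforms are plain callables that take and return an nn.Module; a minimal sketch of how such a list is applied in order (the apply_transforms helper name is illustrative, not from this file):

from typing import Callable, List

import torch.nn as nn

def apply_transforms(
    model: nn.Module, transforms: List[Callable[[nn.Module], nn.Module]]
) -> nn.Module:
    # With args.use_sdpa_with_kv_cache set, the list now contains
    # replace_kv_cache_with_custom_kv_cache followed by
    # replace_sdpa_with_custom_op, so the KV cache is swapped before
    # the SDPA module is replaced.
    for transform in transforms:
        model = transform(model)
    return model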

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 74 additions & 5 deletions
@@ -6,6 +6,7 @@

 import logging
 from enum import Enum
+from typing import Tuple

 import torch
 import torch.nn as nn
@@ -44,7 +45,6 @@ def __init__(
             QuantizedCacheType.AffineSymmetric,
             QuantizedCacheType.AffineAsymmetric,
         ):
-
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
@@ -81,10 +81,11 @@ def __init__(
         )

     def _quantize(self, value):
-        scales, zero_points = (
-            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
-                value, self.quantized_cache_dtype
-            )
+        (
+            scales,
+            zero_points,
+        ) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+            value, self.quantized_cache_dtype
         )
         quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
             value,
@@ -262,3 +263,71 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
     return module
+
+
+class CustomKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+        self.max_seq_length = max_seq_length
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [S], k_val: [B, S, H, D]
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+        return self.k_cache, self.v_cache
+
+
+def replace_kv_cache_with_custom_kv_cache(module):
+    r"""
+    Replace KVCache with CustomKVCache. This modifies the model in place.
+    At the moment custom kv cache only supports cache with shape
+    [B, S, H, D] as opposed to [B, H, S, D]
+    This is because the custom op treats second dim as sequence dim.
+    Future work: support [B, H, S, D]
+    """
+    logging.warning(
+        "Replacing KVCache with CustomKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            cache_shape = child.k_cache.shape
+            cache_dtype = child.k_cache.dtype
+            assert (
+                child.is_transposed is False
+            ), "CustomKVCache does not support transposed cache"
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            setattr(
+                module,
+                name,
+                CustomKVCache(
+                    max_batch_size,
+                    max_seq_length,
+                    n_heads,
+                    head_dim,
+                    dtype=cache_dtype,
+                ),
+            )
+        else:
+            replace_kv_cache_with_custom_kv_cache(child)
+    return module
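A minimal usage sketch of the new cache (shapes follow the [B, S, H, D] layout described in the docstring; the torch.ops.llama.update_cache custom op used inside update() must be registered, e.g. by loading ExecuTorch's llama custom-op library, for this to run):

import torch

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    CustomKVCache,
)

cache = CustomKVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=8)
k = torch.randn(1, 3, 4, 8)  # [B, S, H, D]: three new tokens
v = torch.randn(1, 3, 4, 8)
input_pos = torch.tensor([0, 1, 2])  # update() writes starting at input_pos[0]
k_cache, v_cache = cache.update(input_pos, k, v)  # full [1, 16, 4, 8] buffers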

examples/models/llama/source_transformation/sdpa.py

Lines changed: 10 additions & 29 deletions
@@ -56,33 +56,16 @@ def forward(

         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
-            # updated quantize cache, scale and zero points
-            # returns dequantized kv cache
-            # Not most optimal. Optimizations to follow next
-            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-            output = torch.ops.llama.custom_sdpa(
-                q,
-                k_cache,
-                v_cache,
-                input_pos[0].item(),
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
-            )
-        else:
-            output = torch.ops.llama.sdpa_with_kv_cache(
-                q,
-                k,
-                v,
-                k_cache,
-                v_cache,
-                input_pos[0].item(),
-                seqlen,
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
-            )
+        k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k_cache,
+            v_cache,
+            input_pos[0].item(),
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal
+        )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


@@ -106,7 +89,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:


 class SDPASimple(torch.nn.Module):
-
     def __init__(
         self,
         kv_cache: KVCache,
@@ -166,7 +148,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


 class SDPAFlex(torch.nn.Module):
-
     def __init__(
         self,
         kv_cache: KVCache,
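The simplification in forward works because both QuantizedKVCache and the new CustomKVCache expose the same update(input_pos, k_val, v_val) -> (k_cache, v_cache) interface; a sketch of that shared contract (the KVCacheProtocol name is hypothetical, not part of the codebase):

from typing import Protocol, Tuple

import torch

class KVCacheProtocol(Protocol):
    k_cache: torch.Tensor
    v_cache: torch.Tensor

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write k_val/v_val at input_pos and return the full caches."""
        ...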

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 15 additions & 2 deletions
@@ -11,6 +11,7 @@
 from executorch.examples.models.llama.llama_transformer import KVCache

 from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+    CustomKVCache,
     QuantizedCacheType,
     QuantizedKVCache,
 )
@@ -19,7 +20,6 @@


 class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
-
     def _init_cache(self):
         self.kv_cache = KVCache(
             self.max_batch_size,
@@ -33,6 +33,19 @@ def _init_cache(self):
         self.quantized_kv_cache = QuantizedKVCache.from_float(
             self.kv_cache, QuantizedCacheType.AffineAsymmetric
         )
+        # Need this because first test actually has seq_len > 1
+        # and vanilla kvcache cannot handle seq_len > 1, due to
+        # how input_pos encoding works in the current stack.
+        # This needs fixing by making sure rest of the stack including
+        # custom ops or other backends can work with input_pos
+        # as a sequence of token positions
+        self.custom_kv_cache = CustomKVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            dtype=self.dtype,
+        )

     def _init_kv(self):
         kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
@@ -59,7 +72,7 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k, v = self._init_kv()
-        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.float_sdpa = SDPACustom(self.custom_kv_cache, self.dim)
         self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
