
Commit 056c0f2

[Executorch][llama] Update SDPA op to use quantized kv cache
Using a quantized kv cache, we cannot rely on SDPA to update the original cache, so we insert a cache update op instead.

Differential Revision: [D62301841](https://our.internmc.facebook.com/intern/diff/D62301841/)

[ghstack-poisoned]
1 parent f38581d commit 056c0f2
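
To make the summary above concrete: the cache update now goes through torch.ops.llama.update_quantized_cache, called with (value, cache, start_pos) as in the quantized_kv_cache.py diff below. A minimal, hedged sketch of that call, assuming the llama custom ops AOT library is loaded so the op is registered; the shapes are illustrative only:

import torch

# Persistent int8 K cache laid out [batch, max_seq_len, n_kv_heads, head_dim]
# (the non-transposed layout used by the new custom-op path).
k_cache = torch.zeros(1, 5, 4, 17, dtype=torch.int8)
# Three freshly quantized tokens worth of keys to append at position 0.
quantized_k_val = torch.randint(-128, 127, (1, 3, 4, 17), dtype=torch.int8)
start_pos = 0

# In-place write into the persistent cache; with a quantized cache, the SDPA op can no
# longer perform this update itself, which is the reason for the explicit op.
_ = torch.ops.llama.update_quantized_cache(quantized_k_val, k_cache, start_pos)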

File tree

examples/models/llama2/export_llama_lib.py
examples/models/llama2/source_transformation/TARGETS
examples/models/llama2/source_transformation/quantized_kv_cache.py
examples/models/llama2/source_transformation/sdpa.py
examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
extension/llm/custom_ops/targets.bzl

6 files changed: +197 -41 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 3 deletions
@@ -864,9 +864,7 @@ def _get_source_transforms(  # noqa
         transforms.append(replace_sdpa_with_custom_op)

     if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)

     if args.use_kv_cache:
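
With the assertion relaxed, quantizing the KV cache and using the SDPA custom op are no longer mutually exclusive. A hedged sketch of how the two source transforms compose at export time; the guard on args.use_sdpa_with_kv_cache and the model-in/model-out convention are assumptions based on the surrounding code, not shown in this hunk:

from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)


def apply_kv_cache_transforms(model, args):
    # Mirrors the logic above: both transforms can now be requested together,
    # as long as the KV cache itself is enabled.
    transforms = []
    if args.use_sdpa_with_kv_cache:
        transforms.append(replace_sdpa_with_custom_op)
    if args.quantize_kv_cache:
        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)
    for transform in transforms:
        model = transform(model)
    return model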

examples/models/llama2/source_transformation/TARGETS

Lines changed: 32 additions & 0 deletions
@@ -15,13 +15,45 @@ runtime.python_library(
     ],
 )

+runtime.python_library(
+    name = "sdpa",
+    srcs = [
+        "sdpa.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama2.source_transformation",
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
         "test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_sdpa_with_kv_cache_test",
+    srcs = [
+        "test_sdpa_with_quantized_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
+        ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama2:llama_transformer",
     ],
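
The preload_deps entries pull in //executorch/extension/llm/custom_ops:custom_ops_aot_lib so that torch.ops.llama.sdpa_with_kv_cache and torch.ops.llama.update_quantized_cache are registered before the test modules import. Outside of Buck, a hedged equivalent is loading the AOT shared library by hand; the path and library name below are assumptions that depend on how the library was built:

import torch

# Hypothetical build-output path; substitute wherever your build places the AOT library.
torch.ops.load_library("cmake-out/extension/llm/custom_ops/libcustom_ops_aot_lib.so")

# Once loaded, the ops used in this commit resolve in eager mode:
print(torch.ops.llama.update_quantized_cache)
print(torch.ops.llama.sdpa_with_kv_cache)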

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 64 additions & 36 deletions
@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -100,48 +101,75 @@ def update(self, input_pos, k_val, v_val):

         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            dim_to_slice = 2 if self.is_transposed else 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            # pyre-ignore: Incompatible parameter type [6]
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. not.
+            # The only reason it is done this way is to accommodate
+            # lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                # pyre-ignore: Incompatible parameter type [6]
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
+            else:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now using custom ops on this path.
+            # In the future we can update the custom op to handle a transposed cache
+            # as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache, with dynamic shape,
+            # instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )

         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
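
The reshuffled update() above keeps the index_put-style writes for transposed caches and routes the default (non-transposed) layout through the new custom op, then dequantizes and returns the full cache. A small usage sketch with shapes borrowed from the test further down; it assumes the custom ops AOT library is loaded so update_quantized_cache and the quantized_decomposed ops used for (de)quantization are registered:

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

# Non-transposed fp32 cache laid out [batch, max_seq_len, n_kv_heads, head_dim].
fp_cache = KVCache(1, 5, 4, 17, False, False, dtype=torch.float32)
quantized_cache = QuantizedKVCache.from_float(fp_cache, QuantizedCacheType.AffineAsymmetric)

input_pos = torch.tensor([0], dtype=torch.int64)
k = torch.rand(1, 3, 4, 17)  # three new tokens of keys
v = torch.rand(1, 3, 4, 17)  # three new tokens of values

# Writes the int8 cache plus per-token scales and zero points in place, then returns
# dequantized fp32 copies of the whole cache for SDPA to consume.
k_out, v_out = quantized_cache.update(input_pos, k, v)
print(k_out.shape, k_out.dtype)  # expected: torch.Size([1, 5, 4, 17]) torch.float32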

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 20 additions & 2 deletions
@@ -14,6 +14,9 @@
 import torch

 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)


 class SDPACustom(torch.nn.Module):
@@ -44,12 +47,27 @@ def forward(
         q = q.to(dtype=torch.float)
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
+
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # updates the quantized cache, scales and zero points,
+            # and returns the dequantized kv cache.
+            # Not the most optimal. Optimizations to follow next.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache, v_cache.
+        # When we are not using the quantized kv cache, this will just mutate
+        # the original kv cache.
+        # When we are using the quantized kv cache, this will mutate the
+        # k_cache, v_cache returned from the cache update operation.
+        # That operation just dequantizes the cache and returns it.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
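
As the added comments note, this path is correct but not yet optimal: every forward call dequantizes the entire cache in update(), and the fused op then writes its own copy of k/v into those transient fp32 tensors. A rough, hedged sense of the per-step overhead; the shapes are illustrative assumptions, not values from this commit:

# Each decode step materializes full-size fp32 k and v caches alongside the int8 storage.
max_seq_len, n_kv_heads, head_dim = 2048, 8, 128
fp32_bytes_per_cache = max_seq_len * n_kv_heads * head_dim * 4  # 4 bytes per float32 element
print(2 * fp32_bytes_per_cache / 2**20, "MiB of transient fp32 k/v per step")  # 16.0 MiB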
examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import unittest
+
+import torch
+
+from executorch.examples.models.llama2.llama_transformer import KVCache
+
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedCacheType,
+    QuantizedKVCache,
+)
+
+from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom
+
+
+class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
+
+    def _init_cache(self):
+        self.kv_cache = KVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            False,
+            self.enable_dynamic_shape,
+            dtype=self.dtype,
+        )
+        self.quantized_kv_cache = QuantizedKVCache.from_float(
+            self.kv_cache, QuantizedCacheType.AffineAsymmetric
+        )
+
+    def _init_kv(self):
+        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        q = torch.rand(q_shape, dtype=self.dtype)
+        k = torch.rand(kv_shape, dtype=self.dtype)
+        v = torch.rand(kv_shape, dtype=self.dtype)
+        return q, k, v
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.max_seq_len = 5
+        self.n_kv_heads = 4
+        self.n_heads = 8
+        self.head_dim = 17
+        self.dim = self.n_heads * self.head_dim
+        self.enable_dynamic_shape = False
+        self.dtype = torch.float32
+
+    def test_simple(self, is_dynamic_shape=False):
+        self.enable_dynamic_shape = is_dynamic_shape
+        input_pos = torch.tensor([0], dtype=torch.int64)
+        self.seq_len = 3
+        self._init_cache()
+        q, k, v = self._init_kv()
+        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+            )
+        )
+
+        input_pos = torch.tensor([3], dtype=torch.int64)
+        self.seq_len = 1
+        q, k, v = self._init_kv()
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        self.assertTrue(
+            torch.allclose(
+                float_out,
+                quantized_out,
+                rtol=1e-03,
+                atol=1e-03,
+            )
+        )
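
The looser rtol/atol in the second check leaves room for int8 affine quantization error in the cached k/v once the quantized history is attended over. For a sense of the error magnitude, a self-contained sketch of an int8 asymmetric round trip on values in roughly the same range as the test's random inputs (illustrative only, per-tensor rather than the per-token scheme used by the cache):

import torch

torch.manual_seed(42)
x = torch.rand(1, 3, 4, 17)  # uniform [0, 1), like the test's k/v

# int8 asymmetric (affine) quantization round trip.
scale = (x.max() - x.min()) / 255.0
zero_point = torch.round(-128 - x.min() / scale)
q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)
x_dq = (q - zero_point) * scale

print((x - x_dq).abs().max())  # about scale / 2, i.e. roughly 2e-3 for this range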

extension/llm/custom_ops/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ def define_common_targets():
             "op_sdpa.h",
         ],
         exported_deps = [
+            ":update_quantized_cache",
             "//executorch/runtime/kernel:kernel_includes",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/optimized:libblas{}".format(mkl_dep),
