
Commit bca3ad6

kimishpatel authored and facebook-github-bot committed
Update SDPA op to use quantized kv cache (#5666)
Summary:
Pull Request resolved: #5666

When using a quantized kv cache, we cannot rely on SDPA to update the original cache, so we insert a cache update op.

ghstack-source-id: 245751546
exported-using-ghexport
//oss complaining of internal lint
bypass-github-export-checks
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D62301841

fbshipit-source-id: 4ca1c27e7bb8c13604b2a2aa57efa573482be7cd
1 parent 5f324ce commit bca3ad6

File tree

5 files changed: +207 -45 lines changed

examples/models/llama2/TARGETS
examples/models/llama2/export_llama_lib.py
examples/models/llama2/source_transformation/quantized_kv_cache.py
examples/models/llama2/source_transformation/sdpa.py
examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

examples/models/llama2/TARGETS

Lines changed: 31 additions & 0 deletions

@@ -168,13 +168,44 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "sdpa",
+    srcs = [
+        "source_transformation/sdpa.py",
+    ],
+    _is_external_target = True,
+    visibility = ["//executorch/..."],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
         "source_transformation/test_quantized_kv_cache.py",
     ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        ":quantized_kv_cache",
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
+
+runtime.python_test(
+    name = "quantized_sdpa_with_kv_cache_test",
+    srcs = [
+        "source_transformation/test_sdpa_with_quantized_kv_cache.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
     deps = [
         ":quantized_kv_cache",
+        ":sdpa",
         "//caffe2:torch",
         "//executorch/examples/models/llama2:llama_transformer",
     ],
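
Both python_test targets preload //executorch/extension/llm/custom_ops:custom_ops_aot_lib because the code under test calls torch.ops.llama.* custom ops, which only resolve once the library that registers them is loaded into the process. Outside of Buck, a rough equivalent is to load the shared library by hand before importing the modules; this is only a sketch, and the library path shown is purely illustrative:

import torch


def ensure_custom_ops_loaded(lib_path: str) -> None:
    # Registers the llama custom ops (e.g. sdpa_with_kv_cache and
    # update_quantized_cache) with the PyTorch dispatcher so that
    # torch.ops.llama.* lookups succeed.
    torch.ops.load_library(lib_path)


# Hypothetical usage; the actual shared-library location is build dependent:
# ensure_custom_ops_loaded("cmake-out/extension/llm/custom_ops/libcustom_ops_aot_lib.so")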

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 3 deletions

@@ -890,9 +890,7 @@ def _get_source_transforms(  # noqa
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
-        assert (
-            args.use_kv_cache and not args.use_sdpa_with_kv_cache
-        ), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
+        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
     if args.use_kv_cache:
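
The relaxed assertion means the quantize_kv_cache option no longer excludes the custom SDPA path. A minimal sketch of how the transform list is assembled after this change; the transform callables are passed in so the sketch stays self-contained, and it mirrors, but is not, the actual _get_source_transforms:

from typing import Callable, List


def collect_transforms(
    use_kv_cache: bool,
    use_sdpa_with_kv_cache: bool,
    quantize_kv_cache: bool,
    replace_sdpa_with_custom_op: Callable,
    replace_kv_cache_with_quantized_kv_cache: Callable,
) -> List[Callable]:
    transforms: List[Callable] = []
    if use_sdpa_with_kv_cache:
        # Swap eager SDPA modules for the custom-op wrapper.
        transforms.append(replace_sdpa_with_custom_op)
    if quantize_kv_cache:
        # Previously this also required use_sdpa_with_kv_cache=False;
        # after this change the two options can be combined.
        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)
    return transforms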

examples/models/llama2/source_transformation/quantized_kv_cache.py

Lines changed: 65 additions & 37 deletions

@@ -47,6 +47,7 @@ def __init__(
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
+
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
@@ -65,10 +66,10 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
         self.register_buffer(
-            "k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
         self.register_buffer(
-            "v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
+            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
         if cache_type == QuantizedCacheType.AffineAsymmetric:
             self.register_buffer(
@@ -100,47 +101,74 @@ def update(self, input_pos, k_val, v_val):
 
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[0].item()
-            torch._check_is_size(start_pos)
-            dim_to_slice = 2 if self.is_transposed else 1
-            torch._check(start_pos < self.k_cache.size(dim_to_slice))
-            seq_length = k_val.size(dim_to_slice)
-            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_k_scales = self.k_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k_zp = self.k_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_k.copy_(quantized_k_val)
-            narrowed_k_scales.copy_(k_scales)
-            narrowed_k_zp.copy_(k_zero_points)
-            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-            narrowed_v_scales = self.v_cache_scales.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v_zp = self.v_cache_zero_points.narrow(
-                dim_to_slice, start_pos, seq_length
-            )
-            narrowed_v.copy_(quantized_v_val)
-            narrowed_v_scales.copy_(v_scales)
-            narrowed_v_zp.copy_(v_zero_points)
-        else:
-            if self.is_transposed:
+        if self.is_transposed:
+            # We cannot use the update_cache op at the moment
+            # if the cache is transposed.
+            # Also note that we should not need separate paths
+            # for dynamic shape vs. static shape.
+            # The only reason it is done this way is to accommodate
+            # the lowering pains of backends that work better
+            # with the index_put op.
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[0].item()
+                torch._check_is_size(start_pos)
+                dim_to_slice = 2 if self.is_transposed else 1
+                torch._check(start_pos < self.k_cache.size(dim_to_slice))
+                seq_length = k_val.size(dim_to_slice)
+                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_k_scales = self.k_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k_zp = self.k_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_k.copy_(quantized_k_val)
+                narrowed_k_scales.copy_(k_scales)
+                narrowed_k_zp.copy_(k_zero_points)
+                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+                narrowed_v_scales = self.v_cache_scales.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v_zp = self.v_cache_zero_points.narrow(
+                    dim_to_slice, start_pos, seq_length
+                )
+                narrowed_v.copy_(quantized_v_val)
+                narrowed_v_scales.copy_(v_scales)
+                narrowed_v_zp.copy_(v_zero_points)
+            else:
                 self.k_cache[:, :, input_pos] = quantized_k_val
                 self.k_cache_scales[:, :, input_pos] = k_scales
                 self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                 self.v_cache[:, :, input_pos] = quantized_v_val
                 self.v_cache_scales[:, :, input_pos] = v_scales
                 self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-            else:
-                self.k_cache[:, input_pos] = quantized_k_val
-                self.k_cache_scales[:, input_pos] = k_scales
-                self.k_cache_zero_points[:, input_pos] = k_zero_points
-                self.v_cache[:, input_pos] = quantized_v_val
-                self.v_cache_scales[:, input_pos] = v_scales
-                self.v_cache_zero_points[:, input_pos] = v_zero_points
+        else:
+            # Right now we use custom ops on this path.
+            # In the future we can update the custom op to handle the
+            # transposed cache as well.
+            # Note that we may have to revert this change if other ET
+            # backends such as QNN want to use the quantized cache with
+            # dynamic shape, instead of quantizing on their own.
+            # But until then, opting for code simplicity.
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_k_val, self.k_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_scales, self.k_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                quantized_v_val, self.v_cache, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_scales, self.v_cache_scales, start_pos
+            )
+            _ = torch.ops.llama.update_quantized_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
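
On the non-transposed path, the torch.ops.llama.update_quantized_cache calls replace both the narrow/copy_ and fancy-indexing updates. As a rough mental model (this is an assumed reference semantic written in plain PyTorch, not the actual custom-op implementation in //executorch/extension/llm/custom_ops), the op writes a chunk into a pre-allocated buffer along the sequence dimension, in place:

import torch


def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> torch.Tensor:
    # Assumed semantics: copy `value` into `cache` starting at `start_pos`
    # along dim 1 (the sequence dimension of the non-transposed [B, S, H, D]
    # layout), mutating `cache` in place and returning it.
    seq_len = value.size(1)
    cache.narrow(1, start_pos, seq_len).copy_(value)
    return cache


# Example: write a 3-token chunk at position 2 of an int8 cache.
cache = torch.zeros(1, 8, 4, 16, dtype=torch.int8)
chunk = torch.randint(-128, 128, (1, 3, 4, 16), dtype=torch.int8)
update_cache_reference(chunk, cache, 2)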

examples/models/llama2/source_transformation/sdpa.py

Lines changed: 29 additions & 5 deletions

@@ -9,23 +9,32 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import math
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
 
 
 class SDPACustom(torch.nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
+        kv_cache: Union[KVCache, QuantizedKVCache],
         dim: int,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
-        self.kv_cache = kv_cache.to(torch.float)
+        self.kv_cache = kv_cache
+        if not isinstance(kv_cache, QuantizedKVCache):
+            self.kv_cache = kv_cache.to(torch.float)
+        else:
+            assert (
+                kv_cache.cache_fp_type == torch.float32
+            ), "Only float32 is supported for custom SDPA"
         self.dim = dim
 
     def forward(
@@ -44,12 +53,27 @@ def forward(
         q = q.to(dtype=torch.float)
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
+
+        k_cache = self.kv_cache.k_cache
+        v_cache = self.kv_cache.v_cache
+        if isinstance(self.kv_cache, QuantizedKVCache):
+            # Update the quantized cache, scales, and zero points;
+            # this returns the dequantized kv cache.
+            # Not the most optimal. Optimizations to follow next.
+            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        # Note that this path will still in-place mutate k_cache and v_cache.
+        # When we are not using a quantized kv cache, this just mutates
+        # the original kv cache.
+        # When we are using a quantized kv cache, this mutates the
+        # k_cache and v_cache returned from the cache update operation,
+        # which just dequantizes the cache and returns the result.
+        # Future diffs will optimize this.
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            k_cache,
+            v_cache,
             input_pos[-1].item(),
             seqlen,
             None,  # Attention mask
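
For context, the export flow applies this wrapper through a module walk that replaces each eager SDPA with SDPACustom. A rough sketch of that pattern follows; the attribute names child.kv_cache and child.dim are assumed from the constructor arguments shown above, and the canonical version is replace_sdpa_with_custom_op in this same file:

import torch

from executorch.examples.models.llama2.llama_transformer import SDPA
from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom


def swap_sdpa_for_custom_op(module: torch.nn.Module) -> torch.nn.Module:
    # Recursively replace every eager SDPA child with the custom-op wrapper,
    # reusing the child's kv cache (float or quantized) and model dimension.
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(module, name, SDPACustom(child.kv_cache, child.dim))
        else:
            swap_sdpa_for_custom_op(child)
    return module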

examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.examples.models.llama2.llama_transformer import KVCache
+
+from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
+    QuantizedCacheType,
+    QuantizedKVCache,
+)
+
+from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom
+
+
+class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
+
+    def _init_cache(self):
+        self.kv_cache = KVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            False,
+            self.enable_dynamic_shape,
+            dtype=self.dtype,
+        )
+        self.quantized_kv_cache = QuantizedKVCache.from_float(
+            self.kv_cache, QuantizedCacheType.AffineAsymmetric
+        )
+
+    def _init_kv(self):
+        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
+        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
+        q = torch.rand(q_shape, dtype=self.dtype)
+        k = torch.rand(kv_shape, dtype=self.dtype)
+        v = torch.rand(kv_shape, dtype=self.dtype)
+        return q, k, v
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.max_batch_size = 1
+        self.max_seq_len = 5
+        self.n_kv_heads = 4
+        self.n_heads = 8
+        self.head_dim = 17
+        self.dim = self.n_heads * self.head_dim
+        self.enable_dynamic_shape = False
+        self.dtype = torch.float32
+
+    def test_simple(self, is_dynamic_shape=False):
+        self.enable_dynamic_shape = is_dynamic_shape
+        input_pos = torch.tensor([0], dtype=torch.int64)
+        self.seq_len = 3
+        self._init_cache()
+        q, k, v = self._init_kv()
+        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        torch.testing.assert_close(
+            float_out,
+            quantized_out,
+        )
+
+        input_pos = torch.tensor([3], dtype=torch.int64)
+        self.seq_len = 1
+        q, k, v = self._init_kv()
+        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
+        torch.testing.assert_close(
+            float_out,
+            quantized_out,
+            rtol=1e-03,
+            atol=1e-03,
+        )
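
With the custom ops AOT library already loaded in the process (under Buck the preload_deps entry in TARGETS handles this), the new test can also be driven directly from Python. A minimal sketch of such a runner, under that assumption:

import unittest

from executorch.examples.models.llama2.source_transformation.test_sdpa_with_quantized_kv_cache import (
    SDPAWithQuantizedKVCacheTest,
)

# Load and run just this test case with a text runner; assumes the llama
# custom ops (sdpa_with_kv_cache, update_quantized_cache) are registered.
suite = unittest.TestLoader().loadTestsFromTestCase(SDPAWithQuantizedKVCacheTest)
unittest.TextTestRunner(verbosity=2).run(suite)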
