
Commit 750e7da

Changes to split kv cache and sdpa
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 6356acb
Pull Request resolved: #7413
1 parent f94dda6 commit 750e7da

3 files changed, +66 -149 lines changed


examples/models/llama/llama_transformer.py

Lines changed: 17 additions & 30 deletions
@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
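With the layout fixed, the dynamic-shape path always slices along dim 2. A hedged sketch of the narrow/copy_ pattern this code path relies on (toy sizes, torch._check calls omitted):

```python
import torch

k_cache = torch.zeros(1, 8, 128, 64)       # [B, H, S, D]
k_val = torch.randn(1, 8, 4, 64)           # four new positions
start_pos = 10
dim_to_slice = 2                           # always the sequence dim now
seq_length = k_val.size(dim_to_slice)
# In-place write of the new tokens into the cache window:
k_cache.narrow(dim_to_slice, start_pos, seq_length).copy_(k_val)
```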
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
 
         return k_out, v_out
 
 
 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,16 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
 
-        k, v = self.kv_cache.update(input_pos, k, v)
+        # TODO(kimishpatel): Move this slicing logic to Attention block so that
+        # SDPA does not have to take input_pos as arg
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]
 
+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
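The new TODO refers to the grouped-query path. A hedged sketch of the equivalence it points at, with made-up sizes; the `enable_gqa` keyword exists only on newer PyTorch builds (2.5+ at the time of writing), so treat the second call as an assumption about the target runtime:

```python
import torch
import torch.nn.functional as F

bs, n_heads, n_kv_heads, seqlen, head_dim = 1, 8, 2, 16, 64
n_rep = n_heads // n_kv_heads
q = torch.randn(bs, n_heads, seqlen, head_dim)
k = torch.randn(bs, n_kv_heads, seqlen, head_dim)
v = torch.randn(bs, n_kv_heads, seqlen, head_dim)

# What the code above does today: materialize the repeated KV heads.
y_repeat = F.scaled_dot_product_attention(
    q, k.repeat_interleave(n_rep, dim=1), v.repeat_interleave(n_rep, dim=1)
)

# Possible follow-up flagged by the TODO (assumes a PyTorch build with enable_gqa):
y_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(torch.allclose(y_repeat, y_gqa, atol=1e-5))
```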
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
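The net effect of the split, as a hedged end-to-end sketch with plain tensors standing in for the real modules: Attention.forward transposes to [B, H, S, D] right after RoPE, updates the cache itself, and hands the returned tensors to SDPA, which never sees the cache object.

```python
import torch
import torch.nn.functional as F

bs, seqlen, n_heads, head_dim, max_seq = 1, 1, 8, 64, 16
q = torch.randn(bs, seqlen, n_heads, head_dim)
k = torch.randn(bs, seqlen, n_heads, head_dim)
v = torch.randn(bs, seqlen, n_heads, head_dim)

# 1. Transpose after RoPE, as Attention.forward now does.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))        # -> [B, H, S, D]

# 2. Update the cache in Attention.forward (stand-in for self.kv_cache.update).
k_cache = torch.zeros(bs, n_heads, max_seq, head_dim)
v_cache = torch.zeros(bs, n_heads, max_seq, head_dim)
input_pos = torch.tensor([0])
k_cache[:, :, input_pos] = k
v_cache[:, :, input_pos] = v

# 3. SDPA then only consumes plain tensors; reduced here to the core call.
mask = torch.full((max_seq, max_seq), float("-inf")).triu(1)
attn_mask = mask[None, None, input_pos]
y = F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=attn_mask)
print(y.shape)  # torch.Size([1, 8, 1, 64])
```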

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 42 additions & 103 deletions
@@ -37,8 +37,6 @@ def __init__(
         n_heads,
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
-        tranposed=False,
-        enable_dynamic_shape=False,
     ):
         super().__init__()
         if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.is_transposed = tranposed
-        self.enable_dynamic_shape = enable_dynamic_shape
-        if self.is_transposed:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
-            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
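QuantizedKVCache now always stores [B, S, H, D] with one scale and zero-point per (batch, position, head). A hedged sketch of the int8 per-token scheme those shapes imply, using manual math rather than the quantized_decomposed ops the file actually calls, and made-up sizes:

```python
import torch

B, S_max, H, D = 1, 128, 8, 64
k_cache = torch.zeros(B, S_max, H, D, dtype=torch.int8)   # cache_shape
k_scales = torch.ones(B, S_max, H, 1)                     # scale_shape: one per (b, s, h)

# Symmetric per-token int8 quantization of one new [B, 1, H, D] slice:
pos = 5
k_val = torch.randn(B, 1, H, D)
scales = k_val.abs().amax(dim=-1, keepdim=True) / 127.0
k_cache[:, pos : pos + 1] = torch.clamp(torch.round(k_val / scales), -128, 127).to(torch.int8)
k_scales[:, pos : pos + 1] = scales

# Dequantizing that slice recovers k_val up to quantization error:
recon = k_cache[:, pos : pos + 1].float() * k_scales[:, pos : pos + 1]
print((recon - k_val).abs().max())
```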
@@ -98,71 +90,37 @@ def _quantize(self, value):
         return quantized_value, scales, zero_points
 
     def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         # quantize current k_val and store it in the cache
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
 
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
-        if self.is_transposed:
-            # We cannot use update_cache op at the moment
-            # if the cache is transposed
-            # Also note that we shold not need separate paths
-            # for dynamic shape vs !
-            # Only reason it is done this way is to accommodate
-            # for lowering pains of backends that work better
-            # with index_put op.
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k_scales = self.k_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k_zp = self.k_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k.copy_(quantized_k_val)
-                narrowed_k_scales.copy_(k_scales)
-                narrowed_k_zp.copy_(k_zero_points)
-                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v_scales = self.v_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v_zp = self.v_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v.copy_(quantized_v_val)
-                narrowed_v_scales.copy_(v_scales)
-                narrowed_v_zp.copy_(v_zero_points)
-            else:
-                self.k_cache[:, :, input_pos] = quantized_k_val
-                self.k_cache_scales[:, :, input_pos] = k_scales
-                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-                self.v_cache[:, :, input_pos] = quantized_v_val
-                self.v_cache_scales[:, :, input_pos] = v_scales
-                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-        else:
-            # Right now using custom ops on this path.
-            # In future we can update custom op to handle transposed cache
-            # as well.
-            # Note that we may have to revert this change if other ET
-            # backends such as QNN want to use quantized cache, with dynamic shape,
-            # instead of quantizing on their own.
-            # But until this opting for code simplicity
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+        # Right now using custom ops on this path.
+        # In future we can update custom op to handle transposed cache
+        # as well.
+        # Note that we may have to revert this change if other ET
+        # backends such as QNN want to use quantized cache, with dynamic shape,
+        # instead of quantizing on their own.
+        # But until this opting for code simplicity
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            k_zero_points, self.k_cache_zero_points, start_pos
+        )
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            v_zero_points, self.v_cache_zero_points, start_pos
+        )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
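The new update() keeps the [B, H, S, D] contract at the boundary while storing [B, S, H, D] internally. A hedged, op-free sketch of that contract, emulating torch.ops.llama.update_cache with narrow().copy_() purely for illustration (the real op also handles the quantized buffers, scales, and zero-points):

```python
import torch

B, H, S_max, D = 1, 8, 128, 64
k_cache = torch.zeros(B, S_max, H, D)          # storage layout is [B, S, H, D]

def update(cache, input_pos, k_val):           # k_val arrives as [B, H, S, D]
    k_val = k_val.transpose(1, 2)              # transpose in, to match storage
    start_pos = int(input_pos[0])
    cache.narrow(1, start_pos, k_val.size(1)).copy_(k_val)
    return cache.transpose(1, 2)               # transpose out, back to [B, H, S, D]

k_out = update(k_cache, torch.tensor([7]), torch.randn(B, H, 1, D))
print(k_out.shape)                             # torch.Size([1, 8, 128, 64])
```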
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )
 
-        if self.is_transposed:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k.copy_(k_val)
-                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v.copy_(v_val)
-            else:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-        else:
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
 
-        return k_out, v_out
+        return k_out.transpose(1, 2), v_out.transpose(1, 2)
 
     @classmethod
     def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
-        cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_transposed:
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-        else:
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
         return cls(
             max_batch_size,
             max_seq_length,
             n_heads,
             head_dim,
             cache_type,
-            kv_cache.is_transposed,
-            kv_cache.enable_dynamic_shape,
         )
 
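from_float() now has to cope with two source layouts: a float KVCache buffer is [B, H, S, D] after this change, while a CustomKVCache buffer is [B, S, H, D]. A hedged, shape-only sketch of that unpacking with stand-in buffers:

```python
import torch

kv_cache_k = torch.zeros(1, 8, 128, 64)        # float KVCache buffer: [B, H, S, D]
custom_k = torch.zeros(1, 128, 8, 64)          # CustomKVCache buffer: [B, S, H, D]

B, H, S, D = kv_cache_k.shape                  # default unpacking order
# When the source is a CustomKVCache, the same four sizes come out of a different order:
B, S, H, D = custom_k.shape
assert (B, S, H, D) == (1, 128, 8, 64)
```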

@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
     for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
             setattr(
                 module,
                 name,
@@ -291,11 +231,13 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, S, H, D]
+        # input_pos: [S], k_val: [B, H, S, D]
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
         _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
         _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
-        return self.k_cache, self.v_cache
+        return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)
 
 
 def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
         if isinstance(child, KVCache):
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            assert (
-                child.is_transposed is False
-            ), "CustomKVCache does not support transposed cache"
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
             setattr(
                 module,
                 name,
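Since the float KVCache is now [B, H, S, D], the replacement pass only needs to reorder the sizes taken from the existing buffer when it allocates the CustomKVCache's [B, S, H, D] storage. A hedged sketch of that reordering with stand-in buffers:

```python
import torch

float_k_cache = torch.zeros(1, 8, 128, 64)                 # KVCache buffer: [B, H, S, D]
max_batch_size, n_heads, max_seq_length, head_dim = float_k_cache.shape
custom_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
custom_k_cache = torch.zeros(custom_shape, dtype=float_k_cache.dtype)  # [B, S, H, D]
print(custom_k_cache.shape)                                # torch.Size([1, 128, 8, 64])
```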
