Commit 38c2d0d

Browse files
committed
[ExecuTorch][BE] Split kv cache and SDPA for better code sharing
Summary:

Why? We have coupled SDPA with the KV cache for a while. Initially this was done when we implemented the sdpa_with_kv_cache custom op to reduce the multiple-copy overhead of KV cache updates. (This could have been done with separate custom KV cache update and custom SDPA ops; recent changes enabled this.) Because the SDPA module owns the KV cache, we get a) a non-composable implementation and b) model definitions and components that are harder to reuse from repos like tune. The result is multiple definitions of the same model, llama, lying around in ET, TorchChat, and Tune.

This diff and subsequent ones move in the direction where the custom KV cache and custom SDPA become decoupled and composable, making them more module-swap friendly with tune's model definition.

How? Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v tensors in [B, n_heads, seq_len, head_dim] format.
3. Step 2 introduces multiple transposes when KVCache and SDPA are replaced by custom modules, but we will write a graph pass to undo those.

Test Plan: Existing tests. Make sure perf doesn't regress.

ghstack-source-id: 6289ce2
Pull Request resolved: #7413
1 parent d8e1d04 commit 38c2d0d
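To make the standardized interface concrete, here is a minimal sketch (not part of this commit; ToyKVCache and toy_sdpa are hypothetical stand-ins for the real KVCache and SDPA modules): the cache update and the SDPA call both consume q, k, v in [B, n_heads, seq_len, head_dim], and the attention block owns the composition.

import torch

# Minimal sketch of the decoupled interfaces after this change (hypothetical
# stand-ins, not the ExecuTorch modules): both pieces use [B, H, S, D] tensors.
class ToyKVCache(torch.nn.Module):
    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos, k_val, v_val):
        # k_val, v_val: [B, H, S, D]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


def toy_sdpa(q, k, v):
    # q, k, v: [B, H, S, D]
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


B, H, S_max, D = 1, 4, 16, 8
cache = ToyKVCache(B, H, S_max, D)
q = torch.randn(B, H, 1, D)
k = torch.randn(B, H, 1, D)
v = torch.randn(B, H, 1, D)
input_pos = torch.tensor([0])

# The attention block now owns the composition: update the cache, then run SDPA.
k_all, v_all = cache.update(input_pos, k, v)
out = toy_sdpa(q, k_all, v_all)
print(out.shape)  # torch.Size([1, 4, 1, 8])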

File tree

7 files changed: +335 -171 lines changed


examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -657,6 +657,8 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     # export_to_edge
     builder_exported = _prepare_for_llama_export(args).export()
 
+    builder_exported.run_canonical_optimizations()
+
     if args.export_only:
         exit()
 
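The summary says a graph pass will undo the extra transposes introduced by the standardized layout, and the new run_canonical_optimizations() call above is where such passes run. As a rough, hypothetical illustration only (the real pass lives in ExecuTorch and operates on the exported graph, not on a symbolically traced module), a torch.fx pass that cancels back-to-back transpose(1, 2) calls could look like this:

import torch
from torch import fx

# Hypothetical sketch, not the ExecuTorch pass: cancel pairs of adjacent
# transpose(1, 2) calls, e.g. the "transpose in, transpose out" introduced
# by the custom KV cache's [B, S, H, D] storage layout.
def cancel_double_transpose(gm: fx.GraphModule) -> fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op != "call_method" or node.target != "transpose":
            continue
        inner = node.args[0]
        if (
            isinstance(inner, fx.Node)
            and inner.op == "call_method"
            and inner.target == "transpose"
            and node.args[1:] == inner.args[1:] == (1, 2)
            and len(inner.users) == 1
        ):
            # x.transpose(1, 2).transpose(1, 2) == x
            node.replace_all_uses_with(inner.args[0])
            gm.graph.erase_node(node)
            gm.graph.erase_node(inner)
    gm.graph.lint()
    gm.recompile()
    return gm


class Double(torch.nn.Module):
    def forward(self, x):
        return x.transpose(1, 2).transpose(1, 2) + 1


gm = cancel_double_transpose(fx.symbolic_trace(Double()))
print(gm.code)  # both transposes are gone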

examples/models/llama/llama_transformer.py

Lines changed: 17 additions & 30 deletions
@@ -232,22 +232,16 @@ def __init__(
         max_seq_length: int,
         n_heads: int,
         head_dim: int,
-        transpose_cache: bool,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_transposed = transpose_cache
-        if transpose_cache:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
         self.head_dim = head_dim
-        self.transpose_cache = transpose_cache
         self.enable_dynamic_shape = enable_dynamic_shape
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
+        # input_pos: [S], k_val: [B, H, S, D]
        if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
             torch._check(start_pos < self.max_seq_length)
-            dim_to_slice = 2 if self.transpose_cache else 1
+            dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
             # The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
         else:
             k_out = self.k_cache
             v_out = self.v_cache
-            if self.transpose_cache:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-            else:
-                k_out[:, input_pos] = k_val
-                v_out[:, input_pos] = v_val
+            k_out[:, :, input_pos] = k_val
+            v_out[:, :, input_pos] = v_val
 
         return k_out, v_out
 
 
 class SDPA(nn.Module):
     def __init__(
         self,
-        kv_cache: KVCache,
         dim: int,
         head_dim: int,
         n_rep: int,
         max_seq_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
-        self.kv_cache = kv_cache
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
@@ -314,18 +302,16 @@ def __init__(
     def forward(
         self,
         input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+        v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
 
-        k, v = self.kv_cache.update(input_pos, k, v)
+        # TODO(kimishpatel): Move this slicing logic to Attention block so that
+        # SDPA does not have to take input_pos as arg
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward(
         else:
             attn_mask = mask[None, None, input_pos]
 
+        # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+        # can natively support GQA now. But needs enable_gqa=True
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
             args.max_seq_len,
             self.n_kv_heads,
             self.head_dim,
-            not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
             args.enable_dynamic_shape,
         )
         self.SDPA = SDPA(
-            kv_cache=self.kv_cache,
             dim=self.n_local_heads * self.head_dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward(
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
         if self.use_kv_cache:
             assert input_pos is not None
+            k, v = self.kv_cache.update(input_pos, k, v)
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # grouped multiquery attention: expand out keys and values
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
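As an aside on the GQA TODO in the hunk above: on a recent enough PyTorch (enable_gqa was added to scaled_dot_product_attention around 2.5; treat the exact version as an assumption), the repeat_interleave expansion of k and v could eventually be dropped. A small hedged sketch, using the [B, H, S, D] convention from this diff:

import torch
import torch.nn.functional as F

# Hedged sketch of the TODO: let SDPA handle grouped-query attention directly
# instead of expanding k/v with repeat_interleave. Requires enable_gqa support.
bsz, n_heads, n_kv_heads, seqlen, head_dim = 1, 32, 8, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(y.shape)  # torch.Size([1, 32, 16, 64])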

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 42 additions & 103 deletions
@@ -37,8 +37,6 @@ def __init__(
         n_heads,
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
-        tranposed=False,
-        enable_dynamic_shape=False,
     ):
         super().__init__()
         if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
         # For now supporting int8 only
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.is_transposed = tranposed
-        self.enable_dynamic_shape = enable_dynamic_shape
-        if self.is_transposed:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
-            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
@@ -98,71 +90,37 @@ def _quantize(self, value):
         return quantized_value, scales, zero_points
 
     def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However the storage is [B, S, H, D] so we incur transpose in, transpose out
+        This shall be removed by subsequent post-export graph pass
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         # quantize current k_val and store it in the cache
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
 
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
-        if self.is_transposed:
-            # We cannot use update_cache op at the moment
-            # if the cache is transposed
-            # Also note that we shold not need separate paths
-            # for dynamic shape vs !
-            # Only reason it is done this way is to accommodate
-            # for lowering pains of backends that work better
-            # with index_put op.
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k_scales = self.k_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k_zp = self.k_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k.copy_(quantized_k_val)
-                narrowed_k_scales.copy_(k_scales)
-                narrowed_k_zp.copy_(k_zero_points)
-                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v_scales = self.v_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v_zp = self.v_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v.copy_(quantized_v_val)
-                narrowed_v_scales.copy_(v_scales)
-                narrowed_v_zp.copy_(v_zero_points)
-            else:
-                self.k_cache[:, :, input_pos] = quantized_k_val
-                self.k_cache_scales[:, :, input_pos] = k_scales
-                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-                self.v_cache[:, :, input_pos] = quantized_v_val
-                self.v_cache_scales[:, :, input_pos] = v_scales
-                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-        else:
-            # Right now using custom ops on this path.
-            # In future we can update custom op to handle transposed cache
-            # as well.
-            # Note that we may have to revert this change if other ET
-            # backends such as QNN want to use quantized cache, with dynamic shape,
-            # instead of quantizing on their own.
-            # But until this opting for code simplicity
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+        # Right now using custom ops on this path.
+        # In future we can update custom op to handle transposed cache
+        # as well.
+        # Note that we may have to revert this change if other ET
+        # backends such as QNN want to use quantized cache, with dynamic shape,
+        # instead of quantizing on their own.
+        # But until this opting for code simplicity
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            k_zero_points, self.k_cache_zero_points, start_pos
+        )
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            v_zero_points, self.v_cache_zero_points, start_pos
+        )
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )
 
-        if self.is_transposed:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k.copy_(k_val)
-                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v.copy_(v_val)
-            else:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-        else:
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
 
-        return k_out, v_out
+        return k_out.transpose(1, 2), v_out.transpose(1, 2)
 
     @classmethod
     def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
-        cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_transposed:
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-        else:
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
         return cls(
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            cache_type,
-           kv_cache.is_transposed,
-           kv_cache.enable_dynamic_shape,
        )
 
 
@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
     )
     for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
             setattr(
                 module,
                 name,
@@ -291,11 +231,13 @@ def __init__(
     def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, S, H, D]
+        # input_pos: [S], k_val: [B, H, S, D]
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
         _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
         _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
-        return self.k_cache, self.v_cache
+        return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)
 
 
 def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
         if isinstance(child, KVCache):
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            assert (
-                child.is_transposed is False
-            ), "CustomKVCache does not support transposed cache"
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
             setattr(
                 module,
                 name,
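For context, a hedged usage sketch of the source transformations touched above; the import path mirrors this file's location and the wrapper function is hypothetical, so check the repo for the exact module path and return conventions.

import torch

# Import path assumed from examples/models/llama/source_transformation/quantized_kv_cache.py.
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_custom_kv_cache,
    replace_kv_cache_with_quantized_kv_cache,
)


def apply_kv_cache_transforms(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: both replacements modify the model in place, and after
    # this change they share the [B, H, S, D] update() interface, so they can be
    # applied independently or in sequence.
    replace_kv_cache_with_custom_kv_cache(model)
    # QuantizedKVCache.from_float now also accepts CustomKVCache, whose storage
    # layout is [B, S, H, D].
    replace_kv_cache_with_quantized_kv_cache(model)
    return model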
