Skip to content

Commit f369ada

Browse files
committed
Make kv cache pos buffer name more specific
1 parent 2a4256b commit f369ada

File tree

5 files changed

+19
-16
lines changed

5 files changed

+19
-16
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
778778

779779
additional_passes = []
780780
if args.model in TORCHTUNE_DEFINED_MODELS:
781-
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
781+
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
782782
if args.generate_etrecord:
783783
if not builder_exported_to_edge.edge_manager:
784784
raise ValueError("Unable to generate etrecord due to missing edge manager.")

extension/llm/modules/attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,13 +284,13 @@ def calculate_kv(y):
284284

285285
def true_fn(y):
286286
kv_cache = self.kv_cache.clone()
287-
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
287+
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.kv_cache_pos
288288

289289
def false_fn(y):
290290
k, v = calculate_kv(y)
291291
kv_cache = self.kv_cache.clone()
292292
kv_cache.update(k, v)
293-
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
293+
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.kv_cache_pos
294294

295295
# If kv cache is None, we expect y to be provided
296296
if self.kv_cache is None:
@@ -308,7 +308,7 @@ def false_fn(y):
308308
# Update key-value cache
309309
self.kv_cache.k_cache.copy_(k)
310310
self.kv_cache.v_cache.copy_(v)
311-
self.kv_cache.cache_pos.copy_(cache_pos)
311+
self.kv_cache.kv_cache_pos.copy_(cache_pos)
312312

313313
output = self._sdpa(q, k, v, b, s_x, mask=mask)
314314
return self.output_proj(output)

extension/llm/modules/kv_cache.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ def __init__(
5555
self.register_buffer(
5656
"v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
5757
)
58+
# We use "kv_cache_pos" here instead of "cache_pos" since the latter is too generic, and we have
59+
# an InitializedMutableBufferPass that needs to single out this buffer to initialize (and not others)
60+
# since it takes up space in the pte file.
5861
self.register_buffer(
59-
"cache_pos", torch.arange(0, self.max_seq_len), persistent=False
62+
"kv_cache_pos", torch.arange(0, self.max_seq_len), persistent=False
6063
)
6164
self.batch_size = batch_size
6265

@@ -105,17 +108,17 @@ def update(
105108
f", but found new key tensors with batch size {k_val.shape[0]}!"
106109
)
107110

108-
assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
111+
assert (self.kv_cache_pos[0] + seq_len) <= self.max_seq_len
109112

110113
k_out = self.k_cache
111114
v_out = self.v_cache
112115

113116
if self.transpose_cache:
114-
k_out[:, :, self.cache_pos[:seq_len]] = k_val
115-
v_out[:, :, self.cache_pos[:seq_len]] = v_val
117+
k_out[:, :, self.kv_cache_pos[:seq_len]] = k_val
118+
v_out[:, :, self.kv_cache_pos[:seq_len]] = v_val
116119
else:
117-
k_out[:, self.cache_pos[:seq_len]] = k_val
118-
v_out[:, self.cache_pos[:seq_len]] = v_val
120+
k_out[:, self.kv_cache_pos[:seq_len]] = k_val
121+
v_out[:, self.kv_cache_pos[:seq_len]] = v_val
119122

120123
# forward cache_pos seq_len positions along
121124
# cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
@@ -124,7 +127,7 @@ def update(
124127
# this allows us to track the current position in the cache
125128
# after the last update in a compile-friendly way without any dynamism
126129
# e.g. relying on an int size tracker, or re-creating cache_pos every time
127-
self.cache_pos.add_(seq_len)
130+
self.kv_cache_pos.add_(seq_len)
128131

129132
return k_out, v_out
130133

@@ -144,5 +147,5 @@ def clone(self) -> "KVCache":
144147
)
145148
clone.k_cache.copy_(self.k_cache)
146149
clone.v_cache.copy_(self.v_cache)
147-
clone.cache_pos.copy_(self.cache_pos)
150+
clone.kv_cache_pos.copy_(self.kv_cache_pos)
148151
return clone

extension/llm/modules/test/test_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_attention_executorch(self):
219219
),
220220
).to_executorch(
221221
config=ExecutorchBackendConfig(
222-
passes=[InitializedMutableBufferPass(["cache_pos"])],
222+
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
223223
)
224224
)
225225

@@ -330,7 +330,7 @@ def test_attention_torch_cond_executorch(self):
330330
),
331331
).to_executorch(
332332
config=ExecutorchBackendConfig(
333-
passes=[InitializedMutableBufferPass(["cache_pos"])],
333+
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
334334
)
335335
)
336336

extension/llm/modules/test/test_kv_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def test_kv_cache_executorch(self):
174174
),
175175
)
176176
et_config = ExecutorchBackendConfig(
177-
passes=[InitializedMutableBufferPass(["cache_pos"])],
177+
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
178178
)
179179
et_program = edge_program.to_executorch(config=et_config)
180180

@@ -198,7 +198,7 @@ def test_kv_cache_executorch_from_file(self):
198198
),
199199
)
200200
et_config = ExecutorchBackendConfig(
201-
passes=[InitializedMutableBufferPass(["cache_pos"])],
201+
passes=[InitializedMutableBufferPass(["kv_cache_pos"])],
202202
)
203203
et_program = edge_program.to_executorch(config=et_config)
204204

0 commit comments

Comments
 (0)