Commit 5ed2284 (2 parents: 3375f85 + cb6b2bf)

Update on "[Executorch][llm] Enable leveraging ring kv cache via module swap"

This allows us to make some of the attention modules use a sliding-window KV cache, which will help enable models like Gemma 3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)

[ghstack-poisoned]
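For context on the approach: "module swap" here means replacing an attention module's KV cache with a ring-buffer variant after model construction, so cache writes wrap around a fixed sliding window. The sketch below is a minimal illustration of that idea, not the actual ExecuTorch implementation; `RingKVCache`, the cache shapes, and the `kv_cache` attribute name are all assumptions for illustration.

```python
import torch


class RingKVCache(torch.nn.Module):
    """Illustrative sliding-window (ring) KV cache: positions wrap modulo
    the window size, which is exactly the kind of scattered write the new
    update_cache_with_indices op in this commit makes expressible."""

    def __init__(self, batch_size, window_size, n_heads, head_dim):
        super().__init__()
        self.window_size = window_size
        shape = (batch_size, window_size, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, start_pos, k_val, v_val):
        # k_val/v_val: [batch, seq_len, n_heads, head_dim]
        seq_len = k_val.size(1)
        # Wrapped destination slots for this chunk of tokens.
        indices = (start_pos + torch.arange(seq_len)) % self.window_size
        self.k_cache[:, indices] = k_val
        self.v_cache[:, indices] = v_val
        return self.k_cache, self.v_cache


def swap_in_ring_kv_cache(model, window_size):
    # Module swap: replace each attention layer's cache in place.
    # "kv_cache" is a hypothetical attribute name used only for this sketch.
    for module in model.modules():
        if hasattr(module, "kv_cache"):
            old = module.kv_cache
            b, _, h, d = old.k_cache.shape
            module.kv_cache = RingKVCache(b, window_size, h, d)
    return model
```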

File tree: 6 files changed (+225 −72)

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 58 additions & 22 deletions
@@ -110,24 +110,44 @@ def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(
-                quantized_k_val, self.k_cache, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                k_scales, self.k_cache_scales, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                quantized_v_val, self.v_cache, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                v_scales, self.v_cache_scales, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos, indices
-            )
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_k_val, self.k_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_scales, self.k_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_v_val, self.v_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_scales, self.v_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(
+                    quantized_k_val, self.k_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_scales, self.k_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_zero_points, self.k_cache_zero_points, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    quantized_v_val, self.v_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_scales, self.v_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_zero_points, self.v_cache_zero_points, start_pos
+                )
         else:
             assert indices is None, "Indices not supported for this path"
             # Following is also broken because in prefill input_pos = [0]

@@ -165,8 +185,16 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None)
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices)
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_val, k_out, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_val, v_out, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
         else:
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val

@@ -310,8 +338,16 @@ def update(
         v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
 
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices)
+        if indices is not None:
+            _ = torch.ops.llama.update_cache_with_indices(
+                k_val, self.k_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache_with_indices(
+                v_val, self.v_cache, start_pos, indices
+            )
+        else:
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
 
         return (
             self.k_cache.transpose(1, 2),
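To make the new split concrete: `update_cache` now writes a contiguous slice starting at `start_pos`, while `update_cache_with_indices` scatters each token to an explicit per-batch slot. Below is a hedged eager-mode restatement of the two ops' semantics, a reference model for intuition rather than the custom kernels themselves:

```python
import torch


def update_cache_reference(value, cache, start_pos):
    # Contiguous write: value is [B, S, H, D]; rows land at
    # cache[:, start_pos : start_pos + S].
    seq_len = value.size(1)
    cache[:, start_pos : start_pos + seq_len] = value


def update_cache_with_indices_reference(value, cache, start_pos, indices):
    # Scattered write: indices is [B, S]; token (b, s) lands at
    # cache[b, indices[b, s]]. start_pos is kept only for schema parity.
    for b in range(value.size(0)):
        cache[b, indices[b]] = value[b]


# Tiny check: a ring-style write that wraps past the end of a 4-slot window.
cache = torch.zeros(1, 4, 2, 8)
value = torch.ones(1, 2, 2, 8)
update_cache_with_indices_reference(value, cache, 0, torch.tensor([[3, 0]]))
assert cache[0, 3].any() and cache[0, 0].any() and not cache[0, 1].any()
```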

extension/llm/custom_ops/custom_ops.py

Lines changed: 19 additions & 1 deletion
@@ -232,7 +232,25 @@ def update_cache_meta(
     value,
     cache,
     start_pos,
-    indices=None,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # update_cache doesn't really return anything, but I don't know a better
+    # workaround. Should we just return cache instead? But I am afraid that
+    # will result in extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
+
+
+@impl(custom_ops_lib, "update_cache_with_indices", "Meta")
+def update_cache_with_indices_meta(
+    value,
+    cache,
+    start_pos,
+    indices,
 ):
     _validate_update_cache_params(
         value,
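These meta ("fake tensor") kernels exist so that `torch.export` and shape propagation can trace the op without running the real kernel; since the op mutates `cache` in place and its declared return goes unused, a dummy `(1,)` tensor satisfies the schema. A minimal sketch of the registration pattern, using a hypothetical `demo` library name rather than the real `llama` one:

```python
import torch
from torch.library import Library, impl

# Hypothetical library for illustration; the real op lives in the "llama"
# library defined by the ExecuTorch extension.
demo_lib = Library("demo", "DEF")
demo_lib.define(
    "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
    "SymInt start_pos, Tensor indices) -> Tensor"
)


@impl(demo_lib, "update_cache_with_indices", "Meta")
def update_cache_with_indices_meta(value, cache, start_pos, indices):
    # No data movement on the meta device: just report a shape/dtype so
    # export and fake-tensor tracing can proceed.
    return torch.empty((1,), dtype=value.dtype, device="meta")
```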

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 56 additions & 9 deletions
@@ -122,14 +122,26 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const std::optional<Tensor> indices,
     Tensor& output);
 
 at::Tensor update_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos);
+
+// New functions for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output);
+
+at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const std::optional<at::Tensor>& indices);
+    const at::Tensor& indices);
 
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,

@@ -326,20 +338,41 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const std::optional<Tensor> indices,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, indices, output);
+      context, value, cache, start_pos, output);
 }
 
 at::Tensor update_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
+// Implementations for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  executorch::aten::RuntimeContext context{};
+  return torch::executor::native::update_cache_with_indices_out(
+      context, value, cache, start_pos, indices, output);
+}
+
+at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const std::optional<at::Tensor>& indices) {
+    const at::Tensor& indices) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
   (value, cache, start_pos, indices, output);
   return output;
 }

@@ -367,10 +400,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor? indices=None) -> Tensor");
+      "SymInt start_pos) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+  m.def(
+      "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices) -> Tensor");
+  m.def(
+      "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "

@@ -400,7 +439,15 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
       "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+  m.impl(
+      "update_cache_with_indices",
+      torch::executor::native::update_cache_with_indices_aten);
+  m.impl(
+      "update_cache_with_indices.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_cache_with_indices_out_no_context,
+          4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
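With these schemas registered under `TORCH_LIBRARY_FRAGMENT(llama, ...)`, both variants become callable from Python through `torch.ops.llama`. A hedged usage sketch; it assumes the compiled custom-ops shared library has already been loaded (path elided):

```python
import torch

# torch.ops.load_library("<path to the ExecuTorch custom-ops .so>")  # assumed

B, S, H, D, MAX_S = 1, 2, 4, 16, 8
k_val = torch.randn(B, S, H, D)
k_cache = torch.zeros(B, MAX_S, H, D)

# Contiguous update at start_pos.
torch.ops.llama.update_cache(k_val, k_cache, 0)

# Scattered update: explicit per-token slots, e.g. a wrapped ring write.
indices = torch.tensor([[7, 0]], dtype=torch.long)
torch.ops.llama.update_cache_with_indices(k_val, k_cache, 0, indices)
```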

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 67 additions & 30 deletions
@@ -20,6 +20,7 @@ namespace executor {
 namespace native {
 
 namespace {
+// Helper function to validate cache parameters
 bool validate_cache_params(
     const Tensor& quantized_value,
     const Tensor& quantized_cache,

@@ -32,26 +33,8 @@
   ET_CHECK_OR_RETURN_FALSE(
       quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
 
-  ET_CHECK_OR_RETURN_FALSE(
-      indices.has_value() || start_pos < quantized_cache.size(1),
-      "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
-      start_pos,
-      quantized_cache.size(1));
-
-  ET_CHECK_OR_RETURN_FALSE(
-      indices.has_value() ||
-          (start_pos + seq_length) <= quantized_cache.size(1),
-      "start_post + seq_length must be less than max seq length supported by cache."
-      "start pos: %" PRId64 ", seq_length: %" PRId64
-      "."
-      "cache size: %zd",
-      start_pos,
-      seq_length,
-      quantized_cache.size(1));
-
-  // Validate indices tensor if provided
   if (indices.has_value()) {
-    const Tensor& indices_tensor = indices.value();
+    const auto& indices_tensor = indices.value();
     ET_CHECK_OR_RETURN_FALSE(
         indices_tensor.dim() == 2,
         "indices must be a 2D tensor [batch_size, seq_len]");

@@ -72,6 +55,22 @@
         is_contiguous_dim_order(
             indices_tensor.dim_order().data(), indices_tensor.dim()),
         "indices must be in contiguous dim order");
+  } else {
+    ET_CHECK_OR_RETURN_FALSE(
+        start_pos < quantized_cache.size(1),
+        "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
+        start_pos,
+        quantized_cache.size(1));
+
+    ET_CHECK_OR_RETURN_FALSE(
+        (start_pos + seq_length) <= quantized_cache.size(1),
+        "start_pos + seq_length must be less than max seq length supported by cache."
+        "start pos: %" PRId64 ", seq_length: %" PRId64
+        "."
+        "cache size: %zd",
+        start_pos,
+        seq_length,
+        quantized_cache.size(1));
   }
 
   // Make sure they are in contiguous dim order

@@ -87,22 +86,16 @@
 
   return true;
 }
-} // anonymous namespace
 
-Tensor& update_cache_out(
+// Helper function for the actual update operation
+Tensor& update_cache_impl(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const optional<Tensor>& indices,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& indices = nullopt) {
   (void)ctx;
-  int64_t seq_len = value.size(1);
-  ET_KERNEL_CHECK(
-      ctx,
-      validate_cache_params(value, cache, start_pos, seq_len, indices),
-      InvalidArgument,
-      output);
 
   ET_CHECK_MSG(
       value.size(0) == cache.size(0),

@@ -151,7 +144,8 @@ Tensor& update_cache_out(
   if (indices.has_value()) {
     // Use the provided indices tensor for each batch and sequence position
     const Tensor& indices_tensor = indices.value();
-    const int64_t* indices_data = indices_tensor.const_data_ptr<int64_t>();
+    const int64_t* indices_data =
+        static_cast<const int64_t*>(indices_tensor.const_data_ptr());
     auto indices_strides = indices_tensor.strides();
     executorch::aten::StridesType indices_batch_stride = indices_strides[0];
     executorch::aten::StridesType indices_seq_stride = indices_strides[1];

@@ -211,6 +205,43 @@
   // No one uses output. Just a placeholder.
   return output;
 }
+} // anonymous namespace
+
+// Original update_cache_out function without indices parameter
+Tensor& update_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output);
+}
+
+// New function that explicitly takes indices
+Tensor& update_cache_with_indices_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len, indices),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output, indices);
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch

@@ -225,3 +256,9 @@ EXECUTORCH_LIBRARY(
     llama,
     "update_cache.out",
     torch::executor::native::update_cache_out);
+
+// Register the new update_cache_with_indices.out op
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_cache_with_indices.out",
+    torch::executor::native::update_cache_with_indices_out);
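Note how the refactor repartitions validation: when `indices` is provided, the bounds checks on `start_pos` no longer apply (the indices themselves name the destination slots, so a ring cache can wrap), whereas the contiguous path still checks that `[start_pos, start_pos + seq_len)` fits in the cache. A hedged Python restatement of the checks visible in this diff:

```python
def validate_cache_params(value, cache, start_pos, seq_length, indices=None):
    # Illustrative restatement of the C++ checks shown above; the real
    # validate_cache_params does more (e.g. contiguous dim-order checks).
    assert value.dim() == 4, "value must be a 4D tensor"
    if indices is not None:
        assert indices.dim() == 2, "indices must be [batch_size, seq_len]"
    else:
        max_seq = cache.size(1)
        assert start_pos < max_seq, "start_pos must be within the cache"
        assert start_pos + seq_length <= max_seq, "chunk must fit in the cache"
    return True
```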
