
Commit cb6b2bf

Update base for Update on "[Executorch][llm] Enable leveraging ring kv cache via module swap"
This allows us to make some of the attention modules use a sliding-window KV cache, which will help enable models like Gemma3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)

[ghstack-poisoned]
1 parent eb677e5 commit cb6b2bf
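
For context on the change described above: the commit splits the former optional-`indices` argument of `update_cache` into two separate operators, `update_cache(value, cache, start_pos)` for the ordinary append path and `update_cache_with_indices(value, cache, start_pos, indices)` for scatter-style writes, which is what a ring (sliding-window) KV cache needs to overwrite its oldest slots. Below is a minimal sketch of how such indices could be produced; the modulo-based policy, the helper name, and `window_size` are assumptions for illustration, not code from this commit.

import torch

def ring_write_positions(start_pos: int, seq_len: int, window_size: int) -> torch.Tensor:
    # Hypothetical policy: map logical token positions onto a fixed-size ring buffer.
    # Returns shape [1, seq_len], dtype int64, matching what update_cache_with_indices expects.
    pos = torch.arange(start_pos, start_pos + seq_len, dtype=torch.long)
    return (pos % window_size).unsqueeze(0)

# Usage sketch (assumes the ExecuTorch custom-op library is already loaded):
# indices = ring_write_positions(start_pos, k_val.size(1), k_cache.size(1))
# torch.ops.llama.update_cache_with_indices(k_val, k_cache, start_pos, indices)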

File tree

6 files changed: 225 additions & 72 deletions

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 58 additions & 22 deletions
@@ -104,24 +104,44 @@ def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(
-                quantized_k_val, self.k_cache, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                k_scales, self.k_cache_scales, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                quantized_v_val, self.v_cache, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                v_scales, self.v_cache_scales, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos, indices
-            )
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_k_val, self.k_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_scales, self.k_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_v_val, self.v_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_scales, self.v_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(
+                    quantized_k_val, self.k_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_scales, self.k_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_zero_points, self.k_cache_zero_points, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    quantized_v_val, self.v_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_scales, self.v_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_zero_points, self.v_cache_zero_points, start_pos
+                )
         else:
             assert indices is None, "Indices not supported for this path"
             # Following is also broken because in prefill input_pos = [0]
@@ -159,8 +179,16 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None)
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices)
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_val, k_out, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_val, v_out, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
         else:
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val
@@ -303,8 +331,16 @@ def update(
         v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
 
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices)
+        if indices is not None:
+            _ = torch.ops.llama.update_cache_with_indices(
+                k_val, self.k_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache_with_indices(
+                v_val, self.v_cache, start_pos, indices
+            )
+        else:
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
 
         return (
             self.k_cache.transpose(1, 2),
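
The pattern repeated across these hunks is a runtime dispatch: when the caller passes an `indices` tensor, the transform routes to `update_cache_with_indices`; otherwise it falls back to the plain `update_cache` op, so neither op schema carries an optional tensor argument. A condensed sketch of that dispatch (the helper name is mine, for illustration only):

import torch
from typing import Optional

def _update_kv(value: torch.Tensor, cache: torch.Tensor, start_pos: int,
               indices: Optional[torch.Tensor] = None) -> None:
    # Route to the indexed op only when indices are given; the two ops have
    # distinct schemas, so no optional Tensor argument leaks into the graph.
    if indices is not None:
        torch.ops.llama.update_cache_with_indices(value, cache, start_pos, indices)
    else:
        torch.ops.llama.update_cache(value, cache, start_pos)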

extension/llm/custom_ops/custom_ops.py

Lines changed: 19 additions & 1 deletion
@@ -232,7 +232,25 @@ def update_cache_meta(
     value,
     cache,
     start_pos,
-    indices=None,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # Update cache doesnt really return anything but I dont know a better
+    # workaround. Should we just return cache instead? But I am afraid that
+    # will result in extra memory allocation
+    return torch.empty((1,), dtype=value.dtype, device="meta")
+
+
+@impl(custom_ops_lib, "update_cache_with_indices", "Meta")
+def update_cache_with_indices_meta(
+    value,
+    cache,
+    start_pos,
+    indices,
 ):
     _validate_update_cache_params(
         value,
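
The Meta ("fake") kernels above exist so that tracing and export can propagate shapes and dtypes through the custom ops without touching real cache memory; returning a 1-element meta tensor mirrors the real op, which mutates `cache` in place and only returns a placeholder. A generic sketch of the same pattern with `torch.library`, using a made-up namespace and op purely for illustration (not ExecuTorch's actual registration):

import torch
from torch.library import Library, impl

demo_lib = Library("demo_cache", "DEF")  # hypothetical namespace for this sketch
demo_lib.define("update(Tensor value, Tensor(a!) cache, SymInt start_pos) -> Tensor")

@impl(demo_lib, "update", "CompositeExplicitAutograd")
def update_eager(value, cache, start_pos):
    # Eager reference: copy `value` into `cache` starting at `start_pos`.
    cache[:, start_pos : start_pos + value.size(1)] = value
    return torch.empty((1,), dtype=value.dtype)

@impl(demo_lib, "update", "Meta")
def update_meta(value, cache, start_pos):
    # Fake kernel: only shapes/dtypes matter during export; no data is written.
    return torch.empty((1,), dtype=value.dtype, device="meta")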

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 56 additions & 9 deletions
@@ -122,14 +122,26 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const std::optional<Tensor> indices,
     Tensor& output);
 
 at::Tensor update_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos);
+
+// New functions for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output);
+
+at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const std::optional<at::Tensor>& indices);
+    const at::Tensor& indices);
 
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
@@ -326,20 +338,41 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const std::optional<Tensor> indices,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, indices, output);
+      context, value, cache, start_pos, output);
 }
 
 at::Tensor update_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
+// Implementations for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  executorch::aten::RuntimeContext context{};
+  return torch::executor::native::update_cache_with_indices_out(
+      context, value, cache, start_pos, indices, output);
+}
+
+at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const std::optional<at::Tensor>& indices) {
+    const at::Tensor& indices) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
   (value, cache, start_pos, indices, output);
   return output;
 }
@@ -367,10 +400,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor? indices=None) -> Tensor");
+      "SymInt start_pos) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+  m.def(
+      "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices) -> Tensor");
+  m.def(
+      "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -400,7 +439,15 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
       "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+  m.impl(
+      "update_cache_with_indices",
+      torch::executor::native::update_cache_with_indices_aten);
+  m.impl(
+      "update_cache_with_indices.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_cache_with_indices_out_no_context,
+          4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
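
These AOT bindings are what make the two ops callable from eager PyTorch (and visible to export) once the shared library is loaded. A hedged sanity-check sketch from Python follows; the library path is a placeholder, and the [batch, seq, heads, head_dim] cache layout is my assumption based on how the Python transform above calls the ops.

import torch

# Placeholder path: load whatever shared library the ExecuTorch build produced for the custom ops.
torch.ops.load_library("path/to/custom_ops_aot_lib.so")

cache = torch.zeros(1, 16, 4, 8)   # [batch, max_seq_len, n_heads, head_dim]
value = torch.randn(1, 3, 4, 8)    # three new tokens

# Plain append starting at start_pos.
torch.ops.llama.update_cache(value, cache, 0)

# Scatter the same tokens to explicit slots, e.g. a wrapped ring-buffer write.
indices = torch.tensor([[14, 15, 0]], dtype=torch.long)
torch.ops.llama.update_cache_with_indices(value, cache, 14, indices)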

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 67 additions & 30 deletions
@@ -20,6 +20,7 @@ namespace executor {
 namespace native {
 
 namespace {
+// Helper function to validate cache parameters
 bool validate_cache_params(
     const Tensor& quantized_value,
     const Tensor& quantized_cache,
@@ -32,26 +33,8 @@ bool validate_cache_params(
   ET_CHECK_OR_RETURN_FALSE(
       quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
 
-  ET_CHECK_OR_RETURN_FALSE(
-      indices.has_value() || start_pos < quantized_cache.size(1),
-      "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
-      start_pos,
-      quantized_cache.size(1));
-
-  ET_CHECK_OR_RETURN_FALSE(
-      indices.has_value() ||
-          (start_pos + seq_length) <= quantized_cache.size(1),
-      "start_post + seq_length must be less than max seq length supported by cache."
-      "start pos: %" PRId64 ", seq_length: %" PRId64
-      "."
-      "cache size: %zd",
-      start_pos,
-      seq_length,
-      quantized_cache.size(1));
-
-  // Validate indices tensor if provided
   if (indices.has_value()) {
-    const Tensor& indices_tensor = indices.value();
+    const auto& indices_tensor = indices.value();
     ET_CHECK_OR_RETURN_FALSE(
         indices_tensor.dim() == 2,
         "indices must be a 2D tensor [batch_size, seq_len]");
@@ -72,6 +55,22 @@ bool validate_cache_params(
         is_contiguous_dim_order(
             indices_tensor.dim_order().data(), indices_tensor.dim()),
         "indices must be in contiguous dim order");
+  } else {
+    ET_CHECK_OR_RETURN_FALSE(
+        start_pos < quantized_cache.size(1),
+        "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
+        start_pos,
+        quantized_cache.size(1));
+
+    ET_CHECK_OR_RETURN_FALSE(
+        (start_pos + seq_length) <= quantized_cache.size(1),
+        "start_post + seq_length must be less than max seq length supported by cache."
+        "start pos: %" PRId64 ", seq_length: %" PRId64
+        "."
+        "cache size: %zd",
+        start_pos,
+        seq_length,
+        quantized_cache.size(1));
   }
 
   // Make sure they are in contiguous dim order
@@ -87,22 +86,16 @@ bool validate_cache_params(
 
   return true;
 }
-} // anonymous namespace
 
-Tensor& update_cache_out(
+// Helper function for the actual update operation
+Tensor& update_cache_impl(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const optional<Tensor>& indices,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& indices = nullopt) {
   (void)ctx;
-  int64_t seq_len = value.size(1);
-  ET_KERNEL_CHECK(
-      ctx,
-      validate_cache_params(value, cache, start_pos, seq_len, indices),
-      InvalidArgument,
-      output);
 
   ET_CHECK_MSG(
       value.size(0) == cache.size(0),
@@ -151,7 +144,8 @@ Tensor& update_cache_out(
   if (indices.has_value()) {
     // Use the provided indices tensor for each batch and sequence position
     const Tensor& indices_tensor = indices.value();
-    const int64_t* indices_data = indices_tensor.const_data_ptr<int64_t>();
+    const int64_t* indices_data =
+        static_cast<const int64_t*>(indices_tensor.const_data_ptr());
     auto indices_strides = indices_tensor.strides();
     executorch::aten::StridesType indices_batch_stride = indices_strides[0];
     executorch::aten::StridesType indices_seq_stride = indices_strides[1];
@@ -211,6 +205,43 @@ Tensor& update_cache_out(
   // Noone uses output. Just a placeholder.
   return output;
 }
+} // anonymous namespace
+
+// Original update_cache_out function without indices parameter
+Tensor& update_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output);
+}
+
+// New function that explicitly takes indices
+Tensor& update_cache_with_indices_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len, indices),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output, indices);
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -225,3 +256,9 @@ EXECUTORCH_LIBRARY(
     llama,
     "update_cache.out",
     torch::executor::native::update_cache_out);
+
+// Register the new update_cache_with_indices.out op
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_cache_with_indices.out",
+    torch::executor::native::update_cache_with_indices_out);
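
To summarize the kernel behavior after this refactor (as I read `update_cache_impl`): without indices, the value is copied into cache rows `start_pos .. start_pos + seq_len - 1`; with indices, each token is written to the slot named by `indices[b, s]`, which is what lets a ring buffer overwrite its oldest entries. A pure-PyTorch reference of those semantics, for illustration only (the real kernel works on raw pointers and, in the quantized path, also updates scales and zero points):

import torch
from typing import Optional

def reference_update_cache(value: torch.Tensor, cache: torch.Tensor, start_pos: int,
                           indices: Optional[torch.Tensor] = None) -> None:
    # value: [batch, seq_len, n_heads, head_dim]; cache: [batch, max_seq_len, n_heads, head_dim]
    batch, seq_len = value.shape[0], value.shape[1]
    if indices is None:
        # Contiguous append at start_pos.
        cache[:, start_pos : start_pos + seq_len] = value
    else:
        # indices: [batch, seq_len] int64; each entry picks the cache row to overwrite.
        for b in range(batch):
            for s in range(seq_len):
                cache[b, indices[b, s]] = value[b, s]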
