
Commit 0ebbf5c

Update on "[Executorch][llm] Make custom update cache op operate on indices"
This allows us to use a ring-buffer KV cache.

Differential Revision: [D73891424](https://our.internmc.facebook.com/intern/diff/D73891424/)

[ghstack-poisoned]
2 parents a56c31f + 50a09ea commit 0ebbf5c
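For context: `update_cache_with_indices` lets the caller decide which physical cache slot each new token lands in, which is exactly what a ring-buffer KV cache needs once positions wrap past the cache length. A minimal sketch of how such indices could be computed (the helper below is illustrative, not part of this commit):

```python
import torch

def ring_buffer_indices(start_pos: int, seq_len: int, cache_len: int) -> torch.Tensor:
    # Hypothetical helper: map logical token positions to physical cache
    # slots by wrapping around the cache length. Shape: [batch=1, seq_len].
    positions = torch.arange(start_pos, start_pos + seq_len, dtype=torch.int64)
    return (positions % cache_len).unsqueeze(0)

# With a cache of length 8, writing tokens 6..9 wraps to slots [6, 7, 0, 1]:
print(ring_buffer_indices(6, 4, 8))  # tensor([[6, 7, 0, 1]])
```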

6 files changed, +225 −72 lines changed


examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 58 additions & 22 deletions
```diff
@@ -104,24 +104,44 @@ def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(
-                quantized_k_val, self.k_cache, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                k_scales, self.k_cache_scales, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                quantized_v_val, self.v_cache, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                v_scales, self.v_cache_scales, start_pos, indices
-            )
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos, indices
-            )
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_k_val, self.k_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_scales, self.k_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    quantized_v_val, self.v_cache, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_scales, self.v_cache_scales, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(
+                    quantized_k_val, self.k_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_scales, self.k_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    k_zero_points, self.k_cache_zero_points, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    quantized_v_val, self.v_cache, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_scales, self.v_cache_scales, start_pos
+                )
+                _ = torch.ops.llama.update_cache(
+                    v_zero_points, self.v_cache_zero_points, start_pos
+                )
         else:
             assert indices is None, "Indices not supported for this path"
             # Following is also broken because in prefill input_pos = [0]
@@ -159,8 +179,16 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None)
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices)
+            if indices is not None:
+                _ = torch.ops.llama.update_cache_with_indices(
+                    k_val, k_out, start_pos, indices
+                )
+                _ = torch.ops.llama.update_cache_with_indices(
+                    v_val, v_out, start_pos, indices
+                )
+            else:
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
         else:
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val
@@ -303,8 +331,16 @@ def update(
         v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
 
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices)
+        if indices is not None:
+            _ = torch.ops.llama.update_cache_with_indices(
+                k_val, self.k_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache_with_indices(
+                v_val, self.v_cache, start_pos, indices
+            )
+        else:
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
 
         return (
             self.k_cache.transpose(1, 2),
```
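The Python-side change is a pure dispatch split: each call site now routes to `update_cache_with_indices` when `indices` is supplied and falls back to the index-free `update_cache` otherwise. Condensed into one helper (the helper name is ours, for illustration; it assumes the llama custom-op library is already loaded):

```python
import torch

def _update_one_cache(val, cache, start_pos, indices=None):
    # Mirrors the branching introduced above: the indices variant scatters
    # into caller-chosen slots, while the plain variant writes contiguously
    # starting at start_pos.
    if indices is not None:
        _ = torch.ops.llama.update_cache_with_indices(val, cache, start_pos, indices)
    else:
        _ = torch.ops.llama.update_cache(val, cache, start_pos)
```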

extension/llm/custom_ops/custom_ops.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -232,7 +232,25 @@ def update_cache_meta(
     value,
     cache,
     start_pos,
-    indices=None,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # Update cache doesnt really return anything but I dont know a better
+    # workaround. Should we just return cache instead? But I am afraid that
+    # will result in extra memory allocation
+    return torch.empty((1,), dtype=value.dtype, device="meta")
+
+
+@impl(custom_ops_lib, "update_cache_with_indices", "Meta")
+def update_cache_with_indices_meta(
+    value,
+    cache,
+    start_pos,
+    indices,
 ):
     _validate_update_cache_params(
         value,
```
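These Meta ("fake") kernels never touch real data; they only report output metadata so the ops can be shape-propagated during tracing and export. A self-contained illustration of the same registration pattern, using a hypothetical `mylib` namespace in place of the file's actual `custom_ops_lib` handle:

```python
import torch
from torch.library import Library, impl

# Hypothetical op library; the real file registers against custom_ops_lib.
mylib = Library("mylib", "DEF")
mylib.define(
    "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
    "SymInt start_pos, Tensor indices) -> Tensor"
)

@impl(mylib, "update_cache_with_indices", "Meta")
def update_cache_with_indices_meta(value, cache, start_pos, indices):
    # On the meta device nothing is computed; we only describe the output,
    # matching the (1,) placeholder the real op returns.
    return torch.empty((1,), dtype=value.dtype, device="meta")
```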

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 56 additions & 9 deletions
```diff
@@ -122,14 +122,26 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const std::optional<Tensor> indices,
     Tensor& output);
 
 at::Tensor update_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos);
+
+// New functions for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output);
+
+at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const std::optional<at::Tensor>& indices);
+    const at::Tensor& indices);
 
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
@@ -326,20 +338,41 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const std::optional<Tensor> indices,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, indices, output);
+      context, value, cache, start_pos, output);
 }
 
 at::Tensor update_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
+// Implementations for update_cache_with_indices
+Tensor& update_cache_with_indices_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  executorch::aten::RuntimeContext context{};
+  return torch::executor::native::update_cache_with_indices_out(
+      context, value, cache, start_pos, indices, output);
+}
+
+at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const std::optional<at::Tensor>& indices) {
+    const at::Tensor& indices) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
   (value, cache, start_pos, indices, output);
   return output;
 }
@@ -367,10 +400,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor? indices=None) -> Tensor");
+      "SymInt start_pos) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+  m.def(
+      "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices) -> Tensor");
+  m.def(
+      "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -400,7 +439,15 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
       "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+  m.impl(
+      "update_cache_with_indices",
+      torch::executor::native::update_cache_with_indices_aten);
+  m.impl(
+      "update_cache_with_indices.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_cache_with_indices_out_no_context,
+          4));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
```
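With these schemas registered, both ops are callable eagerly from Python for testing. A usage sketch, assuming the ops are brought in by importing `executorch.extension.llm.custom_ops.custom_ops` (the exact loading mechanism may differ by build):

```python
import torch
import executorch.extension.llm.custom_ops.custom_ops  # noqa: F401, loads llama ops

batch, seq_len, heads, head_dim, cache_len = 1, 4, 8, 64, 16
value = torch.randn(batch, seq_len, heads, head_dim)
cache = torch.zeros(batch, cache_len, heads, head_dim)

# Contiguous write at start_pos (the old op, now index-free):
torch.ops.llama.update_cache(value, cache, 0)

# Scattered write to explicit slots (the new op), e.g. a wrapped ring-buffer write:
indices = torch.tensor([[14, 15, 0, 1]], dtype=torch.int64)
torch.ops.llama.update_cache_with_indices(value, cache, 0, indices)
```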

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 67 additions & 30 deletions
```diff
@@ -20,6 +20,7 @@ namespace executor {
 namespace native {
 
 namespace {
+// Helper function to validate cache parameters
 bool validate_cache_params(
     const Tensor& quantized_value,
     const Tensor& quantized_cache,
@@ -32,26 +33,8 @@ bool validate_cache_params(
   ET_CHECK_OR_RETURN_FALSE(
       quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
 
-  ET_CHECK_OR_RETURN_FALSE(
-      indices.has_value() || start_pos < quantized_cache.size(1),
-      "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
-      start_pos,
-      quantized_cache.size(1));
-
-  ET_CHECK_OR_RETURN_FALSE(
-      indices.has_value() ||
-          (start_pos + seq_length) <= quantized_cache.size(1),
-      "start_post + seq_length must be less than max seq length supported by cache."
-      "start pos: %" PRId64 ", seq_length: %" PRId64
-      "."
-      "cache size: %zd",
-      start_pos,
-      seq_length,
-      quantized_cache.size(1));
-
-  // Validate indices tensor if provided
   if (indices.has_value()) {
-    const Tensor& indices_tensor = indices.value();
+    const auto& indices_tensor = indices.value();
     ET_CHECK_OR_RETURN_FALSE(
         indices_tensor.dim() == 2,
         "indices must be a 2D tensor [batch_size, seq_len]");
@@ -72,6 +55,22 @@ bool validate_cache_params(
         is_contiguous_dim_order(
             indices_tensor.dim_order().data(), indices_tensor.dim()),
         "indices must be in contiguous dim order");
+  } else {
+    ET_CHECK_OR_RETURN_FALSE(
+        start_pos < quantized_cache.size(1),
+        "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
+        start_pos,
+        quantized_cache.size(1));
+
+    ET_CHECK_OR_RETURN_FALSE(
+        (start_pos + seq_length) <= quantized_cache.size(1),
+        "start_post + seq_length must be less than max seq length supported by cache."
+        "start pos: %" PRId64 ", seq_length: %" PRId64
+        "."
+        "cache size: %zd",
+        start_pos,
+        seq_length,
+        quantized_cache.size(1));
   }
 
   // Make sure they are in contiguous dim order
@@ -87,22 +86,16 @@ bool validate_cache_params(
 
   return true;
 }
-} // anonymous namespace
 
-Tensor& update_cache_out(
+// Helper function for the actual update operation
+Tensor& update_cache_impl(
     RuntimeContext& ctx,
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
-    const optional<Tensor>& indices,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& indices = nullopt) {
   (void)ctx;
-  int64_t seq_len = value.size(1);
-  ET_KERNEL_CHECK(
-      ctx,
-      validate_cache_params(value, cache, start_pos, seq_len, indices),
-      InvalidArgument,
-      output);
 
   ET_CHECK_MSG(
       value.size(0) == cache.size(0),
@@ -151,7 +144,8 @@ Tensor& update_cache_out(
   if (indices.has_value()) {
     // Use the provided indices tensor for each batch and sequence position
     const Tensor& indices_tensor = indices.value();
-    const int64_t* indices_data = indices_tensor.const_data_ptr<int64_t>();
+    const int64_t* indices_data =
+        static_cast<const int64_t*>(indices_tensor.const_data_ptr());
    auto indices_strides = indices_tensor.strides();
    executorch::aten::StridesType indices_batch_stride = indices_strides[0];
    executorch::aten::StridesType indices_seq_stride = indices_strides[1];
@@ -211,6 +205,43 @@ Tensor& update_cache_out(
   // Noone uses output. Just a placeholder.
   return output;
 }
+} // anonymous namespace
+
+// Original update_cache_out function without indices parameter
+Tensor& update_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output);
+}
+
+// New function that explicitly takes indices
+Tensor& update_cache_with_indices_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    const Tensor& indices,
+    Tensor& output) {
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len, indices),
+      InvalidArgument,
+      output);
+
+  return update_cache_impl(ctx, value, cache, start_pos, output, indices);
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -225,3 +256,9 @@ EXECUTORCH_LIBRARY(
     llama,
     "update_cache.out",
     torch::executor::native::update_cache_out);
+
+// Register the new update_cache_with_indices.out op
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_cache_with_indices.out",
+    torch::executor::native::update_cache_with_indices_out);
```
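For intuition, the indices path of the kernel amounts to a per-position scatter along the cache's sequence dimension (the real implementation walks raw strides and is reused for the quantized scale/zero-point caches as well). A pure-PyTorch reference of the semantics as we read them from this diff:

```python
import torch

def update_cache_with_indices_reference(
    value: torch.Tensor,    # [B, S, H, D] new entries
    cache: torch.Tensor,    # [B, L, H, D] persistent cache, mutated in place
    indices: torch.Tensor,  # [B, S] int64 target slots, each in [0, L)
) -> None:
    batch, seq_len = indices.shape
    for b in range(batch):
        for s in range(seq_len):
            cache[b, indices[b, s]] = value[b, s]

# The index-free path is the contiguous special case:
#   cache[:, start_pos:start_pos + seq_len] = value
```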
