
Commit 7fa41a2

[Executorch][llm] Make custom update cache op operate on indices
Pull Request resolved: #10610
This allows us to use ring buffer kv cache.
ghstack-source-id: 282013419
@exported-using-ghexport
Differential Revision: [D73891424](https://our.internmc.facebook.com/intern/diff/D73891424/)
1 parent 19ac5b0 commit 7fa41a2
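The motivation in the message is a ring-buffer KV cache: once writes can target arbitrary per-token slots instead of a contiguous run starting at start_pos, a fixed-size cache can wrap around when it fills up. A minimal sketch of that index arithmetic in plain PyTorch (the buffer size, positions, and batch handling below are illustrative assumptions, not code from this commit):

```python
import torch

# Illustrative numbers only: a cache that holds max_ctx = 16 positions.
max_ctx = 16
start_pos, seq_len = 14, 5  # 5 new tokens arriving at position 14

# A contiguous write at start_pos would run off the end of the cache, so the
# destinations wrap around like a ring buffer: 14, 15, 0, 1, 2.
indices = (start_pos + torch.arange(seq_len)) % max_ctx

# The updated op takes these as an optional [B, S] int64 tensor (batch size 1 here).
indices = indices.unsqueeze(0).to(torch.int64)
print(indices)  # tensor([[14, 15,  0,  1,  2]])
```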

File tree: 6 files changed (+432 −61 lines)

examples/models/llama/source_transformation/custom_kv_cache.py

Lines changed: 42 additions & 21 deletions
@@ -6,7 +6,7 @@
 
 import logging
 from enum import Enum
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points
 
-    def _quantize_and_update(self, input_pos, k_val, v_val):
+    def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)
 
@@ -104,26 +104,37 @@ def _quantize_and_update(self, input_pos, k_val, v_val):
 
         if self.use_custom_update_cache_op:
             start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
             _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
+                quantized_k_val, self.k_cache, start_pos, indices
             )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
             _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
+                k_scales, self.k_cache_scales, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                quantized_v_val, self.v_cache, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                v_scales, self.v_cache_scales, start_pos, indices
+            )
+            _ = torch.ops.llama.update_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos, indices
             )
         else:
+            assert indices is None, "Indices not supported for this path"
+            # Following is also broken because in prefill input_pos = [0]
+            # but we need to update some slice of cache
             self.k_cache[:, input_pos] = quantized_k_val
             self.k_cache_scales[:, input_pos] = k_scales
             self.k_cache_zero_points[:, input_pos] = k_zero_points
             self.v_cache[:, input_pos] = quantized_v_val
             self.v_cache_scales[:, input_pos] = v_scales
             self.v_cache_zero_points[:, input_pos] = v_zero_points
 
-    def _update_and_return_float_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
@@ -144,24 +155,26 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
             self.cache_fp_type,
         )
 
-        # When returning float values we jsut use the last value
+        # When returning float values we just use the last value
         # instead of dequantized value.
         start_pos = input_pos[0].item()
         if self.use_custom_update_cache_op:
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices)
+            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices)
         else:
             k_out[:, input_pos] = k_val
             v_out[:, input_pos] = v_val
 
         return k_out, v_out
 
-    def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
-        self._quantize_and_update(input_pos, k_val, v_val)
+    def _update_and_return_quantized_values(
+        self, input_pos, k_val, v_val, indices=None
+    ):
+        self._quantize_and_update(input_pos, k_val, v_val, indices)
 
         return self.k_cache, self.v_cache
 
-    def update(self, input_pos, k_val, v_val):
+    def update(self, input_pos, k_val, v_val, indices=None):
         """
         k_val, v_val: [B, H, S, D]
         return: [B, H, S, D]
@@ -172,10 +185,12 @@ def update(self, input_pos, k_val, v_val):
         v_val = v_val.transpose(1, 2)
 
         if self.return_float_values:
-            k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+            k_out, v_out = self._update_and_return_float_values(
+                input_pos, k_val, v_val, indices
+            )
         else:
             k_out, v_out = self._update_and_return_quantized_values(
-                input_pos, k_val, v_val
+                input_pos, k_val, v_val, indices
             )
         return k_out.transpose(1, 2), v_out.transpose(1, 2)
 
@@ -277,14 +292,20 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self,
+        input_pos: torch.Tensor,
+        k_val: torch.Tensor,
+        v_val: torch.Tensor,
+        indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
         k_val = k_val.transpose(1, 2)
         v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
+        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices)
+        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices)
+
        return (
            self.k_cache.transpose(1, 2),
            self.v_cache.transpose(1, 2),
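With this change every cache write in the quantized path (values, scales, and zero points) threads the same indices through torch.ops.llama.update_cache, so all three tensors land in the same slots, and the float-returning path does the same for k_out/v_out. Below is a rough pure-PyTorch stand-in for what the op is being asked to do, useful for eyeballing the semantics without building the custom op (the function name, shapes, and scatter implementation are my reading of this diff, not the op's actual kernel):

```python
import torch


def update_cache_ref(value, cache, start_pos, indices=None):
    """Illustrative stand-in for torch.ops.llama.update_cache (not the real kernel).

    value: [B, S, H, D] new entries; cache: [B, S_cache, H, D], updated in place.
    """
    bsz, seq_len = value.size(0), value.size(1)
    if indices is None:
        # Pre-existing behaviour: contiguous write starting at start_pos.
        cache[:, start_pos : start_pos + seq_len] = value
    else:
        # New behaviour: indices is [B, S] int64, one destination slot per new token.
        batch_idx = torch.arange(bsz).unsqueeze(1)
        cache[batch_idx, indices] = value
    return cache


B, S_cache, H, D = 1, 32, 4, 16
cache = torch.zeros(B, S_cache, H, D)
new_kv = torch.randn(B, 3, H, D)

update_cache_ref(new_kv, cache, start_pos=10)  # old call shape
update_cache_ref(
    new_kv, cache, start_pos=10, indices=torch.tensor([[10, 11, 4]], dtype=torch.int64)
)  # new call shape
```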

extension/llm/custom_ops/custom_ops.py

Lines changed: 27 additions & 11 deletions
@@ -184,6 +184,7 @@ def _validate_update_cache_params(
     value,
     cache,
     start_pos,
+    indices=None,
 ):
     seq_len = value.size(1)
     assert (
@@ -200,29 +201,44 @@ def _validate_update_cache_params(
     ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
 
     torch._check_is_size(start_pos)
-    # Setting to arbitrary limit of 256 for now since there is no way
-    # to plumb this information from model config
-    torch._check(start_pos < cache.size(1))
-    assert start_pos < cache.size(
-        1
-    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
-
-    torch._check((start_pos + seq_len) < cache.size(1))
-    assert (start_pos + seq_len) < cache.size(
-        1
-    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+    if indices is None:
+        torch._check(start_pos < cache.size(1))
+        assert start_pos < cache.size(
+            1
+        ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+        torch._check((start_pos + seq_len) < cache.size(1))
+        assert (start_pos + seq_len) < cache.size(
+            1
+        ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+    if indices is not None:
+        assert (
+            indices.dim() == 2
+        ), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions."
+        assert (
+            indices.dtype == torch.int64
+        ), f"Expected indices to be int64 but got {indices.dtype}"
+        assert indices.size(0) == value.size(
+            0
+        ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
+        assert indices.size(1) == value.size(
+            1
+        ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"
 
 
 @impl(custom_ops_lib, "update_cache", "Meta")
 def update_cache_meta(
     value,
     cache,
     start_pos,
+    indices=None,
 ):
     _validate_update_cache_params(
         value,
         cache,
         start_pos,
+        indices,
     )
 
     # Update cache doesnt really return anything but I dont know a better
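The meta kernel now forwards indices into _validate_update_cache_params, which skips the contiguous-range bounds checks when indices are supplied and instead checks the indices tensor itself: 2-D, int64, with batch and sequence-length dims matching value. A small sketch of shapes that satisfy those checks (the concrete sizes are made up for illustration):

```python
import torch

B, S, H, D = 2, 7, 8, 64
S_cache = 128

value = torch.randn(B, S, H, D)        # new entries; seq_len is read from value.size(1)
cache = torch.zeros(B, S_cache, H, D)  # dims 0, 2, 3 must match value exactly

# With indices supplied, start_pos + seq_len no longer has to stay below
# cache.size(1); the indices carry the per-token destinations instead.
indices = torch.randint(0, S_cache, (B, S), dtype=torch.int64)

assert indices.dim() == 2                # 2-D
assert indices.dtype == torch.int64      # int64
assert indices.size(0) == value.size(0)  # batch dims match
assert indices.size(1) == value.size(1)  # sequence-length dims match
```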

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 12 additions & 8 deletions
@@ -122,12 +122,14 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const std::optional<Tensor> indices,
     Tensor& output);
 
 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos);
+    const int64_t start_pos,
+    const std::optional<at::Tensor>& indices);
 
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
@@ -324,19 +326,21 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const std::optional<Tensor> indices,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, output);
+      context, value, cache, start_pos, indices, output);
 }
 
 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos) {
+    const int64_t start_pos,
+    const std::optional<at::Tensor>& indices) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 3)
-  (value, cache, start_pos, output);
+  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  (value, cache, start_pos, indices, output);
   return output;
 }
 
@@ -363,10 +367,10 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos) -> Tensor");
+      "SymInt start_pos, Tensor? indices=None) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -396,7 +400,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl("update_cache", torch::executor::native::update_cache_aten);
  m.impl(
      "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
  m.impl(
      "custom_quantized_sdpa",
      torch::executor::native::custom_quantized_sdpa_aten);
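Because the new argument is declared as `Tensor? indices=None` in both schemas, existing three-argument call sites keep working unchanged; the WRAP_TO_ATEN argument tracks the position of the out tensor, which shifts from 3 to 4 now that indices precedes output in update_cache_out_no_context. A hedged sketch of exercising the registered op from Python once the AOT library is built (the library path below is a placeholder for whatever your build produces):

```python
import torch

# Placeholder path: point this at the custom-ops AOT shared library from your build.
torch.ops.load_library("cmake-out/.../libcustom_ops_aot_lib.so")

B, S, H, D, S_cache = 1, 4, 8, 64, 32
value = torch.randn(B, S, H, D)
cache = torch.zeros(B, S_cache, H, D)

# Old call shape still works: indices defaults to None.
torch.ops.llama.update_cache(value, cache, 0)

# New call shape: explicit int64 destination slots, one per incoming token.
indices = torch.tensor([[30, 31, 0, 1]], dtype=torch.int64)
torch.ops.llama.update_cache(value, cache, 30, indices)
```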
