
[Executorch][llm] Make custom update cache op operate on indices #10610


Merged: 8 commits, May 13, 2025
63 changes: 42 additions & 21 deletions examples/models/llama/source_transformation/custom_kv_cache.py
@@ -6,7 +6,7 @@

import logging
from enum import Enum
-from typing import Tuple
+from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -93,7 +93,7 @@ def _quantize(self, value):
)
return quantized_value, scales, zero_points

-def _quantize_and_update(self, input_pos, k_val, v_val):
+def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

@@ -104,26 +104,37 @@ def _quantize_and_update(self, input_pos, k_val, v_val):

if self.use_custom_update_cache_op:
start_pos = input_pos[0].item()
-_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
-    k_zero_points, self.k_cache_zero_points, start_pos
+    quantized_k_val, self.k_cache, start_pos, indices
)
-_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
-    v_zero_points, self.v_cache_zero_points, start_pos
+    k_scales, self.k_cache_scales, start_pos, indices
)
+_ = torch.ops.llama.update_cache(
+    k_zero_points, self.k_cache_zero_points, start_pos, indices
+)
+_ = torch.ops.llama.update_cache(
+    quantized_v_val, self.v_cache, start_pos, indices
+)
+_ = torch.ops.llama.update_cache(
+    v_scales, self.v_cache_scales, start_pos, indices
+)
+_ = torch.ops.llama.update_cache(
+    v_zero_points, self.v_cache_zero_points, start_pos, indices
+)
else:
+assert indices is None, "Indices not supported for this path"
# Following is also broken because in prefill input_pos = [0]
# but we need to update some slice of cache
self.k_cache[:, input_pos] = quantized_k_val
self.k_cache_scales[:, input_pos] = k_scales
self.k_cache_zero_points[:, input_pos] = k_zero_points
self.v_cache[:, input_pos] = quantized_v_val
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points

-def _update_and_return_float_values(self, input_pos, k_val, v_val):
-    self._quantize_and_update(input_pos, k_val, v_val)
+def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
+    self._quantize_and_update(input_pos, k_val, v_val, indices)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
@@ -144,24 +155,26 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val):
self.cache_fp_type,
)

-# When returning float values we jsut use the last value
+# When returning float values we just use the last value
# instead of dequantized value.
start_pos = input_pos[0].item()
if self.use_custom_update_cache_op:
-_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+_ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices)
+_ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices)
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val

return k_out, v_out

-def _update_and_return_quantized_values(self, input_pos, k_val, v_val):
-    self._quantize_and_update(input_pos, k_val, v_val)
+def _update_and_return_quantized_values(
+    self, input_pos, k_val, v_val, indices=None
+):
+    self._quantize_and_update(input_pos, k_val, v_val, indices)

return self.k_cache, self.v_cache

-def update(self, input_pos, k_val, v_val):
+def update(self, input_pos, k_val, v_val, indices=None):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
@@ -172,10 +185,12 @@ def update(self, input_pos, k_val, v_val):
v_val = v_val.transpose(1, 2)

if self.return_float_values:
-k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val)
+k_out, v_out = self._update_and_return_float_values(
+    input_pos, k_val, v_val, indices
+)
else:
k_out, v_out = self._update_and_return_quantized_values(
-input_pos, k_val, v_val
+input_pos, k_val, v_val, indices
)
return k_out.transpose(1, 2), v_out.transpose(1, 2)

@@ -277,14 +292,20 @@ def __init__(
)

def update(
-self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+self,
+input_pos: torch.Tensor,
+k_val: torch.Tensor,
+v_val: torch.Tensor,
+indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
start_pos = input_pos[0].item()
-_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+
+_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices)
+_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices)
+
return (
self.k_cache.transpose(1, 2),
self.v_cache.transpose(1, 2),
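To make the new argument concrete, here is a minimal eager-mode sketch of calling the op with explicit indices. It assumes torch.ops.llama.update_cache is already registered (for example by loading the compiled custom-ops extension from an ExecuTorch build; the exact import path depends on your setup), and all sizes are made-up examples. Shapes follow the checks in _validate_update_cache_params below: value is [B, S, H, D], the cache is [B, MAX_SEQ, H, D], and indices is an int64 tensor of shape [B, S] naming the cache row for each new token.

    import torch

    # Assumes the llama custom ops are registered; sizes are arbitrary examples.
    B, S, H, D, MAX_SEQ = 1, 4, 8, 64, 128
    value = torch.randn(B, S, H, D)
    cache = torch.zeros(B, MAX_SEQ, H, D)

    # Scatter the four new tokens into non-contiguous cache rows instead of a
    # contiguous block; start_pos is still passed, but with indices present the
    # contiguous-range bounds checks are skipped.
    indices = torch.tensor([[10, 11, 40, 41]], dtype=torch.int64)
    torch.ops.llama.update_cache(value, cache, 0, indices)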
38 changes: 27 additions & 11 deletions extension/llm/custom_ops/custom_ops.py
@@ -184,6 +184,7 @@ def _validate_update_cache_params(
value,
cache,
start_pos,
+indices=None,
):
seq_len = value.size(1)
assert (
@@ -200,29 +201,44 @@
), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"

torch._check_is_size(start_pos)
# Setting to arbitrary limit of 256 for now since there is no way
# to plumb this information from model config
-torch._check(start_pos < cache.size(1))
-assert start_pos < cache.size(
-    1
-), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
-
-torch._check((start_pos + seq_len) < cache.size(1))
-assert (start_pos + seq_len) < cache.size(
-    1
-), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+if indices is None:
+    torch._check(start_pos < cache.size(1))
+    assert start_pos < cache.size(
+        1
+    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+    torch._check((start_pos + seq_len) < cache.size(1))
+    assert (start_pos + seq_len) < cache.size(
+        1
+    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+if indices is not None:
+    assert (
+        indices.dim() == 2
+    ), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions."
+    assert (
+        indices.dtype == torch.int64
+    ), f"Expected indices to be int64 but got {indices.dtype}"
+    assert indices.size(0) == value.size(
+        0
+    ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}"
+    assert indices.size(1) == value.size(
+        1
+    ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}"


@impl(custom_ops_lib, "update_cache", "Meta")
def update_cache_meta(
value,
cache,
start_pos,
+indices=None,
):
_validate_update_cache_params(
value,
cache,
start_pos,
+indices,
)

# Update cache doesnt really return anything but I dont know a better
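Taken together, the new checks pin down the contract for indices: a 2-D int64 tensor whose batch and sequence dimensions match those of value, so each batch element can target its own set of cache rows. As an illustration only (make_batch_indices is a hypothetical helper, not part of this PR), per-sequence start offsets could be expanded into such a tensor like this:

    import torch

    # Hypothetical helper: expand per-sequence start offsets into a [B, S]
    # int64 indices tensor that satisfies _validate_update_cache_params.
    def make_batch_indices(start_positions, seq_len):
        starts = torch.tensor(start_positions, dtype=torch.int64).unsqueeze(1)  # [B, 1]
        return starts + torch.arange(seq_len, dtype=torch.int64)  # broadcasts to [B, S]

    indices = make_batch_indices([0, 7], seq_len=4)
    # tensor([[ 0,  1,  2,  3],
    #         [ 7,  8,  9, 10]])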
20 changes: 12 additions & 8 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -122,12 +122,14 @@ Tensor& update_cache_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
+const std::optional<Tensor> indices,
Tensor& output);

at::Tensor update_cache_aten(
const at::Tensor& value,
at::Tensor& cache,
-const int64_t start_pos);
+const int64_t start_pos,
+const std::optional<at::Tensor>& indices);

Tensor& sdpa_with_kv_cache_out_no_context(
const Tensor& q_projected,
@@ -324,19 +326,21 @@ Tensor& update_cache_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
+const std::optional<Tensor> indices,
Tensor& output) {
executorch::aten::RuntimeContext context{};
return torch::executor::native::update_cache_out(
-context, value, cache, start_pos, output);
+context, value, cache, start_pos, indices, output);
}

at::Tensor update_cache_aten(
const at::Tensor& value,
at::Tensor& cache,
-const int64_t start_pos) {
+const int64_t start_pos,
+const std::optional<at::Tensor>& indices) {
auto output = at::empty({1});
-WRAP_TO_ATEN(update_cache_out_no_context, 3)
-(value, cache, start_pos, output);
+WRAP_TO_ATEN(update_cache_out_no_context, 4)
+(value, cache, start_pos, indices, output);
return output;
}

@@ -363,10 +367,10 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"update_cache(Tensor value, Tensor(a!) cache, "
"SymInt start_pos) -> Tensor");
"SymInt start_pos, Tensor? indices=None) -> Tensor");
m.def(
"update_cache.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
"SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)");
m.def(
"custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
"Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
@@ -396,7 +400,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl("update_cache", torch::executor::native::update_cache_aten);
m.impl(
"update_cache.out",
-WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
m.impl(
"custom_quantized_sdpa",
torch::executor::native::custom_quantized_sdpa_aten);
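Because the updated schemas declare the new argument as Tensor? indices=None, existing call sites that pass only value, cache, and start_pos keep resolving exactly as before; passing an indices tensor opts into the scatter-style update. A short sketch (again assuming the ops are loaded; shapes are made up):

    import torch

    value = torch.randn(1, 2, 8, 64)    # [B, S, H, D]
    cache = torch.zeros(1, 128, 8, 64)  # [B, MAX_SEQ, H, D]

    # Old call shape: contiguous update starting at start_pos = 5.
    torch.ops.llama.update_cache(value, cache, 5)

    # New call shape: write the two tokens to rows 9 and 63.
    idx = torch.tensor([[9, 63]], dtype=torch.int64)
    torch.ops.llama.update_cache(value, cache, 0, idx)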