
Kv Cache as mutable buffer #2595


Closed · wants to merge 1 commit
21 changes: 8 additions & 13 deletions examples/models/llama2/builder.py
@@ -179,15 +179,6 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
logging.info(f"model.to {torch_dtype}")
self.model = self.model.to(dtype=torch_dtype)
self.dtype = dtype_override

# convert kv cache to dtype as well. This should be removed after mutable buffer is supported.
# assuming the kv cache are the last 2 tensors in the example inputs
if self.use_kv_cache:
dtype = torch.float16 if self.dtype == DType.fp16 else torch.float32
example_inputs = list(self.example_inputs[:-2]) + [
cache.to(dtype) for cache in self.example_inputs[-2:]
]
self.example_inputs = tuple(example_inputs)
return self

def source_transform(
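The removal above works because, once the KV cache lives inside the model as mutable buffers, `model.to(dtype=...)` converts it together with the parameters, so there is nothing left to patch in `example_inputs`. A minimal sketch of that idea, with illustrative attribute names and shapes rather than the llama2 example's exact ones:

import torch
from torch import nn

class KVCacheSketch(nn.Module):
    # Per-layer cache kept as registered buffers so module.to(dtype) and
    # export treat it as model state, not as an extra input.
    def __init__(self, max_seq_len: int, n_heads: int, head_dim: int):
        super().__init__()
        shape = (1, max_seq_len, n_heads, head_dim)  # [bs, seq, heads, dim]
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, start_pos: int, k: torch.Tensor, v: torch.Tensor):
        seq_len = k.size(1)
        # In-place writes into the buffers; no functional copy of the cache.
        self.k_cache[:, start_pos : start_pos + seq_len] = k
        self.v_cache[:, start_pos : start_pos + seq_len] = v
        return self.k_cache, self.v_cache

cache = KVCacheSketch(max_seq_len=128, n_heads=4, head_dim=16).to(torch.float16)
assert cache.k_cache.dtype == torch.float16  # buffers follow .to(dtype)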
@@ -209,11 +200,15 @@ def source_transform(
return self

def _get_dynamic_shape(self) -> Optional[Dict[str, Any]]:
if self.use_kv_cache:
return None
dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
dynamic_shape = {"tokens": {1: dim}}
return dynamic_shape
if self.use_kv_cache:
if self.use_sdpa_with_kv_cache:
return None
else:
# return {"tokens": {1: dim}, "input_pos": {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
return None
else:
return {"tokens": {1: dim}}

def _get_edge_config(self) -> EdgeCompileConfig:
edge_config = EdgeCompileConfig(
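For context on the `_get_dynamic_shape` change: the dict it returns is a `dynamic_shapes` spec for `torch.export`, mapping an argument name to the dimensions allowed to vary. A rough sketch of how such a spec is consumed, using a toy model in place of the llama2 one:

import torch

class ToyModel(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2

max_seq_len = 128
dim = torch.export.Dim("token_dim", max=max_seq_len - 1)
dynamic_shapes = {"tokens": {1: dim}}  # dim 1 (sequence length) may vary

example_tokens = torch.zeros(1, 8, dtype=torch.long)
exported = torch.export.export(ToyModel(), (example_tokens,), dynamic_shapes=dynamic_shapes)

With the KV cache enabled, the method returns None for now (fully static shapes); the commented-out `{"tokens": {1: dim}, "input_pos": {0: dim}}` spec is the intended shape once the XNNPACK path can handle a dynamic-shape KV cache.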
2 changes: 1 addition & 1 deletion examples/models/llama2/custom_ops/custom_ops.yaml
@@ -8,7 +8,7 @@
- arg_meta: null
kernel_name: torch::executor::flash_attention_kernel_out

- func: llama::sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor key_cache, Tensor value_cache, int layer_id, int start_pos, int seq_len, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(a!) out) -> Tensor(a!)
- func: llama::sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, Tensor(b!) value_cache, int start_pos, int seq_len, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)
kernels:
- arg_meta: null
kernel_name: torch::executor::sdpa_with_kv_cache_out
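Two schema changes are doing the work here: `layer_id` is gone because each layer now carries its own cache, and the `Tensor(a!)`/`Tensor(b!)` alias annotations declare that the kernel mutates `key_cache` and `value_cache` in place, which is what lets them become mutable buffers rather than pure inputs. A rough eager-mode reference of those semantics (a sketch assuming `[bs, seq_len, n_heads, head_dim]` layouts, with `F.scaled_dot_product_attention` standing in for the optimized kernel):

import torch
import torch.nn.functional as F

def sdpa_with_kv_cache_reference(q, k, v, k_cache, v_cache, start_pos, seq_len,
                                 attn_mask=None, is_causal=False, scale=None):
    # In-place cache update: the Tensor(a!)/Tensor(b!) part of the schema.
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    # Attend over everything cached so far, including the new tokens.
    k_all = k_cache[:, : start_pos + seq_len]
    v_all = v_cache[:, : start_pos + seq_len]
    # SDPA wants [bs, heads, seq, dim], so swap the seq/head dims.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k_all.transpose(1, 2), v_all.transpose(1, 2),
        attn_mask=attn_mask, is_causal=is_causal, scale=scale)
    return out.transpose(1, 2)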
68 changes: 25 additions & 43 deletions examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -541,41 +541,34 @@ bool validate_flash_attention_args(
bool validate_cache_params(
const Tensor& k_cache,
const Tensor& v_cache,
int64_t layer_id,
int64_t start_pos,
int64_t seq_length) {
ET_LOG_MSG_AND_RETURN_IF_FALSE(
k_cache.dim() == 5, "kcache must be a 5D tensor");
k_cache.dim() == 4, "kcache must be a 4D tensor");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
v_cache.dim() == 5, "v_cache must be a 5D tensor");
v_cache.dim() == 4, "v_cache must be a 4D tensor");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
layer_id < k_cache.size(0), "layer_id must be less than kcache dim 0");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
layer_id < v_cache.size(0), "layer_id must be less than vcache dim 0");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
start_pos < k_cache.size(2),
start_pos < k_cache.size(1),
"start_pos must be less than key cache at dim 1");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
start_pos < v_cache.size(2),
start_pos < v_cache.size(1),
"start_pos must be less than value cache at dim 1");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
(start_pos + seq_length) < k_cache.size(2),
(start_pos + seq_length) < k_cache.size(1),
"start_post + seq_length must be less than max seq length supported by key cache."
"start pos: %" PRId64 ", seq_length: %" PRId64
"."
"key cache size: %zd",
start_pos,
seq_length,
k_cache.size(2));
k_cache.size(1));

ET_LOG_MSG_AND_RETURN_IF_FALSE(
(start_pos + seq_length) < v_cache.size(2),
(start_pos + seq_length) < v_cache.size(1),
"start_post + seq_length must be less than max seq length supported by key cache."
"start pos: %" PRId64 ", seq_length: %" PRId64
"."
@@ -600,14 +593,13 @@ bool validate_cache_params(
void update_cache(
const Tensor& projected_value,
const Tensor& cache,
int64_t layer_id,
int64_t start_pos,
int64_t seq_length) {
ET_CHECK_MSG(seq_length == 1, "seq_length must be 1");
ET_CHECK_MSG(
projected_value.size(0) == 1,
"projected_value must have batch size of 1");
ET_CHECK_MSG(cache.size(1) == 1, "cache must have batch size of 1");
ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
ET_CHECK_MSG(
is_default_dim_order(
projected_value.dim_order().data(), projected_value.dim()),
@@ -619,10 +611,8 @@ void update_cache(
ET_CHECK_MSG(cache_data, "cache data is null");

auto strides = cache.strides();
exec_aten::StridesType layer_stride = strides[0];
exec_aten::StridesType seq_dim_stride = strides[2];
exec_aten::SizesType pos_offset =
layer_id * layer_stride + start_pos * seq_dim_stride;
exec_aten::StridesType seq_dim_stride = strides[1];
exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
exec_aten::SizesType pos_offset_bytes =
pos_offset * projected_value.element_size();
exec_aten::SizesType num_bytes =
@@ -713,19 +703,16 @@ Tensor& flash_attention_kernel_out(
@param[in] key_cache Cache of previous v_projected.
Format [n_layers, batch size, max_seq_len, num heads, head dim]
....
@param[in] layer_id which layer this call belongs to.
Used to updated appropriate entry of kv cache
@param[in] start_pos sequence position
@param[in] seq_len Seq length. e.g. seq_len dim of q_projected.
@param[in] start_pos: sequence position
@param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
*/
Tensor& sdpa_with_kv_cache_out(
RuntimeContext& ctx,
const Tensor& q_projected,
const Tensor& k_projected,
const Tensor& v_projected,
const Tensor& key_cache,
const Tensor& value_cache,
const int64_t layer_id, // THis should be gone with buffer based impl
Tensor& key_cache,
Tensor& value_cache,
const int64_t start_pos,
const int64_t seq_len,
const optional<Tensor>& attn_mask,
@@ -737,34 +724,31 @@ Tensor& sdpa_with_kv_cache_out(
(void)ctx;
ET_KERNEL_CHECK(
ctx,
validate_cache_params(
key_cache, value_cache, layer_id, start_pos, seq_len),
validate_cache_params(key_cache, value_cache, start_pos, seq_len),
InvalidArgument,
output);

ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");

update_cache(k_projected, key_cache, layer_id, start_pos, seq_len);
update_cache(v_projected, value_cache, layer_id, start_pos, seq_len);
update_cache(k_projected, key_cache, start_pos, seq_len);
update_cache(v_projected, value_cache, start_pos, seq_len);

auto q_seq_len = q_projected.size(1);

std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
0, 1, 2, 3};
std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
sliced_key_sizes[0] = key_cache.size(1);
sliced_key_sizes[0] = key_cache.size(0);
sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
sliced_key_sizes[2] = key_cache.size(3);
sliced_key_sizes[3] = key_cache.size(4);
sliced_key_sizes[2] = key_cache.size(2);
sliced_key_sizes[3] = key_cache.size(3);
std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
dim_order_to_stride_nocheck(
sliced_key_sizes.data(),
sliced_key_dim_order.data(),
util::kKVDim,
sliced_key_strides.data());
void* key_cache_data = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(key_cache.mutable_data_ptr()) +
layer_id * key_cache.strides()[0] * key_cache.element_size());
void* key_cache_data = key_cache.mutable_data_ptr();
TensorImpl k_impl = TensorImpl(
key_cache.scalar_type(),
util::kKVDim,
@@ -778,19 +762,17 @@ Tensor& sdpa_with_kv_cache_out(
std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
0, 1, 2, 3};
std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
sliced_value_sizes[0] = value_cache.size(1);
sliced_value_sizes[0] = value_cache.size(0);
sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
sliced_value_sizes[2] = value_cache.size(3);
sliced_value_sizes[3] = value_cache.size(4);
sliced_value_sizes[2] = value_cache.size(2);
sliced_value_sizes[3] = value_cache.size(3);
std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
dim_order_to_stride_nocheck(
sliced_value_sizes.data(),
sliced_value_dim_order.data(),
util::kKVDim,
sliced_value_strides.data());
void* value_cache_data = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(value_cache.mutable_data_ptr()) +
layer_id * value_cache.strides()[0] * value_cache.element_size());
void* value_cache_data = value_cache.mutable_data_ptr();
TensorImpl value_impl = TensorImpl(
value_cache.scalar_type(),
util::kKVDim,
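With `layer_id` gone, the kernel no longer offsets the data pointer into a shared 5D cache; it builds a `TensorImpl` of sizes `[bs, start_pos + seq_len, n_heads, head_dim]` directly over `key_cache.mutable_data_ptr()` with freshly computed contiguous strides. That is only equivalent to slicing the cache because the cache is contiguous and its batch size is 1 (which `update_cache` checks), so the first `(start_pos + seq_len) * n_heads * head_dim` elements are exactly the sliced view. A small PyTorch check of that layout assumption, with made-up sizes:

import torch

bs, max_seq_len, n_heads, head_dim = 1, 16, 4, 8
start_pos, seq_len = 5, 1
n = start_pos + seq_len

k_cache = torch.arange(bs * max_seq_len * n_heads * head_dim, dtype=torch.float32)
k_cache = k_cache.view(bs, max_seq_len, n_heads, head_dim)

# The sliced view the attention actually needs...
k_view = k_cache[:, :n]
# ...and a tensor rebuilt over the same storage with contiguous strides,
# mirroring what the kernel does with key_cache.mutable_data_ptr().
k_rebuilt = k_cache.reshape(-1)[: bs * n * n_heads * head_dim].view(bs, n, n_heads, head_dim)

assert k_view.data_ptr() == k_cache.data_ptr()
assert torch.equal(k_view, k_rebuilt)  # holds because bs == 1 and the cache is contiguous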