
Commit 1d934f0

JacobSzwejbka authored and facebook-github-bot committed

Kv Cache as mutable buffer (#2595)

Summary:
Pull Request resolved: #2595

Big mishmash of changes to support the KV cache as a mutable buffer; see this stack (D55219107) for the individual reviews.

bypass-github-export-checks

Reviewed By: mergennachin

Differential Revision: D55229223

fbshipit-source-id: caac8a94d69126bf86349556020814a08e4b43c0

1 parent d52ebdc · commit 1d934f0

11 files changed (+290 -395 lines)


examples/models/llama2/builder.py

Lines changed: 8 additions & 13 deletions
@@ -179,15 +179,6 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
             logging.info(f"model.to {torch_dtype}")
             self.model = self.model.to(dtype=torch_dtype)
             self.dtype = dtype_override
-
-            # convert kv cache to dtype as well. This should be removed after mutable buffer is supported.
-            # assuming the kv cache are the last 2 tensors in the example inputs
-            if self.use_kv_cache:
-                dtype = torch.float16 if self.dtype == DType.fp16 else torch.float32
-                example_inputs = list(self.example_inputs[:-2]) + [
-                    cache.to(dtype) for cache in self.example_inputs[-2:]
-                ]
-                self.example_inputs = tuple(example_inputs)
         return self

     def source_transform(
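
The manual dtype conversion of the example-input caches is no longer needed because, once the KV cache lives on the module as a mutable buffer, model.to(dtype=...) converts it along with the rest of the model. A minimal sketch of that behavior, using a hypothetical ToyCache module rather than code from this repo:

import torch

class ToyCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A buffer's dtype and device follow the owning module, unlike a loose
        # tensor carried around in example_inputs.
        self.register_buffer("k_cache", torch.zeros(1, 128, 8, 64))

m = ToyCache().to(dtype=torch.float16)
print(m.k_cache.dtype)  # torch.float16, no separate conversion step required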
@@ -209,11 +200,15 @@ def source_transform(
         return self

     def _get_dynamic_shape(self) -> Optional[Dict[str, Any]]:
-        if self.use_kv_cache:
-            return None
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
-        dynamic_shape = {"tokens": {1: dim}}
-        return dynamic_shape
+        if self.use_kv_cache:
+            if self.use_sdpa_with_kv_cache:
+                return None
+            else:
+                # return {"tokens": {1: dim}, "input_pos": {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
+                return None
+        else:
+            return {"tokens": {1: dim}}

     def _get_edge_config(self) -> EdgeCompileConfig:
         edge_config = EdgeCompileConfig(
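
For context on what the returned dict does downstream, the sketch below shows how a {"tokens": {1: dim}} mapping marks the sequence dimension of the tokens input as dynamic in torch.export, while returning None keeps the export fully static, as both kv-cache branches above do for now. TokenEmbed is a toy stand-in, not the llama2 model:

import torch

class TokenEmbed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(32000, 64)

    def forward(self, tokens):
        return self.emb(tokens)

max_seq_len = 128
dim = torch.export.Dim("token_dim", max=max_seq_len - 1)
# Dim 1 of the "tokens" argument (the sequence length) is exported as dynamic.
ep = torch.export.export(
    TokenEmbed(),
    (torch.zeros(1, 16, dtype=torch.long),),
    dynamic_shapes={"tokens": {1: dim}},
)
print(ep)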

examples/models/llama2/custom_ops/custom_ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     - arg_meta: null
       kernel_name: torch::executor::flash_attention_kernel_out

-- func: llama::sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor key_cache, Tensor value_cache, int layer_id, int start_pos, int seq_len, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: llama::sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, Tensor(b!) value_cache, int start_pos, int seq_len, Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::sdpa_with_kv_cache_out
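
The schema change is the user-visible part of the layout switch: key_cache and value_cache are now declared mutable (Tensor(a!), Tensor(b!)), layer_id is gone, and each transformer layer is expected to pass its own 4D cache rather than indexing into one shared 5D tensor. A hedged Python sketch of that per-layer ownership; LayerKVCache and the shapes are illustrative, not taken from the llama2 sources:

import torch

class LayerKVCache(torch.nn.Module):
    # One mutable KV cache per transformer layer: [batch, max_seq_len, n_heads, head_dim].
    def __init__(self, max_seq_len=128, n_heads=8, head_dim=64):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, max_seq_len, n_heads, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, max_seq_len, n_heads, head_dim))

# Before: one [n_layers, batch, max_seq_len, n_heads, head_dim] tensor plus a
# layer_id argument. After: each layer owns its caches and hands them to
# sdpa_with_kv_cache.out, which mutates them in place.
caches = [LayerKVCache() for _ in range(4)]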

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 25 additions & 43 deletions
@@ -541,41 +541,34 @@ bool validate_flash_attention_args(
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
-    int64_t layer_id,
     int64_t start_pos,
     int64_t seq_length) {
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      k_cache.dim() == 5, "kcache must be a 5D tensor");
+      k_cache.dim() == 4, "kcache must be a 4D tensor");

   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      v_cache.dim() == 5, "v_cache must be a 5D tensor");
+      v_cache.dim() == 4, "v_cache must be a 4D tensor");

   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      layer_id < k_cache.size(0), "layer_id must be less than kcache dim 0");
-
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      layer_id < v_cache.size(0), "layer_id must be less than vcache dim 0");
-
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      start_pos < k_cache.size(2),
+      start_pos < k_cache.size(1),
       "start_pos must be less than key cache at dim 1");

   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      start_pos < v_cache.size(2),
+      start_pos < v_cache.size(1),
       "start_pos must be less than value cache at dim 1");

   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      (start_pos + seq_length) < k_cache.size(2),
+      (start_pos + seq_length) < k_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
       "."
       "key cache size: %zd",
       start_pos,
       seq_length,
-      k_cache.size(2));
+      k_cache.size(1));

   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      (start_pos + seq_length) < v_cache.size(2),
+      (start_pos + seq_length) < v_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
       "."
@@ -600,14 +593,13 @@ bool validate_cache_params(
 void update_cache(
     const Tensor& projected_value,
     const Tensor& cache,
-    int64_t layer_id,
     int64_t start_pos,
     int64_t seq_length) {
   ET_CHECK_MSG(seq_length == 1, "seq_length must be 1");
   ET_CHECK_MSG(
       projected_value.size(0) == 1,
       "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(1) == 1, "cache must have batch size of 1");
+  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
   ET_CHECK_MSG(
       is_default_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),

@@ -619,10 +611,8 @@ void update_cache(
   ET_CHECK_MSG(cache_data, "cache data is null");

   auto strides = cache.strides();
-  exec_aten::StridesType layer_stride = strides[0];
-  exec_aten::StridesType seq_dim_stride = strides[2];
-  exec_aten::SizesType pos_offset =
-      layer_id * layer_stride + start_pos * seq_dim_stride;
+  exec_aten::StridesType seq_dim_stride = strides[1];
+  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
   exec_aten::SizesType pos_offset_bytes =
       pos_offset * projected_value.element_size();
   exec_aten::SizesType num_bytes =
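
With the layer dimension removed, the write offset reduces to start_pos * strides[1]. The following eager-PyTorch reference illustrates the effect of update_cache on the new [batch, max_seq_len, n_heads, head_dim] layout; it is a sketch of the kernel's semantics, not the kernel itself:

import torch

def update_cache_ref(projected_value, cache, start_pos, seq_length):
    # cache: [batch=1, max_seq_len, n_heads, head_dim]
    # projected_value: [1, seq_length, n_heads, head_dim]
    assert seq_length == 1 and projected_value.size(0) == 1 and cache.size(0) == 1
    # The kernel memcpy's into the cache at byte offset
    # start_pos * strides[1] * element_size; an in-place slice assignment
    # expresses the same write.
    cache[:, start_pos : start_pos + seq_length] = projected_value

k_cache = torch.zeros(1, 128, 8, 64)
k_proj = torch.randn(1, 1, 8, 64)
update_cache_ref(k_proj, k_cache, start_pos=5, seq_length=1)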
@@ -713,19 +703,16 @@ Tensor& flash_attention_kernel_out(
   @param[in] key_cache Cache of previous v_projected.
   Format [n_layers, batch size, max_seq_len, num heads, head dim]
   ....
-  @param[in] layer_id which layer this call belongs to.
-  Used to updated appropriate entry of kv cache
-  @param[in] start_pos sequence position
-  @param[in] seq_len Seq length. e.g. seq_len dim of q_projected.
+  @param[in] start_pos: sequence position
+  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
 */
 Tensor& sdpa_with_kv_cache_out(
     RuntimeContext& ctx,
     const Tensor& q_projected,
     const Tensor& k_projected,
     const Tensor& v_projected,
-    const Tensor& key_cache,
-    const Tensor& value_cache,
-    const int64_t layer_id, // THis should be gone with buffer based impl
+    Tensor& key_cache,
+    Tensor& value_cache,
     const int64_t start_pos,
     const int64_t seq_len,
     const optional<Tensor>& attn_mask,
@@ -737,34 +724,31 @@ Tensor& sdpa_with_kv_cache_out(
   (void)ctx;
   ET_KERNEL_CHECK(
       ctx,
-      validate_cache_params(
-          key_cache, value_cache, layer_id, start_pos, seq_len),
+      validate_cache_params(key_cache, value_cache, start_pos, seq_len),
       InvalidArgument,
       output);

   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");

-  update_cache(k_projected, key_cache, layer_id, start_pos, seq_len);
-  update_cache(v_projected, value_cache, layer_id, start_pos, seq_len);
+  update_cache(k_projected, key_cache, start_pos, seq_len);
+  update_cache(v_projected, value_cache, start_pos, seq_len);

   auto q_seq_len = q_projected.size(1);

   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
       0, 1, 2, 3};
   std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
-  sliced_key_sizes[0] = key_cache.size(1);
+  sliced_key_sizes[0] = key_cache.size(0);
   sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = key_cache.size(3);
-  sliced_key_sizes[3] = key_cache.size(4);
+  sliced_key_sizes[2] = key_cache.size(2);
+  sliced_key_sizes[3] = key_cache.size(3);
   std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
   dim_order_to_stride_nocheck(
       sliced_key_sizes.data(),
       sliced_key_dim_order.data(),
       util::kKVDim,
       sliced_key_strides.data());
-  void* key_cache_data = reinterpret_cast<void*>(
-      reinterpret_cast<ptrdiff_t>(key_cache.mutable_data_ptr()) +
-      layer_id * key_cache.strides()[0] * key_cache.element_size());
+  void* key_cache_data = key_cache.mutable_data_ptr();
   TensorImpl k_impl = TensorImpl(
       key_cache.scalar_type(),
       util::kKVDim,

@@ -778,19 +762,17 @@ Tensor& sdpa_with_kv_cache_out(
   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
       0, 1, 2, 3};
   std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
-  sliced_value_sizes[0] = value_cache.size(1);
+  sliced_value_sizes[0] = value_cache.size(0);
   sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = value_cache.size(3);
-  sliced_value_sizes[3] = value_cache.size(4);
+  sliced_value_sizes[2] = value_cache.size(2);
+  sliced_value_sizes[3] = value_cache.size(3);
   std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
   dim_order_to_stride_nocheck(
       sliced_value_sizes.data(),
       sliced_value_dim_order.data(),
       util::kKVDim,
       sliced_value_strides.data());
-  void* value_cache_data = reinterpret_cast<void*>(
-      reinterpret_cast<ptrdiff_t>(value_cache.mutable_data_ptr()) +
-      layer_id * value_cache.strides()[0] * value_cache.element_size());
+  void* value_cache_data = value_cache.mutable_data_ptr();
   TensorImpl value_impl = TensorImpl(
       value_cache.scalar_type(),
       util::kKVDim,
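
The sliced TensorImpl views built above amount to attending over only the filled prefix of each cache, with no more pointer bump by layer_id * strides[0]. A simplified eager reference of that slicing follows; attn_mask, dropout, and the kernel's other options are elided:

import torch
import torch.nn.functional as F

def sdpa_over_cache_ref(q, k_cache, v_cache, start_pos, seq_len):
    # q: [1, seq_len, n_heads, head_dim]; caches: [1, max_seq_len, n_heads, head_dim]
    k = k_cache[:, : start_pos + seq_len]  # the sliced view: sizes[1] = start_pos + seq_len
    v = v_cache[:, : start_pos + seq_len]
    # scaled_dot_product_attention expects [batch, n_heads, seq, head_dim]
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q_t, k_t, v_t).transpose(1, 2)

q = torch.randn(1, 1, 8, 64)
k_cache = torch.zeros(1, 128, 8, 64)
v_cache = torch.zeros(1, 128, 8, 64)
out = sdpa_over_cache_ref(q, k_cache, v_cache, start_pos=5, seq_len=1)
print(out.shape)  # torch.Size([1, 1, 8, 64])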
