Skip to content

Commit 758168e

Browse files
Varun Puri and facebook-github-bot
authored and committed
Revert changes to getKVCacheShape()
Summary: The KV cache does not support dynamic shapes. Do not change the size of the KV cache based on the sequence length.
Reviewed By: kimishpatel, larryliu0820
Differential Revision: D54218307
fbshipit-source-id: 5a40093fd44db082a1de57126eab970bfc022b4b
1 parent ce99c21 commit 758168e

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

examples/models/llama2/runner/runner.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
119119
return res;
120120
}
121121

122-
std::vector<exec_aten::SizesType> Runner::getKVCacheShape(int32_t seq_len) {
122+
std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
123123
// shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
124124
// self.head_dim)
125125
std::vector<std::string> methods = {
@@ -134,9 +134,6 @@ std::vector<exec_aten::SizesType> Runner::getKVCacheShape(int32_t seq_len) {
134134
// convert from int64_t to int32_t
135135
result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
136136
}
137-
// update seq_len if one is provided between 1 and max_seq_len
138-
ET_CHECK_MSG(result.size() == 5, "KV cache shape must have 5 elements");
139-
result[2] = (seq_len > 0 && seq_len <= result[2]) ? seq_len : result[2];
140137
return result;
141138
}
142139

@@ -201,7 +198,7 @@ Error Runner::generate(
201198
int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
202199
int eos_counter = 0; // counter to capture EOS
203200
int logits_index = 0; // index of the logits tensor in the output
204-
std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape(seq_len);
201+
std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape();
205202
std::vector<exec_aten::SizesType> input_shape = {1, 1};
206203
std::vector<exec_aten::SizesType> pos_shape = {};
207204
std::vector<uint8_t> k_data;

examples/models/llama2/runner/runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class Runner {
4545
template <typename T>
4646
int32_t
4747
logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
48-
std::vector<exec_aten::SizesType> getKVCacheShape(int32_t seq_len);
48+
std::vector<exec_aten::SizesType> getKVCacheShape();
4949
// metadata
5050
int32_t vocab_size_;
5151
int32_t bos_id_;

0 commit comments

Comments
 (0)