Commit 91f989b

Varun Puri authored and facebook-github-bot committed
Add seq_len to llama runner for early stopping
Summary: By default, the llama runner keeps generating until it reaches max_seq_len, a value embedded in the model metadata. This adds a way to limit the number of tokens generated per call.

Reviewed By: larryliu0820

Differential Revision: D53873431
1 parent ca6995b commit 91f989b

File tree

3 files changed: +26 −8 lines changed

examples/models/llama2/main.cpp

Lines changed: 8 additions & 1 deletion
@@ -24,6 +24,11 @@ DEFINE_double(
     0.8f,
     "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
 
+DEFINE_int32(
+    seq_len,
+    128,
+    "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -38,11 +43,13 @@ int32_t main(int32_t argc, char** argv) {
 
   double temperature = FLAGS_temperature;
 
+  int32_t seq_len = FLAGS_seq_len;
+
   // create llama runner
   ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
 
   // generate
-  runner.generate(prompt);
+  runner.generate(prompt, seq_len);
 
   return 0;
 }
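
The new flag follows the same gflags pattern as the existing temperature flag: DEFINE_int32 registers --seq_len with a default of 128, and ParseCommandLineFlags fills FLAGS_seq_len before the runner is constructed. A minimal standalone sketch of that mechanism (the demo program itself is illustrative, not part of this commit):

#include <cstdio>
#include <gflags/gflags.h>

DEFINE_int32(seq_len, 128, "Total number of tokens to generate (prompt + output).");

int main(int argc, char** argv) {
  // e.g. `./demo --seq_len=64` prints 64; running with no flag prints the default 128.
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  printf("seq_len = %d\n", FLAGS_seq_len);
  return 0;
}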

examples/models/llama2/runner/runner.cpp

Lines changed: 16 additions & 6 deletions
@@ -119,7 +119,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
   return res;
 }
 
-std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
+std::vector<exec_aten::SizesType> Runner::getKVCacheShape(int32_t seq_len) {
   // shape: (n_layers, args.max_batch_size, args.max_seq_len, self.n_kv_heads,
   // self.head_dim)
   std::vector<std::string> methods = {
@@ -134,6 +134,8 @@ std::vector<exec_aten::SizesType> Runner::getKVCacheShape() {
     // convert from int64_t to int32_t
     result.push_back(getMetadataHelper<int64_t>(methods[i], default_values[i]));
   }
+  // update seq_len if one is provided between 1 and max_seq_len
+  result[2] = (seq_len > 0 && seq_len <= result[2]) ? seq_len : result[2];
   return result;
 }
 
@@ -155,6 +157,7 @@ int32_t Runner::logitsToToken(
 
 Error Runner::generate(
     const std::string& prompt,
+    int32_t seq_len,
     std::function<void(const std::string&)> callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
@@ -168,6 +171,9 @@ Error Runner::generate(
   // max # of prompt tokens: len(prompt) + '\0', ?BOS, ?EOS
   int* prompt_tokens = new int[prompt.size() + 1 + n_bos_ + n_eos_];
 
+  // Set the sequence length to the max seq length if not provided
+  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;
+
   tokenizer_->encode(
       prompt.c_str(),
       n_bos_,
@@ -182,6 +188,10 @@ Error Runner::generate(
       num_prompt_tokens < max_seq_len_,
       "Max seq length exceeded - please increase max seq len value in .../llama2/model.py");
 
+  ET_CHECK_MSG(
+      num_prompt_tokens < seq_len,
+      "Sequence length exceeded - please increase the seq_len value passed to generate()");
+
   // start the main loop
   long start =
       0; // used to time our code, only initialized after first iteration
@@ -190,7 +200,7 @@ Error Runner::generate(
   int token = prompt_tokens[pos]; // prefill starts from 0 to num_prompt_tokens
   int eos_counter = 0; // counter to capture EOS
   int logits_index = 0; // index of the logits tensor in the output
-  std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape();
+  std::vector<exec_aten::SizesType> kv_cache_shape = getKVCacheShape(seq_len);
   std::vector<exec_aten::SizesType> input_shape = {1, 1};
   std::vector<exec_aten::SizesType> pos_shape = {};
   std::vector<uint8_t> k_data;
@@ -215,7 +225,7 @@ Error Runner::generate(
     token_data.resize(1);
   } else {
     // reserve data for tokens, notice the size is still 0.
-    token_data.resize(max_seq_len_);
+    token_data.resize(seq_len);
   }
 
   // initialize tensor wrappers
@@ -235,7 +245,7 @@ Error Runner::generate(
     }
   }
   // create a 1xN int tensor with next as value
-  while (pos < max_seq_len_) {
+  while (pos < seq_len) {
     // ET_LOG(Info, "Generating step %d...", pos);
     // set the current token in the tensor
     std::vector<EValue> inputs;
@@ -348,8 +358,8 @@ Error Runner::generate(
   }
   printf("\n");
 
-  if (pos == max_seq_len_) {
-    ET_LOG(Info, "Maximum sequence length reached!");
+  if (pos == seq_len) {
+    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
   // report achieved tok/s (pos-1 because the timer starts after first
   // iteration)
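
Both new call sites apply the same resolution rule: a requested seq_len is honored only when it lies in (0, max_seq_len]; otherwise the runner falls back to max_seq_len (in generate() against the member max_seq_len_, in getKVCacheShape() against the metadata value in result[2]). A dependency-free sketch of that rule and its edge cases:

#include <cassert>
#include <cstdint>

// Mirrors the ternary used in the diff above: honor a valid request,
// otherwise fall back to the model's max sequence length.
int32_t resolve_seq_len(int32_t requested, int32_t max_seq_len) {
  return (requested > 0 && requested <= max_seq_len) ? requested : max_seq_len;
}

int main() {
  assert(resolve_seq_len(64, 128) == 64);   // in range: honored
  assert(resolve_seq_len(0, 128) == 128);   // non-positive: fall back to max
  assert(resolve_seq_len(512, 128) == 128); // above max: fall back to max
  return 0;
}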

examples/models/llama2/runner/runner.h

Lines changed: 2 additions & 1 deletion
@@ -34,6 +34,7 @@ class Runner {
   Error load();
   Error generate(
       const std::string& prompt,
+      int32_t seq_len = 128,
       std::function<void(const std::string&)> callback = {});
   void stop();
 
@@ -44,7 +45,7 @@ class Runner {
   template <typename T>
   int32_t
   logitsToToken(const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
-  std::vector<exec_aten::SizesType> getKVCacheShape();
+  std::vector<exec_aten::SizesType> getKVCacheShape(int32_t seq_len);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;
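
Because seq_len is a defaulted parameter, existing callers keep compiling unchanged, while new callers can cap generation explicitly. A hypothetical caller-side sketch based on the signatures above and the constructor used in main.cpp (file paths, include path, and the per-piece callback behavior are assumptions, not taken from this commit):

#include <executorch/examples/models/llama2/runner/runner.h>

#include <cstdio>
#include <string>

int main() {
  // Paths are placeholders; temperature matches the example's default.
  ::torch::executor::Runner runner("llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);

  // Stop after 64 tokens total (prompt + output) instead of the default 128;
  // the callback is assumed to receive each decoded piece as it is emitted.
  runner.generate("Once upon a time", /*seq_len=*/64, [](const std::string& piece) {
    printf("%s", piece.c_str());
  });
  return 0;
}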
