Skip to content

Commit 0cff7c9

Browse files
author
Joey Tsai
committed
Rebase and minor fix
- Fix rebase conflict
- Change input dtype of calibration function
1 parent b7061d7 commit 0cff7c9

File tree

3 files changed

+3
-5
lines changed

3 files changed

+3
-5
lines changed

examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
203203
return probs_indices.gather(dim=-1, index=next_token)
204204

205205
with torch.no_grad():
206-
while token_list[-1] != sp_model.eos_id() and pos < max_seq_len:
206+
while token_list[-1] != sp_model.eos_id() and pos < max_seq_len - 1:
207207
logits, new_k_caches, new_v_caches = module(
208208
torch.full((1, 1), token_list[pos]),
209209
atten_mask,
@@ -248,7 +248,7 @@ def _bert_calibrate(
248248
token_list = torch.cat(
249249
[
250250
token_list,
251-
torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int64),
251+
torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
252252
],
253253
dim=1,
254254
)

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _bert_calibrate(
122122
token_list = torch.cat(
123123
[
124124
token_list,
125-
torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int64),
125+
torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
126126
],
127127
dim=1,
128128
)

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ Runner::Runner(
4545
const int eval_mode)
4646
: n_bos_(1),
4747
n_eos_(1),
48-
vocab_size_(QNN_LLAMA3_2_LOGITS),
49-
max_seq_len_(QNN_LLAMA3_2_SEQLEN),
5048
tokenizer_path_(tokenizer_path),
5149
temperature_(temperature),
5250
eval_mode_(eval_mode),

0 commit comments

Comments (0)