Commit 00c49c8

fix eager_eval with kv cache and improve pybind eval speed

1 parent 1cb97e0

File tree

2 files changed: +4 −10 lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 3 additions & 6 deletions
@@ -54,12 +54,9 @@ def _model_call(self, inps):
         # inps: Tensor of shape (1, max_seq_len - 1)
         # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
         if self._use_kv_cache:
-            result_logits = []
-            for pos in range(self._max_seq_length):
-                pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
-                result_logits.append(logits[0])
-            return torch.cat(result_logits, dim=1)
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+            result = self._et_model.forward((inps[:, : self._max_seq_length], pos_tensor))
+            return result[0]
         else:
             result = self._et_model.forward((inps,))
             return result[0]
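
Why this is faster: the old kv-cache path made one pybind forward call per position (max_seq_length Python-to-runtime round trips) and concatenated single-token logits, while the new path issues a single forward over the whole sequence with start position 0. A minimal sketch of the equivalence, using a hypothetical ToyModel standing in for the pybind-wrapped _et_model (assumed convention: forward takes (tokens, start_pos) and returns logits for every input position):

import torch

# Hypothetical toy model, not the ExecuTorch module: forward takes
# (tokens, start_pos) and returns logits for every input position.
class ToyModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor, start_pos: torch.Tensor) -> torch.Tensor:
        # A real kv-cache model reads/writes its cache at start_pos; this toy
        # has no cross-token state, so it can ignore it.
        return self.head(self.embed(tokens))

model = ToyModel().eval()
inps = torch.randint(0, 32, (1, 8))

with torch.no_grad():
    # Old path: one forward call per position, one-token logits concatenated.
    looped = torch.cat(
        [model(inps[:, p : p + 1], torch.tensor([p])) for p in range(inps.size(1))],
        dim=1,
    )
    # New path: a single forward over the whole sequence, starting at position 0.
    batched = model(inps, torch.tensor([0], dtype=torch.int64))

print(torch.allclose(looped, batched))  # True: same logits, one call instead of eight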

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 1 addition & 4 deletions
@@ -77,10 +77,7 @@ def tok_decode(self, tokens):
 
     def _model_call(self, inps):
         if self._use_kv_cache:
-            pos_tensor = torch.arange(
-                self._max_seq_length, dtype=torch.int64, device=self.device
-            )
-
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
             logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
             return logits
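
Why [0] rather than arange: when the whole sequence is fed in one call, the position argument is the single starting index into the kv cache, not a per-token position vector. A minimal sketch of that convention (TinyKVCache and its update method are illustrative assumptions, not ExecuTorch's actual cache implementation):

import torch

# Illustrative kv cache: new entries land at [start_pos, start_pos + seq_len),
# so a batched prefill of the whole sequence uses the single index 0.
class TinyKVCache:
    def __init__(self, max_seq_len: int = 16, dim: int = 4):
        self.k = torch.zeros(1, max_seq_len, dim)

    def update(self, start_pos: int, new_k: torch.Tensor) -> torch.Tensor:
        seq_len = new_k.size(1)
        self.k[:, start_pos : start_pos + seq_len] = new_k
        return self.k

cache = TinyKVCache()

# Batched prefill of an 8-token sequence: start_pos is the scalar 0, which is
# why the fix passes torch.tensor([0]) instead of torch.arange(max_seq_length).
cache.update(start_pos=0, new_k=torch.randn(1, 8, 4))

# Incremental decode of the next token would then use start_pos=8.
cache.update(start_pos=8, new_k=torch.randn(1, 1, 4))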
