Skip to content

Commit f91ae96

Browse files
committed
fix eager_eval with kv cache and improve pybind eval speed
1 parent 7315894 commit f91ae96

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,9 @@ def _model_call(self, inps):
         # inps: Tensor of shape (1, max_seq_len - 1)
         # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
         if self._use_kv_cache:
-            result_logits = []
-            for pos in range(self._max_seq_length):
-                pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
-                result_logits.append(logits[0])
-            return torch.cat(result_logits, dim=1)
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+            result = self._et_model.forward((inps[:, : self._max_seq_length], pos_tensor))
+            return result[0]
         else:
             result = self._et_model.forward((inps,))
             return result[0]

examples/models/llama2/evaluate/eager_eval.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,7 @@ def tok_decode(self, tokens):

     def _model_call(self, inps):
         if self._use_kv_cache:
-            pos_tensor = torch.arange(
-                self._max_seq_length, dtype=torch.int64, device=self.device
-            )
-
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
             logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
             return logits

0 commit comments

Comments (0)