
Commit a54d62c

fix eager_eval with kv cache and improve pybind eval speed

Differential Revision: D61302251
Pull Request resolved: #4720
Parent: 5c9a00a

2 files changed: 6 additions, 10 deletions

examples/models/llama2/eval_llama_lib.py (5 additions, 6 deletions)

```diff
@@ -54,12 +54,11 @@ def _model_call(self, inps):
         # inps: Tensor of shape (1, max_seq_len - 1)
         # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
         if self._use_kv_cache:
-            result_logits = []
-            for pos in range(self._max_seq_length):
-                pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
-                result_logits.append(logits[0])
-            return torch.cat(result_logits, dim=1)
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+            result = self._et_model.forward(
+                (inps[:, : self._max_seq_length], pos_tensor)
+            )
+            return result[0]
         else:
             result = self._et_model.forward((inps,))
             return result[0]
```
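
This hunk replaces the per-position decode loop, which made one `forward` call (and one Python-to-C++ pybind crossing) per token, with a single prefill call over the whole sequence at start position 0, filling the KV cache in one pass. Below is a minimal sketch of why the two are equivalent, using a hypothetical `ToyKVModel` stand-in (one causal-attention "layer" over cached keys/values, no projections) rather than the real exported model or the ExecuTorch pybind API:

```python
import torch

# Hypothetical stand-in for a KV-cache model's forward(tokens, start_pos).
# Not the ExecuTorch API; just an illustration of why one prefill call at
# start position 0 matches the old per-token loop.
class ToyKVModel:
    def __init__(self, vocab=16, dim=8, max_seq=12):
        torch.manual_seed(0)
        self.emb = torch.randn(vocab, dim)
        self.out = torch.randn(dim, vocab)
        self.k_cache = torch.zeros(max_seq, dim)
        self.v_cache = torch.zeros(max_seq, dim)

    def forward(self, tokens, start_pos):
        # tokens: (1, T); start_pos: 1-element tensor holding the first position.
        x = self.emb[tokens[0]]          # (T, dim)
        s = int(start_pos.item())
        T = x.shape[0]
        self.k_cache[s : s + T] = x      # fill the cache for these positions
        self.v_cache[s : s + T] = x
        rows = []
        for i in range(T):               # causal: position s+i attends to cache[: s+i+1]
            attn = torch.softmax(self.k_cache[: s + i + 1] @ x[i], dim=0)
            rows.append((attn @ self.v_cache[: s + i + 1]) @ self.out)
        return torch.stack(rows)[None]   # (1, T, vocab)

inps = torch.randint(0, 16, (1, 8))

# Old pybind path: one forward call (one Python-to-C++ crossing) per position.
m1 = ToyKVModel()
loop_logits = torch.cat(
    [m1.forward(inps[:, p : p + 1], torch.tensor([p])) for p in range(8)], dim=1
)

# New path: a single prefill call starting at position 0.
m2 = ToyKVModel()
batch_logits = m2.forward(inps, torch.tensor([0]))

assert torch.allclose(loop_logits, batch_logits)
```

Either way, position i attends to exactly the same cached keys and values, so the logits match; the batched call simply replaces `max_seq_length` pybind crossings with one.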

examples/models/llama2/evaluate/eager_eval.py (1 addition, 4 deletions)

```diff
@@ -77,10 +77,7 @@ def tok_decode(self, tokens):
 
     def _model_call(self, inps):
         if self._use_kv_cache:
-            pos_tensor = torch.arange(
-                self._max_seq_length, dtype=torch.int64, device=self.device
-            )
-
+            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
             # Batch process the whole sequence.
             logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
             return logits
```
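
In the eager path, the position argument changes from `torch.arange(self._max_seq_length)` (one index per cache slot) to a single start position of 0, matching the pybind path above. The commit message calls this a fix; one visible problem with the old argument is a length mismatch, sketched below under the assumption, taken from the comment in eval_llama_lib.py, that the evaluator feeds `max_seq_len - 1` tokens (the eager llama2 model's exact `input_pos` contract is not shown in this diff):

```python
import torch

max_seq_length = 8
# The evaluator feeds (1, max_seq_len - 1) tokens, per the comment in
# eval_llama_lib.py above.
inps = torch.randint(0, 32000, (1, max_seq_length - 1))
tokens = inps[:, :max_seq_length]   # slicing cannot grow it: still 7 tokens

# Old argument: one index per cache slot, i.e. 8 positions for 7 tokens.
old_pos = torch.arange(max_seq_length, dtype=torch.int64)

# New argument: just the start position of the batch, as in the pybind path.
new_pos = torch.tensor([0], dtype=torch.int64)

print(tokens.shape[1], old_pos.numel(), new_pos.numel())  # 7 8 1
```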
