Commit f42942a

cccclai authored and facebook-github-bot committed
improve the eval perf with kv cache (#3732)
Summary:
Pull Request resolved: #3732

The original implementation was too slow because of the frequent round trips cpu -> gpu -> cpu -> gpu, which made it inefficient. Change it to batch-process the sequence so the compute stays on the gpu.

When evaluating the stories model, before the change:

```
2024-05-23:23:42:25,115 INFO [evaluator.py:362] Running loglikelihood_rolling requests
100%|██████████| 5/5 [02:37<00:00, 31.50s/it]
wikitext: {'word_perplexity,none': 10589.525426446424, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 6.111053701258041, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.6114211588515417, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```

After the change:

```
2024-05-23:23:36:50,339 INFO [evaluator.py:362] Running loglikelihood_rolling requests
100%|██████████| 5/5 [00:03<00:00, 1.55it/s]
wikitext: {'word_perplexity,none': 10589.52618994558, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 6.111053787314264, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.611421179167659, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```

Reviewed By: lucylq

Differential Revision: D57764318

fbshipit-source-id: 984578cfb52e625e15b624a743b7dfd8340c1755
1 parent d44877b commit f42942a
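For intuition, here is a minimal standalone sketch of the change described above. `ToyModel`, its vocabulary size, and the tensor shapes are made up for illustration; only the call pattern `model(tokens_slice, input_pos)` and the names `inps` / `max_seq_length` mirror the diff below. This is not the actual ExecuTorch model.

```python
import torch


# Toy stand-in for the exported llama model: it just projects token ids to
# "logits" so the two call patterns can be compared shape-for-shape.
class ToyModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # Returns logits of shape [batch, tokens.shape[1], vocab_size].
        return self.head(self.emb(tokens))


model = ToyModel()
max_seq_length = 16
inps = torch.randint(0, 32, (1, max_seq_length))

# Old path: one forward call per position, then a concat along the
# sequence dimension -- one host/device round trip per token.
per_pos = []
for pos in range(max_seq_length):
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    per_pos.append(model(inps[:, pos : pos + 1], pos_tensor))
looped_logits = torch.cat(per_pos, dim=1)

# New path: a single forward call over the whole sequence with all positions.
pos_tensor = torch.arange(max_seq_length, dtype=torch.int64)
batched_logits = model(inps[:, :max_seq_length], pos_tensor)

assert looped_logits.shape == batched_logits.shape == (1, max_seq_length, 32)
```

Both paths yield logits of the same shape (and, per the summary, essentially identical perplexity); the speedup reported above comes from collapsing `max_seq_length` separate forward calls into one.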

File tree

1 file changed: +7 −6 lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -88,12 +88,13 @@ def tok_decode(self, tokens):
 
     def _model_call(self, inps):
         if self._use_kv_cache:
-            result_logits = []
-            for pos in range(self._max_seq_length):
-                pos_tensor = torch.tensor([pos], dtype=torch.int64)
-                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
-                result_logits.append(logits)
-            return torch.cat(result_logits, dim=1)
+            pos_tensor = torch.arange(
+                self._max_seq_length, dtype=torch.int64, device=self.device
+            )
+
+            # Batch process the whole sequence.
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
         else:
             return self._model(inps)
```
