
Commit c8b43d2

lucylq authored and facebook-github-bot committed
add kv cache to eval (#3162)
Summary: Pull Request resolved: #3162
Reviewed By: kirklandsign
Differential Revision: D56365716
Pulled By: lucylq
fbshipit-source-id: 707c5b869df128cc7e669fc0d78ca185f1c68f31
1 parent 7469a28 commit c8b43d2

File tree

1 file changed: +25 −5 lines changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 25 additions & 5 deletions
@@ -41,13 +41,15 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache

     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded

     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)

     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
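
With a KV-cache model, each forward call consumes a single token plus its position and updates the cache internally, so the eval wrapper walks the sequence position by position and stitches the per-step logits back together before perplexity is computed. A minimal runnable sketch of that pattern, using a hypothetical stand-in model (toy_kv_cache_model is not part of this diff):

import torch

VOCAB_SIZE, MAX_SEQ_LENGTH = 32, 8

def toy_kv_cache_model(token: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: a real KV-cache model would update its cache
    # at position `pos` and attend over everything cached so far.
    # token: (1, 1) int64, pos: (1,) int64 -> logits: (1, 1, VOCAB_SIZE)
    return torch.randn(1, 1, VOCAB_SIZE)

inps = torch.randint(0, VOCAB_SIZE, (1, MAX_SEQ_LENGTH))
result_logits = []
for pos in range(MAX_SEQ_LENGTH):
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    result_logits.append(toy_kv_cache_model(inps[:, pos : pos + 1], pos_tensor))

full_logits = torch.cat(result_logits, dim=1)
print(full_logits.shape)  # torch.Size([1, 8, 32]): same shape the non-KV path returns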
@@ -107,13 +117,22 @@ def __init__(
         from executorch.extension.pybindings.portable_lib import _load_for_executorch

         self._et_model = _load_for_executorch(self._model)
+        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]

     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]


 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
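
The ExecuTorch path runs the same per-position loop, with one wrinkle: as the diff shows, the loaded module's forward returns a list of output tensors, so the logits are taken from index 0 of each step's result before concatenation. A sketch against a mock of that calling convention (MockETModule is hypothetical; the real module comes from _load_for_executorch):

import torch

class MockETModule:
    # Hypothetical mock of the loaded ExecuTorch module's convention assumed
    # here: forward(inputs_tuple) -> list of output tensors, logits first.
    def forward(self, inputs):
        token, pos = inputs  # (1, 1) token slice and (1,) position
        return [torch.randn(1, 1, 32)]

et_model = MockETModule()
inps = torch.randint(0, 32, (1, 8))

result_logits = []
for pos in range(inps.shape[1]):
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    outputs = et_model.forward((inps[:, pos : pos + 1], pos_tensor))
    result_logits.append(outputs[0])  # index 0: the logits tensor

print(torch.cat(result_logits, dim=1).shape)  # torch.Size([1, 8, 32])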
@@ -139,7 +158,7 @@ def _model_call(self, inps):

         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass

@@ -225,6 +244,7 @@ def gen_eval_wrapper(
         model=model,
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
     )
