Commit 053c645

add kv cache to eval
Parent: fa433cb

File tree

1 file changed: +27 −5 lines


examples/models/llama2/eval_llama_lib.py

Lines changed: 27 additions & 5 deletions
@@ -41,13 +41,15 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
 
     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded
 
     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)
 
     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
@@ -100,9 +110,11 @@ def __init__(
         model: str,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         super().__init__(None, tokenizer, max_seq_length)
         self._model = model  # Expects model to be path to a .pte file
+        self._use_kv_cache = use_kv_cache
 
         from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
@@ -111,9 +123,17 @@ def __init__(
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]
 
 
 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
@@ -139,7 +159,7 @@ def _model_call(self, inps):
 
         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass
 
 
@@ -212,6 +232,7 @@ def gen_eval_wrapper(
         # Exported model takes at most (max_seq_length - 1) tokens.
        # Note that the eager model takes at most max_seq_length tokens.
         max_seq_length=args.max_seq_length - 1,
+        use_kv_cache=args.use_kv_cache,
     )
 
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
@@ -225,6 +246,7 @@ def gen_eval_wrapper(
         model=model,
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
     )
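For context, here is a minimal sketch (not part of the commit) of the calling convention the new use_kv_cache path relies on: when the flag is set, _model_call feeds the model one token per step together with an explicit int64 position tensor and concatenates the per-step logits along the sequence dimension. ToyModel below is a hypothetical stand-in for the exported Llama module, which is assumed to maintain its KV cache internally and index it by the position argument; only the loop mirrors GPTFastEvalWrapper._model_call from the diff.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical stand-in: a real KV-cache model would read/write its cache at start_pos.
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor, start_pos: torch.Tensor) -> torch.Tensor:
        # tokens: (1, 1) slice of the prompt; returns logits of shape (1, 1, vocab_size).
        return self.head(self.embed(tokens))


def kv_cache_model_call(model: nn.Module, inps: torch.Tensor, max_seq_length: int) -> torch.Tensor:
    # Mirrors the use_kv_cache=True branch: one forward call per position,
    # then concatenate the per-step logits along the time axis.
    result_logits = []
    for pos in range(max_seq_length):
        pos_tensor = torch.tensor([pos], dtype=torch.int64)
        logits = model(inps[:, pos : pos + 1], pos_tensor)
        result_logits.append(logits)
    return torch.cat(result_logits, dim=1)  # (1, max_seq_length, vocab_size)


if __name__ == "__main__":
    seq_len = 8
    tokens = torch.randint(0, 32, (1, seq_len))
    out = kv_cache_model_call(ToyModel(), tokens, max_seq_length=seq_len)
    print(out.shape)  # torch.Size([1, 8, 32])

The ETEagerEvalWrapper variant in the diff follows the same pattern, except each step calls self._et_model.forward((token_slice, pos_tensor)) on the loaded .pte program and takes the first element of the returned tuple.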
