
Commit 5611dbe

add kv cache to eval
1 parent fa433cb commit 5611dbe

1 file changed

examples/models/llama2/eval_llama_lib.py

Lines changed: 27 additions & 6 deletions
@@ -41,13 +41,15 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
 
     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded
 
     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)
 
     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
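
For illustration, here is a minimal, self-contained sketch of what the new KV-cache branch of _model_call does: it feeds the model one token at a time together with its position and concatenates the per-step logits back into a full-sequence tensor. ToyKVCacheModel is a hypothetical stand-in, not part of this repository; it only mimics a model whose forward takes a (1, 1) token slice plus a position tensor and returns logits of shape (1, 1, vocab_size).

import torch

# Hypothetical stand-in for a KV-cache-enabled model: forward() takes a single
# token slice of shape (1, 1) and its position, returning (1, 1, vocab_size).
class ToyKVCacheModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, token, pos):
        # A real model would also update its attention cache at position `pos`.
        return self.emb(token)

model = ToyKVCacheModel()
inps = torch.randint(0, 32, (1, 8))  # (1, seq_len) token ids
result_logits = []
for pos in range(inps.shape[1]):
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    result_logits.append(model(inps[:, pos : pos + 1], pos_tensor))
logits = torch.cat(result_logits, dim=1)
print(logits.shape)  # torch.Size([1, 8, 32]), same shape as the non-KV-cache path

Note that the committed code loops over self._max_seq_length rather than the actual input length; the sketch uses inps.shape[1] only to stay self-contained.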
@@ -99,21 +109,30 @@ def __init__(
         self,
         model: str,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
-        max_seq_length: Optional[int] = None,
+        max_seq_length: Optional[int] = None
     ):
         super().__init__(None, tokenizer, max_seq_length)
         self._model = model  # Expects model to be path to a .pte file
 
         from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
         self._et_model = _load_for_executorch(self._model)
+        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]
 
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]
 
 
 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
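
The ExecuTorch path mirrors the same per-position loop, but the pybindings module's forward takes a tuple of input tensors and returns its outputs as a list, which is why the code unpacks result[0] and logits[0] (and reads run_method("use_kv_cache")[0] the same way). Below is a minimal hedged sketch of just that calling convention, using a stand-in object rather than a real .pte module; FakeEtModule is hypothetical, not an ExecuTorch API.

import torch

# Hypothetical stand-in that only mimics the calling convention used above:
# forward() accepts a tuple of inputs and returns a list of output tensors.
class FakeEtModule:
    def __init__(self, vocab_size: int = 32):
        self.emb = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs):
        tokens = inputs[0]         # (1, 1) token slice on the KV-cache path
        return [self.emb(tokens)]  # outputs come back as a list, logits first

et_model = FakeEtModule()
out = et_model.forward((torch.tensor([[7]]), torch.tensor([0], dtype=torch.int64)))
print(out[0].shape)  # torch.Size([1, 1, 32])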
@@ -139,7 +158,7 @@ def _model_call(self, inps):
 
         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass
 
 
@@ -212,6 +231,7 @@ def gen_eval_wrapper(
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
            max_seq_length=args.max_seq_length - 1,
+            use_kv_cache=args.use_kv_cache,
         )
 
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
@@ -225,6 +245,7 @@ def gen_eval_wrapper(
         model=model,
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
     )
 
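
Both call sites above simply plumb the new flag through from the eval script's arguments. The diff references args.use_kv_cache, which presumably maps to a --use_kv_cache argparse option defined elsewhere in the file; here is a small hedged sketch of such a parser (flag names assumed, not confirmed against the full file).

import argparse

# Hypothetical sketch of the two arguments the diff relies on
# (args.max_seq_length and args.use_kv_cache).
parser = argparse.ArgumentParser()
parser.add_argument("--max_seq_length", type=int, default=2048)
parser.add_argument("--use_kv_cache", default=False, action="store_true")

args = parser.parse_args(["--use_kv_cache"])
assert args.use_kv_cache is True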

0 commit comments
