Commit 51057ac

add kv cache to eval
1 parent 74dba6e commit 51057ac

File tree

1 file changed: +26 -5 lines


examples/models/llama2/eval_llama_lib.py

@@ -41,13 +41,15 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache

     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded

     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)

     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
@@ -100,6 +110,7 @@ def __init__(
         model: str,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         super().__init__(None, tokenizer, max_seq_length)
         self._model = model  # Expects model to be path to a .pte file
@@ -111,9 +122,17 @@ def __init__(
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]


 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
@@ -139,7 +158,7 @@ def _model_call(self, inps):

         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass


@@ -203,6 +222,7 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             tokenizer_bin=tokenizer_bin,
             max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
         )

     # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated eagerly
@@ -225,6 +245,7 @@ def gen_eval_wrapper(
             model=model,
             tokenizer=tokenizer,
             max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
         )

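For reference, both wrappers now share the same evaluation strategy when the KV cache is enabled: instead of one forward pass over the whole sequence, the model is called once per position with a single-token slice and an explicit position index, and the per-token logits are concatenated afterwards. The sketch below is a minimal, self-contained reconstruction of that loop, not code from eval_llama_lib.py; TinyKVModel, its sizes, and model_call_kv are illustrative stand-ins, and a real KV-cache model would read and update its cache at the given position.

import torch
import torch.nn as nn


class TinyKVModel(nn.Module):
    # Hypothetical stand-in for a KV-cache-enabled llama2 model: it accepts a
    # single-token slice plus its position, like the models this commit targets.
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        # A real model would index its KV cache with `pos`; this toy model
        # only needs the matching call signature.
        return self.head(self.embed(tokens))


def model_call_kv(model: nn.Module, inps: torch.Tensor) -> torch.Tensor:
    # Mirrors the new _model_call branch: one forward call per position,
    # collecting (batch, 1, vocab) logits and stitching them back into a
    # single (batch, seq, vocab) tensor. The commit iterates over
    # self._max_seq_length; here we use the actual input length.
    result_logits = []
    for pos in range(inps.shape[1]):
        pos_tensor = torch.tensor([pos], dtype=torch.int64)
        logits = model(inps[:, pos : pos + 1], pos_tensor)
        result_logits.append(logits)
    return torch.cat(result_logits, dim=1)


if __name__ == "__main__":
    model = TinyKVModel()
    inps = torch.randint(0, 32, (1, 8))
    print(model_call_kv(model, inps).shape)  # torch.Size([1, 8, 32])

Feeding one token per call matches the static single-token input shape a KV-cache model is typically exported with, at the cost of up to max_seq_length forward calls per sequence.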