
Commit b48bfd1

kv cache
1 parent b1f5f96 commit b48bfd1

1 file changed (+28 −7 lines)


examples/models/llama2/runner/generation.py

Lines changed: 28 additions & 7 deletions
@@ -1,9 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
 import argparse
-from typing import List, Optional, Tuple, TypedDict
 
 import json
+from typing import List, Optional, Tuple, TypedDict
+
 import torch
 import torch.nn.functional as F
 from executorch.examples.models.llama2.llama_transformer import ModelArgs
@@ -87,11 +88,17 @@ def generate(
         token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
 
         prev_pos = 0
+        if self.params.use_kv_cache:
+            min_prompt_len = 1
 
         eos_reached = torch.tensor([False] * bsz, device="cpu")
         input_text_mask = tokens != pad_id
+        pos = torch.tensor([prev_pos], dtype=torch.int64)
         if min_prompt_len == total_len:
-            inputs = (tokens,)
+            if self.params.use_kv_cache:
+                inputs = (tokens, pos)
+            else:
+                inputs = (tokens,)
             logits = self.model.forward(inputs)  # updated forward call.
             logits = logits[0]
             token_logprobs = -F.cross_entropy(
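
Note: the hunk above changes how the prefill call is assembled. When the exported model uses a KV cache, min_prompt_len is pinned to 1 and the forward call takes a start-position tensor alongside the tokens. A minimal sketch of that input convention, with a hypothetical build_inputs helper (not part of the runner) standing in for the inline branches:

import torch

def build_inputs(tokens: torch.Tensor, prev_pos: int, use_kv_cache: bool):
    # KV-cache models are exported with an extra position input so the cache
    # knows where to write new key/value entries; non-cached models take only
    # the token window.
    if use_kv_cache:
        pos = torch.tensor([prev_pos], dtype=torch.int64)
        return (tokens, pos)
    return (tokens,)

# e.g. prefill from position 0:
# inputs = build_inputs(tokens, prev_pos=0, use_kv_cache=True)
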
@@ -104,7 +111,11 @@ def generate(
         stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
 
         for cur_pos in range(min_prompt_len, total_len):
-            inputs = (tokens[:, :cur_pos],)
+            pos = torch.tensor([prev_pos], dtype=torch.int64)
+            if self.params.use_kv_cache:
+                inputs = (tokens[:, prev_pos:cur_pos], pos)
+            else:
+                inputs = (tokens[:, :cur_pos],)
             logits = self.model.forward(inputs)  # updated forward call.
             logits = logits[0]
             if temperature > 0:
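
Note: with the cache enabled, each loop iteration feeds only the tokens produced since prev_pos together with their start position, instead of re-running the whole prefix. A toy comparison of how much work each convention hands to the model per step; dummy_forward is hypothetical and just counts tokens, and the sketch assumes prev_pos is advanced to cur_pos at the end of each iteration, as in the surrounding loop:

import torch

def dummy_forward(inputs):
    # Hypothetical stand-in for model.forward: report how many token
    # positions this call would have to process.
    return inputs[0].shape[1]

tokens = torch.zeros(1, 8, dtype=torch.long)
prev_pos = 0
for cur_pos in range(1, 8):
    full_prefix = dummy_forward((tokens[:, :cur_pos],))  # grows each step: 1, 2, 3, ...
    incremental = dummy_forward(
        (tokens[:, prev_pos:cur_pos], torch.tensor([prev_pos], dtype=torch.int64))
    )  # stays at 1 once the cache holds the earlier positions
    prev_pos = cur_pos
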
@@ -116,9 +127,10 @@ def generate(
             next_token = next_token.reshape(-1)
 
             # only replace token if prompt has already been generated
-            next_token = torch.where(
-                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-            )
+            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
+                next_token = torch.where(
+                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+                )
 
             tokens[:, cur_pos] = next_token
             if logprobs:
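
Note: the guard above limits prompt forcing to positions still covered by the prompt. There, torch.where picks the prompt token over the freshly sampled one, using the non-pad mask (input_text_mask in the runner). A small illustration with made-up values:

import torch

pad_id = 0
tokens = torch.tensor([[5, 7, pad_id, pad_id]])  # prompt = [5, 7], rest is padding
input_text_mask = tokens != pad_id

cur_pos = 1
sampled = torch.tensor([9])  # token the model just sampled
forced = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], sampled)
# forced == tensor([7]): position 1 is still part of the prompt, so the prompt token wins
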
@@ -316,6 +328,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
     )
 
+    parser.add_argument(
+        "--max_gen_len",
+        type=int,
+        default=10,
+        help="Maximum length of the generated response sequence.",
+    )
+
     return parser
 
 
@@ -335,7 +354,9 @@ def main() -> None:
         model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
     )
     result = runner.text_completion(
-        prompts=[args.prompt], max_gen_len=10, temperature=args.temperature
+        prompts=[args.prompt],
+        max_gen_len=args.max_gen_len,
+        temperature=args.temperature,
     )
     print(f"Result: {result}")
 
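
Note: the last two hunks add a --max_gen_len flag and thread it through to text_completion in place of the hard-coded 10. A quick, self-contained check that the flag parses as intended; only this argument is reproduced here, the real parser defines many more:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_gen_len",
    type=int,
    default=10,
    help="Maximum length of the generated response sequence.",
)

args = parser.parse_args(["--max_gen_len", "64"])
assert args.max_gen_len == 64  # then passed along as max_gen_len=args.max_gen_len
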
