
Commit 5c98ea6

kv cache
1 parent b1f5f96 commit 5c98ea6

File tree

1 file changed: +24 −6 lines


examples/models/llama2/runner/generation.py

Lines changed: 24 additions & 6 deletions
@@ -87,11 +87,17 @@ def generate(
         token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

         prev_pos = 0
+        if self.params.use_kv_cache:
+            min_prompt_len = 1

         eos_reached = torch.tensor([False] * bsz, device="cpu")
         input_text_mask = tokens != pad_id
+        pos = torch.tensor([prev_pos], dtype=torch.int64)
         if min_prompt_len == total_len:
-            inputs = (tokens,)
+            if self.params.use_kv_cache:
+                inputs = (tokens, pos)
+            else:
+                inputs = (tokens,)
             logits = self.model.forward(inputs)  # updated forward call.
             logits = logits[0]
             token_logprobs = -F.cross_entropy(
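With the cache enabled, this hunk pins min_prompt_len to 1, so prompt tokens are pushed through the decode loop below one position at a time (filling the cache incrementally) instead of in one batched prefill, and the added pos tensor carries the start index of whatever slice is handed to forward. A minimal sketch of the input-tuple decision for the single-forward branch; the helper name build_prefill_inputs is hypothetical, and the (tokens, pos) calling convention is taken from this diff, not from any library API.

import torch

def build_prefill_inputs(tokens, use_kv_cache, prev_pos=0):
    # Inputs for the one-shot forward taken when min_prompt_len == total_len (illustrative only).
    if use_kv_cache:
        # The exported forward is assumed (per this diff) to also take a
        # 1-element int64 tensor holding the position the fed tokens start at.
        pos = torch.tensor([prev_pos], dtype=torch.int64)
        return (tokens, pos)
    return (tokens,)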
@@ -104,7 +110,11 @@ def generate(
         stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

         for cur_pos in range(min_prompt_len, total_len):
-            inputs = (tokens[:, :cur_pos],)
+            pos = torch.tensor([prev_pos], dtype=torch.int64)
+            if self.params.use_kv_cache:
+                inputs = (tokens[:, prev_pos:cur_pos], pos)
+            else:
+                inputs = (tokens[:, :cur_pos],)
             logits = self.model.forward(inputs)  # updated forward call.
             logits = logits[0]
             if temperature > 0:
@@ -116,9 +126,10 @@ def generate(
             next_token = next_token.reshape(-1)

             # only replace token if prompt has already been generated
-            next_token = torch.where(
-                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-            )
+            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
+                next_token = torch.where(
+                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+                )

             tokens[:, cur_pos] = next_token
             if logprobs:
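Taken together, the two loop hunks above switch decoding from re-encoding the whole prefix at every step to feeding only the slice the cache has not yet seen plus its start position, and they skip the prompt-token override once generation has moved past the prompt. A condensed, hedged sketch of that pattern; decode_loop, model_forward and the greedy argmax are stand-ins for illustration, and prev_pos is assumed to advance at the end of each step as in the surrounding (unchanged) code.

import torch

def decode_loop(model_forward, tokens, input_text_mask, prompt_len,
                min_prompt_len, total_len, use_kv_cache):
    prev_pos = 0
    for cur_pos in range(min_prompt_len, total_len):
        pos = torch.tensor([prev_pos], dtype=torch.int64)
        if use_kv_cache:
            # Only the tokens the cache has not seen yet, plus where they start.
            inputs = (tokens[:, prev_pos:cur_pos], pos)
        else:
            # No cache: re-encode the entire prefix every iteration.
            inputs = (tokens[:, :cur_pos],)
        logits = model_forward(inputs)                     # assumed shape (bsz, seq, vocab)
        next_token = torch.argmax(logits[:, -1], dim=-1)   # greedy stand-in for sampling
        # While still inside the prompt, keep the prompt token instead of the sample.
        if not use_kv_cache or cur_pos < prompt_len:
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos  # advance so the next step feeds only the newest tokens
    return tokens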
@@ -316,6 +327,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
     )

+    parser.add_argument(
+        "--max_gen_len",
+        type=int,
+        default=10,
+        help="Maximum length of the generated response sequence.",
+    )
+
     return parser

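The new --max_gen_len flag threads a user-controlled generation length into text_completion instead of the previously hard-coded 10. A self-contained sketch of how the flag behaves under argparse; the bare ArgumentParser and the example values are illustrative only.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_gen_len",
    type=int,
    default=10,
    help="Maximum length of the generated response sequence.",
)

# The default applies when the flag is omitted; an explicit value overrides it.
print(parser.parse_args([]).max_gen_len)                        # 10
print(parser.parse_args(["--max_gen_len", "128"]).max_gen_len)  # 128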
@@ -335,7 +353,7 @@ def main() -> None:
         model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
     )
     result = runner.text_completion(
-        prompts=[args.prompt], max_gen_len=10, temperature=args.temperature
+        prompts=[args.prompt], max_gen_len=args.max_gen_len, temperature=args.temperature
     )
     print(f"Result: {result}")
