# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
-from typing import List, Optional, Tuple, TypedDict

import json
+from typing import List, Optional, Tuple, TypedDict
+
import torch
import torch.nn.functional as F
from executorch.examples.models.llama2.llama_transformer import ModelArgs
@@ -87,11 +88,17 @@ def generate(
        token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
+        if self.params.use_kv_cache:
+            min_prompt_len = 1

        eos_reached = torch.tensor([False] * bsz, device="cpu")
        input_text_mask = tokens != pad_id
+        pos = torch.tensor([prev_pos], dtype=torch.int64)
        if min_prompt_len == total_len:
-            inputs = (tokens,)
+            if self.params.use_kv_cache:
+                inputs = (tokens, pos)
+            else:
+                inputs = (tokens,)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            token_logprobs = -F.cross_entropy(
@@ -104,7 +111,11 @@ def generate(
        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        for cur_pos in range(min_prompt_len, total_len):
-            inputs = (tokens[:, :cur_pos],)
+            pos = torch.tensor([prev_pos], dtype=torch.int64)
+            if self.params.use_kv_cache:
+                inputs = (tokens[:, prev_pos:cur_pos], pos)
+            else:
+                inputs = (tokens[:, :cur_pos],)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            if temperature > 0:
@@ -116,9 +127,10 @@ def generate(
            next_token = next_token.reshape(-1)

            # only replace token if prompt has already been generated
-            next_token = torch.where(
-                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-            )
+            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
+                next_token = torch.where(
+                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+                )

            tokens[:, cur_pos] = next_token
            if logprobs:
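For readers skimming the diff, the KV-cache path reduces to an incremental decode pattern: once the cache is enabled, each forward call receives only the not-yet-cached token slice plus the current start position, and the prompt itself is replayed one step at a time (min_prompt_len is forced to 1). Below is a minimal, self-contained sketch of that loop, not the runner's actual code: `model` is a placeholder callable standing in for self.model.forward, batch size is fixed at 1, and sampling is replaced by greedy argmax.

import torch


def decode_sketch(model, prompt_tokens, max_gen_len, use_kv_cache=True):
    # `model` is a stand-in for self.model.forward: it takes (tokens, pos) when the
    # KV cache is enabled and (tokens,) otherwise, returning [bsz, seq_len, vocab] logits.
    tokens = torch.tensor([prompt_tokens], dtype=torch.long)
    total_len = len(prompt_tokens) + max_gen_len
    # With a KV cache, decoding starts at position 1 and the prompt is fed one token
    # per step to warm the cache; without it, the full prefix is re-fed on every step.
    min_prompt_len = 1 if use_kv_cache else len(prompt_tokens)
    prev_pos = 0
    for cur_pos in range(min_prompt_len, total_len):
        if use_kv_cache:
            pos = torch.tensor([prev_pos], dtype=torch.int64)
            logits = model(tokens[:, prev_pos:cur_pos], pos)  # only the uncached slice
        else:
            logits = model(tokens[:, :cur_pos])  # full prefix every step
        if cur_pos >= len(prompt_tokens):
            # Past the prompt: greedily append the predicted next token.
            next_token = torch.argmax(logits[:, -1], dim=-1).reshape(1, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
        # While still inside the prompt, the forward call only fills the cache.
        prev_pos = cur_pos
    return tokens


# Quick check with a dummy model that emits random logits over a 32-token vocabulary:
dummy = lambda toks, pos=None: torch.randn(toks.shape[0], toks.shape[1], 32)
print(decode_sketch(dummy, prompt_tokens=[1, 5, 7], max_gen_len=4))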
@@ -316,6 +328,13 @@ def build_args_parser() -> argparse.ArgumentParser:
        action="store_true",
    )

+    parser.add_argument(
+        "--max_gen_len",
+        type=int,
+        default=10,
+        help="Maximum length of the generated response sequence.",
+    )
+
    return parser

@@ -335,7 +354,9 @@ def main() -> None:
        model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
    )
    result = runner.text_completion(
-        prompts=[args.prompt], max_gen_len=10, temperature=args.temperature
+        prompts=[args.prompt],
+        max_gen_len=args.max_gen_len,
+        temperature=args.temperature,
    )
    print(f"Result: {result}")
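The new --max_gen_len flag is plain argparse; the parsed value then replaces the previously hard-coded max_gen_len=10 in the runner.text_completion call above. As a sanity check, a standalone parser with the same argument definition behaves like this (illustration only, not the project's build_args_parser, which also carries the other runner flags such as --prompt and --temperature):

import argparse

# Illustration only: a throwaway parser with the same --max_gen_len definition.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_gen_len",
    type=int,
    default=10,
    help="Maximum length of the generated response sequence.",
)

print(parser.parse_args([]).max_gen_len)                        # -> 10 (the default)
print(parser.parse_args(["--max_gen_len", "128"]).max_gen_len)  # -> 128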