@@ -87,11 +87,17 @@ def generate(
         token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

         prev_pos = 0
+        if self.params.use_kv_cache:
+            min_prompt_len = 1

         eos_reached = torch.tensor([False] * bsz, device="cpu")
         input_text_mask = tokens != pad_id
+        pos = torch.tensor([prev_pos], dtype=torch.int64)
         if min_prompt_len == total_len:
-            inputs = (tokens,)
+            if self.params.use_kv_cache:
+                inputs = (tokens, pos)
+            else:
+                inputs = (tokens,)
             logits = self.model.forward(inputs)  # updated forward call.
             logits = logits[0]
             token_logprobs = -F.cross_entropy(
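With the cache enabled, `min_prompt_len` is forced to 1 so the prompt is replayed through the decode loop one token at a time, letting the exported model accumulate keys and values as it goes; the new `pos` input tells that cache at which sequence position the incoming tokens should be written. Below is a minimal sketch of the write-at-position idea, assuming a statically pre-allocated cache — `ToyKVCache` and its shapes are illustrative stand-ins, not ExecuTorch's implementation:

```python
import torch

class ToyKVCache:
    """Pre-allocated K/V buffers that are written at an explicit position."""

    def __init__(self, n_heads: int, max_seq_len: int, head_dim: int):
        self.k = torch.zeros(1, n_heads, max_seq_len, head_dim)
        self.v = torch.zeros(1, n_heads, max_seq_len, head_dim)

    def update(self, pos: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor):
        # pos mirrors the 1-element int64 tensor passed alongside the tokens.
        start = int(pos.item())
        length = k_new.shape[2]
        self.k[:, :, start : start + length] = k_new
        self.v[:, :, start : start + length] = v_new
        # Attention then runs against everything cached so far.
        return self.k[:, :, : start + length], self.v[:, :, : start + length]
```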
@@ -104,7 +110,11 @@ def generate(
         stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

         for cur_pos in range(min_prompt_len, total_len):
-            inputs = (tokens[:, :cur_pos],)
+            pos = torch.tensor([prev_pos], dtype=torch.int64)
+            if self.params.use_kv_cache:
+                inputs = (tokens[:, prev_pos:cur_pos], pos)
+            else:
+                inputs = (tokens[:, :cur_pos],)
             logits = self.model.forward(inputs)  # updated forward call.
             logits = logits[0]
             if temperature > 0:
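In the cache path, each iteration feeds only the slice the model has not yet seen, `tokens[:, prev_pos:cur_pos]`, plus its start position, instead of re-running the whole prefix. This assumes `prev_pos` advances to `cur_pos` at the end of each iteration, as in the upstream Llama generate loop this runner follows (that line falls outside the hunk). A toy comparison of the per-step input widths under that assumption:

```python
import torch

tokens = torch.arange(1, 9).reshape(1, 8)  # stand-in token ids
prev_pos = 0
for cur_pos in range(1, 8):
    incremental = tokens[:, prev_pos:cur_pos]  # KV cache: new token(s) only
    full_prefix = tokens[:, :cur_pos]          # no cache: recompute everything
    print(f"step {cur_pos}: cache sees {incremental.shape[1]} token(s), "
          f"no-cache sees {full_prefix.shape[1]}")
    prev_pos = cur_pos  # assumed update, mirroring the upstream loop
```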
@@ -116,9 +126,10 @@ def generate(
             next_token = next_token.reshape(-1)

             # only replace token if prompt has already been generated
-            next_token = torch.where(
-                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
-            )
+            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
+                next_token = torch.where(
+                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+                )

             tokens[:, cur_pos] = next_token
             if logprobs:
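The `torch.where` keeps sampled tokens from clobbering real prompt tokens while the loop is still walking through the prompt; with the cache on, the new guard applies it only at prompt positions (note that `len(prompt_tokens[0])` looks at the first prompt only, which is safe for batch size 1 or equal-length prompts). A tiny demonstration of the masking itself:

```python
import torch

pad_id = 0
tokens = torch.tensor([[5, 6, 0, 0]])   # prompt [5, 6] padded to length 4
input_text_mask = tokens != pad_id      # True at real prompt positions
sampled = torch.tensor([9])             # pretend the model sampled token 9

# Inside the prompt (cur_pos=1) the prompt token wins...
print(torch.where(input_text_mask[:, 1], tokens[:, 1], sampled))  # tensor([6])
# ...past the prompt (cur_pos=2) the sampled token is kept.
print(torch.where(input_text_mask[:, 2], tokens[:, 2], sampled))  # tensor([9])
```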
@@ -316,6 +327,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
     )

+    parser.add_argument(
+        "--max_gen_len",
+        type=int,
+        default=10,
+        help="Maximum length of the generated response sequence.",
+    )
+
     return parser

@@ -335,7 +353,7 @@ def main() -> None:
         model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
     )
     result = runner.text_completion(
-        prompts=[args.prompt], max_gen_len=10, temperature=args.temperature
+        prompts=[args.prompt], max_gen_len=args.max_gen_len, temperature=args.temperature
     )
     print(f"Result: {result}")
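The last two hunks replace the hard-coded generation length with a `--max_gen_len` flag whose default of 10 matches the old constant, so existing invocations behave the same. A quick round-trip of just that argument (the real parser defines many more options than shown here):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_gen_len",
    type=int,
    default=10,  # same value that was hard-coded before, so the default is unchanged
    help="Maximum length of the generated response sequence.",
)

print(parser.parse_args([]).max_gen_len)                       # 10
print(parser.parse_args(["--max_gen_len", "64"]).max_gen_len)  # 64
```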