File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -323,9 +323,10 @@ def decode_n_tokens(
323
323
# Actually better for Inductor to codegen attention here
324
324
with torch .nn .attention .sdpa_kernel ([torch .nn .attention .SDPBackend .MATH ]):
325
325
326
+ out_token = cur_token .clone ()
326
327
next_token , next_prob = self .decode_one_token (
327
328
model ,
328
- cur_token . clone () ,
329
+ out_token ,
329
330
input_pos ,
330
331
need_probs = need_probs ,
331
332
** sampling_kwargs ,
@@ -334,10 +335,10 @@ def decode_n_tokens(
334
335
new_tokens .append (next_token .clone ())
335
336
callback (new_tokens [- 1 ], done_generating = _i == num_new_tokens - 2 )
336
337
if need_probs or next_prob is None :
337
- yield cur_token . clone () , None
338
+ yield out_token , None
338
339
else :
339
340
new_probs .append (next_prob .clone ())
340
- yield cur_token . clone () , next_prob .clone ()
341
+ yield out_token , next_prob .clone ()
341
342
cur_token = next_token
342
343
343
344
# encountered eos
You can’t perform that action at this time.
0 commit comments