Skip to content

Commit 5f6fba2

Browse files
committed
Update
[ghstack-poisoned]
1 parent 8d5377f commit 5f6fba2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

generate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,10 @@ def decode_n_tokens(
323323
# Actually better for Inductor to codegen attention here
324324
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
325325

326+
out_token = cur_token.clone()
326327
next_token, next_prob = self.decode_one_token(
327328
model,
328-
cur_token.clone(),
329+
out_token,
329330
input_pos,
330331
need_probs=need_probs,
331332
**sampling_kwargs,
@@ -334,10 +335,10 @@ def decode_n_tokens(
334335
new_tokens.append(next_token.clone())
335336
callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
336337
if need_probs or next_prob is None:
337-
yield cur_token.clone(), None
338+
yield out_token, None
338339
else:
339340
new_probs.append(next_prob.clone())
340-
yield cur_token.clone(), next_prob.clone()
341+
yield out_token, next_prob.clone()
341342
cur_token = next_token
342343

343344
# encountered eos

0 commit comments

Comments
 (0)