
Commit 87a6457

JacobSzwejbka authored and malfet committed
patch a couple issues related to token output in chat (#462)
1 parent 7966318 commit 87a6457

File tree

1 file changed (+4, -0 lines)


generate.py

Lines changed: 4 additions & 0 deletions
@@ -360,6 +360,7 @@ def generate(
         **sampling_kwargs,
     )
     seq[T] = next_token
+    callback(next_token.clone().view(-1))
 
     num_tokens_generated = 0
     input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
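The first hunk routes the token sampled during prefill through the same callback used for every subsequently generated token, so that first token also reaches whatever the callback drives (here, the streamed chat display). Below is a minimal sketch of a callback compatible with the added call; the make_print_callback helper is an illustration, not code from this repository, and it only assumes that tokenizer.decode accepts a list of token ids (as the second hunk shows it does).

import torch

def make_print_callback(tokenizer):
    # Hypothetical helper: builds a callback matching the call added above,
    # where generate() hands over each token as a 1-D tensor holding a single
    # id (next_token.clone().view(-1)).
    def callback(x: torch.Tensor) -> None:
        # Decode the single token id and stream it to stdout immediately.
        print(tokenizer.decode(x.tolist()), end="", flush=True)
    return callback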
@@ -609,6 +610,9 @@ def callback(
         buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])  # I think this results in the first output token being dropped from the display which is wrong.
         if x.item() == tokenizer.eos_id():
             done_generating = True
+        if (is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]):
+            done_generating = True
+            buffer = buffer[:-1]  # drop the eot_id from the output buffer
         if len(buffer) == 4 or done_generating:
             print("".join(buffer), end="", flush=True)
             buffer.clear()
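The second hunk covers Llama 3 models, whose chat template ends each assistant turn with the special <|eot_id|> token rather than the tokenizer's generic EOS id: generation now also stops on that token, and the token is dropped from the buffered text before it is printed. The combined stop condition can be read as the sketch below; the standalone is_stop_token form and its name are assumptions made for illustration, not code from this commit.

def is_stop_token(token_id: int, tokenizer, is_llama3_model: bool) -> bool:
    # Generic end-of-sequence id used by every model.
    if token_id == tokenizer.eos_id():
        return True
    # Llama 3 chat turns terminate with <|eot_id|>, a special token distinct
    # from the generic EOS id, so it needs its own check.
    if is_llama3_model and token_id == tokenizer.special_tokens["<|eot_id|>"]:
        return True
    return False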
