
Commit d7a4d24

metascroy authored and malfet committed
remove duplicate code in generate + fix generate.py prompt for llama2 (#590)
* remove duplicate code in generate
* fix indent
* add fixes
1 parent aa00167 commit d7a4d24

File tree

1 file changed: +23 -41 lines


generate.py

Lines changed: 23 additions & 41 deletions
@@ -463,6 +463,24 @@ def get_device_info(name: str) -> str:
         return torch.cuda.get_device_name(0)
     return ""
 
+def _callback(x, buffer, period_id, done_generating, tokenizer, is_llama3_model):
+    if done_generating:
+        return
+    buffer.append(
+        tokenizer.decode([period_id] + x.tolist())[1:]
+    )  # I think this results in the first output token being dropped from the display which is wrong.
+    if x.item() == tokenizer.eos_id():
+        done_generating = True
+    if (
+        is_llama3_model
+        and x.item() == tokenizer.special_tokens["<|eot_id|>"]
+    ):
+        done_generating = True
+        buffer = buffer[:-1]  # drop the eot_id from the output buffer
+    if len(buffer) == 4 or done_generating:
+        print("".join(buffer), end="", flush=True)
+        buffer.clear()
+    # print(, end='', flush=True)
 
 def _main(
     builder_args: BuilderArgs,
@@ -612,7 +630,7 @@ def _main(
             break
         if not is_llama3_model:
             if system_prompt:
-                prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
+                prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
                 system_prompt = (
                     None  # can only provide system prompt on first interaction
                 )
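
The only functional change in this hunk is the added parentheses on prompt.strip(). A minimal sketch of why the original was broken, with illustrative instruction markers standing in for the module's B_INST/E_INST constants: interpolating the bound method prompt.strip into an f-string embeds the method's repr rather than the stripped prompt text.

    prompt = "  Tell me a joke.  "
    broken = f"[INST] {prompt.strip} [/INST]"   # embeds "<built-in method strip of str object at 0x...>"
    fixed = f"[INST] {prompt.strip()} [/INST]"  # embeds "Tell me a joke."
    print(broken)
    print(fixed)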
@@ -659,53 +677,17 @@ def _main(
             period_id = tokenizer.encode(".")[0]
             done_generating = False
 
-            def callback(
-                x, buffer=buffer, period_id=period_id, done_generating=done_generating
-            ):
-                if done_generating:
-                    return
-                buffer.append(
-                    tokenizer.decode([period_id] + x.tolist())[1:]
-                )  # I think this results in the first output token being dropped from the display which is wrong.
-                if x.item() == tokenizer.eos_id():
-                    done_generating = True
-                if (
-                    is_llama3_model
-                    and x.item() == tokenizer.special_tokens["<|eot_id|>"]
-                ):
-                    done_generating = True
-                    buffer = buffer[:-1]  # drop the eot_id from the output buffer
-                if len(buffer) == 4 or done_generating:
-                    print("".join(buffer), end="", flush=True)
-                    buffer.clear()
-                # print(, end='', flush=True)
+            def callback(x):
+                return _callback(x, buffer=buffer, period_id=period_id, done_generating=done_generating, tokenizer=tokenizer, is_llama3_model=is_llama3_model)
 
         else:
             assert not generator_args.chat_mode
             buffer = [generator_args.prompt]
             period_id = tokenizer.encode(".")[0]
             done_generating = False
 
-            def callback(
-                x, buffer=buffer, period_id=period_id, done_generating=done_generating
-            ):
-                if done_generating:
-                    return
-                buffer.append(
-                    tokenizer.decode([period_id] + x.tolist())[1:]
-                )  # I think this results in the first output token being dropped from the display which is wrong.
-                if x.item() == tokenizer.eos_id():
-                    done_generating = True
-                if (
-                    is_llama3_model
-                    and x.item() == tokenizer.special_tokens["<|eot_id|>"]
-                ):
-                    done_generating = True
-                    buffer = buffer[:-1]  # drop the eot_id from the output buffer
-                if len(buffer) == 4 or done_generating:
-                    print("".join(buffer), end="", flush=True)
-                    buffer.clear()
-                # print(, end='', flush=True)
+            def callback(x):
+                return _callback(x, buffer=buffer, period_id=period_id, done_generating=done_generating, tokenizer=tokenizer, is_llama3_model=is_llama3_model)
 
         t0 = time.perf_counter()
         import contextlib
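
Both branches now reduce to a thin callback(x) wrapper that forwards fixed keyword arguments to the shared _callback defined above. A rough, self-contained sketch of that deduplication pattern, with illustrative names rather than anything from generate.py; functools.partial would be an equivalent way to pin the extra arguments:

    import functools

    def _emit(x, buffer, flush_every):
        # Shared helper: collect pieces and flush them in groups.
        buffer.append(str(x))
        if len(buffer) >= flush_every:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    buffer = []

    # Style used in the commit: a one-line wrapper over the shared helper.
    def callback(x):
        return _emit(x, buffer=buffer, flush_every=4)

    # Equivalent alternative using functools.partial.
    callback_alt = functools.partial(_emit, buffer=buffer, flush_every=4)

    for piece in ["Hello", ", ", "world", "!\n"]:
        callback(piece)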

0 commit comments