Commit 671f9c5

update llama runner to decode single token (#6768)
Pull Request resolved: #6703

Right now the eager runner does not print the generated response until all tokens have been generated, which is a poor experience: you have to wait for the full generation before seeing any output. This PR updates the runner to decode and print each new token immediately after it is generated.

ghstack-source-id: 252924039
Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)

Co-authored-by: Lunwen He <[email protected]>
1 parent 6887ae9 commit 671f9c5

6 files changed: +39, -29 lines

.ci/scripts/test_llama_runner_eager.sh

Lines changed: 2 additions & 1 deletion

@@ -42,11 +42,12 @@ run_and_verify() {
     -d fp32 \
     --max_seq_length 32 \
     --temperature 0 \
+    --show_tokens \
     --prompt "Once upon a time," > result.txt

   # Verify result.txt
   RESULT=$(cat result.txt)
-  EXPECTED_RESULT="there was a little girl"
+  EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
   if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
     echo "Actual result: ${RESULT}"
     echo "Success"

examples/models/llama/runner/eager.py

Lines changed: 10 additions & 6 deletions

@@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=0,
     )

+    parser.add_argument(
+        "--show_tokens",
+        action="store_true",
+        default=False,
+        help="Show the tokens that were generated",
+    )
+
     return parser


@@ -71,15 +78,12 @@ def main() -> None:
     args = parser.parse_args()

     runner = EagerLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    if args.show_tokens:
+        print(f"Tokens: {generated_tokens}")


 if __name__ == "__main__":
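
A minimal, standalone illustration (not from the PR) of how the new store_true flag behaves once wired into a parser:

# Standalone sketch of the new flag's behavior; the real parser lives in
# build_args_parser() above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--show_tokens",
    action="store_true",
    default=False,
    help="Show the tokens that were generated",
)

print(parser.parse_args([]).show_tokens)                  # False
print(parser.parse_args(["--show_tokens"]).show_tokens)   # True

With the flag off, main() itself now prints nothing; the streamed response text comes from generate() in generation.py (next file).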

examples/models/llama/runner/generation.py

Lines changed: 9 additions & 16 deletions

@@ -5,19 +5,14 @@
 # LICENSE file in the root directory of this source tree.

 from abc import ABC, abstractmethod
-from typing import List, Optional, TypedDict
+from typing import List, Optional

 import torch

 from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer


-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[int]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -84,6 +79,7 @@ def generate(  # noqa: C901
        )

        current_token = next_token(logits, temperature, top_p)
+       print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        tokens = prompt_tokens + [current_token]

        while len(tokens) < self.params.max_seq_len:
@@ -101,12 +97,14 @@ def generate(  # noqa: C901
                tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
            )
            current_token = next_token(logits, temperature, top_p)
+           tokens.append(current_token)
            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break
-           tokens.append(current_token)
+           print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+       print("\n")

        return tokens if echo else tokens[len(prompt_tokens) :]

@@ -116,7 +114,7 @@ def text_completion(
        temperature: float = 0.6,
        top_p: float = 0.9,
        echo: bool = False,
-   ) -> CompletionPrediction:
+   ) -> List[int]:
        """
        Perform text completion for a prompt using the language model.

@@ -127,19 +125,14 @@ def text_completion(
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
-           CompletionPrediction: Completion prediction, which contains the generated text completion.
+           Generated list of tokens.

        Note:
            This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
        """
-       prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
-       generation_tokens = self.generate(
-           prompt_tokens=prompt_tokens,
+       return self.generate(
+           prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
            temperature=temperature,
            top_p=top_p,
            echo=echo,
        )
-       return {
-           "generation": self.tokenizer.decode(generation_tokens),
-           "tokens": generation_tokens,
-       }
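
The behavioral core of the PR is in generate(): each sampled token is decoded and flushed to stdout immediately, and the CompletionPrediction dict is replaced by a plain token list. Below is a minimal, self-contained sketch of that streaming pattern; the toy tokenizer and fake sampler are stand-ins, not the runner's real objects.

# Toy illustration of the streaming pattern adopted above; ToyTokenizer and
# fake_next_token are stand-ins for the real tokenizer and sampler.
from typing import List


class ToyTokenizer:
    eos_id = 99

    def decode_token(self, t: int) -> str:
        return f"<{t}>"


def fake_next_token(step: int) -> int:
    # Pretend sampler: emit a few tokens, then EOS.
    return step if step < 5 else ToyTokenizer.eos_id


def generate(prompt_tokens: List[int], max_seq_len: int = 16) -> List[int]:
    tokenizer = ToyTokenizer()
    tokens = list(prompt_tokens)
    while len(tokens) < max_seq_len:
        current = fake_next_token(len(tokens))
        tokens.append(current)
        if current == tokenizer.eos_id:
            break
        # Stream the decoded token right away instead of waiting for the end.
        print(tokenizer.decode_token(current), end="", flush=True)
    print()
    return tokens


print(generate([0, 1, 2]))

One side effect visible in the diff: tokens.append now runs before the EOS/stop-token check, so the terminating token is included in the returned list, whereas previously it was dropped.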

examples/models/llama/runner/native.py

Lines changed: 2 additions & 6 deletions

@@ -107,15 +107,11 @@ def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()
    runner = NativeLlamaRunner(args)
-   result = runner.text_completion(
+   generated_tokens = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
-   print(
-       "Response: \n{response}\n Tokens:\n {tokens}".format(
-           response=result["generation"], tokens=result["tokens"]
-       )
-   )
+   print(f"Response: {generated_tokens}")


 if __name__ == "__main__":
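
Note that native.py now prints the raw token IDs rather than decoded text. A hedged sketch (not part of the PR) of turning those IDs back into text after the fact; the tokenizer path is a placeholder and the example IDs are only illustrative.

# Hypothetical post-processing, not from this PR: decode printed token IDs
# back into text with the repo's tokenizer utilities.
from executorch.extension.llm.tokenizer.utils import get_tokenizer

tokenizer = get_tokenizer("/path/to/tokenizer.model")  # placeholder path
generated_tokens = [727, 471, 263, 2217, 7826]  # e.g. IDs emitted by the runner
print(tokenizer.decode(generated_tokens))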

examples/models/llama/tokenizer/tiktoken.py

Lines changed: 12 additions & 0 deletions

@@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str:
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

+   def decode_token(self, t: int) -> str:
+       """
+       Decodes a single token ID into a string.
+
+       Args:
+           t (int): The token ID to be decoded.
+
+       Returns:
+           str: The decoded string.
+       """
+       return self.model.decode_single_token_bytes(t).decode("utf-8")
+
    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
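
decode_single_token_bytes is tiktoken's call for decoding one token in isolation; it returns raw bytes, hence the explicit .decode("utf-8") above. A small illustration with the standalone tiktoken package (the cl100k_base encoding is my assumption here, not the Llama model this class loads):

# Standalone tiktoken illustration; the encoding name is an assumption and is
# not the tokenizer wrapped by this class.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
for tok in enc.encode("Hello world"):
    print(tok, enc.decode_single_token_bytes(tok))

One caveat worth noting: because tiktoken operates on bytes, a single token is not guaranteed to be a complete UTF-8 sequence, so the .decode("utf-8") call can fail for some token IDs (e.g. a multi-byte character split across tokens).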

extension/llm/tokenizer/tokenizer.py

Lines changed: 4 additions & 0 deletions

@@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str:
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        return self.sp_model.decode(t)

+   def decode_token(self, t: int) -> str:
+       # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+       return self.sp_model.decode(t)
+
    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. Here we did some lightweight
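
Here decode_token simply forwards a single ID to SentencePieceProcessor.decode. A small illustration with the sentencepiece package directly (placeholder model path, not this repo's wrapper); note that per-token decoding may render word boundaries differently from decoding the whole list, since the leading-space marker is attached to individual pieces.

# Direct sentencepiece illustration; the model path is a placeholder.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="/path/to/tokenizer.model")
ids = sp.encode("Once upon a time,")
print(sp.decode(ids))                    # decode the whole list at once
print([sp.id_to_piece(i) for i in ids])  # per-token pieces, showing the leading-space marker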
