Commit 891b780

update llama runner to decode single token
Differential Revision: D65578306
Pull Request resolved: #6703
1 parent df7be71 commit 891b780
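
In short, after this change the runners stream each token to stdout as soon as it is decoded, and text_completion returns the raw list of generated token IDs instead of a CompletionPrediction dict. A minimal caller-side sketch of the new eager flow follows; the args object and literal values are illustrative, not taken from the diff:

    # Sketch only; EagerLlamaRunner comes from examples/models/llama/runner/eager.py
    # and args is the result of build_args_parser().parse_args().
    runner = EagerLlamaRunner(args)

    # The decoded text is now printed token by token inside generate(),
    # so text_completion only hands back the token IDs.
    generated_tokens = runner.text_completion(
        prompt="Once upon a time,",
        temperature=0,
    )
    if args.show_tokens:
        print(f"Tokens: {generated_tokens}")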

File tree: 6 files changed, +39 -29 lines changed

6 files changed

+39
-29
lines changed

.ci/scripts/test_llama_runner_eager.sh

Lines changed: 2 additions & 1 deletion
@@ -42,11 +42,12 @@ run_and_verify() {
     -d fp32 \
     --max_seq_length 32 \
     --temperature 0 \
+    --show_tokens \
     --prompt "Once upon a time," > result.txt

   # Verify result.txt
   RESULT=$(cat result.txt)
-  EXPECTED_RESULT="there was a little girl"
+  EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
   if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
     echo "Actual result: ${RESULT}"
     echo "Success"

examples/models/llama/runner/eager.py

Lines changed: 10 additions & 6 deletions
@@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=0,
     )

+    parser.add_argument(
+        "--show_tokens",
+        action="store_true",
+        default=False,
+        help="Show the tokens that were generated",
+    )
+
     return parser


@@ -71,15 +78,12 @@ def main() -> None:
     args = parser.parse_args()

     runner = EagerLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    if args.show_tokens:
+        print(f"Tokens: {generated_tokens}")


 if __name__ == "__main__":
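
The new --show_tokens flag uses argparse's store_true action, so it stays False unless the flag is passed on the command line. A tiny standalone illustration of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--show_tokens",
        action="store_true",
        default=False,
        help="Show the tokens that were generated",
    )

    print(parser.parse_args([]).show_tokens)                 # False
    print(parser.parse_args(["--show_tokens"]).show_tokens)  # True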

examples/models/llama/runner/generation.py

Lines changed: 9 additions & 16 deletions
@@ -5,19 +5,14 @@
 # LICENSE file in the root directory of this source tree.

 from abc import ABC, abstractmethod
-from typing import List, Optional, TypedDict
+from typing import List, Optional

 import torch

 from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer


-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[int]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -84,6 +79,7 @@ def generate(  # noqa: C901
         )

         current_token = next_token(logits, temperature, top_p)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

         while len(tokens) < self.params.max_seq_len:
@@ -101,12 +97,14 @@ def generate(  # noqa: C901
                 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
             )
             current_token = next_token(logits, temperature, top_p)
+            tokens.append(current_token)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
-            tokens.append(current_token)
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+        print("\n")

         return tokens if echo else tokens[len(prompt_tokens) :]

@@ -116,7 +114,7 @@ def text_completion(
         temperature: float = 0.6,
         top_p: float = 0.9,
         echo: bool = False,
-    ) -> CompletionPrediction:
+    ) -> List[int]:
         """
         Perform text completion for a prompt using the language model.

@@ -127,19 +125,14 @@ def text_completion(
             echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

         Returns:
-            CompletionPrediction: Completion prediction, which contains the generated text completion.
+            Generated list of tokens.

         Note:
             This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
-        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
-        generation_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+        return self.generate(
+            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
             temperature=temperature,
             top_p=top_p,
             echo=echo,
         )
-        return {
-            "generation": self.tokenizer.decode(generation_tokens),
-            "tokens": generation_tokens,
-        }
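
The core of the change is in generate(): each sampled token is appended to the sequence and immediately decoded with decode_token and printed with end="" and flush=True, so the text streams to the terminal instead of being decoded once at the end. Note that the append now happens before the stop check, so the stop token also ends up in the returned list. A self-contained toy sketch of the same pattern; the tokenizer and next_token below are placeholders, not the ExecuTorch APIs:

    from typing import List

    class ToyTokenizer:
        eos_id = -1

        def decode_token(self, t: int) -> str:
            return f"<{t}>"

    def next_token(step: int) -> int:
        # Emit five fake tokens, then the stop token.
        return step if step < 5 else ToyTokenizer.eos_id

    def generate(prompt_tokens: List[int], max_seq_len: int = 32) -> List[int]:
        tokenizer = ToyTokenizer()
        tokens = list(prompt_tokens)
        step = 0
        while len(tokens) < max_seq_len:
            current = next_token(step)
            tokens.append(current)  # recorded before the stop check, as in the diff
            if current == tokenizer.eos_id:
                break
            # Stream the decoded piece as soon as the token is sampled.
            print(tokenizer.decode_token(current), end="", flush=True)
            step += 1
        print()
        return tokens[len(prompt_tokens):]

    print(generate([101, 102]))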

examples/models/llama/runner/native.py

Lines changed: 2 additions & 6 deletions
@@ -107,15 +107,11 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
     runner = NativeLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    print(f"Response: {generated_tokens}")


 if __name__ == "__main__":

examples/models/llama/tokenizer/tiktoken.py

Lines changed: 12 additions & 0 deletions
@@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str:
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
         return self.model.decode(cast(List[int], t))

+    def decode_token(self, t: int) -> str:
+        """
+        Decodes a single token ID into a string.
+
+        Args:
+            t (int): The token ID to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.model.decode_single_token_bytes(t).decode("utf-8")
+
     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(
         s: str, max_consecutive_slice_len: int
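
decode_token mirrors decode but for a single ID, which is what lets generate() stream output. Since self.model here is a tiktoken Encoding (the diff calls its decode_single_token_bytes), the same idea can be illustrated standalone with the tiktoken package; cl100k_base is only a stand-in for the real Llama 3 encoding:

    # Standalone illustration using the tiktoken package directly.
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")   # stand-in encoding, not the one this file loads
    token_id = enc.encode("hello")[0]            # grab some valid token ID
    piece = enc.decode_single_token_bytes(token_id).decode("utf-8")
    print(piece, end="", flush=True)

One thing to note: decode_single_token_bytes returns raw bytes, so decoding with "utf-8" assumes each token maps to a complete, valid UTF-8 sequence; a token covering a partial multi-byte character would raise UnicodeDecodeError.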

extension/llm/tokenizer/tokenizer.py

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str:
         # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
         return self.sp_model.decode(t)

+    def decode_token(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        return self.sp_model.decode(t)
+
     def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
         """
         Export tokenizer.model to another serialization format. Here we did some lightweight
