Commit 7ec2374

update llama runner to decode single token
Right now, we don't print the generated response in the eager runner until all tokens are generated. This is a poor experience, since we have to wait for the entire generation to finish before seeing any output. This PR updates the runner to decode and print each new token immediately after it is generated.

Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)

[ghstack-poisoned]
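In short, decoding moves from one call on the full token list after generation finishes to one call per token inside the generation loop, with flush=True so each piece appears immediately. A minimal, self-contained sketch of that pattern (the token source and decoder below are placeholders, not the runner's actual API):

# Minimal sketch of streaming decode: print each token's text as soon as it is
# sampled instead of waiting for the whole sequence. The token source and
# decoder here are stand-ins for the model and tokenizer used by the runner.
from typing import Iterable, List


def decode_token(token_id: int) -> str:
    # Stand-in for decoding a single token ID to its text piece.
    return f"<{token_id}>"


def stream_generation(token_ids: Iterable[int]) -> List[int]:
    seen: List[int] = []
    for tok in token_ids:
        print(decode_token(tok), end="", flush=True)  # show the piece right away
        seen.append(tok)
    print()  # trailing newline once generation ends
    return seen


if __name__ == "__main__":
    stream_generation([101, 102, 103])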
1 parent 1c0c17c commit 7ec2374

File tree

5 files changed: 36 additions, 27 deletions

examples/models/llama/runner/eager.py

Lines changed: 10 additions & 6 deletions
@@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=0,
     )

+    parser.add_argument(
+        "--show_tokens",
+        action="store_true",
+        default=False,
+        help="Show the tokens that were generated",
+    )
+
     return parser


@@ -71,15 +78,12 @@ def main() -> None:
     args = parser.parse_args()

     runner = EagerLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    if args.show_tokens:
+        print(f"Tokens: {generated_tokens}")


 if __name__ == "__main__":
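With action="store_true" the new flag defaults to False and flips to True only when --show_tokens is passed, so raw token IDs are printed only on request while the decoded text streams during generation. A tiny standalone check of that argparse pattern (illustrative only, not part of the diff):

# Demonstrates the store_true flag behavior used by --show_tokens above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--show_tokens", action="store_true", default=False)

print(parser.parse_args([]).show_tokens)                 # False
print(parser.parse_args(["--show_tokens"]).show_tokens)  # True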

examples/models/llama/runner/generation.py

Lines changed: 8 additions & 15 deletions
@@ -5,19 +5,14 @@
 # LICENSE file in the root directory of this source tree.

 from abc import ABC, abstractmethod
-from typing import List, Optional, TypedDict
+from typing import List, Optional

 import torch

 from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer


-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[int]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -84,6 +79,7 @@ def generate( # noqa: C901
         )

         current_token = next_token(logits, temperature, top_p)
+        print(f"{self.tokenizer.decode(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

         while len(tokens) < self.params.max_seq_len:
@@ -101,12 +97,14 @@ def generate( # noqa: C901
                 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
             )
             current_token = next_token(logits, temperature, top_p)
+            tokens.append(current_token)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
-            tokens.append(current_token)
+            print(f"{self.tokenizer.decode(current_token)}", end="", flush=True)
+        print("\n")

         return tokens if echo else tokens[len(prompt_tokens) :]

@@ -116,7 +114,7 @@ def text_completion(
         temperature: float = 0.6,
         top_p: float = 0.9,
         echo: bool = False,
-    ) -> CompletionPrediction:
+    ) -> List[int]:
         """
         Perform text completion for a prompt using the language model.

@@ -132,14 +130,9 @@ def text_completion(
         Note:
             This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
-        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
-        generation_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+        return self.generate(
+            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
             temperature=temperature,
             top_p=top_p,
             echo=echo,
         )
-        return {
-            "generation": self.tokenizer.decode(generation_tokens),
-            "tokens": generation_tokens,
-        }
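Since text_completion() now returns the generated token IDs rather than a dict with a pre-decoded string, and the text is streamed to stdout as it is generated, a caller that still wants the full string has to decode the IDs itself. A minimal sketch of that calling pattern; RunnerLike and TokenizerLike are illustrative stand-ins for the runner classes in this repo, not real names from it:

# Sketch of the updated caller contract for text_completion().
from typing import List, Protocol


class TokenizerLike(Protocol):
    def decode(self, t: List[int]) -> str: ...


class RunnerLike(Protocol):
    tokenizer: TokenizerLike

    def text_completion(
        self, prompt: str, temperature: float = 0.6
    ) -> List[int]: ...


def run_and_collect(runner: RunnerLike, prompt: str) -> str:
    # text_completion() now returns token IDs and streams the text during
    # generation; decode afterwards only if the full string is still needed.
    generated_tokens = runner.text_completion(prompt=prompt, temperature=0.6)
    return runner.tokenizer.decode(generated_tokens)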

examples/models/llama/runner/native.py

Lines changed: 2 additions & 6 deletions
@@ -107,15 +107,11 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
     runner = NativeLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    print(f"Response: {generated_tokens}")


 if __name__ == "__main__":

examples/models/llama/tokenizer/tiktoken.py

Lines changed: 12 additions & 0 deletions
@@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str:
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
         return self.model.decode(cast(List[int], t))

+    def decode(self, t: int) -> str:
+        """
+        Decodes a single token ID into a string.
+
+        Args:
+            t (int): The token ID to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.model.decode_single_token_bytes(t).decode("utf-8")
+
     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(
         s: str, max_consecutive_slice_len: int
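The new single-token decode wraps tiktoken's decode_single_token_bytes, which returns the raw bytes of one BPE token. A quick standalone illustration, assuming the tiktoken package is installed and using the public cl100k_base encoding as a stand-in for the Llama BPE model loaded in this repo (note that an individual token's bytes are not always valid UTF-8 on their own):

# Decode one token ID at a time with tiktoken.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # stand-in encoding for illustration
for tok in enc.encode("hello world"):
    piece = enc.decode_single_token_bytes(tok).decode("utf-8")
    print(tok, repr(piece))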

extension/llm/tokenizer/tokenizer.py

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str:
         # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
         return self.sp_model.decode(t)

+    def decode(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        return self.sp_model.decode(t)
+
     def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
         """
         Export tokenizer.model to another serialization format. Here we did some lightweight
