
Commit 49756f6

Print the number of tokens generated (#6773)
Pull Request resolved: #6771

This is useful for verifying the correctness of AttentionSink.

ghstack-source-id: 252993225
@exported-using-ghexport

Differential Revision: [D65784095](https://our.internmc.facebook.com/intern/diff/D65784095/)

Co-authored-by: Lunwen He <[email protected]>
1 parent 99ba779 commit 49756f6

2 files changed: +6 −2 lines changed

examples/models/llama/runner/eager.py

Lines changed: 2 additions & 1 deletion
@@ -91,10 +91,11 @@ def main() -> None:
         else runner.text_completion(
             prompt=args.prompt,
             temperature=args.temperature,
+            echo=True,
         )
     )
     if args.show_tokens:
-        print(f"Tokens: {generated_tokens}")
+        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")


 if __name__ == "__main__":
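For reference, a minimal sketch (not part of the commit) of what the new --show_tokens output looks like; the token ids below are hypothetical, and generated_tokens stands in for the list returned by runner.text_completion():

    generated_tokens = [9906, 11, 1917, 0]  # hypothetical token ids for illustration
    print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
    # Output: Generated 4 tokens: [9906, 11, 1917, 0]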

examples/models/llama/runner/generation.py

Lines changed: 4 additions & 1 deletion
@@ -64,6 +64,7 @@ def forward(
     def generate(  # noqa: C901
         self,
         prompt_tokens: List[int],
+        max_seq_len: int,
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
@@ -83,7 +84,7 @@ def generate(  # noqa: C901
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

-        while len(tokens) < self.params.max_seq_len:
+        while len(tokens) < max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
@@ -135,6 +136,7 @@ def text_completion(
         """
         return self.generate(
             prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
@@ -169,6 +171,7 @@ def chat_completion(
             prompt_tokens=self.tokenizer.encode(
                 self._format_prompt(prompt), bos=True, eos=False
             ),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=True,
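To illustrate the loop-bound change above, here is a minimal, self-contained sketch (a toy stub, not the actual runner code) of how generate() now stops at the max_seq_len its caller passes in, while text_completion() and chat_completion() keep the previous behavior by forwarding self.params.max_seq_len explicitly:

    from typing import List

    def generate_stub(prompt_tokens: List[int], max_seq_len: int) -> List[int]:
        # Toy stand-in for generate(): the stopping condition now reads the
        # max_seq_len argument instead of self.params.max_seq_len.
        tokens = list(prompt_tokens)
        while len(tokens) < max_seq_len:
            tokens.append(0)  # placeholder for the next sampled token
        return tokens

    # Callers forward the model's configured limit, preserving old behavior:
    print(generate_stub([1, 2, 3], max_seq_len=8))  # -> [1, 2, 3, 0, 0, 0, 0, 0]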
