Commit 645256d

lucylq authored and facebook-github-bot committed
generation.py with kv cache (#3030)
Summary:
Python end-to-end generation, using the tiktoken tokenizer. Uses `text_completion`; `chat_completion` has not been tried yet.

Pull Request resolved: #3030

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Command, with prompt "Hello, I am" and seq_len = 10:
```
python -m examples.models.llama2.runner.generation --pte llama_4ckpts_x.pte --tokenizer tokenizer.model --prompt="Hello I am" --temperature=0 --params ../llama-models/llama3/params_less.json --max_gen_len=10
```

fp32, xnn, kv and fp32, xnn give the same result:
```
Result: [{'generation': ' a 25 year old woman. I am a'}]
```

fp32, xnn, int4:
```
Result: [{'generation': ' interested in the following products: - 1 x'}]
```

fp32, xnn, kv, sdpa (needs investigation):
```
Result: [{'generation': 'ฉopteraenthalenthalenthalenthalenthalenthalenthalenthal'}]
```

Reviewed By: larryliu0820

Differential Revision: D56087430

Pulled By: lucylq

fbshipit-source-id: 31c73fe87af8646bf2512e1a6aadc8804a101719
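For reference, a minimal sketch of the programmatic equivalent of the command above. This is not part of the commit: it simply mirrors `main()` in the file below, assumes the same placeholder artifacts (`llama_4ckpts_x.pte`, `tokenizer.model`, `params_less.json`) exist at those paths, and the `LlamaRunner` import path is an assumption based on the module path used in the command.
```
import json

from executorch.examples.models.llama2.llama_transformer import ModelArgs
from examples.models.llama2.runner.generation import LlamaRunner

# Load the model hyperparameters json, as main() does.
with open("../llama-models/llama3/params_less.json", "r") as f:
    params = json.load(f)

model_args = ModelArgs(max_seq_len=128, max_batch_size=1, use_kv_cache=True, **params)
runner = LlamaRunner(
    model_path="llama_4ckpts_x.pte",
    tokenizer_path="tokenizer.model",
    model_args=model_args,
)
# Mirrors: --prompt="Hello I am" --temperature=0 --max_gen_len=10
print(runner.text_completion(prompts=["Hello I am"], max_gen_len=10, temperature=0))
```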
1 parent 7c81155 commit 645256d

File tree: 1 file changed, +370 -0 lines changed

Lines changed: 370 additions & 0 deletions
@@ -0,0 +1,370 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import json
from typing import List, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from executorch.examples.models.llama2.llama_transformer import ModelArgs

from executorch.examples.models.llama2.tokenizer.tiktoken import (
    Dialog,
    Message,
    Tokenizer,
)
from executorch.extension.pybindings.portable_lib import _load_for_executorch


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

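A quick standalone illustration of the nucleus-sampling cutoff described in the docstring above (not part of the file; it assumes `sample_top_p` from above is in scope): a token is dropped once the cumulative probability of the tokens ranked above it already exceeds p, and the survivors are renormalized before sampling.
```
import torch

# Toy next-token distribution over a 4-token vocabulary (batch size 1).
probs = torch.tensor([[0.02, 0.03, 0.45, 0.50]])

# Sorted: [0.50, 0.45, 0.03, 0.02]. The mass accumulated before the last two
# tokens already exceeds p=0.9, so they are zeroed out; only token ids 3 and 2
# can ever be sampled here.
next_token = sample_top_p(probs, p=0.9)
assert next_token.item() in (2, 3)
print(next_token)
```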
class LlamaRunner:
    def __init__(self, model_path: str, tokenizer_path: str, model_args: ModelArgs):
        # model is a pte file.
        self.model = _load_for_executorch(model_path)
        self.params = model_args
        self.tokenizer = Tokenizer(tokenizer_path)
        assert model_args.vocab_size == self.tokenizer.n_words

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        bsz = len(prompt_tokens)
        params = self.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        if self.params.use_kv_cache:
            # With the KV cache, decoding starts at position 1 and feeds one
            # slice at a time; without it, the full prefix is re-fed each step.
            min_prompt_len = 1

        eos_reached = torch.tensor([False] * bsz, device="cpu")
        input_text_mask = tokens != pad_id
        pos = torch.tensor([prev_pos], dtype=torch.int64)
        if min_prompt_len == total_len:
            if self.params.use_kv_cache:
                inputs = (tokens, pos)
            else:
                inputs = (tokens,)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        for cur_pos in range(min_prompt_len, total_len):
            pos = torch.tensor([prev_pos], dtype=torch.int64)
            if self.params.use_kv_cache:
                inputs = (tokens[:, prev_pos:cur_pos], pos)
            else:
                inputs = (tokens[:, :cur_pos],)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)

            # only replace token if prompt has already been generated
            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
                next_token = torch.where(
                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
                )

            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to after eos tok if any
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

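To make the two input layouts built in `generate` concrete, here is a standalone sketch (not part of the file; `stub_forward` is a made-up stand-in for the loaded `.pte` module): with the KV cache enabled only the tokens not yet in the cache plus their start position are fed, while without it the whole prefix is re-fed every step.
```
import torch

# Hypothetical stand-in for the loaded .pte module: returns logits of shape
# [bsz, seq_len, vocab] wrapped in a list, matching how generate() reads logits[0].
def stub_forward(inputs, vocab_size=32):
    toks = inputs[0]
    return [torch.randn(toks.shape[0], toks.shape[1], vocab_size)]

tokens = torch.zeros((1, 8), dtype=torch.long)
prev_pos, cur_pos = 3, 4  # a later decode step

# KV-cache path: one new slice plus its start position.
pos = torch.tensor([prev_pos], dtype=torch.int64)
kv_inputs = (tokens[:, prev_pos:cur_pos], pos)
print(stub_forward(kv_inputs)[0].shape)    # torch.Size([1, 1, 32])

# No-cache path: the entire prefix so far.
full_inputs = (tokens[:, :cur_pos],)
print(stub_forward(full_inputs)[0].shape)  # torch.Size([1, 4, 32])
```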
    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            # The loaded .pte module does not expose params; use the stored ModelArgs.
            max_gen_len = self.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )

        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

        Returns:
            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

        Raises:
            AssertionError: If the last message in a dialog is not from the user.
            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            # The loaded .pte module does not expose params; use the stored ModelArgs.
            max_gen_len = self.params.max_seq_len - 1

        # NOTE: self.formatter is not created in __init__; a dialog formatter must
        # be attached to the runner before chat_completion can be used.
        prompt_tokens = [
            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
        ]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),
                    },
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t),
                },
            }
            for t in generation_tokens
        ]


def build_args_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        default=False,
        action="store_true",
    )

    parser.add_argument(
        "--max_gen_len",
        type=int,
        default=10,
        help="Maximum length of the generated response sequence.",
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    with open(args.params, "r") as f:
        params = json.loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=128,
        max_batch_size=1,
        use_kv_cache=args.kv_cache,
        **params,
    )
    runner = LlamaRunner(
        model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
    )
    result = runner.text_completion(
        prompts=[args.prompt],
        max_gen_len=args.max_gen_len,
        temperature=args.temperature,
    )
    print(f"Result: {result}")


if __name__ == "__main__":
    main()  # pragma: no cover
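Beyond the plain completion that `main()` prints, `text_completion` can also return per-token log probabilities. A minimal sketch, assuming a `runner` built exactly as in `main()` above:
```
result = runner.text_completion(
    prompts=["Hello I am"],
    max_gen_len=10,
    temperature=0,
    logprobs=True,
)
# With logprobs=True, each CompletionPrediction also carries the decoded
# tokens and their log probabilities.
for pred in result:
    print(pred["generation"])
    print(list(zip(pred["tokens"], pred["logprobs"])))
```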
