Commit ed62e07

committed
generation.py
1 parent b1edc3d commit ed62e07

File tree: 1 file changed (+345 −0)

Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse

import json
from typing import List, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from executorch.examples.models.llama2.llama_transformer import ModelArgs

from executorch.examples.models.llama2.tokenizer.tiktoken import (
    Dialog,
    Message,
    Tokenizer,
)
from executorch.extension.pybindings.portable_lib import _load_for_executorch

class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required

def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

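# Illustrative only, not part of the committed module: a minimal sketch of calling
# sample_top_p on a toy batch-of-one distribution over a four-token vocabulary,
# showing the expected input/output shapes.
#
#     probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)
#     next_token = sample_top_p(probs, p=0.9)  # LongTensor of shape (1, 1)
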
class LlamaRunner:
    def __init__(self, model_path: str, tokenizer_path: str, model_args: ModelArgs):
        # model is a pte file.
        self.model = _load_for_executorch(model_path)
        self.params = model_args
        self.tokenizer = Tokenizer(tokenizer_path)
        assert model_args.vocab_size == self.tokenizer.n_words

    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """
        Generate token sequences for a batch of already-tokenized prompts.

        Prompts are placed into a right-padded (bsz, total_len) buffer, and tokens are
        decoded one position at a time by re-running the exported model on the growing
        prefix. Sampling uses top-p with temperature when temperature > 0, and greedy
        argmax otherwise. Returns the generated token ids per prompt and, if logprobs
        is True, the corresponding per-token log probabilities.
        """
        bsz = len(prompt_tokens)
        params = self.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0

        eos_reached = torch.tensor([False] * bsz, device="cpu")
        input_text_mask = tokens != pad_id
        if min_prompt_len == total_len:
            inputs = (tokens,)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        for cur_pos in range(min_prompt_len, total_len):
            inputs = (tokens[:, :cur_pos],)
            logits = self.model.forward(inputs)  # updated forward call.
            logits = logits[0]
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)

            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )

            tokens[:, cur_pos] = next_token
            if logprobs:
                # Slice the logits so positions prev_pos..cur_pos-1 line up with the
                # target tokens at prev_pos+1..cur_pos.
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits[:, prev_pos:cur_pos].transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to after eos tok if any
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            # use the ModelArgs stored at construction time
            max_gen_len = self.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )

        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

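    # Illustrative only, not part of the committed module: a minimal sketch of driving
    # text_completion directly, assuming a runner built as in main() below; the file
    # names are placeholders.
    #
    #     runner = LlamaRunner("model.pte", "tokenizer.model", model_args)
    #     preds = runner.text_completion(prompts=["Hello"], max_gen_len=10)
    #     print(preds[0]["generation"])
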
    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

        Returns:
            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

        Raises:
            AssertionError: If the last message in a dialog is not from the user.
            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            # use the ModelArgs stored at construction time
            max_gen_len = self.params.max_seq_len - 1

        # NOTE: this assumes a dialog formatter (e.g. the tiktoken tokenizer module's
        # chat formatter) has been attached to the runner as `self.formatter`;
        # __init__ above does not create one.
        prompt_tokens = [
            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
        ]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),
                    },
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t),
                },
            }
            for t in generation_tokens
        ]

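# Illustrative only, not part of the committed module: chat_completion takes a list of
# Dialog values, where each Dialog is a list of Message entries from the tiktoken
# tokenizer module; the "role"/"content" fields below are assumed from that Message type.
#
#     dialogs = [
#         [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Tell me a joke."},
#         ]
#     ]
#     preds = runner.chat_completion(dialogs)
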
def build_args_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        default=False,
        action="store_true",
    )

    return parser

def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    with open(args.params, "r") as f:
        params = json.loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=128,
        max_batch_size=1,
        use_kv_cache=args.kv_cache,
        **params,
    )
    runner = LlamaRunner(
        model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
    )
    result = runner.text_completion(
        prompts=[args.prompt], max_gen_len=10, temperature=args.temperature
    )
    print(f"Result: {result}")


if __name__ == "__main__":
    main()  # pragma: no cover
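
For reference, a hypothetical invocation of this script using the flags defined in build_args_parser (file names are placeholders, not part of this commit; the exact module path depends on where the file lives in the repo):

python generation.py -f model.pte -p params.json -t tokenizer.model --prompt "Hello" --temperature 0.6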
