
Commit c74f161

generation.py
1 parent b1edc3d commit c74f161

1 file changed: 344 additions, 0 deletions

@@ -0,0 +1,344 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from typing import List, Optional, Tuple, TypedDict

import json
import torch
import torch.nn.functional as F
from executorch.examples.models.llama2.llama_transformer import ModelArgs

from executorch.examples.models.llama2.tokenizer.tiktoken import (
    Dialog,
    Message,
    Tokenizer,
)
from executorch.extension.pybindings.portable_lib import _load_for_executorch


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

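
# Illustrative sketch (the helper below is only an example, not used by the runner):
# a quick sanity check of sample_top_p on a toy distribution. For
# probs = [0.5, 0.3, 0.15, 0.05] and p = 0.7, the mask `cumsum - prob > p` is False
# only for the first two entries, so the nucleus is {0, 1} renormalized to
# [0.625, 0.375]; indices 2 and 3 are never sampled.
def _example_sample_top_p() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    counts = [0, 0, 0, 0]
    for _ in range(1000):
        counts[sample_top_p(probs, p=0.7).item()] += 1
    print(counts)  # roughly [625, 375, 0, 0]

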
class LlamaRunner:
    def __init__(self, model_path: str, tokenizer_path: str, model_args: ModelArgs):
        # self.model is loaded from an exported ExecuTorch .pte file.
        self.model = _load_for_executorch(model_path)
        self.params = model_args
        self.tokenizer = Tokenizer(tokenizer_path)
        assert model_args.vocab_size == self.tokenizer.n_words

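    # Note (descriptive, based on the calls in generate() below): the module returned
    # by _load_for_executorch is driven with a tuple of input tensors and returns a
    # list of output tensors, hence the `inputs = (tokens,)` / `logits = logits[0]`
    # pattern used during decoding.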
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        bsz = len(prompt_tokens)
        params = self.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0

        eos_reached = torch.tensor([False] * bsz, device="cpu")
        input_text_mask = tokens != pad_id
        if min_prompt_len == total_len:
            # The shortest prompt already fills the buffer, so the decode loop below
            # is empty; just score the prompt tokens.
            inputs = (tokens,)
            logits = self.model.forward(inputs)  # ExecuTorch forward: inputs are passed as a tuple.
            logits = logits[0]  # first output is the logits tensor.
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        for cur_pos in range(min_prompt_len, total_len):
            inputs = (tokens[:, :cur_pos],)
            logits = self.model.forward(inputs)  # ExecuTorch forward: inputs are passed as a tuple.
            logits = logits[0]  # first output is the logits tensor.
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)

            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )

            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to after eos tok if any
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

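    # Illustrative usage sketch (hypothetical `runner` instance): drive generate()
    # directly with pre-tokenized prompts; temperature=0 selects greedy argmax
    # decoding instead of top-p sampling.
    #
    #   toks = runner.tokenizer.encode("Hello", bos=True, eos=False)
    #   out_tokens, _ = runner.generate(prompt_tokens=[toks], max_gen_len=8, temperature=0.0)
    #   print(runner.tokenizer.decode(out_tokens[0]))
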
    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            max_gen_len = self.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )

        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

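    # Illustrative usage sketch (hypothetical `runner` instance): with logprobs=True,
    # each CompletionPrediction also carries the per-token strings and their
    # log-probabilities, aligned one-to-one with the generated tokens.
    #
    #   preds = runner.text_completion(prompts=["Hello"], max_gen_len=5, logprobs=True)
    #   for tok, lp in zip(preds[0]["tokens"], preds[0]["logprobs"]):
    #       print(repr(tok), lp)
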
    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

        Returns:
            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

        Raises:
            AssertionError: If the last message in a dialog is not from the user.
            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            max_gen_len = self.params.max_seq_len - 1

        # NOTE: self.formatter (a dialog-prompt encoder for the tokenizer) is expected
        # to be attached to the runner; __init__ above does not set it.
        prompt_tokens = [
            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
        ]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),
                    },
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t),
                },
            }
            for t in generation_tokens
        ]

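# Illustrative sketch (hypothetical `runner` instance): a single-turn chat request.
# Per the chat_completion docstring, a Dialog is a list of Message entries whose
# roles follow system (optional) / user / assistant order and end with a user turn.
#
#   dialog: Dialog = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Write a haiku about autumn."},
#   ]
#   preds = runner.chat_completion(dialogs=[dialog], max_gen_len=64)
#   print(preds[0]["generation"]["content"])

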
def build_args_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        default=False,
        action="store_true",
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    with open(args.params, "r") as f:
        params = json.loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=128,
        max_batch_size=1,
        use_kv_cache=args.kv_cache,
        **params,
    )
    runner = LlamaRunner(
        model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
    )
    result = runner.text_completion(
        prompts=[args.prompt], max_gen_len=10, temperature=args.temperature
    )
    print(f"Result: {result}")


if __name__ == "__main__":
    main()  # pragma: no cover
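

# Example invocation (illustrative; flag names come from build_args_parser above,
# file paths are placeholders):
#
#   python generation.py \
#       --pte llama2.pte \
#       --params params.json \
#       --tokenizer tokenizer.model \
#       --prompt "Hello" \
#       --temperature 0.6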
