# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import json
from typing import List, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from executorch.examples.models.llama2.llama_transformer import ModelArgs

from executorch.examples.models.llama2.tokenizer.tiktoken import (
    Dialog,
    Message,
    Tokenizer,
)
from executorch.extension.pybindings.portable_lib import _load_for_executorch


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


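# For illustration, a CompletionPrediction produced with logprobs=True might
# look like the following (the values are made up):
#
#     {
#         "generation": " world",
#         "tokens": [" world"],
#         "logprobs": [-1.37],
#     }

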
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


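# A minimal, illustrative sketch (not called anywhere in this file): draw one
# token id from a hand-made four-token distribution with nucleus threshold
# p=0.9. The logit values below are invented purely for demonstration.
def _demo_sample_top_p() -> None:
    probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, 0.1]]), dim=-1)
    next_token = sample_top_p(probs, p=0.9)
    print(next_token)  # tensor of shape (1, 1) holding an index in [0, 4)

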
class LlamaRunner:
    def __init__(self, model_path: str, tokenizer_path: str, model_args: ModelArgs):
        # model_path points to an exported ExecuTorch .pte file.
        self.model = _load_for_executorch(model_path)
        self.params = model_args
        self.tokenizer = Tokenizer(tokenizer_path)
        assert model_args.vocab_size == self.tokenizer.n_words

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        bsz = len(prompt_tokens)
        params = self.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cpu")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cpu")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        if self.params.use_kv_cache:
            min_prompt_len = 1

        eos_reached = torch.tensor([False] * bsz, device="cpu")
        input_text_mask = tokens != pad_id
        pos = torch.tensor([prev_pos], dtype=torch.int64)
        if min_prompt_len == total_len:
            if self.params.use_kv_cache:
                inputs = (tokens, pos)
            else:
                inputs = (tokens,)
            # The loaded ExecuTorch program returns a list of outputs; the
            # first entry holds the logits.
            logits = self.model.forward(inputs)
            logits = logits[0]
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))

        for cur_pos in range(min_prompt_len, total_len):
            pos = torch.tensor([prev_pos], dtype=torch.int64)
            if self.params.use_kv_cache:
                inputs = (tokens[:, prev_pos:cur_pos], pos)
            else:
                inputs = (tokens[:, :cur_pos],)
            # Forward the exported program and unpack the logits tensor.
            logits = self.model.forward(inputs)
            logits = logits[0]
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)

            # Only replace the sampled token if the prompt has not yet been consumed.
            if not self.params.use_kv_cache or cur_pos < len(prompt_tokens[0]):
                next_token = torch.where(
                    input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
                )

            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # Cut to max gen len.
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # Cut to just before the first stop token, if any.
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token)
                    toks = toks[:eos_idx]
                    probs = probs[:eos_idx] if logprobs else None
                except ValueError:
                    pass
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

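    # Usage sketch for generate() with pre-tokenized prompts (illustrative;
    # assumes a constructed `runner`):
    #
    #     prompt = runner.tokenizer.encode("Hello", bos=True, eos=False)
    #     out_tokens, _ = runner.generate(prompt_tokens=[prompt], max_gen_len=10)
    #     print(runner.tokenizer.decode(out_tokens[0]))
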
    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            # The loaded .pte module does not expose model params; use the
            # ModelArgs held by the runner.
            max_gen_len = self.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )

        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

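    # Usage sketch (illustrative; assumes a constructed `runner`): request
    # per-token log-probabilities alongside the completion text.
    #
    #     preds = runner.text_completion(
    #         prompts=["Hello"], max_gen_len=16, logprobs=True
    #     )
    #     print(preds[0]["generation"], preds[0]["logprobs"])
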
    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

        Returns:
            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

        Raises:
            AssertionError: If the last message in a dialog is not from the user.
            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.
        """
        if max_gen_len is None:
            max_gen_len = self.params.max_seq_len - 1

        # NOTE: self.formatter is expected to be a dialog encoder in the style
        # of llama3's ChatFormat; it is not initialized in __init__ above and
        # must be set on the runner before calling this method.
        prompt_tokens = [
            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
        ]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t),
                    },
                    "tokens": [self.tokenizer.decode([x]) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t),
                },
            }
            for t in generation_tokens
        ]


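# A minimal, illustrative sketch of chat_completion usage (not called anywhere
# in this file). It assumes a constructed `runner` whose `formatter` attribute
# has been set (see the NOTE inside chat_completion); the message dicts mirror
# the role/content structure used elsewhere in this file and may need adjusting
# to the exact Message type exported by the tokenizer module.
def _demo_chat_completion(runner: LlamaRunner) -> None:
    dialogs: List[Dialog] = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ]
    ]
    preds = runner.chat_completion(dialogs, max_gen_len=32)
    print(preds[0]["generation"]["content"])

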
def build_args_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        default=False,
        action="store_true",
    )

    parser.add_argument(
        "--max_gen_len",
        type=int,
        default=10,
        help="Maximum length of the generated response sequence.",
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    with open(args.params, "r") as f:
        params = json.loads(f.read())
    model_args: ModelArgs = ModelArgs(
        max_seq_len=128,
        max_batch_size=1,
        use_kv_cache=args.kv_cache,
        **params,
    )
    runner = LlamaRunner(
        model_path=args.pte, tokenizer_path=args.tokenizer, model_args=model_args
    )
    result = runner.text_completion(
        prompts=[args.prompt],
        max_gen_len=args.max_gen_len,
        temperature=args.temperature,
    )
    print(f"Result: {result}")


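# Example invocation (script name and artifact paths are placeholders; the
# --kv_cache flag should match how the .pte file was exported):
#
#     python runner.py -f llama2.pte -p params.json -t tokenizer.model \
#         --prompt "Hello" --max_gen_len 32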
if __name__ == "__main__":
    main()  # pragma: no cover