 
 import json
 import math
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 
-from examples.models.model_base import EagerModelBase
-
-from llama.model import ModelArgs, repeat_kv, RMSNorm
 from torch import nn
 
+from ..model_base import EagerModelBase
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layers: int = 32
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = None
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)
-    freqs = torch.outer(t, freqs).float()
+    t = torch.arange(end, device=freqs.device)  # pyre-ignore
+    freqs = torch.outer(t, freqs).float()  # pyre-ignore
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
     return freqs_cos, freqs_sin
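As a quick sanity check on the helpers this hunk inlines, the sketch below (shapes chosen arbitrarily, assuming the definitions above are in scope) shows that repeat_kv matches torch.repeat_interleave along the head dimension, that RMSNorm preserves input shape, and the shape of the RoPE tables:

# Illustrative only; tensor shapes are arbitrary, not taken from this commit.
import torch

x = torch.randn(2, 5, 4, 16)  # (batch, seq_len, n_kv_heads, head_dim)
# repeat_kv duplicates each KV head n_rep times along dim 2.
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=2))

# RMSNorm normalizes over the last dimension and keeps the shape.
norm = RMSNorm(dim=16, eps=1e-5)
assert norm(torch.randn(2, 5, 16)).shape == (2, 5, 16)

# The RoPE tables have shape (end, dim // 2).
cos, sin = precompute_freqs_cis(dim=16, end=128)
assert cos.shape == (128, 8) and sin.shape == (128, 8)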
@@ -155,8 +228,6 @@ def forward(self, x, freqs_cos, freqs_sin):
 
 
 class Transformer(nn.Module):
-    last_loss: Optional[torch.Tensor]
-
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
@@ -194,14 +265,31 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
 
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
+        import pkg_resources
+
         ckpt_dir = Path(__file__).absolute().parent
+
+        # Get the paths to the resource files, falling back to the packaged demo files.
+        params_path = (
+            Path(ckpt_dir) / kwargs["params"]
+            if "params" in kwargs
+            else pkg_resources.resource_filename(
+                "executorch.examples.portable.scripts", "demo_config.json"
+            )
+        )
+        checkpoint_path = (
+            Path(ckpt_dir) / kwargs["checkpoint"]
+            if "checkpoint" in kwargs
+            else pkg_resources.resource_filename(
+                "executorch.examples.portable.scripts", "demo_rand_params.pth"
+            )
+        )
+
         # The example uses a small dummy model with random weights, for demo purposes only.
         # Follow the instructions at https://github.com/facebookresearch/llama to download the real model.
         device = "cpu"
-        checkpoint = torch.load(
-            Path(ckpt_dir) / kwargs["checkpoint"], map_location=device
-        )
-        with open(Path(ckpt_dir) / kwargs["params"], "r") as f:
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        with open(params_path, "r") as f:
             params = json.loads(f.read())
         max_seq_len = 128
         max_batch_size = 1