13 | 13 | import torch
14 | 14 | import torch.nn.functional as F
15 | 15 |
16 |    | -from executorch.examples.models.llama.llama_transformer import RMSNorm
17 |    | -
18 | 16 | from executorch.examples.models.llama.rope import (
19 | 17 |     hf_apply_rotary_emb,
20 | 18 |     hf_precompute_freqs_cis,
25 | 23 | from torch import nn
26 | 24 |
27 | 25 |
28 |    | -# These are just to prevent to_edge from decomposing SDPA
29 |    | -# A better method is to use the to_edge_transform_and_lower API for CoreML
30 |    | -# and not decompose SDPA
31 |    | -@torch.library.custom_op("coreml::sdpa", mutates_args=())
32 |    | -def sdpa(
33 |    | -    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
34 |    | -) -> torch.Tensor:
35 |    | -    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
36 |    | -    return torch.ops.aten.scaled_dot_product_attention.default(
37 |    | -        q, k, v, attn_mask=attn_mask
38 |    | -    )
39 |    | -
40 |    | -
41 |    | -@torch.library.register_fake("coreml::sdpa")
42 |    | -def _(
43 |    | -    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
44 |    | -) -> torch.Tensor:
45 |    | -    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
46 |    | -    expected_shape = list(q.shape)
47 |    | -    expected_shape[-1] = v.shape[-1]
48 |    | -    return q.new_empty(expected_shape)
49 |    | -
50 |    | -
51 | 26 | def find_multiple(n: int, k: int) -> int:
52 | 27 |     if n % k == 0:
53 | 28 |         return n
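The comments removed above point at the alternative this change relies on: export with the to_edge_transform_and_lower API so the CoreML partitioner can claim SDPA before the edge dialect decomposes it, instead of hiding it behind a custom "coreml::sdpa" op. A minimal sketch of that flow, assuming the executorch import paths below and the lower_to_coreml helper name (neither is part of this diff):

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


def lower_to_coreml(model: torch.nn.Module, example_inputs: tuple):
    # Export to an ATen-dialect program first.
    exported = torch.export.export(model, example_inputs)
    # to_edge_transform_and_lower hands the graph to the partitioner before edge
    # decomposition runs, so SDPA reaches the CoreML delegate intact and no
    # custom "coreml::sdpa" wrapper op is needed.
    edge = to_edge_transform_and_lower(exported, partitioner=[CoreMLPartitioner()])
    return edge.to_executorch()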
@@ -121,6 +96,63 @@ def __post_init__(self):
121 |  96 |         self.head_dim = self.dim // self.n_heads
122 |  97 |
123 |  98 |
    |  99 | +class RMSNorm(torch.nn.Module):
    | 100 | +    def __init__(self, dim: int, eps: float = 1e-6):
    | 101 | +        """
    | 102 | +        Initialize the RMSNorm normalization layer.
    | 103 | +
    | 104 | +        Args:
    | 105 | +            dim (int): The dimension of the input tensor.
    | 106 | +            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
    | 107 | +
    | 108 | +        Attributes:
    | 109 | +            eps (float): A small value added to the denominator for numerical stability.
    | 110 | +            weight (nn.Parameter): Learnable scaling parameter.
    | 111 | +
    | 112 | +        """
    | 113 | +        super().__init__()
    | 114 | +        self.dim = dim
    | 115 | +        self.eps = eps
    | 116 | +        self.weight = nn.Parameter(torch.ones(dim))
    | 117 | +
    | 118 | +    def _norm(self, x):
    | 119 | +        """
    | 120 | +        Apply the RMSNorm normalization to the input tensor.
    | 121 | +
    | 122 | +        Args:
    | 123 | +            x (torch.Tensor): The input tensor.
    | 124 | +
    | 125 | +        Returns:
    | 126 | +            torch.Tensor: The normalized tensor.
    | 127 | +
    | 128 | +        """
    | 129 | +        # CoreML ignores casts to FP32, so the existing RMSNorm implementation was not numerically stable.
    | 130 | +        # We instead use (x * sqrt(n)) / norm(x, dim=-1).
    | 131 | +        # Using torch.norm and preserving this op in CoreML improves stability.
    | 132 | +        # Note that we ignore eps here; it could be added by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator.
    | 133 | +        # In the future, we want to add CoreML support for the functional RMSNorm op.
    | 134 | +        # We have yet to run large-scale evaluations of the numeric stability of this solution, but it
    | 135 | +        # appears better than what exists currently (removing the FP32 casts and running in FP16).
    | 136 | +        rms_norm_eps0 = (
    | 137 | +            x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
    | 138 | +        ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
    | 139 | +        return rms_norm_eps0
    | 140 | +
    | 141 | +    def forward(self, x):
    | 142 | +        """
    | 143 | +        Forward pass through the RMSNorm layer.
    | 144 | +
    | 145 | +        Args:
    | 146 | +            x (torch.Tensor): The input tensor.
    | 147 | +
    | 148 | +        Returns:
    | 149 | +            torch.Tensor: The output tensor after applying RMSNorm.
    | 150 | +
    | 151 | +        """
    | 152 | +        output = self._norm(x)
    | 153 | +        return output * self.weight
    | 154 | +
    | 155 | +
124 | 156 | class Rope(torch.nn.Module):
125 | 157 |     def __init__(self, params: ModelArgs):
126 | 158 |         super().__init__()
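A quick check, not part of the PR, that the rewritten expression in _norm is algebraically the usual RMSNorm with eps = 0: along the last dimension, ||x||_2 = sqrt(n * mean(x^2)), so (x * sqrt(n)) / ||x||_2 reduces to x / sqrt(mean(x^2)).

import torch

torch.manual_seed(0)
dim = 8
x = torch.randn(2, 4, dim)

# Form used in the new RMSNorm._norm: (x * sqrt(n)) / ||x||_2 over the last dim.
y_norm = (
    x * torch.sqrt(torch.tensor(dim, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)

# Reference RMSNorm with eps = 0: x / sqrt(mean(x^2)).
y_ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True))

assert torch.allclose(y_norm, y_ref, atol=1e-6)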
@@ -304,12 +336,11 @@ def forward(
304 | 336 |         k = k.repeat_interleave(self.n_rep, dim=1)
305 | 337 |         v = v.repeat_interleave(self.n_rep, dim=1)
306 | 338 |
307 |     | -        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
308 |     | -
    | 339 | +        output = torch.ops.aten.scaled_dot_product_attention.default(
    | 340 | +            q, k, v, attn_mask=attn_mask
    | 341 | +        )
309 | 342 |         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
310 |     | -
311 | 343 |         output = self.wo(output)
312 |     | -
313 | 344 |         return output, new_k, new_v
314 | 345 |
315 | 346 |
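For reference, the default ATen overload called above is the op that F.scaled_dot_product_attention dispatches to, so the two should agree on the same inputs. A small self-contained check, with illustrative shapes that are not taken from this model:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
attn_mask = torch.zeros(16, 16)  # additive float mask; zeros mask nothing

out_aten = torch.ops.aten.scaled_dot_product_attention.default(
    q, k, v, attn_mask=attn_mask
)
out_f = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
assert torch.allclose(out_aten, out_f, atol=1e-6)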
@@ -413,6 +444,39 @@ def forward(
413 | 444 |         return logits, k_out, v_out
414 | 445 |
415 | 446 |
    | 447 | +def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
    | 448 | +    import json
    | 449 | +
    | 450 | +    with open(params_path, "r") as f:
    | 451 | +        params = json.loads(f.read())
    | 452 | +
    | 453 | +    args = ModelArgs(
    | 454 | +        max_seq_len=max_seq_length,
    | 455 | +        generate_full_logits=False,
    | 456 | +        use_cache_list=use_cache_list,
    | 457 | +        **params,
    | 458 | +    )
    | 459 | +
    | 460 | +    with torch.device("meta"):
    | 461 | +        model = Transformer(args)
    | 462 | +
    | 463 | +    checkpoint = torch.load(
    | 464 | +        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
    | 465 | +    )
    | 466 | +    if "model" in checkpoint:
    | 467 | +        checkpoint = checkpoint["model"]
    | 468 | +
    | 469 | +    missing, unexpected = model.load_state_dict(
    | 470 | +        checkpoint,
    | 471 | +        strict=False,
    | 472 | +        assign=True,
    | 473 | +    )
    | 474 | +    print("Missing keys: ", missing)
    | 475 | +    print("Unexpected keys: ", unexpected)
    | 476 | +
    | 477 | +    return model
    | 478 | +
    | 479 | +
416 | 480 | class InputManager:
417 | 481 |     def __init__(
418 | 482 |         self,
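A hedged usage sketch for the new load_model helper; the paths and sizes below are placeholders, not values from this PR. The meta-device construction plus assign=True means parameters are never materialized with random data: the mmap'd checkpoint tensors are bound to the module directly, which keeps peak memory low for large Llama checkpoints.

# Hypothetical invocation; point these at your own Llama checkpoint and params.json.
model = load_model(
    checkpoint_path="/path/to/consolidated.00.pth",
    params_path="/path/to/params.json",
    max_seq_length=512,
    use_cache_list=True,
)
model.eval()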