Add dtype, fix RMS norm for FP16 #8641
@@ -13,8 +13,6 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.llama_transformer import RMSNorm
-
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
@@ -25,29 +23,6 @@
 from torch import nn
 
 
-# These are just to prevent to_edge from decomposing SDPA
-# A better method is to use the to_edge_transform_and_lower API for CoreML
-# and not decompose SDPA
-@torch.library.custom_op("coreml::sdpa", mutates_args=())
-def sdpa(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
-    return torch.ops.aten.scaled_dot_product_attention.default(
-        q, k, v, attn_mask=attn_mask
-    )
-
-
-@torch.library.register_fake("coreml::sdpa")
-def _(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
-    expected_shape = list(q.shape)
-    expected_shape[-1] = v.shape[-1]
-    return q.new_empty(expected_shape)
-
-
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
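For reference, the "better method" mentioned in the removed comment is the to_edge_transform_and_lower flow, which lets the CoreML partitioner claim SDPA before edge decompositions run instead of guarding it behind a custom op. A rough sketch of that flow, not part of this PR: it assumes a `model` and `example_inputs` from the surrounding export script and that the ExecuTorch CoreML backend is installed.

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

# `model` and `example_inputs` are assumed to come from the export script.
exported_program = torch.export.export(model, example_inputs)

# Hand the exported graph to the CoreML partitioner before edge decompositions
# run, so aten.scaled_dot_product_attention can stay fused in the delegate.
edge_manager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[CoreMLPartitioner()],
)
executorch_program = edge_manager.to_executorch()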
@@ -121,6 +96,63 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
+class RMSNorm(torch.nn.Module):

Review comment (on this line): In sync with the CoreML team, we might try using https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html and then write a CoreML op definition for it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py. @YifanShenSZ mentioned they have a fused norm op that could be used.

Reply: Yes, there are 2 possibilities.
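A minimal sketch of that functional path, not part of this PR: it assumes PyTorch >= 2.4, where torch.nn.functional.rms_norm is available, and the class name RMSNormFunctional is made up for illustration.

import torch
import torch.nn.functional as F


class RMSNormFunctional(torch.nn.Module):
    # Hypothetical variant for illustration only; the PR keeps the
    # torch.norm-based implementation shown below.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Computes x / sqrt(mean(x**2, dim=-1) + eps) * weight in one op,
        # which a CoreML translation could map to a fused kernel.
        return F.rms_norm(x, [self.dim], weight=self.weight, eps=self.eps)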
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        # CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable
+        # We instead use (x * sqrt(n)) / norm(x, dim=-1)
+        # Using torch.norm and preserving this op in CoreML improves stability
+        # Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
+        # In future, we want to add CoreML support for the functional RMSNorm op
+        # We have yet to do large scale evaluations on the numeric stability of this solution, but note that
+        # it appears better than what exists currently (removing FP32 casts and using FP16)
+        rms_norm_eps0 = (
+            x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+        ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
+        return rms_norm_eps0
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x)
+        return output * self.weight
+
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
@@ -304,12 +336,11 @@ def forward(
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
 
-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
 
+        output = torch.ops.aten.scaled_dot_product_attention.default(
+            q, k, v, attn_mask=attn_mask
+        )
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
         output = self.wo(output)
 
         return output, new_k, new_v
@@ -413,6 +444,39 @@ def forward(
         return logits, k_out, v_out
 
 
+def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+    import json
+
+    with open(params_path, "r") as f:
+        params = json.loads(f.read())
+
+    args = ModelArgs(
+        max_seq_len=max_seq_length,
+        generate_full_logits=False,
+        use_cache_list=use_cache_list,
+        **params,
+    )
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
+    )
+    if "model" in checkpoint:
+        checkpoint = checkpoint["model"]
+
+    missing, unexpected = model.load_state_dict(
+        checkpoint,
+        strict=False,
+        assign=True,
+    )
+    print("Missing keys: ", missing)
+    print("Unexpected keys: ", unexpected)
+
+    return model
+
+
 class InputManager:
     def __init__(
         self,
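A usage sketch for the new load_model helper, illustrative only: the checkpoint and params paths are hypothetical placeholders for a real Llama checkpoint and its params.json.

model = load_model(
    checkpoint_path="consolidated.00.pth",  # hypothetical path
    params_path="params.json",              # hypothetical path
    max_seq_length=512,
    use_cache_list=True,
)
model.eval()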
Review comment: Does CoreML support an RMSNorm op? It will be a lot easier if they do.

Reply: I don't see it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py
Reply: Something existing in Core ML is the translation for torch.norm, which uses the Core ML fused reduce_l2_norm kernel. That is to say, we may compute RMS norm by something like:
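The inline snippet in that comment was lost in the page capture; a hedged reconstruction of the idea, mirroring the _norm implementation this PR adds, could look like:

import torch

def rms_norm_via_l2(x: torch.Tensor, dim: int) -> torch.Tensor:
    # eps is ignored here, exactly as in the PR's _norm; per the comment above,
    # the vector-norm reduction translates to Core ML's fused reduce_l2_norm.
    return (
        x * torch.sqrt(torch.tensor(dim, dtype=x.dtype))
    ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)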
Reply: I think this is a slightly different op when eps > 0, although I'm not sure how much it matters in practice. RMSNorm would actually be something like x / torch.norm([x / sqrt(n), sqrt(eps)]).
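To make the eps term concrete (a sketch of the algebra, not code from this PR): sqrt(sum(x**2) + n*eps) equals the L2 norm of x with an extra sqrt(n*eps) element appended, so the eps-aware form can still ride on the fused norm kernel.

import math
import torch

def rms_norm_with_eps(x: torch.Tensor, dim: int, eps: float) -> torch.Tensor:
    # norm(concat(x, [sqrt(n*eps)])) == sqrt(sum(x**2) + n*eps), so dividing
    # x * sqrt(n) by it gives x / sqrt(mean(x**2) + eps), the usual RMSNorm.
    pad = torch.full(
        x.shape[:-1] + (1,), math.sqrt(dim * eps), dtype=x.dtype, device=x.device
    )
    denom = torch.linalg.vector_norm(
        torch.cat([x, pad], dim=-1), dim=-1, keepdim=True
    )
    return x * math.sqrt(dim) / denom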
Reply: I'll update to use norm, and then maybe we can work on a longer-term solution of supporting rmsnorm in CoreML @YifanShenSZ?