Add dtype, fix RMS norm for FP16 #8641

Merged: 4 commits, Feb 26, 2025
81 changes: 38 additions & 43 deletions examples/apple/coreml/llama/export.py
@@ -3,7 +3,6 @@
# pyre-strict

import argparse
import json

import sys

@@ -20,10 +19,11 @@
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.exir.program._program import to_edge_with_preserved_ops
from executorch.extension.export_util.utils import save_pte_program

sys.path.insert(0, ".")
from llama_transformer import InputManager, ModelArgs, Transformer
from llama_transformer import InputManager, load_model


class SplitLinearModule(torch.nn.Module):
@@ -141,42 +141,23 @@ def main() -> None:
default=8,
help="Maximum number of splits to divide linear layers",
)
parser.add_argument(
"--dtype",
type=str,
default="fp16",
)

export_args = parser.parse_args()
params_path = export_args.params
checkpoint_path = export_args.checkpoint

# Load model args
with open(params_path, "r") as f:
params = json.loads(f.read())

args = ModelArgs(
max_seq_len=export_args.max_seq_length,
generate_full_logits=False,
model = load_model(
export_args.checkpoint,
export_args.params,
max_seq_length=export_args.max_seq_length,
use_cache_list=export_args.use_cache_list,
**params,
)

with torch.device("meta"):
model = Transformer(args)

checkpoint = torch.load(
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
)
if "model" in checkpoint:
checkpoint = checkpoint["model"]

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
print("Missing keys: ", missing)
print("Unexpected keys: ", unexpected)

float_dtype = torch.float16 # dtype for model/inputs
model.eval()
model.to(float_dtype)
float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
export_args.dtype
] # dtype for model/inputs

if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
@@ -197,7 +178,8 @@ def main() -> None:
model, export_args.target_split_size, export_args.max_splits
)

model = model.to(float_dtype)
model.eval()
model.to(float_dtype)

op_linear_quantizer_config = None
if export_args.coreml_quantize == "b4w":
@@ -217,7 +199,10 @@

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=ct.target.iOS18,
compute_precision=ct.precision(ct.precision.FLOAT16.value),
compute_precision={
torch.float16: ct.precision.FLOAT16,
torch.float32: ct.precision.FLOAT32,
}[float_dtype],
compute_unit=ct.ComputeUnit.CPU_AND_NE,
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
op_linear_quantizer_config=op_linear_quantizer_config,
@@ -232,11 +217,11 @@
)

input_manager = InputManager(
n_layers=args.n_layers,
max_batch_size=args.max_batch_size,
n_kv_heads=args.n_kv_heads,
max_seq_length=args.max_seq_len,
head_dim=args.head_dim,
n_layers=model.params.n_layers,
max_batch_size=model.params.max_batch_size,
n_kv_heads=model.params.n_kv_heads,
max_seq_length=model.params.max_seq_len,
head_dim=model.params.head_dim,
use_cache_list=export_args.use_cache_list,
seq_length=export_args.seq_length,
dtype=float_dtype,
@@ -245,10 +230,20 @@
)
example_inputs = input_manager.get_inputs(tokens=[0])

edge_manager = export_to_edge(
ep = torch.export.export(
model,
example_inputs,
edge_compile_config=EdgeCompileConfig(
)
print("Exported program")
print(ep)

edge_manager = to_edge_with_preserved_ops(
ep,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
torch.ops.aten.linalg_vector_norm.default,
],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_type_promotion=(float_dtype == torch.float16),
_skip_dim_order=True,
122 changes: 93 additions & 29 deletions examples/apple/coreml/llama/llama_transformer.py
@@ -13,8 +13,6 @@
import torch
import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import RMSNorm

from executorch.examples.models.llama.rope import (
hf_apply_rotary_emb,
hf_precompute_freqs_cis,
@@ -25,29 +23,6 @@
from torch import nn


# These are just to prevent to_edge from decomposing SDPA
# A better method is to use the to_edge_transform_and_lower API for CoreML
# and not decompose SDPA
@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
return torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)


@torch.library.register_fake("coreml::sdpa")
def _(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
"""Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
expected_shape = list(q.shape)
expected_shape[-1] = v.shape[-1]
return q.new_empty(expected_shape)


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
@@ -121,6 +96,63 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
Review thread on RMSNorm:

Contributor: Does CoreML support an RMSNorm op? It will be a lot easier if it does.

Collaborator (@YifanShenSZ, Feb 24, 2025): Something existing in Core ML is the translation for torch.norm, which uses the Core ML fused reduce_l2_norm kernel. That is to say, we may compute RMS norm by something like x / torch.norm(x).

Contributor Author (@metascroy, Feb 25, 2025): I think this is a slightly different op when eps > 0, although I'm not sure how much it matters in practice. RMSNorm would actually be something like x / torch.norm([x/sqrt(n), sqrt(eps)]).

Contributor Author: I'll update to use norm, and then maybe we can work on a longer-term solution of supporting rmsnorm in CoreML, @YifanShenSZ?

Contributor Author (@metascroy, Feb 24, 2025): In sync with the CoreML team, we might try using https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html and then write a CoreML op definition for it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py. @YifanShenSZ mentioned they have a fused norm op that could be used.

Collaborator (@YifanShenSZ, Feb 24, 2025): Yes, there are 2 possibilities:

  1. (Simpler) As mentioned above, we already have a torch.norm translation using the Core ML fused reduce_l2_norm kernel, so we may have RMS norm torch source like x / torch.norm(x).
  2. (Better) I think we can further fuse to Core ML l2_norm, which may directly correspond to RMS norm (with some restrictions, though, e.g. x must have rank >= 3). We will need to add the translation function in CoreMLTools.
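
For reference, a standalone sketch (not part of the PR) checking the equivalence discussed above: the norm-based formulation, with eps folded into the denominator as suggested in the thread and in the code comment below, matches the conventional RMSNorm definition.

```
import torch

n, eps = 8, 1e-6
x = torch.randn(2, 3, n, dtype=torch.float32)

# Conventional RMSNorm (without the learnable weight): x / sqrt(mean(x^2) + eps)
reference = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# Norm-based form with eps folded in: (x * sqrt(n)) / ||concat(x, sqrt(n * eps))||_2
pad = torch.full((*x.shape[:-1], 1), (n * eps) ** 0.5, dtype=x.dtype)
denom = torch.linalg.vector_norm(torch.cat([x, pad], dim=-1), dim=-1, keepdim=True)
norm_based = x * (n ** 0.5) / denom

print(torch.allclose(reference, norm_based, atol=1e-6))  # expected: True
```

With eps = 0 the pad term drops out and the expression reduces to the (x * sqrt(n)) / norm(x, dim=-1) form used in this PR.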

def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
# CoreML ignores casts to FP32, so the existing implementation of RMSNorm was not stable
# We instead use (x * sqrt(n)) / norm(x, dim=-1)
# Using torch.norm and preserving this op in CoreML improves stability
# Note: we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
# In the future, we want to add CoreML support for the functional RMSNorm op
# We have yet to do large-scale evaluations of the numeric stability of this solution, but note that
# it appears better than what exists currently (removing FP32 casts and using FP16)
rms_norm_eps0 = (
x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
return rms_norm_eps0

def forward(self, x):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x)
return output * self.weight


class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
@@ -304,12 +336,11 @@ def forward(
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)

output = torch.ops.coreml.sdpa(q, k, v, attn_mask)

output = torch.ops.aten.scaled_dot_product_attention.default(
q, k, v, attn_mask=attn_mask
)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

output = self.wo(output)

return output, new_k, new_v


@@ -413,6 +444,39 @@ def forward(
return logits, k_out, v_out


def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
import json

with open(params_path, "r") as f:
params = json.loads(f.read())

args = ModelArgs(
max_seq_len=max_seq_length,
generate_full_logits=False,
use_cache_list=use_cache_list,
**params,
)

with torch.device("meta"):
model = Transformer(args)

checkpoint = torch.load(
checkpoint_path, map_location="cpu", mmap=True, weights_only=True
)
if "model" in checkpoint:
checkpoint = checkpoint["model"]

missing, unexpected = model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)
print("Missing keys: ", missing)
print("Unexpected keys: ", unexpected)

return model


class InputManager:
def __init__(
self,
8 changes: 7 additions & 1 deletion examples/apple/coreml/llama/readme.md
@@ -4,7 +4,7 @@ This directory contains ANE-friendly Llama models.

Export model with:
```
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
```
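
(`--dtype` accepts `fp16` or `fp32`; `fp16` is the default, per the argument definition in export.py.)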

(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
@@ -17,6 +17,12 @@ Run model with:
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
```

The runner can also be used to run an eager model to compare with CoreML numerics (--use_eager). In this case, you must specify the following (see the example sketch after this list):
* --checkpoint
* --dtype
* --max_seq_length
* --seq_length
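
For example, a sketch of an eager comparison run, appending these flags to the run command above (illustrative only; the exact set of required flags is defined in run.py):
```
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time," \
    --use_eager --checkpoint /path/to/model.pth --dtype fp16 --max_seq_length 1024 --seq_length 64
```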

(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)

