Commit 5a594a7

Add dtype, fix RMS norm for FP16 (#8641)
* Add dtype, fix RMS norm for FP16
* up
* up
* Update llama_transformer.py
1 parent 2be4e94 commit 5a594a7

4 files changed: +223 -92 lines changed

examples/apple/coreml/llama/export.py

Lines changed: 38 additions & 43 deletions
@@ -3,7 +3,6 @@
 # pyre-strict
 
 import argparse
-import json
 
 import sys
 
@@ -20,10 +19,11 @@
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.exir.program._program import to_edge_with_preserved_ops
+from executorch.extension.export_util.utils import save_pte_program
 
 sys.path.insert(0, ".")
-from llama_transformer import InputManager, ModelArgs, Transformer
+from llama_transformer import InputManager, load_model
 
 
 class SplitLinearModule(torch.nn.Module):
@@ -141,42 +141,23 @@ def main() -> None:
         default=8,
         help="Maximum number of splits to divide linear layers",
     )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="fp16",
+    )
 
     export_args = parser.parse_args()
-    params_path = export_args.params
-    checkpoint_path = export_args.checkpoint
-
-    # Load model args
-    with open(params_path, "r") as f:
-        params = json.loads(f.read())
-
-    args = ModelArgs(
-        max_seq_len=export_args.max_seq_length,
-        generate_full_logits=False,
+    model = load_model(
+        export_args.checkpoint,
+        export_args.params,
+        max_seq_length=export_args.max_seq_length,
         use_cache_list=export_args.use_cache_list,
-        **params,
-    )
-
-    with torch.device("meta"):
-        model = Transformer(args)
-
-    checkpoint = torch.load(
-        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
     )
-    if "model" in checkpoint:
-        checkpoint = checkpoint["model"]
 
-    missing, unexpected = model.load_state_dict(
-        checkpoint,
-        strict=False,
-        assign=True,
-    )
-    print("Missing keys: ", missing)
-    print("Unexpected keys: ", unexpected)
-
-    float_dtype = torch.float16  # dtype for model/inputs
-    model.eval()
-    model.to(float_dtype)
+    float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
+        export_args.dtype
+    ]  # dtype for model/inputs
 
     if export_args.embedding_quantize:
         bitwidth, group_size = export_args.embedding_quantize.split(",")
@@ -197,7 +178,8 @@ def main() -> None:
         model, export_args.target_split_size, export_args.max_splits
     )
 
-    model = model.to(float_dtype)
+    model.eval()
+    model.to(float_dtype)
 
     op_linear_quantizer_config = None
     if export_args.coreml_quantize == "b4w":
@@ -217,7 +199,10 @@ def main() -> None:
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
-        compute_precision=ct.precision(ct.precision.FLOAT16.value),
+        compute_precision={
+            torch.float16: ct.precision.FLOAT16,
+            torch.float32: ct.precision.FLOAT32,
+        }[float_dtype],
         compute_unit=ct.ComputeUnit.CPU_AND_NE,
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
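
For reference, the hunk above is what ties the new --dtype flag to both the torch dtype and the CoreML compute precision. A minimal standalone sketch of that mapping (dtype_flag is a stand-in for the parsed argument; coremltools is assumed to be importable as ct, as it is in export.py):

```
import coremltools as ct
import torch

dtype_flag = "fp16"  # stand-in for export_args.dtype; "fp16" (default) or "fp32"

# Same two lookup tables used in export.py: flag -> torch dtype -> CoreML precision
float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[dtype_flag]
compute_precision = {
    torch.float16: ct.precision.FLOAT16,
    torch.float32: ct.precision.FLOAT32,
}[float_dtype]
print(float_dtype, compute_precision)
```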
@@ -232,11 +217,11 @@ def main() -> None:
     )
 
     input_manager = InputManager(
-        n_layers=args.n_layers,
-        max_batch_size=args.max_batch_size,
-        n_kv_heads=args.n_kv_heads,
-        max_seq_length=args.max_seq_len,
-        head_dim=args.head_dim,
+        n_layers=model.params.n_layers,
+        max_batch_size=model.params.max_batch_size,
+        n_kv_heads=model.params.n_kv_heads,
+        max_seq_length=model.params.max_seq_len,
+        head_dim=model.params.head_dim,
         use_cache_list=export_args.use_cache_list,
         seq_length=export_args.seq_length,
         dtype=float_dtype,
@@ -245,10 +230,20 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])
 
-    edge_manager = export_to_edge(
+    ep = torch.export.export(
         model,
         example_inputs,
-        edge_compile_config=EdgeCompileConfig(
+    )
+    print("Exported program")
+    print(ep)
+
+    edge_manager = to_edge_with_preserved_ops(
+        ep,
+        preserve_ops=[
+            torch.ops.aten.scaled_dot_product_attention.default,
+            torch.ops.aten.linalg_vector_norm.default,
+        ],
+        compile_config=EdgeCompileConfig(
             _check_ir_validity=False,
             _skip_type_promotion=(float_dtype == torch.float16),
             _skip_dim_order=True,
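
For reference, a minimal sketch of the new lowering path: export with torch.export, then call to_edge_with_preserved_ops so that SDPA and linalg_vector_norm are not decomposed before the CoreML backend sees them. TinyBlock, its shapes, and the config values are toy stand-ins, not the Llama transformer or the exact settings above:

```
import torch

from executorch.exir import EdgeCompileConfig
from executorch.exir.program._program import to_edge_with_preserved_ops


class TinyBlock(torch.nn.Module):
    # Toy module (assumption, not from the commit) exercising the two preserved ops
    def forward(self, q, k, v):
        out = torch.ops.aten.scaled_dot_product_attention.default(q, k, v)
        return out / torch.linalg.vector_norm(out, dim=-1, keepdim=True)


q = k = v = torch.randn(1, 4, 8, 16)
ep = torch.export.export(TinyBlock(), (q, k, v))
edge_manager = to_edge_with_preserved_ops(
    ep,
    preserve_ops=[
        torch.ops.aten.scaled_dot_product_attention.default,
        torch.ops.aten.linalg_vector_norm.default,
    ],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
print(edge_manager.exported_program().graph)  # both ops remain as aten calls
```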

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 93 additions & 29 deletions
@@ -13,8 +13,6 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.llama_transformer import RMSNorm
-
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
@@ -25,29 +23,6 @@
 from torch import nn
 
 
-# These are just to prevent to_edge from decomposing SDPA
-# A better method is to use the to_edge_transform_and_lower API for CoreML
-# and not decompose SDPA
-@torch.library.custom_op("coreml::sdpa", mutates_args=())
-def sdpa(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
-    return torch.ops.aten.scaled_dot_product_attention.default(
-        q, k, v, attn_mask=attn_mask
-    )
-
-
-@torch.library.register_fake("coreml::sdpa")
-def _(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
-    expected_shape = list(q.shape)
-    expected_shape[-1] = v.shape[-1]
-    return q.new_empty(expected_shape)
-
-
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
@@ -121,6 +96,63 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        # CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable
+        # We instead use (x * sqrt(n)) / norm(x, dim=-1)
+        # Using torch.norm and preserving this op in CoreML improves stability
+        # Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
+        # In future, we want to add CoreML support for the functional RMSNorm op
+        # We have yet to do large scale evaluations on the numeric stability of this solution, but note that
+        # it appears better than what exists currently (removing FP32 casts and using FP16)
+        rms_norm_eps0 = (
+            x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+        ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
+        return rms_norm_eps0
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x)
+        return output * self.weight
+
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
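
For context, the reformulation used in RMSNorm._norm above is algebraically (x * sqrt(n)) / ||x||, which matches the conventional x / sqrt(mean(x^2) + eps) whenever eps is negligible relative to the activations. A small self-contained check of that identity (not part of the commit; it uses FP32 on CPU since the identity itself is dtype-independent):

```
import torch


def rms_norm_via_vector_norm(x: torch.Tensor) -> torch.Tensor:
    # (x * sqrt(n)) / ||x||_2 over the last dim, as in RMSNorm._norm above
    n = x.shape[-1]
    return (x * torch.sqrt(torch.tensor(n, dtype=x.dtype))) / torch.linalg.vector_norm(
        x, dim=-1, keepdim=True
    )


def rms_norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Conventional RMSNorm: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(2, 8, 64)
# Difference is small, driven only by eps and float32 rounding
print((rms_norm_via_vector_norm(x) - rms_norm_reference(x)).abs().max())
```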
@@ -304,12 +336,11 @@ def forward(
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
 
-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
-
+        output = torch.ops.aten.scaled_dot_product_attention.default(
+            q, k, v, attn_mask=attn_mask
+        )
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
         output = self.wo(output)
-
         return output, new_k, new_v
 
 
@@ -413,6 +444,39 @@ def forward(
         return logits, k_out, v_out
 
 
+def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+    import json
+
+    with open(params_path, "r") as f:
+        params = json.loads(f.read())
+
+    args = ModelArgs(
+        max_seq_len=max_seq_length,
+        generate_full_logits=False,
+        use_cache_list=use_cache_list,
+        **params,
+    )
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
+    )
+    if "model" in checkpoint:
+        checkpoint = checkpoint["model"]
+
+    missing, unexpected = model.load_state_dict(
+        checkpoint,
+        strict=False,
+        assign=True,
+    )
+    print("Missing keys: ", missing)
+    print("Unexpected keys: ", unexpected)
+
+    return model
+
+
 class InputManager:
     def __init__(
         self,
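
For reference, a hedged usage sketch of the new load_model helper, mirroring how export.py calls it (paths are placeholders, and the script is assumed to be run from examples/apple/coreml/llama so llama_transformer is importable):

```
import torch

from llama_transformer import load_model

model = load_model(
    "/path/to/model.pth",    # placeholder checkpoint path
    "/path/to/params.json",  # placeholder params path
    max_seq_length=1024,
    use_cache_list=True,
)
model.eval()
model.to(torch.float16)  # or torch.float32, matching the --dtype choice
```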

examples/apple/coreml/llama/readme.md

Lines changed: 7 additions & 1 deletion
@@ -4,7 +4,7 @@ This directory contains ANE-friendly Llama models.
 
 Export model with:
 ```
-python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
+python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
 ```
 
 (Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
@@ -17,6 +17,12 @@ Run model with:
 python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
 ```
 
+The runner can also be used to run an eager model to compare with CoreML numerics (--use_eager). In this case, you must specify:
+* --checkpoint
+* --dtype
+* --max_seq_length
+* --seq_length
+
 (Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
 
 