
Commit 1652a15

Add dtype, fix RMS norm for FP16
1 parent 366d87e commit 1652a15

File tree

4 files changed: +198 −65 lines changed


examples/apple/coreml/llama/export.py

Lines changed: 24 additions & 40 deletions

```diff
@@ -3,7 +3,6 @@
 # pyre-strict
 
 import argparse
-import json
 
 import sys
 
@@ -23,7 +22,7 @@
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 
 sys.path.insert(0, ".")
-from llama_transformer import InputManager, ModelArgs, Transformer
+from llama_transformer import InputManager, load_model
 
 
 class SplitLinearModule(torch.nn.Module):
@@ -141,42 +140,23 @@ def main() -> None:
         default=8,
         help="Maximum number of splits to divide linear layers",
     )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="fp16",
+    )
 
     export_args = parser.parse_args()
-    params_path = export_args.params
-    checkpoint_path = export_args.checkpoint
-
-    # Load model args
-    with open(params_path, "r") as f:
-        params = json.loads(f.read())
-
-    args = ModelArgs(
-        max_seq_len=export_args.max_seq_length,
-        generate_full_logits=False,
+    model = load_model(
+        export_args.checkpoint,
+        export_args.params,
+        max_seq_length=export_args.max_seq_length,
         use_cache_list=export_args.use_cache_list,
-        **params,
-    )
-
-    with torch.device("meta"):
-        model = Transformer(args)
-
-    checkpoint = torch.load(
-        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
     )
-    if "model" in checkpoint:
-        checkpoint = checkpoint["model"]
 
-    missing, unexpected = model.load_state_dict(
-        checkpoint,
-        strict=False,
-        assign=True,
-    )
-    print("Missing keys: ", missing)
-    print("Unexpected keys: ", unexpected)
-
-    float_dtype = torch.float16  # dtype for model/inputs
-    model.eval()
-    model.to(float_dtype)
+    float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
+        export_args.dtype
+    ]  # dtype for model/inputs
 
     if export_args.embedding_quantize:
         bitwidth, group_size = export_args.embedding_quantize.split(",")
@@ -197,7 +177,8 @@ def main() -> None:
             model, export_args.target_split_size, export_args.max_splits
         )
 
-    model = model.to(float_dtype)
+    model.eval()
+    model.to(float_dtype)
 
     op_linear_quantizer_config = None
     if export_args.coreml_quantize == "b4w":
@@ -217,7 +198,10 @@ def main() -> None:
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
-        compute_precision=ct.precision(ct.precision.FLOAT16.value),
+        compute_precision={
+            torch.float16: ct.precision.FLOAT16,
+            torch.float32: ct.precision.FLOAT32,
+        }[float_dtype],
         compute_unit=ct.ComputeUnit.CPU_AND_NE,
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
         op_linear_quantizer_config=op_linear_quantizer_config,
@@ -232,11 +216,11 @@ def main() -> None:
     )
 
     input_manager = InputManager(
-        n_layers=args.n_layers,
-        max_batch_size=args.max_batch_size,
-        n_kv_heads=args.n_kv_heads,
-        max_seq_length=args.max_seq_len,
-        head_dim=args.head_dim,
+        n_layers=model.params.n_layers,
+        max_batch_size=model.params.max_batch_size,
+        n_kv_heads=model.params.n_kv_heads,
+        max_seq_length=model.params.max_seq_len,
+        head_dim=model.params.head_dim,
         use_cache_list=export_args.use_cache_list,
         seq_length=export_args.seq_length,
         dtype=float_dtype,
```
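
With the new `--dtype` flag (default `fp16`), an FP32 export can be requested by mirroring the readme's export command with the dtype changed; the paths below are placeholders:

```
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp32
```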

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 82 additions & 5 deletions

```diff
@@ -13,8 +13,6 @@
 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.llama_transformer import RMSNorm
-
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
     hf_precompute_freqs_cis,
@@ -121,6 +119,55 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        x_max, _ = torch.abs(x).max(-1, keepdim=True)
+        x = x / x_max  # This makes the op more stable in FP16
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x)
+        return output * self.weight
+
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
@@ -305,11 +352,8 @@ def forward(
         v = v.repeat_interleave(self.n_rep, dim=1)
 
         output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
-
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
         output = self.wo(output)
-
         return output, new_k, new_v
 
 
@@ -413,6 +457,39 @@ def forward(
         return logits, k_out, v_out
 
 
+def load_model(checkpoint_path, params_path, max_seq_length, use_cache_list):
+    import json
+
+    with open(params_path, "r") as f:
+        params = json.loads(f.read())
+
+    args = ModelArgs(
+        max_seq_len=max_seq_length,
+        generate_full_logits=False,
+        use_cache_list=use_cache_list,
+        **params,
+    )
+
+    with torch.device("meta"):
+        model = Transformer(args)
+
+    checkpoint = torch.load(
+        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
+    )
+    if "model" in checkpoint:
+        checkpoint = checkpoint["model"]
+
+    missing, unexpected = model.load_state_dict(
+        checkpoint,
+        strict=False,
+        assign=True,
+    )
+    print("Missing keys: ", missing)
+    print("Unexpected keys: ", unexpected)
+
+    return model
+
+
 class InputManager:
     def __init__(
         self,
```
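
The rewritten `_norm` pre-scales each row by its maximum absolute value before squaring. RMS normalization is scale-invariant (dividing `x` by a positive constant cancels out of `x / rms(x)`), so this leaves the output essentially unchanged while keeping the intermediate `x * x` inside the FP16 range (max finite value ≈ 65504). A minimal standalone sketch of the difference, not part of the commit:

```python
import torch


def rms_norm_naive(x, eps=1e-6):
    # Squares the raw activations; values above ~256 overflow to inf in FP16.
    return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)


def rms_norm_stable(x, eps=1e-6):
    # Pre-scale by the per-row max so (x * x) stays <= 1 before the mean.
    x_max, _ = torch.abs(x).max(-1, keepdim=True)
    x = x / x_max
    return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)


# 300^2 = 90000 exceeds the FP16 max (~65504), so the naive version's
# intermediate mean becomes inf and the output collapses to zero, while
# the stable version returns the expected ~1.0 values.
x = torch.full((1, 8), 300.0, dtype=torch.float16)
print(rms_norm_naive(x))
print(rms_norm_stable(x))
```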

examples/apple/coreml/llama/readme.md

Lines changed: 7 additions & 1 deletion

````diff
@@ -4,7 +4,7 @@ This directory contains ANE-friendly Llama models.
 
 Export model with:
 ```
-python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
+python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --dtype fp16
 ```
 
 (Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
@@ -17,6 +17,12 @@ Run model with:
 python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time,"
 ```
 
+The runner can also be used to run an eager model to compare with CoreML numerics (--use_eager). In this case, you must specify:
+* --checkpoint
+* --dtype
+* --max_seq_length
+* --seq_length
+
 (Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
 
 
````
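
Based on the flags added to run.py in this commit, an eager-mode comparison run might look like the following (paths are placeholders, and any other arguments run.py requires still apply):

```
python run.py -m /path/to/model.pte -t /path/to/tokenizer.model --prompt "Once upon a time," --use_eager -c /path/to/model.pth -p /path/to/params.json --dtype fp16 --max_seq_length 1024 --seq_length 64
```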

examples/apple/coreml/llama/run.py

Lines changed: 85 additions & 19 deletions

```diff
@@ -11,7 +11,7 @@
 sys.path.insert(0, ".")
 from executorch.examples.models.llama.runner.generation import next_token
 from executorch.examples.models.llama.tokenizer import tiktoken
-from llama_transformer import InputManager
+from llama_transformer import InputManager, load_model
 
 
 class Tokenizer:
@@ -71,28 +71,90 @@ def main() -> None:
         type=float,
         default=0.9,
     )
+    parser.add_argument(
+        "--use_eager",
+        action="store_true",
+    )
+    parser.add_argument(
+        "-p",
+        "--params",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "-c",
+        "--checkpoint",
+        type=str,
+        default=None,
+    )
+    parser.add_argument("--dtype", type=str, choices=["fp16", "fp32"], default=None)
+    parser.add_argument(
+        "--seq_length",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--cache_size",
+        type=int,
+        default=None,
+    )
 
     args = parser.parse_args()
 
     tokenizer = Tokenizer(args.tokenizer)
 
     runtime = Runtime.get()
-    program = runtime.load_program(args.model)
-    method = program.load_method("forward")
-
-    metadata = method.metadata
-    print("Method metadata: ", metadata, "\n\n")
-
-    assert (
-        metadata.num_inputs() == 6
-    ), "Do not export with --use_cache_list for use in pybindings"
-    # k_cache input
-    n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = (
-        metadata.input_tensor_meta(3).sizes()
-    )
-
-    # mask input
-    seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes()
+    if args.use_eager:
+        assert args.params is not None
+        assert args.checkpoint is not None
+        assert args.dtype is not None
+        assert args.max_seq_length is not None
+        assert args.seq_length is not None
+
+        max_seq_length = args.max_seq_length
+        seq_length = args.seq_length
+        model = load_model(
+            args.checkpoint,
+            args.params,
+            max_seq_length=max_seq_length,
+            use_cache_list=False,
+        )
+        n_layers = model.params.n_layers
+        max_batch_size = model.params.max_batch_size
+        n_kv_heads = model.params.n_kv_heads
+        head_dim = model.params.head_dim
+        cache_size = args.cache_size
+
+        float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[
+            args.dtype
+        ]  # dtype for model/inputs
+        model.eval()
+        model.to(float_dtype)
+    else:
+        program = runtime.load_program(args.model)
+        method = program.load_method("forward")
+
+        metadata = method.metadata
+        print("Method metadata: ", metadata, "\n\n")
+
+        assert (
+            metadata.num_inputs() == 6
+        ), "Do not export with --use_cache_list for use in pybindings"
+        # k_cache input
+        n_layers, max_batch_size, n_kv_heads, cache_size, head_dim = (
+            metadata.input_tensor_meta(3).sizes()
+        )
+        float_dtype = {5: torch.float16, 6: torch.float32}[
+            metadata.input_tensor_meta(3).dtype()
+        ]
+
+        # mask input
+        seq_length, max_seq_length = metadata.input_tensor_meta(5).sizes()
 
     input_manager = InputManager(
         n_layers=n_layers,
@@ -102,7 +164,7 @@ def main() -> None:
         head_dim=head_dim,
         use_cache_list=False,
         seq_length=seq_length,
-        dtype=torch.float16,
+        dtype=float_dtype,
         minus_infinity=-30000.0,
         cache_size=cache_size,
     )
@@ -117,7 +179,11 @@ def main() -> None:
             tokens
         )
         processed_tokens = len(tokens) - len(remaining_tokens)
-        logits, k, v = method.execute(inputs)
+        if args.use_eager:
+            logits, k, v = model(*inputs)
+        else:
+            logits, k, v = method.execute(inputs)
+
         input_manager.update(
             input_length=processed_tokens, new_k_caches=k, new_v_caches=v
         )
```
