Commit 4c4d88b

Sheng Feng Wu (shewu-quic) authored and committed
Use transform to replace rms_norm
1 parent adbb1e7 commit 4c4d88b

File tree

5 files changed: +30 -5 lines changed


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/quantize.py",
+        "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
     ],

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 1 deletion
@@ -43,13 +43,13 @@
     get_qnn_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
-from torch._export import capture_pre_autograd_graph
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,

@@ -409,6 +409,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_kv_cache_with_simple_kv_cache)
         transforms.append(replace_sdpa_with_flex_sdpa)
         transforms.append(replace_causal_mask)
+        transforms.append(replace_rms_norm_with_native_rms_norm)
         transforms.append(convert_linear_to_conv2d)
 
     elif args.coreml or args.mps:
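
For context, every entry appended to transforms here is a callable that takes an nn.Module and returns the (possibly mutated) module; the QNN branch above simply adds replace_rms_norm_with_native_rms_norm to that pipeline. A minimal sketch of how such a list of source transformations could be applied in order (illustrative only, not the actual LLMEdgeManager code):

    import torch

    def apply_source_transforms(
        model: torch.nn.Module, transforms
    ) -> torch.nn.Module:
        # Each transform maps nn.Module -> nn.Module; apply them in order,
        # feeding the output of one into the next.
        for transform in transforms:
            model = transform(model)
        return model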

examples/models/llama2/llama_transformer.py

Lines changed: 4 additions & 3 deletions
@@ -39,6 +39,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
 
         """
         super().__init__()
+        self.dim = dim
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 

@@ -416,8 +417,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
-        self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
     def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN
         h = self.attention.forward(

@@ -443,7 +444,7 @@ def __init__(self, params: ModelArgs):
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
-        self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps)
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
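
Two changes here: the eager RMSNorm now stores self.dim, which the new transform needs in order to construct the native replacement, and the model goes back to instantiating the eager RMSNorm directly, so the swap to torch.nn.RMSNorm becomes opt-in per backend. As a sanity check on the substitution, torch.nn.RMSNorm (available in recent PyTorch releases) computes the same root-mean-square normalization as the eager module; a small hedged check, with arbitrary shapes:

    import torch

    dim, eps = 8, 1e-6
    x = torch.randn(2, 4, dim)

    native = torch.nn.RMSNorm(dim, eps=eps)
    # Eager-style RMSNorm: scale by the reciprocal RMS over the last dim,
    # then apply the learned elementwise weight.
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * native.weight

    assert torch.allclose(native(x), manual, atol=1e-6)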
examples/models/llama2/source_transformation/rms_norm.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.examples.models.llama2.llama_transformer import RMSNorm
+
+
+def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, RMSNorm):
+            rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps)
+            rms_norm.weight = child.weight
+            setattr(
+                module,
+                name,
+                rms_norm,
+            )
+        else:
+            replace_rms_norm_with_native_rms_norm(child)
+    return module
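
The transform recurses through named_children, swaps every eager RMSNorm for a torch.nn.RMSNorm built from the stored dim and eps, and reuses the existing weight parameter, so trained weights carry over unchanged. A minimal usage sketch (the Wrapper container below is hypothetical, just to show the in-place swap):

    import torch
    from executorch.examples.models.llama2.llama_transformer import RMSNorm
    from executorch.examples.models.llama2.source_transformation.rms_norm import (
        replace_rms_norm_with_native_rms_norm,
    )

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.norm = RMSNorm(16)

    model = replace_rms_norm_with_native_rms_norm(Wrapper())
    assert isinstance(model.norm, torch.nn.RMSNorm)  # eager module replaced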

extension/llm/export/builder.py

Lines changed: 0 additions & 1 deletion
@@ -71,7 +71,6 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
-        export_fn=capture_pre_autograd_graph,
     ):
         self.model = model
         # graph module returned from capture_pre_autograd_graph
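
With the export_fn hook removed, the builder appears to always capture via torch._export.capture_pre_autograd_graph (the surviving comment on the next line points the same way). A hedged sketch of the direct call, with model and example_inputs as placeholder names:

    from torch._export import capture_pre_autograd_graph

    # Hypothetical: the pre-autograd capture invoked directly, without the
    # removed export_fn indirection.
    pre_autograd_module = capture_pre_autograd_graph(model, example_inputs)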
