Commit 8c4a721

Sheng Feng Wu (shewu-quic) authored and committed
Use transform to replace rms_norm
1 parent adbb1e7 commit 8c4a721

3 files changed: +29 -3 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,
@@ -409,6 +410,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_kv_cache_with_simple_kv_cache)
         transforms.append(replace_sdpa_with_flex_sdpa)
         transforms.append(replace_causal_mask)
+        transforms.append(replace_rms_norm_with_native_rms_norm)
         transforms.append(convert_linear_to_conv2d)

     elif args.coreml or args.mps:
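For context, the transforms list assembled here holds the export flow's source transformations: each entry is a callable that takes the eager model and returns a rewritten module. A minimal sketch of that pattern, with an illustrative function name since the actual call site is not part of this diff:

import torch

def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Each source transformation takes a module and returns the (possibly
    # rewritten) module; they are applied in order before export.
    for transform in transforms:
        model = transform(model)
    return model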

examples/models/llama2/llama_transformer.py

Lines changed: 4 additions & 3 deletions
@@ -39,6 +39,7 @@ def __init__(self, dim: int, eps: float = 1e-6):

         """
         super().__init__()
+        self.dim = dim
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))

@@ -416,8 +417,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
-        self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

     def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN
         h = self.attention.forward(
@@ -443,7 +444,7 @@ def __init__(self, params: ModelArgs):
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
-        self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps)
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
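This change keeps the example's own RMSNorm in the eager model and leaves it to the new source transformation to swap in torch.nn.RMSNorm at export time. The swap assumes the two modules are numerically equivalent once the weight is shared; a minimal sketch of that check, with dimensions and tolerance chosen for illustration (torch.nn.RMSNorm requires a recent PyTorch, and lower-precision dtypes may need a looser tolerance):

import torch
from executorch.examples.models.llama2.llama_transformer import RMSNorm

dim, eps = 64, 1e-6
example_norm = RMSNorm(dim, eps=eps)
native_norm = torch.nn.RMSNorm(dim, eps=eps)
native_norm.weight = example_norm.weight  # share the learnable scale

x = torch.randn(2, 8, dim)
# The two implementations should agree up to floating-point error.
assert torch.allclose(example_norm(x), native_norm(x), atol=1e-6)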
examples/models/llama2/source_transformation/rms_norm.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.examples.models.llama2.llama_transformer import RMSNorm
+
+
+def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, RMSNorm):
+            rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps)
+            rms_norm.weight = child.weight
+            setattr(
+                module,
+                name,
+                rms_norm,
+            )
+        else:
+            replace_rms_norm_with_native_rms_norm(child)
+    return module
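A quick usage sketch of the new transform, assuming the file above lives at examples/models/llama2/source_transformation/rms_norm.py as the relative import in export_llama_lib.py suggests; the TinyBlock module is made up for illustration:

import torch
from executorch.examples.models.llama2.llama_transformer import RMSNorm
from executorch.examples.models.llama2.source_transformation.rms_norm import (
    replace_rms_norm_with_native_rms_norm,
)

class TinyBlock(torch.nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.norm = RMSNorm(dim, eps=1e-6)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(self.norm(x))

block = replace_rms_norm_with_native_rms_norm(TinyBlock())
# The example RMSNorm child has been swapped for the native module in place.
print(type(block.norm))  # <class 'torch.nn.modules.normalization.RMSNorm'>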
