Commit 99be628

YIWENX14 authored and facebook-github-bot committed
Update coreml transformer with Rope and RMSNorm original definitions (#11583)
Summary:
Pull Request resolved: #11583

Keep the module in sync with https://github.com/pytorch/executorch/tree/main/examples/models/llama

Differential Revision: D76468703
1 parent 0d244f9 commit 99be628

File tree

1 file changed, +30 -3 lines changed

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 30 additions & 3 deletions
@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn.functional as F
+from executorch.examples.models.llama.norm import RMSNorm
 
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
@@ -64,6 +65,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
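
For orientation: scale_factor and high_freq_factor are the knobs used by llama3.1-style RoPE frequency scaling. The sketch below shows the published llama3.1 recipe for scaling a single rotary frequency; it illustrates what these parameters mean and is not a copy of the ExecuTorch precompute_freqs_cis implementation. low_freq_factor and old_context_len are fixed assumptions here.

import math

def scale_rope_freq_sketch(freq: float, scale_factor: int = 8, high_freq_factor: int = 4,
                           low_freq_factor: int = 1, old_context_len: int = 8192) -> float:
    # Llama 3.1-style scaling: leave high-frequency components alone, divide
    # low-frequency components by scale_factor, and blend smoothly in between.
    wavelen = 2 * math.pi / freq
    if wavelen < old_context_len / high_freq_factor:
        return freq                   # high-frequency band: unchanged
    if wavelen > old_context_len / low_freq_factor:
        return freq / scale_factor    # low-frequency band: fully scaled
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    return (1 - smooth) * freq / scale_factor + smooth * freq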
@@ -74,6 +77,12 @@ class ModelArgs:
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
+    use_qk_norm: bool = False
+    qk_norm_before_rope: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
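
The new ModelArgs fields are plain dataclass defaults, so callers opt into the new behavior at construction time. A minimal sketch with a hypothetical, trimmed-down stand-in for ModelArgs (only the fields touched by this diff; the real class has many more):

from dataclasses import dataclass

@dataclass
class MiniModelArgs:  # hypothetical stand-in, not the real ModelArgs
    rope_scale_factor: int = 8
    high_freq_factor: int = 4           # new: feeds the scaled-RoPE path
    use_kv_cache: bool = False          # new
    enable_dynamic_shape: bool = False  # new
    use_qk_norm: bool = False           # new: RMSNorm on q/k inside attention
    qk_norm_before_rope: bool = False   # new: norm before vs. after RoPE

# Example: enable QK-norm, applied after the rotary embedding.
args = MiniModelArgs(use_qk_norm=True, qk_norm_before_rope=False)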
@@ -96,7 +105,7 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
-class RMSNorm(torch.nn.Module):
+class CoreMLRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
@@ -160,10 +169,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
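
The change above only pre-binds extra keyword arguments with functools.partial, so the existing call site, which passes head_dim and the sequence length, stays untouched. A small self-contained illustration of the pattern, using a fake stand-in for precompute_freqs_cis:

from functools import partial

def fake_precompute_freqs_cis(dim, end, use_scaled=False, scale_factor=8, high_freq_factor=4):
    # Stand-in only; the real function returns (freqs_cos, freqs_sin) tensors.
    return (dim, end, use_scaled, scale_factor, high_freq_factor)

# Bind the scaling knobs once, as Rope.__init__ now does ...
precompute = partial(fake_precompute_freqs_cis, use_scaled=True,
                     scale_factor=8, high_freq_factor=4)

# ... so the downstream call only supplies the dimension and length.
print(precompute(64, 2048))  # -> (64, 2048, True, 8, 4)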
@@ -303,6 +318,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
 
         self.rope = rope
 
+        self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
     def forward(
         self,
         x: torch.Tensor,
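
The RMSNorm imported at the top of the file (the "original definition" from examples/models/llama) is instantiated once for q and once for k, each over head_dim. For intuition, here is a sketch of a standard RMSNorm and how normalizing over the last dimension behaves on per-head tensors; this is an illustration, not the exact ExecuTorch class:

import torch

class SketchRMSNorm(torch.nn.Module):  # illustrative only
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square of the last dimension.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight

# q shaped (batch, n_heads, seq_len, head_dim): each head and position is
# normalized independently, because the norm only touches head_dim.
q = torch.randn(1, 8, 16, 64)
q_normed = SketchRMSNorm(64)(q)  # same shape as q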
@@ -327,6 +350,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         new_k = k
         new_v = v
 
0 commit comments
