Skip to content

Update coreml transformer with Rope and RMSNorm original definitions #11583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions examples/apple/coreml/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
import torch.nn.functional as F
from executorch.examples.models.llama.norm import RMSNorm

from executorch.examples.models.llama.rope import (
hf_apply_rotary_emb,
Expand Down Expand Up @@ -64,6 +65,8 @@ class ModelArgs:
use_scaled_rope: bool = True # Use scaled RoPE, introduced in llama3.1.
# Additional Model Metadata needed at runtime
rope_scale_factor: int = 8
high_freq_factor: int = 4

bos_idx: int = 1
eos_idx: int = 3
bos_count: int = -1 # i.e., a single EOS is used as BOS
Expand All @@ -74,6 +77,12 @@ class ModelArgs:

use_cache_list: bool = True

use_kv_cache: bool = False
enable_dynamic_shape: bool = False

use_qk_norm: bool = False
qk_norm_before_rope: bool = False

def __post_init__(self):
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
Expand All @@ -96,7 +105,7 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
class CoreMLRMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.
Expand Down Expand Up @@ -160,10 +169,16 @@ def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
if self.params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
self.precompute_freqs_cis = partial(
hf_precompute_freqs_cis,
partial_rotary_factor=self.params.partial_rotary_factor,
)
else:
self.precompute_freqs_cis = partial(
precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
precompute_freqs_cis,
use_scaled=self.params.use_scaled_rope,
scale_factor=self.params.rope_scale_factor,
high_freq_factor=self.params.high_freq_factor,
)
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
Expand Down Expand Up @@ -303,6 +318,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):

self.rope = rope

self.use_qk_norm = args.use_qk_norm
self.qk_norm_before_rope = args.qk_norm_before_rope
if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

def forward(
self,
x: torch.Tensor,
Expand All @@ -327,6 +350,10 @@ def forward(
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if self.use_qk_norm and not self.qk_norm_before_rope:
q = self.q_norm_fn(q)
k = self.k_norm_fn(k)

new_k = k
new_v = v

Expand Down
Loading