@@ -13,5 +13,6 @@
 import torch
 import torch.nn.functional as F
+from executorch.examples.models.llama.norm import RMSNorm

 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
@@ -64,6 +65,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
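The new `high_freq_factor` field joins `rope_scale_factor` in parameterizing Llama 3.1-style scaled RoPE; both are threaded into the frequency precomputation further down. For context, a minimal sketch of that scaling rule follows. The authoritative implementation lives in `executorch.examples.models.llama.rope`; `low_freq_factor=1` and `old_context_len=8192` are assumed defaults, not values taken from this diff:

```python
import math

import torch


def apply_llama31_rope_scaling(
    freqs: torch.Tensor,
    scale_factor: int = 8,
    high_freq_factor: int = 4,
    low_freq_factor: int = 1,      # assumed default, not shown in this diff
    old_context_len: int = 8192,   # assumed default, not shown in this diff
) -> torch.Tensor:
    """Llama 3.1-style frequency scaling for long-context RoPE.

    Low-frequency components are divided by scale_factor, high-frequency
    components pass through unchanged, and the band in between is linearly
    interpolated between the two regimes.
    """
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    out = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            out.append(freq)                 # high frequency: unchanged
        elif wavelen > low_freq_wavelen:
            out.append(freq / scale_factor)  # low frequency: scaled down
        else:
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(out, dtype=freqs.dtype, device=freqs.device)
```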
@@ -74,6 +77,12 @@ class ModelArgs:

     use_cache_list: bool = True

+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
+    use_qk_norm: bool = False
+    qk_norm_before_rope: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
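Together with the earlier hunk, `ModelArgs` now exposes KV-cache, dynamic-shape, and qk-norm switches. A hypothetical construction exercising them (all values illustrative; the remaining fields fall back to their dataclass defaults):

```python
# Illustrative only: dim/n_heads values are made up for the example.
args = ModelArgs(
    dim=2048,
    n_heads=32,
    use_kv_cache=True,
    enable_dynamic_shape=False,
    use_qk_norm=True,
    qk_norm_before_rope=False,  # normalize q/k after RoPE; see forward() below
)
```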
@@ -96,7 +105,7 @@ def __post_init__(self):
             self.head_dim = self.dim // self.n_heads


-class RMSNorm(torch.nn.Module):
+class CoreMLRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
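The rename to `CoreMLRMSNorm` avoids shadowing the `RMSNorm` now imported from `executorch.examples.models.llama.norm`; both compute the same normalization, y = x · g / sqrt(mean(x²) + eps). A minimal sketch of that computation (class name hypothetical, for illustration only):

```python
import torch


class TinyRMSNorm(torch.nn.Module):
    """Minimal RMSNorm for illustration; not the implementation used here."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dim, then rescale.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
```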
@@ -160,10 +169,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
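Wrapping the precompute functions in `functools.partial` binds the scaling hyperparameters once at construction, so the call site above keeps passing only `head_dim` and the sequence parameters. A toy illustration of the pattern (function and parameter names are placeholders, not the real rope signatures):

```python
from functools import partial


def toy_precompute(dim, end, use_scaled=False, scale_factor=8, high_freq_factor=4):
    # Stand-in for precompute_freqs_cis; only the binding pattern matters here.
    return (dim, end, use_scaled, scale_factor, high_freq_factor)


fn = partial(toy_precompute, use_scaled=True, scale_factor=8, high_freq_factor=4)
print(fn(64, 2048))  # (64, 2048, True, 8, 4) -- call site unchanged
```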
@@ -303,6 +318,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):

         self.rope = rope

+        self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
     def forward(
         self,
         x: torch.Tensor,
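Since RMSNorm normalizes over its last dimension, constructing it with `head_dim` makes the normalization per-head once q and k are laid out with `head_dim` last. A small shape check, assuming the imported `RMSNorm` broadcasts over leading dimensions the way this usage implies:

```python
import torch

from executorch.examples.models.llama.norm import RMSNorm  # as imported above

head_dim = 64
norm = RMSNorm(head_dim, eps=1e-6)
q = torch.randn(2, 8, 16, head_dim)  # (bs, n_heads, seq_len, head_dim)
# Each (batch, head, position) vector of length head_dim is normalized
# independently; the shape is preserved.
assert norm(q).shape == q.shape
```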
@@ -327,6 +350,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         new_k = k
         new_v = v

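The `not self.qk_norm_before_rope` guard applies the normalization to the already-rotated projections; the before-RoPE branch is presumably handled earlier in `forward`, outside this excerpt. A sketch of the two orderings the flag selects between (helper names hypothetical):

```python
def qk_norm_ordering(q, k, q_norm, k_norm, rope_fn, before_rope: bool):
    # Illustrative helper: the two placements that qk_norm_before_rope
    # toggles between. rope_fn stands in for the rotary-embedding step.
    if before_rope:
        q, k = q_norm(q), k_norm(k)  # normalize the raw q/k projections
    q, k = rope_fn(q, k)
    if not before_rope:
        q, k = q_norm(q), k_norm(k)  # normalize the rotated q/k (this diff's path)
    return q, k
```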