@@ -19,6 +19,7 @@
     precompute_freqs_cis,
     RotaryEmbedding,
 )
+from executorch.examples.models.llama.norm import RMSNorm
 
 from torch import nn
 
@@ -64,6 +65,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
@@ -74,6 +77,12 @@ class ModelArgs:
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
+    use_qk_norm: bool = False
+    qk_norm_before_rope: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
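The new fields are plain dataclass defaults, so they compose with the existing ones. A minimal sketch of a configuration that exercises them (the values here are illustrative, not defaults the PR prescribes for any particular checkpoint; all other fields keep their defaults):

    args = ModelArgs(
        use_scaled_rope=True,       # Llama 3.1-style RoPE scaling
        rope_scale_factor=8,
        high_freq_factor=4,         # new: high-frequency band for the scaling rule
        use_qk_norm=True,           # new: RMSNorm on per-head Q/K vectors
        qk_norm_before_rope=False,  # new: apply the norm after rotary embedding
    )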
@@ -96,7 +105,7 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
-class RMSNorm(torch.nn.Module):
+class CoreMLRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
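Presumably the file-local norm is renamed to CoreMLRMSNorm so it no longer shadows the RMSNorm imported at the top of the file; after the rename both can coexist in this module:

    shared_norm = RMSNorm(64)        # imported implementation, reused for QK-norm below
    local_norm = CoreMLRMSNorm(64)   # renamed file-local variant, same (dim, eps) interface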
@@ -160,10 +169,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
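Both branches now go through functools.partial, so the later call site only supplies the positional arguments. A self-contained sketch of the pattern, using a toy stand-in (toy_precompute is hypothetical, not the real precompute_freqs_cis):

    from functools import partial

    def toy_precompute(dim, end, use_scaled=False, scale_factor=8, high_freq_factor=4):
        # Stand-in for precompute_freqs_cis; only echoes its arguments.
        return dim, end, use_scaled, scale_factor, high_freq_factor

    fn = partial(toy_precompute, use_scaled=True, scale_factor=8, high_freq_factor=4)
    print(fn(128, 2048))  # caller passes just (head_dim, seq_len), as in __init__ above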
@@ -303,6 +318,15 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
 
         self.rope = rope
 
+        self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
+
     def forward(
         self,
         x: torch.Tensor,
@@ -327,6 +351,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         new_k = k
         new_v = v
 
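Because the norm runs after the transpose, q and k are (batch, n_heads, seq_len, head_dim) here, and an RMSNorm over the trailing head_dim normalizes each head's vector independently. A runnable sketch under those shape assumptions (this RMSNorm is a stand-in for the imported one, not its actual source):

    import torch

    class RMSNorm(torch.nn.Module):
        # Stand-in for executorch.examples.models.llama.norm.RMSNorm.
        def __init__(self, dim, eps=1e-6):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # Scale by the reciprocal RMS over the last dim, then apply the gain.
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            return x * self.weight

    q = torch.randn(2, 8, 16, 64)  # (batch, n_heads, seq_len, head_dim)
    q_norm_fn = RMSNorm(64)        # dim matches head_dim, as in __init__ above
    q = q_norm_fn(q)               # each head vector is normalized independently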