@@ -114,15 +114,30 @@ def update(
114
114
return all_data , (out_k_cache , out_v_cache )
115
115
116
116
117
- def _apply_rotary_embedding (
118
- x : torch .Tensor , freqs_cos : torch .Tensor , freqs_sin : torch .Tensor
119
- ) -> torch .Tensor :
120
- x_r , x_i = x [..., ::2 ], x [..., 1 ::2 ]
121
- x_out_r = x_r * freqs_cos - x_i * freqs_sin
122
- x_out_i = x_r * freqs_sin + x_i * freqs_cos
117
+ class _Rope (nn .Module ):
118
+ def __init__ (self , use_hf_rope ):
119
+ super ().__init__ ()
120
+ self .use_hf_rope = use_hf_rope
121
+
122
+ def forward (
123
+ self , x : torch .Tensor , freqs_cos : torch .Tensor , freqs_sin : torch .Tensor
124
+ ) -> torch .Tensor :
125
+ if self .use_hf_rope :
126
+ if len (freqs_cos .shape ) == 2 :
127
+ freqs_cos = freqs_cos .unsqueeze (0 )
128
+ if len (freqs_sin .shape ) == 2 :
129
+ freqs_sin = freqs_sin .unsqueeze (0 )
130
+ x1 = x [..., : x .shape [- 1 ] // 2 ]
131
+ x2 = x [..., x .shape [- 1 ] // 2 :]
132
+ x_rotated = torch .cat ((- x2 , x1 ), dim = - 1 )
133
+ return x * freqs_cos + x_rotated * freqs_sin
134
+ else :
135
+ x_r , x_i = x [..., ::2 ], x [..., 1 ::2 ]
136
+ x_out_r = x_r * freqs_cos - x_i * freqs_sin
137
+ x_out_i = x_r * freqs_sin + x_i * freqs_cos
123
138
124
- x_out = torch .cat ([x_out_r , x_out_i ], dim = - 1 )
125
- return x_out
139
+ x_out = torch .cat ([x_out_r , x_out_i ], dim = - 1 )
140
+ return x_out
126
141
127
142
128
143
@register_attention ("static" )
@@ -172,6 +187,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
172
187
[StaticVCache (layer_id , i ) for i in range (self .n_kv_heads )]
173
188
)
174
189
self .wo = nn .Linear (self .n_heads * self .head_dim , self .dim , bias = False )
190
+ self .rope = _Rope (rope .params .use_hf_rope )
175
191
176
192
def forward (
177
193
self ,
@@ -191,8 +207,8 @@ def forward(
191
207
new_qs = [self .wqs [i ](x ) for i in range (self .n_heads )]
192
208
new_ks = [self .wks [i ](x ) for i in range (self .n_kv_heads )]
193
209
new_vs = [self .wvs [i ](x ) for i in range (self .n_kv_heads )]
194
- new_qs = [_apply_rotary_embedding (q , freqs_cos , freqs_sin ) for q in new_qs ]
195
- new_ks = [_apply_rotary_embedding (k , freqs_cos , freqs_sin ) for k in new_ks ]
210
+ new_qs = [self . rope (q , freqs_cos , freqs_sin ) for q in new_qs ]
211
+ new_ks = [self . rope (k , freqs_cos , freqs_sin ) for k in new_ks ]
196
212
197
213
all_ks = []
198
214
all_vs = []
0 commit comments