@@ -147,6 +147,81 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads


+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.head_dim,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get the precomputed frequencies for the given input position and sequence length.
+
+        Args:
+            input_pos (torch.Tensor): The input position tensor.
+            seq_len (int): The sequence length.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
+        """
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        return freqs_cos, freqs_sin
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
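Conceptually, the new Rope module only packages the standard rotary-embedding math together with buffer registration and position slicing. Below is a minimal, self-contained sketch of that math with made-up sizes; it is illustrative only, not the exact precompute_freqs_cis / RotaryEmbedding implementation imported in this file.

import torch

def toy_precompute(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # One rotation frequency per pair of head-dim channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()

def toy_apply(q, k, cos, sin):
    # Rotate each even/odd channel pair by the per-position angle.
    def rotate(x):
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        return out.flatten(-2)
    return rotate(q), rotate(k)

cos, sin = toy_precompute(head_dim=8, max_seq_len=4)
q = torch.randn(1, 4, 2, 8)  # (batch, seq, heads, head_dim)
k = torch.randn(1, 4, 2, 8)
q_rot, k_rot = toy_apply(q, k, cos[None, :, None, :], sin[None, :, None, :])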
@@ -266,7 +341,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -287,6 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

         self.layer_id = layer_id

+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -303,7 +380,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -314,10 +391,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
@@ -336,7 +409,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

         if self.use_kv_cache:
             assert input_pos is not None
@@ -424,13 +497,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -459,33 +532,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers

         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.head_dim,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

     def forward(
         self,
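For reference, a toy sketch (hypothetical ToyRope/ToyBlock/ToyModel names) of the sharing pattern introduced above: the Transformer builds one Rope and hands the same instance to every TransformerBlock, so the cos/sin tables live in memory once rather than once per layer.

import torch
import torch.nn as nn

class ToyRope(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for the precomputed frequency tables.
        self.register_buffer("freqs_cos", torch.randn(16, 4), persistent=False)

class ToyBlock(nn.Module):
    def __init__(self, rope: nn.Module):
        super().__init__()
        self.rope = rope  # reference, not a copy

class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4):
        super().__init__()
        self.rope = ToyRope()
        self.layers = nn.ModuleList(ToyBlock(self.rope) for _ in range(n_layers))

m = ToyModel()
# Every layer sees the exact same buffer tensor.
assert all(layer.rope.freqs_cos is m.rope.freqs_cos for layer in m.layers)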
@@ -502,33 +559,7 @@ def forward(
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
         seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

         for layer in self.layers:
             h = layer(
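The net effect of the last hunk is that the frequency-slicing logic now lives in one place, Rope.get_freqs. Below is a standalone sketch of that slicing behavior under assumed sizes (prefill without a KV cache vs. a single decode step at a given position); it mirrors, but is not, the method above.

import torch

max_seq_len, half_dim = 16, 4
freqs_cos = torch.randn(max_seq_len, half_dim)  # stand-in for a precomputed table

def slice_freqs(freqs, input_pos, seq_len):
    if input_pos is None:
        # No KV cache: take the first seq_len positions of the prompt.
        return freqs[:seq_len]
    # KV cache decode step: slice seq_len rows starting at the current position.
    start = int(input_pos[-1].item())
    return freqs.narrow(0, start, seq_len)

print(slice_freqs(freqs_cos, None, 5).shape)               # torch.Size([5, 4])
print(slice_freqs(freqs_cos, torch.tensor([7]), 1).shape)  # torch.Size([1, 4])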