@@ -143,6 +143,69 @@ def __post_init__(self):
             self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.Tensor] = None,
+    ):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # When the KV cache is used, seq_len is most likely 1; slice starting at input_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, calling .item() results in
+                # symints because it queries data from the tensor.
+                # This path avoids that for the MPS backend, although the MPS
+                # backend can probably support dynamic shape.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
@@ -262,7 +325,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +347,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +365,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,16 +376,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
        self,
        x: torch.Tensor,
-       freqs_cos: torch.Tensor,
-       freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
@@ -333,7 +392,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)
 
         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,24 +480,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, input_pos=None):  # x: 1xN
+        h = self.attention.forward(self.attention_norm(x), input_pos)
 
         h = x + h
         if hasattr(self, "block_sparse_moe"):
@@ -456,33 +513,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
         self,
@@ -498,42 +539,9 @@ def forward(
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
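
For context, here is a minimal, self-contained sketch of what the new Rope module centralizes: precomputing cos/sin tables once and applying the rotation to q/k at the requested positions. This is a generic interleaved-pair RoPE illustration, not the repository's exact precompute_freqs_cis/RotaryEmbedding implementation; the helper names, shapes, and the base=10000.0 default are assumptions for the sketch.

import torch


def precompute_freqs(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # One frequency per channel pair, then an outer product with positions.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x: (bsz, seq_len, n_heads, head_dim); rotate each channel pair by a
    # position-dependent angle.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, :]  # broadcast over batch and heads
    sin = sin[None, :, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


# Slicing the tables at the current positions is analogous to what Rope.forward
# does with input_pos when the KV cache is enabled.
cos, sin = precompute_freqs(head_dim=64, max_seq_len=128)
q = torch.randn(1, 16, 8, 64)  # (bsz, seq_len, n_heads, head_dim)
k = torch.randn(1, 16, 2, 64)  # fewer KV heads (grouped-query attention)
q = apply_rope(q, cos[:16], sin[:16])
k = apply_rope(k, cos[:16], sin[:16])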