1 file changed: +10 −0 lines changed

@@ -281,6 +281,8 @@ class TransformerArgs:
     # Optional biases
     attention_bias: bool = False
     feed_forward_bias: bool = False
+    # Whether or not to tie the input word embeddings to the output
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
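The flag above opts a model into weight tying: the output projection reuses the input embedding matrix as its weight, so a single vocab_size × dim tensor serves both layers and the two stay in sync during training. A minimal sketch of the idea, independent of the PR's classes (the sizes are placeholder values, not taken from this repo):

import torch.nn as nn

vocab_size, dim = 32000, 4096
tok_embeddings = nn.Embedding(vocab_size, dim)   # weight shape: (vocab_size, dim)
output = nn.Linear(dim, vocab_size, bias=False)  # weight shape: (vocab_size, dim)
output.weight = tok_embeddings.weight            # share the Parameter, no copy

# Both modules now read and update the same storage.
assert output.weight.data_ptr() == tok_embeddings.weight.data_ptr()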
@@ -632,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
         if config.stage_idx == config.n_stages - 1:
             self.norm = RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+            if config.tie_word_embeddings:
+                self.output.weight = self.tok_embeddings.weight
         else:
             self.norm = None
             self.output = None
 
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        """Handle tied embeddings at load time"""
+        if self.config.tie_word_embeddings:
+            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])
 
     def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
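The second hunk ties the weights at construction time and registers a load-state-dict pre-hook so checkpoints that omit the duplicated key still load: `_register_load_state_dict_pre_hook` runs the hook before `load_state_dict` copies any parameters, so `setdefault` can alias the embedding weight under the missing output key and a strict load passes. A runnable sketch of that mechanism, assuming a standalone module rather than the PR's `Transformer` (it also uses the hook's `prefix` argument where the PR hard-codes the "model." prefix of its wrapper class):

import torch
import torch.nn as nn

class TiedLM(nn.Module):
    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.output.weight = self.tok_embeddings.weight  # tie at init
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Called before parameters are copied; mutating state_dict here
        # makes the tied key visible to the strict missing-key check.
        state_dict.setdefault(prefix + "output.weight",
                              state_dict[prefix + "tok_embeddings.weight"])

m = TiedLM()
ckpt = {"tok_embeddings.weight": torch.randn(100, 16)}  # no output.weight saved
m.load_state_dict(ckpt, strict=True)                    # hook fills in the tied key

Using setdefault rather than plain assignment leaves an explicitly saved output.weight untouched, so untied checkpoints load exactly as before.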