
Commit 254fc51

feat: Add the option to tie_word_embeddings

Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
Parent: bbea338


torchchat/model.py (10 additions, 0 deletions)
@@ -281,6 +281,8 @@ class TransformerArgs:
     # Optional biases
     attention_bias: bool = False
     feed_forward_bias: bool = False
+    # Whether or not to tie the input word embeddings to the output
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.n_local_heads == -1:
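Tying means the output projection reuses the token embedding's parameter tensor, so one (vocab_size, dim) matrix serves both layers and the vocabulary-related parameter count is roughly halved. A standalone sketch of the mechanism, using a hypothetical TinyLM toy module rather than torchchat's Transformer:

```python
import torch
import torch.nn as nn

# Minimal sketch of weight tying. TinyLM is a hypothetical toy module,
# not torchchat's actual model.
class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, dim: int, tie_word_embeddings: bool = False):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        if tie_word_embeddings:
            # Rebind to the same nn.Parameter: no copy is made; both layers
            # read and update one (vocab_size, dim) tensor.
            self.output.weight = self.tok_embeddings.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))

model = TinyLM(vocab_size=32, dim=16, tie_word_embeddings=True)
assert model.output.weight is model.tok_embeddings.weight  # one shared tensor
print(sum(p.numel() for p in model.parameters()))  # 512, not 1024
```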
@@ -632,12 +634,20 @@ def __init__(self, config: TransformerArgs) -> None:
         if config.stage_idx == config.n_stages - 1:
             self.norm = RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+            if config.tie_word_embeddings:
+                self.output.weight = self.tok_embeddings.weight
         else:
             self.norm = None
             self.output = None
 
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        """Handle tied embeddings at load time"""
+        if self.config.tie_word_embeddings:
+            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])
 
     def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
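The second hunk implements the tying and registers a load_state_dict pre-hook so that checkpoints saved without the (redundant) output projection still load: PyTorch calls the hook before any tensors are copied, letting it backfill the missing key. A standalone sketch of the same pattern; TiedLM and its key names are illustrative, and it uses the hook's prefix argument where the commit hard-codes the "model."-prefixed keys of its surrounding wrapper:

```python
import torch
import torch.nn as nn

# Minimal sketch of the load-time half of weight tying: a pre-hook that
# backfills the tied key. TiedLM is a toy stand-in, not torchchat's model.
class TiedLM(nn.Module):
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.output.weight = self.tok_embeddings.weight  # tie at init
        # PyTorch invokes the hook before load_state_dict copies any tensors.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Alias the missing output weight to the embedding weight."""
        state_dict.setdefault(prefix + "output.weight",
                              state_dict[prefix + "tok_embeddings.weight"])

model = TiedLM()
ckpt = {"tok_embeddings.weight": torch.randn(32, 16)}  # saved without output.weight
model.load_state_dict(ckpt, strict=True)  # hook fills the gap; strict load succeeds
assert model.output.weight is model.tok_embeddings.weight  # still tied after load
```

Using setdefault keeps the hook a no-op for checkpoints that do store the output projection, so tied and untied checkpoints load through the same path.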
