Commit 8ca715c

Committed by Martin Yuan
etLLM: add options to apply embedding or output.
1 parent 9c51e58 commit 8ca715c

2 files changed: +17 −4 lines

examples/models/llama/llama_transformer.py (15 additions, 4 deletions)
@@ -170,14 +170,24 @@ def __init__(self, params: ModelArgs):
         self.params = params
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
+        self.apply_embedding = params.apply_embedding
+        self.apply_output = params.apply_output
 
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.tok_embeddings = (
+            nn.Embedding(params.vocab_size, params.dim)
+            if self.apply_embedding
+            else None
+        )
         self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output = (
+            nn.Linear(params.dim, params.vocab_size, bias=False)
+            if self.apply_output
+            else None
+        )
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
@@ -195,7 +205,7 @@ def forward(
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
-        if tokens is not None and h is None:
+        if self.apply_embedding and tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
 
         if attn_options is None:
@@ -219,7 +229,8 @@ def forward(
 
         h = self.norm(h)
 
-        logits = self.output(h)
+        if self.apply_output:
+            logits = self.output(h)
 
         if self.output_prune_map is not None:
             # expand to original size so that downstream applications can use the logits as-is.
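
With apply_embedding disabled, the new guard in forward skips the lookup and the caller supplies precomputed embeddings via h. A minimal usage sketch follows; it is hypothetical, not part of this commit, and assumes the constructor above belongs to the Transformer class, that ModelArgs defaults are otherwise sufficient, and that forward accepts tokens and h as keyword arguments (the full signature lies outside the hunks shown):

import torch

# Hypothetical usage sketch: import paths assume the files in this diff are
# importable as a package.
from examples.models.llama.llama_transformer import Transformer
from examples.models.llama.model_args import ModelArgs

params = ModelArgs(apply_embedding=False)  # tok_embeddings becomes None
model = Transformer(params)

# The embedding lookup now happens outside the model.
external_embed = torch.nn.Embedding(params.vocab_size, params.dim)
tokens = torch.randint(0, params.vocab_size, (1, 8))
h = external_embed(tokens)

# Per the guard in forward, pass exactly one of tokens / h.
logits = model(tokens=None, h=h)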

examples/models/llama/model_args.py (2 additions, 0 deletions)
@@ -34,6 +34,8 @@ class ModelArgs:
     input_prune_map: Optional[Dict[int, int]] = None
     # A dictionary mapping from pruned token-id to original token-id
     output_prune_map: Optional[Dict[int, int]] = None
+    apply_embedding: bool = True  # Use embedding inside the transformer
+    apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
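
Both flags default to True, so existing configurations keep their current behavior. For the complementary case, a sketch of a body-only setup where the caller also applies the unembedding; it is again hypothetical and additionally assumes the model returns hidden states rather than logits when apply_output is False (that return path is outside the hunks shown):

import torch

from examples.models.llama.llama_transformer import Transformer
from examples.models.llama.model_args import ModelArgs

# Hypothetical sketch: keep the embedding table and LM head out of the
# transformer, e.g. to export them as separate graphs.
params = ModelArgs(
    apply_embedding=False,  # caller supplies `h` instead of token ids
    apply_output=False,     # self.output is None; caller unembeds
)
model = Transformer(params)

embed = torch.nn.Embedding(params.vocab_size, params.dim)
unembed = torch.nn.Linear(params.dim, params.vocab_size, bias=False)

tokens = torch.randint(0, params.vocab_size, (1, 8))
hidden = model(tokens=None, h=embed(tokens))  # assumed: hidden states out
logits = unembed(hidden)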
