Skip to content

Commit 1ddc7e0

Browse files
author
Martin Yuan
committed
etLLM: add options to apply the embedding or output layer.
1 parent 9c51e58 commit 1ddc7e0

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -170,14 +170,16 @@ def __init__(self, params: ModelArgs):
170170
self.params = params
171171
self.vocab_size = params.vocab_size
172172
self.n_layers = params.n_layers
173+
self.apply_embedding = params.apply_embedding
174+
self.apply_output = params.apply_output
173175

174-
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
176+
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) if self.apply_embedding else None
175177
self.rope = Rope(params)
176178
self.layers = torch.nn.ModuleList()
177179
for layer_id in range(params.n_layers):
178180
self.layers.append(TransformerBlock(layer_id, params, self.rope))
179181
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
180-
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
182+
self.output = nn.Linear(params.dim, params.vocab_size, bias=False) if self.apply_output else None
181183
self.use_kv_cache = params.use_kv_cache
182184
self.generate_full_logits = params.generate_full_logits
183185
self.max_seq_len = params.max_seq_len
@@ -195,7 +197,7 @@ def forward(
195197
raise ValueError(
196198
"You cannot specify both tokens and h at the same time, and must specify either one"
197199
)
198-
if tokens is not None and h is None:
200+
if self.apply_embedding and tokens is not None and h is None:
199201
h = self.tok_embeddings(tokens)
200202

201203
if attn_options is None:
@@ -219,7 +221,8 @@ def forward(
219221

220222
h = self.norm(h)
221223

222-
logits = self.output(h)
224+
if self.apply_output:
225+
logits = self.output(h)
223226

224227
if self.output_prune_map is not None:
225228
# expand to original size so that downstream applications can use the logits as-is.

examples/models/llama/model_args.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@ class ModelArgs:
3434
input_prune_map: Optional[Dict[int, int]] = None
3535
# A dictionary mapping from pruned token-id to original token-id
3636
output_prune_map: Optional[Dict[int, int]] = None
37+
apply_embedding: bool = True # Use embedding inside the transformer
38+
apply_output: bool = True # Use output layer (unembedding) inside the transformer
3739
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
3840
rope_theta: Optional[float] = (
3941
None # The official name to override self.rope_freq_base.

0 commit comments

Comments (0)