@@ -170,14 +170,16 @@ def __init__(self, params: ModelArgs):
         self.params = params
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
+        self.apply_embedding = params.apply_embedding
+        self.apply_output = params.apply_output
 
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) if self.apply_embedding else None
         self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output = nn.Linear(params.dim, params.vocab_size, bias=False) if self.apply_output else None
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
@@ -195,7 +197,7 @@ def forward(
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
-        if tokens is not None and h is None:
+        if self.apply_embedding and tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
 
         if attn_options is None:
@@ -219,7 +221,8 @@ def forward(
 
         h = self.norm(h)
 
-        logits = self.output(h)
+        if self.apply_output:
+            logits = self.output(h)
 
         if self.output_prune_map is not None:
             # expand to original size so that downstream applications can use the logits as-is.
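For orientation, a minimal sketch of how the two new flags might be exercised. The class name Transformer, the import path, and every ModelArgs field other than apply_embedding and apply_output are assumptions made for illustration, not taken from this diff:

    # Illustrative sketch only: the import path and field values below are assumed.
    # from llama_transformer import ModelArgs, Transformer

    args = ModelArgs(
        vocab_size=32_000,
        dim=2048,
        n_layers=16,
        apply_embedding=False,  # tok_embeddings stays None; caller supplies embeddings via h
        apply_output=False,     # output head stays None; caller projects hidden states itself
    )
    model = Transformer(args)

    # With apply_embedding=False, forward() skips the nn.Embedding lookup and expects
    # pre-computed embeddings h instead of token ids; with apply_output=False, the final
    # nn.Linear projection is skipped, leaving the vocabulary projection to downstream
    # code (the return path is outside the hunks shown above).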