@@ -170,14 +170,24 @@ def __init__(self, params: ModelArgs):
         self.params = params
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
+        self.apply_embedding = params.apply_embedding
+        self.apply_output = params.apply_output
 
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.tok_embeddings = (
+            nn.Embedding(params.vocab_size, params.dim)
+            if self.apply_embedding
+            else None
+        )
         self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.output = (
+            nn.Linear(params.dim, params.vocab_size, bias=False)
+            if self.apply_output
+            else None
+        )
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
@@ -195,7 +205,7 @@ def forward(
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
-        if tokens is not None and h is None:
+        if self.apply_embedding and tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
 
         if attn_options is None:
@@ -219,7 +229,8 @@ def forward(
 
         h = self.norm(h)
 
-        logits = self.output(h)
+        if self.apply_output:
+            logits = self.output(h)
 
         if self.output_prune_map is not None:
             # expand to original size so that downstream applications can use the logits as-is.
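For context, a minimal, self-contained sketch of the pattern this diff introduces: the embedding and output projection are only instantiated when the corresponding config flag is set, and the forward path branches on the same flags. The ToyConfig/ToyTransformer names are illustrative stand-ins rather than the real ModelArgs/Transformer classes, and returning the hidden states when apply_output is disabled is an assumption for illustration; that branch is not shown in the hunk above.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class ToyConfig:
    # Illustrative stand-in for ModelArgs; only the fields used below.
    vocab_size: int = 128
    dim: int = 32
    apply_embedding: bool = True
    apply_output: bool = True


class ToyTransformer(nn.Module):
    def __init__(self, cfg: ToyConfig):
        super().__init__()
        self.apply_embedding = cfg.apply_embedding
        self.apply_output = cfg.apply_output
        # Sub-modules are only created when their flag is set, so a model
        # exported without them carries no unused parameters.
        self.tok_embeddings = (
            nn.Embedding(cfg.vocab_size, cfg.dim) if self.apply_embedding else None
        )
        self.output = (
            nn.Linear(cfg.dim, cfg.vocab_size, bias=False) if self.apply_output else None
        )

    def forward(
        self,
        tokens: Optional[torch.Tensor] = None,
        h: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if (tokens is None) == (h is None):
            raise ValueError("Specify exactly one of tokens or h")
        if self.apply_embedding and tokens is not None and h is None:
            h = self.tok_embeddings(tokens)
        # ... attention blocks and the final norm would run here ...
        if self.apply_output:
            return self.output(h)
        # Assumption: with no output head, hand the hidden states back to the caller.
        return h


# Usage: an "embedding-less" variant that consumes precomputed hidden states.
cfg = ToyConfig(apply_embedding=False)
model = ToyTransformer(cfg)
hidden = torch.randn(1, 4, cfg.dim)
logits = model(h=hidden)  # shape: (1, 4, cfg.vocab_size)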