@@ -86,6 +86,7 @@ class ModelArgs:
86
86
n_heads : int = 32
87
87
n_kv_heads : Optional [int ] = None
88
88
vocab_size : int = - 1 # defined later by tokenizer
89
+ hidden_dim : Optional [int ] = None
89
90
multiple_of : int = 256 # make SwiGLU hidden layer size multiple of large power of 2
90
91
ffn_dim_multiplier : Optional [float ] = None
91
92
norm_eps : float = 1e-5
@@ -281,10 +282,18 @@ def forward(
281
282
282
283
283
284
class FeedForward(nn.Module):
    """Position-wise feed-forward layer with three linear projections.

    ``w1``/``w3`` project ``dim -> hidden_dim`` and ``w2`` projects back
    ``hidden_dim -> dim`` (SwiGLU-style sizing; the gated combination is
    presumably computed in ``forward``, which is not shown here).

    Args:
        args: Model hyper-parameters. Reads ``args.dim``,
            ``args.hidden_dim`` (may be ``None``) and ``args.multiple_of``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.dim
        hidden_dim = args.hidden_dim
        if hidden_dim is None:
            # If hidden_dim is not explicitly set in the ModelArgs, derive it
            # implicitly from dim: 2/3 of 4*dim, rounded UP to the next
            # multiple of args.multiple_of (keeps the SwiGLU hidden layer
            # size a multiple of a large power of 2).
            multiple_of = args.multiple_of
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
@@ -294,18 +303,22 @@ def forward(self, x):
294
303
295
304
296
305
class ConditionalFeedForward(nn.Module):
    """Stacked per-expert feed-forward weights for the MoE path.

    Stores ``w1``/``w2``/``w3`` as ``(num_experts, hidden_dim, dim)``
    parameter tensors; ``forward`` indexes them with per-token
    ``expert_indices`` to select each token's routed experts.

    Args:
        args: Model hyper-parameters. Reads ``args.dim``,
            ``args.hidden_dim`` (may be ``None``), ``args.multiple_of``
            and ``args.num_experts``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        hidden_dim = args.hidden_dim
        if hidden_dim is None:
            # If hidden_dim is not explicitly set in the ModelArgs, use the
            # same implicit sizing rule as FeedForward: 2/3 of 4*dim,
            # rounded UP to the next multiple of args.multiple_of.
            multiple_of = args.multiple_of
            hidden_dim = 4 * self.dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # NOTE(review): randn initialization — presumably overwritten by a
        # checkpoint load; confirm before training these from scratch.
        self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.num_experts = args.num_experts
309
322
310
323
def forward (self , x : torch .Tensor , expert_indices : torch .Tensor ) -> torch .Tensor :
311
324
w1_weights = self .w1 [expert_indices ].transpose (- 1 , - 2 ) # [T, A, D, D]
@@ -346,11 +359,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
346
359
if args .moe :
347
360
self .block_sparse_moe = MOEFeedForward (args )
348
361
else :
349
- self .feed_forward = FeedForward (
350
- dim = args .dim ,
351
- hidden_dim = 4 * args .dim ,
352
- multiple_of = args .multiple_of ,
353
- )
362
+ self .feed_forward = FeedForward (args )
354
363
self .layer_id = layer_id
355
364
self .attention_norm = RMSNorm (args .dim , eps = args .norm_eps )
356
365
self .ffn_norm = RMSNorm (args .dim , eps = args .norm_eps )
0 commit comments