@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -142,6 +143,9 @@ def __post_init__(self):
                 hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
             self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
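A minimal sketch of the resulting default behavior, assuming `ModelArgs` is constructed directly as defined above; the concrete numbers are illustrative, not from this diff:

```python
# head_dim falls back to dim // n_heads when unset (illustrative numbers):
args = ModelArgs(dim=2048, n_heads=16)
assert args.head_dim == 128  # 2048 // 16

# An explicit value survives __post_init__, e.g. for checkpoints whose
# per-head width is decoupled from dim // n_heads:
custom = ModelArgs(dim=2048, n_heads=16, head_dim=64)
assert custom.head_dim == 64
```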
@@ -272,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
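Why reading `args.head_dim` matters here: once `head_dim` can be set explicitly, `dim // n_heads` is no longer a safe derivation. A hedged example with hypothetical numbers:

```python
# Hypothetical config where the old derivation breaks (numbers are assumptions):
dim, n_heads, head_dim = 2048, 8, 128
assert dim // n_heads == 256  # what the removed line would have computed
assert head_dim == 128        # the per-head width the checkpoint actually uses
```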
@@ -304,7 +308,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         )
         self.SDPA = SDPA(
             kv_cache=self.kv_cache,
-            dim=self.dim,
+            dim=self.n_local_heads * self.head_dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_seq_len=self.max_seq_len,
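Note on the `dim=` change: with a customized `head_dim`, the concatenated attention output spans `n_local_heads * head_dim` features, which need not equal `dim`. A small shape sketch, assuming the usual llama projection layout (`wq: dim -> n_heads * head_dim`); the sizes are illustrative:

```python
import torch

bsz, seqlen, dim, n_heads, head_dim = 1, 8, 2048, 16, 64
wq = torch.nn.Linear(dim, n_heads * head_dim, bias=False)  # 2048 -> 1024

x = torch.randn(bsz, seqlen, dim)
q = wq(x).view(bsz, seqlen, n_heads, head_dim)
# Attention here operates on n_heads * head_dim = 1024 features, not on
# dim = 2048, so passing dim=self.dim would mis-size the SDPA output.
assert q.shape == (1, 8, 16, 64)
```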
@@ -425,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -472,7 +476,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis, use_scaled=params.use_scaled_rope
         )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
+            params.head_dim,
             (
                 params.max_seq_len  # Normal llama2.
                 if params.ffn_dim_multiplier is None
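The RoPE table is sized by the per-head width, so it must follow `head_dim` as well. A simplified reference for llama-style `precompute_freqs_cis` (a sketch, not the exact function in this file, which also takes a `use_scaled` argument):

```python
import torch

def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0):
    # One frequency per rotated channel pair -> head_dim // 2 entries.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    angles = torch.outer(t, freqs)               # (seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)  # freqs_cos, freqs_sin

freqs_cos, freqs_sin = precompute_freqs_cis(64, 2048)
assert freqs_cos.shape == (2048, 32)
```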