
Commit a1f668d

allow customized head_dim (#7065)
Pull Request resolved: #6872

This resolves the ask in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/). Similar change in HF: huggingface/transformers#32502

ghstack-source-id: 255340016
Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

Co-authored-by: Lunwen He <[email protected]>
1 parent 52fa043 commit a1f668d

File tree

1 file changed: +8 −4 lines changed

examples/models/llama/llama_transformer.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -142,6 +143,9 @@ def __post_init__(self):
                 hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
             self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
@@ -272,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
@@ -304,7 +308,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         )
         self.SDPA = SDPA(
             kv_cache=self.kv_cache,
-            dim=self.dim,
+            dim=self.n_local_heads * self.head_dim,
             head_dim=self.head_dim,
             n_rep=self.n_rep,
             max_seq_len=self.max_seq_len,
@@ -425,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -472,7 +476,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis, use_scaled=params.use_scaled_rope
         )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
+            params.head_dim,
             (
                 params.max_seq_len  # Normal llama2.
                 if params.ffn_dim_multiplier is None
```
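For reference, the net effect of the change can be sketched as follows. This is a trimmed-down, standalone version of the defaulting logic, not the actual ExecuTorch module (the real `ModelArgs` in `llama_transformer.py` has many more fields, and the `dim`/`n_heads` defaults below are illustrative only): when `head_dim` is left as `None`, `__post_init__` derives it as `dim // n_heads`, preserving the old behavior; when it is set explicitly, the attention and RoPE code read `args.head_dim` instead of recomputing `dim // n_heads`.

```python
# Standalone sketch of the new head_dim defaulting behavior; not the actual
# ExecuTorch module, just the relevant slice of ModelArgs for illustration.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    dim: int = 4096        # illustrative default
    n_heads: int = 32      # illustrative default
    head_dim: Optional[int] = None  # new field: optional customized head_dim

    def __post_init__(self):
        # Same fallback as the diff above: derive head_dim when not provided.
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads


# Old behavior is preserved when head_dim is unset.
assert ModelArgs(dim=4096, n_heads=32).head_dim == 128

# A customized head_dim decoupled from dim // n_heads (the new capability);
# downstream code then reads args.head_dim directly.
args = ModelArgs(dim=2048, n_heads=32, head_dim=256)
assert args.head_dim == 256
```

Note that the SDPA inner dimension becomes `n_local_heads * self.head_dim` rather than `self.dim`, which is what allows `head_dim` to be decoupled from `dim // n_heads`.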
