Commit 722af90

mergennachin authored and facebook-github-bot committed
Be able to set hidden_dim explicitly
Summary: Currently hidden_dim is derived from a magic formula. However, there are cases where we need to set it explicitly. As is, this change should be a no-op: existing models and CI won't break.

Reviewed By: kimishpatel, iseeyuan, Jack-Khuu, shreydesai

Differential Revision: D54121131

fbshipit-source-id: 996534456c9e83e10e7fdaa6973c754d37e734d4
1 parent 68a16df commit 722af90

File tree

1 file changed: +28 -19 lines

examples/models/llama2/model.py

Lines changed: 28 additions & 19 deletions
@@ -86,6 +86,7 @@ class ModelArgs:
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
+    hidden_dim: Optional[int] = None
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
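
For context, here is a minimal sketch (not part of the diff) of how the new optional field behaves; the dataclass is trimmed to the relevant fields and the example values are hypothetical:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:  # trimmed to the fields relevant here
    dim: int = 4096
    hidden_dim: Optional[int] = None  # new: explicit FFN width, if set
    multiple_of: int = 256

# Explicit: FeedForward uses exactly this width.
explicit = ModelArgs(hidden_dim=14336)

# Implicit (the default): hidden_dim stays None, and FeedForward falls back
# to the existing formula, so current models keep their old FFN width.
implicit = ModelArgs()
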
@@ -281,10 +282,18 @@ def forward(
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        dim = args.dim
+        hidden_dim = args.hidden_dim
+        if hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # calculate it implicitly from dim, rounded up to a multiple of args.multiple_of.
+            multiple_of = args.multiple_of
+            hidden_dim = 4 * dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
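
To see why the fallback is a no-op for existing models, the arithmetic can be checked standalone; this helper (resolve_hidden_dim is our name, not the file's) mirrors the branch above, and for dim=4096 with multiple_of=256 it reproduces the familiar Llama 2 7B FFN width of 11008:

from typing import Optional

def resolve_hidden_dim(dim: int, hidden_dim: Optional[int], multiple_of: int) -> int:
    # Mirrors FeedForward.__init__: an explicit value wins, else the magic formula.
    if hidden_dim is None:
        hidden_dim = 4 * dim                  # 16384 for dim=4096
        hidden_dim = int(2 * hidden_dim / 3)  # 10922
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # rounds up to 11008
    return hidden_dim

assert resolve_hidden_dim(4096, None, 256) == 11008   # implicit, same as before this commit
assert resolve_hidden_dim(4096, 11008, 256) == 11008  # explicit override yielding the same width
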
@@ -294,18 +303,22 @@ def forward(self, x):
 
 
 class ConditionalFeedForward(nn.Module):
-    def __init__(self, config):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        hidden_dim = 4 * config.dim
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = config.multiple_of * (
-            (hidden_dim + config.multiple_of - 1) // config.multiple_of
-        )
-        self.w1 = nn.Parameter(torch.randn(config.num_experts, hidden_dim, config.dim))
-        self.w2 = nn.Parameter(torch.randn(config.num_experts, hidden_dim, config.dim))
-        self.w3 = nn.Parameter(torch.randn(config.num_experts, hidden_dim, config.dim))
-        self.num_experts = config.num_experts
-        self.dim = config.dim
+        self.dim = args.dim
+        hidden_dim = args.hidden_dim
+        if hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # calculate it implicitly from dim, rounded up to a multiple of args.multiple_of.
+            multiple_of = args.multiple_of
+            hidden_dim = 4 * self.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
+        self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
+        self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
+        self.num_experts = args.num_experts
 
     def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
         w1_weights = self.w1[expert_indices].transpose(-1, -2)  # [T, A, D, D]
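
The expert weights now get their width the same way. As a quick, hypothetical shape check of the gather in forward (num_experts=8, two tokens, top-2 experts; all sizes illustrative):

import torch

num_experts, hidden_dim, dim = 8, 11008, 4096
w1 = torch.randn(num_experts, hidden_dim, dim)   # matches self.w1 above

expert_indices = torch.tensor([[0, 3], [5, 7]])  # [T, A]: 2 tokens, top-2 experts each
w1_weights = w1[expert_indices].transpose(-1, -2)
print(w1_weights.shape)  # torch.Size([2, 2, 4096, 11008]), i.e. [T, A, dim, hidden_dim]
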
@@ -346,11 +359,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
-            self.feed_forward = FeedForward(
-                dim=args.dim,
-                hidden_dim=4 * args.dim,
-                multiple_of=args.multiple_of,
-            )
+            self.feed_forward = FeedForward(args)
         self.layer_id = layer_id
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
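
Because the constructor now takes ModelArgs directly, an explicit hidden_dim set once in the args flows to every TransformerBlock. A hedged sketch, assuming the ModelArgs and FeedForward defined in this file are importable (nn.Linear stores its weight as (out_features, in_features)):

args = ModelArgs(dim=4096, hidden_dim=14336, multiple_of=256)
ff = FeedForward(args)
assert ff.w1.weight.shape == (14336, 4096)  # (hidden_dim, dim)
assert ff.w2.weight.shape == (4096, 14336)  # (dim, hidden_dim)
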
