Skip to content

Commit ba9aba1

Browse files
Replace RMSNorm by nn.RMSNorm
1 parent f4ae60f commit ba9aba1

File tree

1 file changed

+3
-17
lines changed

1 file changed

+3
-17
lines changed

torchchat/model.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
657657
self.layers[str(layer_id)] = TransformerBlock(config)
658658

659659
if config.stage_idx == config.n_stages - 1:
660-
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
660+
self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
661661
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
662662
if config.tie_word_embeddings:
663663
self.output.weight = self.tok_embeddings.weight
@@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
751751
super().__init__()
752752
self.attention = Attention(config)
753753
self.feed_forward = FeedForward(config)
754-
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
755-
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
754+
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
755+
self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
756756
# None for llama architecture, set for granite architectures
757757
self.residual_multiplier = (
758758
config.residual_multiplier
@@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
928928
return self.w2(F.silu(self.w1(x)) * self.w3(x))
929929

930930

931-
class RMSNorm(nn.Module):
932-
def __init__(self, dim: int, eps: float = 1e-5):
933-
super().__init__()
934-
self.eps = eps
935-
self.weight = nn.Parameter(torch.ones(dim))
936-
937-
def _norm(self, x):
938-
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
939-
940-
def forward(self, x: Tensor) -> Tensor:
941-
output = self._norm(x.float()).type_as(x)
942-
return output * self.weight
943-
944-
945931
def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
946932
# Check for the presence of the required keys
947933
required_keys = {

0 commit comments

Comments
 (0)