@@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
             self.layers[str(layer_id)] = TransformerBlock(config)

         if config.stage_idx == config.n_stages - 1:
-            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+            self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
             self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
             if config.tie_word_embeddings:
                 self.output.weight = self.tok_embeddings.weight
@@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
         self.attention = Attention(config)
         self.feed_forward = FeedForward(config)
-        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
-        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+        self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
         # None for llama architecture, set for granite architectures
         self.residual_multiplier = (
             config.residual_multiplier
@@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))


-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-5):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
-
-    def forward(self, x: Tensor) -> Tensor:
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
-
-
 def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
     # Check for the presence of the required keys
     required_keys = {
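
Note: this change drops the hand-rolled RMSNorm class in favor of torch.nn.RMSNorm, which is available in PyTorch 2.4 and later. A minimal parity-check sketch follows, assuming such a PyTorch version; the tensor shapes, eps value, and tolerance are illustrative and not taken from this repo's tests:

import torch
import torch.nn as nn

dim, eps = 64, 1e-5
x = torch.randn(2, 8, dim)

# Bare RMS normalization, as computed by the removed class's _norm helper.
ref = x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)

# nn.RMSNorm with its weight left at the default initialization (ones),
# so its output is the same bare normalization.
norm = nn.RMSNorm(dim, eps=eps)
out = norm(x)

print(torch.allclose(ref, out, atol=1e-6))  # expected: True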