@@ -150,12 +150,14 @@ def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        x = self._norm(x.float()).type_as(x)
+        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = self._norm(x.float())
         if self.add_unit_offset:
-            output = x * (1 + self.weight)
+            output = output * (1 + self.weight.float())
         else:
-            output = x * self.weight
-        return output
+            output = output * self.weight.float()
+        return output.type_as(x)
 
 
 class GemmaMLP(nn.Module):
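
For reference only (not part of the commit): a minimal standalone sketch, assuming bfloat16 inputs and the illustrative names `llama_style` / `gemma2_style`, of the two casting orders the comment above contrasts. The Llama order casts the normalized activations back to the input dtype before multiplying by the weight, while the Gemma2 order keeps both in float32 and casts only the final result.

```python
import torch

def llama_style(x, weight, eps=1e-6):
    # RMS-normalize in float32, cast back to the input dtype, then apply the weight
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * (1 + weight)

def gemma2_style(x, weight, eps=1e-6):
    # RMS-normalize and apply the (upcast) weight in float32, cast only at the end
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1 + weight.float())).type_as(x)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.bfloat16)
# The two orders agree mathematically but differ by small rounding errors in 16-bit dtypes.
print((llama_style(x, w) - gemma2_style(x, w)).abs().max().item())
```
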
@@ -621,7 +623,10 @@ def forward(
 
         hidden_states = self.embedder(input_token_ids)
         # Gemma normalizes the embedding by sqrt(hidden_size).
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer
         # hidden_states should be [batch_size, input_len, hidden_size]
 
         hidden_states = self.model(
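
For reference only (not part of the commit): a standalone check of the rounding the new comment describes, assuming bfloat16 as the 16-bit dtype in which the normalizer is stored (3072 is Gemma's hidden_size).

```python
import torch

exact = torch.tensor(3072 ** 0.5)                        # float32: 55.42562...
half = torch.tensor(3072 ** 0.5, dtype=torch.bfloat16)   # bfloat16: 55.5
print(exact.item(), half.item())
```

Building the normalizer directly in the embedding dtype, as the diff does, reproduces this rounding rather than multiplying by the full-precision Python float.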