Commit 1814f8d

Merge pull request #68 from google/fix-xla-downcast
Fix downcasting and upcasting similar to https://github.com/google/ge…
2 parents: 1bcd536 + ce4912d


gemma/model_xla.py

Lines changed: 10 additions & 5 deletions
@@ -150,12 +150,14 @@ def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

     def forward(self, x):
-        x = self._norm(x.float()).type_as(x)
+        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = self._norm(x.float())
         if self.add_unit_offset:
-            output = x * (1 + self.weight)
+            output = output * (1 + self.weight.float())
         else:
-            output = x * self.weight
-        return output
+            output = output * self.weight.float()
+        return output.type_as(x)


 class GemmaMLP(nn.Module):
@@ -621,7 +623,10 @@ def forward(

         hidden_states = self.embedder(input_token_ids)
         # Gemma normalizes the embedding by sqrt(hidden_size).
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        # See https://github.com/huggingface/transformers/pull/29402
+        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+        hidden_states = hidden_states * normalizer
         # hidden_states should be [batch_size, input_len, hidden_size]

         hidden_states = self.model(
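For context on the first hunk: the old code downcast the normalized activations back to the input dtype before applying the weight, while the new code keeps both the normalization and the weight multiply in float32 and casts only the final result. Below is a minimal standalone sketch (not part of this commit; the shapes, seed, and eps value are illustrative) comparing the two orderings under bfloat16 activations:

import torch

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(2, 8, dtype=torch.bfloat16)    # stand-in activations
weight = torch.randn(8, dtype=torch.bfloat16)  # stand-in RMSNorm weight

def rms(v):
    # same normalization as _norm above, computed in float32
    return v * torch.rsqrt(v.pow(2).mean(-1, keepdim=True) + eps)

# Old ordering: downcast right after normalization, multiply in bfloat16.
old = rms(x.float()).type_as(x) * (1 + weight)

# New ordering: stay in float32 for the multiply, downcast once at the end.
new = (rms(x.float()) * (1 + weight.float())).type_as(x)

print((old.float() - new.float()).abs().max())  # typically a small nonzero rounding difference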

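And for the second hunk's embedding normalizer: sqrt(3072) ≈ 55.4256 is not exactly representable at reduced precision, so the diff materializes it as a tensor in the activations' dtype rather than leaving the cast implicit. A quick illustrative check (not from the repo) of where the value lands in each dtype:

import torch

hidden_size = 3072
print(hidden_size**0.5)                                      # 55.42562584220407
print(torch.tensor(hidden_size**0.5, dtype=torch.bfloat16))  # rounds to 55.5
print(torch.tensor(hidden_size**0.5, dtype=torch.float16))   # rounds to 55.4375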