feat: support force downcast after FastRMSNorm multiply for Gemma #1658

Merged
merged 3 commits into main from fix-gemma-bugs on Mar 21, 2024

Conversation

@drbh drbh commented Mar 20, 2024

This PR adds force_downcast_after to FastRMSNorm.forward which is used in the Gemma model. References huggingface/transformers#29402 and huggingface/transformers#29729

Setting force_downcast_after=True will perform the hidden_states * weight multiplication in f32 and then downcast the result to half. This differs slightly from the current implementation, which first casts hidden_states to half and then multiplies.
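
For illustration, a minimal sketch of the two orderings (assuming the normalized hidden states are still in float32 after the variance step and the model runs in float16; this is not the TGI code, just the numerical difference being described):

```python
import torch

# normalized hidden states, still in float32 after the RMSNorm variance step
hidden_states = torch.randn(4, 2048, dtype=torch.float32)
weight = torch.randn(2048, dtype=torch.float32)

# current behaviour: cast hidden_states down to half first, then multiply
out_before = hidden_states.to(torch.float16) * weight.to(torch.float16)

# force_downcast_after=True: multiply in float32, then downcast the product
out_after = (hidden_states * weight).to(torch.float16)

# the two results can differ slightly because rounding happens at a different point
print((out_before - out_after).abs().max())
```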

@@ -687,7 +687,7 @@ def load(cls, prefix, weights, eps=1e-6):
         weight = weights.get_tensor(f"{prefix}.weight")
         return cls(weight, eps)

-    def forward(self, hidden_states, residual=None):
+    def forward(self, hidden_states, residual=None, force_downcast_after=False):
Collaborator commented on this change:
I'd personally use a different method, forward_downcast_after.

Also, this only triggers when hidden_size > 8192, meaning it won't trigger for Gemma.
Having a dedicated method seems simpler.

drbh (Collaborator, Author) replied:
That makes a lot of sense, thanks for the comments. I've removed the branching logic from FastRMSNorm and added forward_downcast_after into the Gemma modeling code.
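
For reference, here is a rough sketch of what such a dedicated downcast-after path could look like (illustrative only, not the merged code; the helper name and exact normalization details are assumptions):

```python
import torch

def rmsnorm_downcast_after(hidden_states, weight, eps=1e-6):
    # hypothetical helper: normalize in float32, multiply by the weight in
    # float32, and only downcast back to the input dtype at the very end
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    hidden_states = hidden_states * weight.to(torch.float32)
    return hidden_states.to(input_dtype)
```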

@Narsil Narsil left a comment


Even better!

@Narsil Narsil merged commit 6f15ac6 into main Mar 21, 2024
@Narsil Narsil deleted the fix-gemma-bugs branch March 21, 2024 09:25
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024
…ggingface#1658)
