Add dtype, fix RMS norm for FP16 #8641


Merged: 4 commits into main on Feb 26, 2025

Conversation

metascroy (Contributor) commented Feb 23, 2025

Llama1B quality in CoreML is bad due to FP16 arithmetic. Here is a sample of generated text:

"Once upon a time,we had our way,as we navigated,through the vast expanse of our understanding,as we journeyed,through the treacherous terrain of our experiences,as we faced our fears,as we stood tall,as we confronted our demons,as we emerged victorious,as we transcend our limits,as we ascend,as we unite,as we become,as we lose ourselves,as we search,as we remember,as we come back,as we return,as we find myself,as I am,as I am not,as I stand tall,as I hear my voice,as I remember,as I forgive,as I am, as I am, as I become, as I lose myself, as I find myself, as I remember, as I come back, as I return, as I find myself, as I am, as I am not, as I stand tall, as I Hear My Voice, as I Remember, as I Find Myself, as I am, as I Become, as I Lose Myself, as I Find Myself, asных, as I Become, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose Myself, as I Lose My"

The corresponding FP16 eager mode model has much better generated text:

"Once upon a time, in a small village nestled between two great mountains, there lived a young girl named Akira. She was a curious and adventurous soul, with a heart full of wonder and a mind full of questions. Akira lived with her grandmother, a wise and kind woman named Kana, who taught her the ways of the world and the secrets of the universe.

One day, while exploring the village, Akira stumbled upon a mysterious shop tucked away in a quiet alley. The sign above the door read "Moonlit Curios and Antiques," and the windows were filled with a dazzling array of strange and exotic objects. Akira felt an inexplicable pull towards the shop, as if the very fabric of the universe was calling to her."

The discrepancy is that the eager mode model actually computes the RMSNorm in FP32 due to a cast operation (which CoreML appears to ignore):

self._norm(x.float()).type_as(x)
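
For context, a minimal sketch of what the eager-mode norm around that cast looks like (modeled on the standard Llama reference RMSNorm; the class name, eps default, and exact layout here are illustrative, not necessarily this repo's code):

    import torch

    class ReferenceRMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-5):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            # rsqrt of the mean of squares over the last dimension
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            # the .float() cast keeps the norm computation in FP32 in eager mode
            return self._norm(x.float()).type_as(x) * self.weight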

Moreover, the norm computation appears unstable in FP16 and gives bad results. We can improve the numerical stability of the norm in FP16 by first dividing x by its maximum absolute value. Here is the generated text from CoreML in FP16 after this change:

"Once upon a time, in a small village nestled in the rolling hills of Provence, there lived a young girl named Colette. Colette was a curious and adventurous soul, with a heart full of wonder and a mind full of questions. She spent most of her days exploring the village, visiting the local market, and listening to the tales of the old villagers.

One day, while wandering through the village, Colette stumbled upon a small, mysterious shop tucked away on a quiet street. The sign above the door read "Curios and Wonders," and the windows were filled with a dazzling array of strange and exotic objects. Colette's curiosity was piqued, and she pushed open the door to reveal a dimly lit interior filled with the scent of old books and dust."

Note that for 4-bit channelwise quantization, the results do not look good even after this change. The ideal solution is to do QAT for Llama1B with 4-bit channelwise quantization + FP16 arithmetic.
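
For illustration, the stabilization described above looks roughly like the following sketch (function name and eps default are illustrative, not the exact diff; the eps rescaling follows the discussion in the review thread below):

    import torch

    def rms_norm_fp16_stable(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Divide by the per-row max |x| so the squared values stay in FP16 range
        x_max, _ = torch.abs(x).max(-1, keepdim=True)
        x_scaled = x / x_max
        # eps must be rescaled by x_max**2 so the result is mathematically unchanged
        out = x_scaled * torch.rsqrt(x_scaled.pow(2).mean(-1, keepdim=True) + eps / (x_max * x_max))
        return out * weight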

pytorch-bot bot commented Feb 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8641

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Pending

As of commit 82356a5 with merge base 366d87e:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label Feb 23, 2025
This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

metascroy (Contributor, Author) commented:

@YifanShenSZ It is probably a bug for CoreML to ignore the cast. The cast was presumably added because FP16 arithmetic was not accurate enough.


"""
x_max, _ = torch.abs(x).max(-1, keepdim=True)
x = x / x_max # This makes the op more stable in FP16
Contributor:

I'll leave review to somebody who is better at math, but I'll just note that it is not at all obvious to me that this does not change the result of the operation.

metascroy (Contributor, Author) replied:

It shouldn't, because the numerator and the denominator are divided by the same thing (x_max). Because the denominator is under a square root, we divide by x_max**2 there.
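
Concretely, with m = max|x| along the reduced dimension, the claim is the identity

    x * rsqrt(mean(x^2) + eps) == (x / m) * rsqrt(mean((x / m)^2) + eps / m^2)

since mean((x/m)^2) + eps/m^2 = (mean(x^2) + eps) / m^2, and rsqrt(t / m^2) = m * rsqrt(t).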

Contributor:

I think the function produces equal results before and after. But would it be a concern that, if we get very small values of x_max, the result of eps = self.eps / (x_max * x_max) could overflow? Should we dynamically use torch.finfo(x.dtype).eps for different dtypes?
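
For illustration, a quick check of that concern in FP16 (eps = 1e-5 is an assumed typical value; the failure mode here is underflow of x_max * x_max followed by an inf):

    import torch

    x_max = torch.tensor(1e-5, dtype=torch.float16)
    eps = 1e-5
    print(x_max * x_max)                   # tensor(0., dtype=torch.float16) -- underflows
    print(eps / (x_max * x_max))           # tensor(inf, dtype=torch.float16)
    print(torch.finfo(torch.float16).eps)  # 0.0009765625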

Contributor:

I don't fully follow why this rewrite should behave better in FP16. If you are normalizing by the max value, then I presume rsqrt is suffering from FP16 precision loss? It is not at all clear.

@@ -121,6 +119,56 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
Contributor:

Does CoreML support an RMSNorm op? It will be a lot easier if it does.


YifanShenSZ (Collaborator) commented Feb 24, 2025:

Something that already exists in Core ML is the translation for torch.norm, which uses the Core ML fused reduce_l2_norm kernel.

That is to say, we may compute RMS norm by something like

x / torch.norm(x)

metascroy (Contributor, Author) commented Feb 25, 2025:

I think this is a slightly different op when eps > 0, although I'm not sure how much it matters in practice.

RMSNorm would actually be something like x / torch.norm([x/sqrt(n), sqrt(eps)])
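
A sketch of that formulation (the function name, eps default, and concatenation trick are illustrative, not the exact change in this PR):

    import math
    import torch

    def rms_norm_via_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        n = x.shape[-1]
        # sqrt(mean(x^2) + eps) == 2-norm of concat(x / sqrt(n), sqrt(eps)) along the last dim
        eps_term = torch.full_like(x[..., :1], math.sqrt(eps))
        denom = torch.norm(torch.cat([x / math.sqrt(n), eps_term], dim=-1), p=2, dim=-1, keepdim=True)
        return x / denom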

metascroy (Contributor, Author) replied:

I'll update to use norm, and then maybe we can work on a longer-term solution of supporting RMSNorm in CoreML, @YifanShenSZ?

@@ -121,6 +119,56 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
metascroy (Contributor, Author) commented Feb 24, 2025:

In sync with the CoreML team, we might try using https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html and then write a CoreML op definition for it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py

@YifanShenSZ mentioned they have a fused norm op that could be used.
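
For reference, the functional form would be called roughly like this (requires a recent PyTorch; shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 64)
    weight = torch.ones(64)
    y = F.rms_norm(x, normalized_shape=(64,), weight=weight, eps=1e-5)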

YifanShenSZ (Collaborator) commented Feb 24, 2025:

Yes, there are two possibilities:

  1. (Simpler) As mentioned above, we already have a torch.norm translation using the Core ML fused reduce_l2_norm kernel, so we may have an RMS norm torch source like x / torch.norm(x).
  2. (Better) I think we can further fuse to Core ML l2_norm, which may directly correspond to RMS norm (with some restrictions, though, e.g. x must have rank >= 3). We will need to add the translation function in coremltools.

metascroy (Contributor, Author) commented:

I rewrote the RMS norm using norm as @YifanShenSZ suggested. Here is the generated text in FP16:

"Once upon a time, in a small village nestled in the rolling hills of Provence, there lived a young girl named Sophie. Sophie was a curious and adventurous child, with a mop of curly brown hair and a smile that could light up the darkest of rooms. She loved nothing more than exploring the countryside, discovering hidden streams and secret meadows, and chasing after butterflies.

As she wandered through the village, Sophie would often stop at the local bakery, where the owner, Monsieur LeFleur, would greet her with a warm smile and a warm baguette. Sophie loved the sweet scent of freshly baked bread and the taste of warm pastries, and she would often sneak into the bakery to sample the latest creations.

One day, while Sophie was exploring the village, she stumbled upon a small, mysterious"

cccclai (Contributor) left a comment:

Looks good - maybe let's add CI to prevent regression

metascroy (Contributor, Author) replied:

> Looks good - maybe let's add CI to prevent regression

There is currently no CI for this model. Let me create a task for that.

metascroy added the partner: apple label Feb 26, 2025
metascroy merged commit 5a594a7 into main Feb 26, 2025
47 of 49 checks passed
metascroy deleted the apple-llama-dtype branch February 26, 2025 01:20
YifanShenSZ (Collaborator) commented:

Looks good! Maybe also test by setting temperature to 0?

Labels: CLA Signed, partner: apple