Add dtype, fix RMS norm for FP16 #8641
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8641
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Pending as of commit 82356a5 with merge base 366d87e. NEW FAILURE: the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
@YifanShenSZ It is probably a bug for CoreML to ignore the cast. It was presumably added because FP16 arithmetic was not accurate enough.
""" | ||
x_max, _ = torch.abs(x).max(-1, keepdim=True) | ||
x = x / x_max # This makes the op more stable in FP16 |
I'll leave review to somebody who is better at math, but I'll just note that it is not at all obvious to me that this does not change the result of the operation.
It shouldn't, because both the numerator and the denominator are divided by the same thing (x_max). Since the denominator is under a square root, we divide by x_max**2 there.
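A quick way to convince yourself of that (a minimal sketch with hypothetical helper names, checked in float64 where precision is not an issue; not code from this PR):

```python
import torch

def rms_norm_ref(x, eps):
    # Plain RMSNorm normalization: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def rms_norm_rescaled(x, eps):
    # Same computation after dividing x by max |x|; eps is divided by
    # x_max**2 because it sits under the same square root.
    x_max, _ = torch.abs(x).max(-1, keepdim=True)
    xs = x / x_max
    return xs * torch.rsqrt(xs.pow(2).mean(-1, keepdim=True) + eps / (x_max * x_max))

x = torch.randn(4, 64, dtype=torch.float64)
print(torch.allclose(rms_norm_ref(x, 1e-5), rms_norm_rescaled(x, 1e-5)))  # True
```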
I think the function produces equal results before and after. But would it be a concern that, if we get very small values of x_max, the result of eps = self.eps / (x_max * x_max) could overflow? Should we dynamically use torch.finfo(x.dtype).eps for different dtypes?
I don't fully follow why this rewrite should behave better in FP16. If you are normalizing by the max value, then I am presuming that rsqrt is suffering from precision loss in FP16? It is not at all clear.
@@ -121,6 +119,56 @@ def __post_init__(self):
        self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
Does CoreML support an RMSNorm op? It would be a lot easier if they do.
Something that already exists in Core ML is the translation for torch.norm, which uses the Core ML fused reduce_l2_norm kernel. That is to say, we may compute RMS norm by something like x / torch.norm(x)
I think this is a slightly different op when eps > 0, although I'm not sure how much it matters in practice.
RMSNorm would actually be something like x / torch.norm([x/sqrt(n), sqrt(eps)])
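A small numeric check of that identity (a sketch with a 1-D x for simplicity; not code from this PR):

```python
import math
import torch

x = torch.randn(64, dtype=torch.float64)
eps = 1e-5
n = x.numel()

# RMSNorm normalization: x / sqrt(mean(x^2) + eps)
ref = x * torch.rsqrt(x.pow(2).mean() + eps)

# norm([x/sqrt(n), sqrt(eps)]) == sqrt(sum(x^2)/n + eps) == sqrt(mean(x^2) + eps)
augmented = torch.cat([x / math.sqrt(n), torch.tensor([math.sqrt(eps)], dtype=x.dtype)])
via_norm = x / torch.norm(augmented)

print(torch.allclose(ref, via_norm))  # True
```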
I'll update to use norm, and then maybe we can work on a longer-term solution of supporting rmsnorm in CoreML @YifanShenSZ?
@@ -121,6 +119,56 @@ def __post_init__(self):
        self.head_dim = self.dim // self.n_heads


class RMSNorm(torch.nn.Module):
In sync with the CoreML team, we might try using https://pytorch.org/docs/stable/generated/torch.nn.functional.rms_norm.html and then write a CoreML op definition for it here: https://github.com/apple/coremltools/blob/main/coremltools/converters/mil/frontend/torch/ops.py
@YifanShenSZ mentioned they have a fused norm op that could be used.
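For reference, a hypothetical sketch of what such a translation could look like; the aten input ordering and the exact MIL ops used are assumptions, not the actual coremltools or PR code:

```python
# Hypothetical sketch only; the real definition would live in
# coremltools/converters/mil/frontend/torch/ops.py.
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op


@register_torch_op
def rms_norm(context, node):
    # Assumed aten signature: rms_norm(input, normalized_shape, weight, eps)
    x = context[node.inputs[0]]
    weight = context[node.inputs[2]]
    eps = context[node.inputs[3]]

    # x / sqrt(mean(x^2, axis=-1) + eps), then apply the learned scale
    mean_sq = mb.reduce_mean(x=mb.mul(x=x, y=x), axes=[-1], keep_dims=True)
    inv_rms = mb.rsqrt(x=mb.add(x=mean_sq, y=eps))
    out = mb.mul(x=mb.mul(x=x, y=inv_rms), y=weight, name=node.name)
    context.add(out)
```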
Yes, there are 2 possibilities:
- (Simpler) As mentioned above, we already have the torch.norm translation using the Core ML fused reduce_l2_norm, so we may have an RMS norm torch source like x / torch.norm(x) (see the sketch after this list)
- (Better) I think we can further fuse to Core ML l2_norm, which may directly correspond to RMS norm? (with some restrictions, though, e.g. x must have rank >= 3.) We will need to add the translation function in CoreMLTools.
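A minimal sketch of what that "simpler" torch source could look like (module name, eps handling, and weight placement are assumptions, not necessarily what this PR landed):

```python
import torch

class RMSNormViaNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.shape[-1]
        # torch.norm over the last axis, which Core ML can lower to reduce_l2_norm
        norm = torch.norm(x, dim=-1, keepdim=True)
        # x / sqrt(mean(x^2) + eps), rewritten in terms of the l2 norm
        return self.weight * x / torch.sqrt(norm * norm / n + self.eps)
```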
I rewrote the RMS norm using norm as @YifanShenSZ suggested. Here is the output of generated text in FP16: "Once upon a time, in a small village nestled in the rolling hills of Provence, there lived a young girl named Sophie. Sophie was a curious and adventurous child, with a mop of curly brown hair and a smile that could light up the darkest of rooms. She loved nothing more than exploring the countryside, discovering hidden streams and secret meadows, and chasing after butterflies. As she wandered through the village, Sophie would often stop at the local bakery, where the owner, Monsieur LeFleur, would greet her with a warm smile and a warm baguette. Sophie loved the sweet scent of freshly baked bread and the taste of warm pastries, and she would often sneak into the bakery to sample the latest creations. One day, while Sophie was exploring the village, she stumbled upon a small, mysterious"
Looks good - maybe let's add CI to prevent regression
There is currently no CI for this model. Let me create a task for that.
Looks good! Maybe also test by setting temperature to 0?
Llama1B quality in CoreML is bad due to FP16 arithmetic. Here is a sample of generated text:
The corresponding FP16 eager mode model has much better generated text:
The discrepancy is that the eager mode model actually computes the RMSNorm in FP32 due to a cast operation (which CoreML appears to ignore):
Moreover, the norm computation appears unstable in FP16 and gives bad results. We can improve the numeric quality of the norm in FP16 by first dividing x by its maximum absolute value. Here is the generated text from CoreML in FP16 after this change:
Note that for 4-bit channelwise quantization, the results do not look good even after this change. The ideal solution is to do QAT for Llama1B with 4-bit channelwise quantization + FP16 arithmetic.
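A rough way one might probe that FP16 behavior (a sketch; the input scale is made up, and the printed errors depend on the backend and on how reductions are accumulated, so treat it as a probe rather than a benchmark):

```python
import torch

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(8, 2048) * 30.0  # FP32 activations with a fairly wide dynamic range

# FP32 reference: x / sqrt(mean(x^2) + eps)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Naive FP16 computation
xh = x.half()
naive = xh * torch.rsqrt(xh.pow(2).mean(-1, keepdim=True) + eps)

# FP16 computation after dividing by max |x| first (eps rescaled to match)
x_max, _ = xh.abs().max(-1, keepdim=True)
xs = xh / x_max
scaled = xs * torch.rsqrt(xs.pow(2).mean(-1, keepdim=True) + eps / (x_max * x_max))

print("naive FP16 error:", (naive.float() - ref).abs().max().item())
print("scaled FP16 error:", (scaled.float() - ref).abs().max().item())
```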