Commit 82356a5

Update llama_transformer.py
1 parent 3992b27 commit 82356a5

1 file changed, +2 -0 lines changed

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 2 additions & 0 deletions
@@ -131,6 +131,8 @@ def _norm(self, x):
         # Using torch.norm and preserving this op in CoreML improves stability
         # Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
         # In future, we want to add CoreML support for the functional RMSNorm op
+        # We have yet to do large scale evaluations on the numeric stability of this solution, but note that
+        # it appears better than what exists currently (removing FP32 casts and using FP16)
         rms_norm_eps0 = (
             x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
         ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
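
For reference, a minimal standalone sketch (function names here are hypothetical, not from the file) of the algebra behind this diff: for a last dimension of size n, x * sqrt(n) / ||x||_2 equals x / sqrt(mean(x^2)), since mean(x^2) = ||x||_2^2 / n. It also sketches the eps trick mentioned in the comment: appending a single sqrt(n * eps) element to x before taking the norm adds n * eps to the squared norm, recovering x / sqrt(mean(x^2) + eps).

import torch

def rms_norm_via_vector_norm(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # RMSNorm computed through a norm op, as in the commit's _norm.
    n = x.shape[-1]
    if eps > 0.0:
        # Append one element of sqrt(n * eps) so the squared norm of the
        # concatenation becomes ||x||^2 + n * eps (the comment's eps trick).
        pad = torch.full((*x.shape[:-1], 1), (n * eps) ** 0.5, dtype=x.dtype)
        denom = torch.linalg.vector_norm(
            torch.cat([x, pad], dim=-1), dim=-1, keepdim=True
        )
    else:
        denom = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
    return x * torch.sqrt(torch.tensor(n, dtype=x.dtype)) / denom

def rms_norm_reference(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # Textbook RMSNorm (without the learned weight): x / sqrt(mean(x^2) + eps).
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(2, 8)
assert torch.allclose(rms_norm_via_vector_norm(x), rms_norm_reference(x), atol=1e-6)
assert torch.allclose(
    rms_norm_via_vector_norm(x, eps=1e-5), rms_norm_reference(x, eps=1e-5), atol=1e-6
)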

0 commit comments
