 from torch import nn
 
 
-# These are just to prevent to_edge from decomposing SDPA
-# A better method is to use the to_edge_transform_and_lower API for CoreML
-# and not decompose SDPA
-@torch.library.custom_op("coreml::sdpa", mutates_args=())
-def sdpa(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
-    return torch.ops.aten.scaled_dot_product_attention.default(
-        q, k, v, attn_mask=attn_mask
-    )
-
-
-@torch.library.register_fake("coreml::sdpa")
-def _(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
-    expected_shape = list(q.shape)
-    expected_shape[-1] = v.shape[-1]
-    return q.new_empty(expected_shape)
-
-
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
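The deleted comment points at the alternative this change relies on: lowering with the to_edge_transform_and_lower API so the CoreML partitioner consumes SDPA before it can be decomposed, which makes the coreml::sdpa wrapper unnecessary. A minimal sketch of that flow follows; the import paths and the default CoreMLPartitioner arguments are assumptions based on the current ExecuTorch layout and may differ by version.

import torch
from torch import nn

# Assumed import paths; they may differ across ExecuTorch versions.
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyAttention(nn.Module):
    """Toy module whose SDPA call we want CoreML to receive un-decomposed."""

    def forward(self, q, k, v, attn_mask):
        return torch.ops.aten.scaled_dot_product_attention.default(
            q, k, v, attn_mask=attn_mask
        )


# Made-up shapes purely for export.
q = k = v = torch.randn(1, 8, 16, 64)
mask = torch.zeros(16, 16)
exported = torch.export.export(TinyAttention(), (q, k, v, mask))

# Partition and lower in one step so the SDPA node is delegated to CoreML
# instead of being decomposed by a plain to_edge() call.
edge = to_edge_transform_and_lower(exported, partitioner=[CoreMLPartitioner()])
executorch_program = edge.to_executorch()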
@@ -149,10 +126,15 @@ def _norm(self, x):
             torch.Tensor: The normalized tensor.
 
         """
-        x_max, _ = torch.abs(x).max(-1, keepdim=True)
-        x = x / x_max  # This makes the op more stable in FP16
-        eps = self.eps / (x_max * x_max)
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
+        # CoreML ignores casts to FP32, so the existing RMSNorm implementation was not stable
+        # We instead use (x * sqrt(n)) / norm(x, dim=-1)
+        # Using torch.norm and preserving this op in CoreML improves stability
+        # Note: we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
+        # In the future, we want to add CoreML support for the functional RMSNorm op
+        rms_norm_eps0 = (
+            x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+        ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
+        return rms_norm_eps0
 
     def forward(self, x):
         """
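As a sanity check on the reasoning in the new comments (and ignoring eps), x * sqrt(n) / ||x||_2 is algebraically identical to the usual x / sqrt(mean(x^2)), because ||x||_2 = sqrt(n * mean(x^2)). A small, self-contained check of that equivalence, with illustrative shapes standing in for self.dim, is sketched below.

import torch

# Illustrative shapes; the last dimension plays the role of self.dim.
x = torch.randn(2, 5, 64)
n = x.shape[-1]

# New formulation: (x * sqrt(n)) / ||x||_2 along the last dimension.
via_norm = (
    x * torch.sqrt(torch.tensor(n, dtype=x.dtype))
) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)

# Classic RMSNorm without eps: x / sqrt(mean(x^2)).
via_rms = x * torch.rsqrt((x * x).mean(-1, keepdim=True))

assert torch.allclose(via_norm, via_rms, atol=1e-6)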
@@ -352,7 +334,9 @@ def forward(
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
 
-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
+        output = torch.ops.aten.scaled_dot_product_attention.default(
+            q, k, v, attn_mask=attn_mask
+        )
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         output = self.wo(output)
         return output, new_k, new_v
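With the custom op gone, the forward pass calls the ATen overload directly; with dropout_p and is_causal left at their defaults, this is the same operator F.scaled_dot_product_attention dispatches to. A quick equivalence check with made-up shapes might look like:

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 32)
k = torch.randn(1, 4, 8, 32)
v = torch.randn(1, 4, 8, 32)
mask = torch.zeros(8, 8)

out_aten = torch.ops.aten.scaled_dot_product_attention.default(
    q, k, v, attn_mask=mask
)
out_func = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

assert torch.allclose(out_aten, out_func)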