Commit 3992b27

Commit message: up
1 parent e87347d

2 files changed: 26 additions, 31 deletions

examples/apple/coreml/llama/export.py

Lines changed: 14 additions & 3 deletions
@@ -19,7 +19,8 @@
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.exir.program._program import to_edge_with_preserved_ops
+from executorch.extension.export_util.utils import save_pte_program

 sys.path.insert(0, ".")
 from llama_transformer import InputManager, load_model
@@ -229,10 +230,20 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])

-    edge_manager = export_to_edge(
+    ep = torch.export.export(
         model,
         example_inputs,
-        edge_compile_config=EdgeCompileConfig(
+    )
+    print("Exported program")
+    print(ep)
+
+    edge_manager = to_edge_with_preserved_ops(
+        ep,
+        preserve_ops=[
+            torch.ops.aten.scaled_dot_product_attention.default,
+            torch.ops.aten.linalg_vector_norm.default,
+        ],
+        compile_config=EdgeCompileConfig(
             _check_ir_validity=False,
             _skip_type_promotion=(float_dtype == torch.float16),
             _skip_dim_order=True,
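
For context, a minimal sketch (mine, not part of this commit) of what the switch to to_edge_with_preserved_ops accomplishes: without preserve_ops, edge conversion decomposes aten.scaled_dot_product_attention into primitive ops, whereas listing it in preserve_ops keeps the single SDPA node in the edge graph. The TinyAttention module and shapes are made up for illustration; the imports mirror the ones in the diff and assume an ExecuTorch build where they are available.

import torch

from executorch.exir import EdgeCompileConfig
from executorch.exir.program._program import to_edge_with_preserved_ops


class TinyAttention(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.aten.scaled_dot_product_attention.default(
            q, k, v, attn_mask=mask
        )


# (batch, heads, seq, head_dim) inputs plus a broadcastable additive mask
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
mask = torch.zeros(8, 8)

ep = torch.export.export(TinyAttention(), (q, k, v, mask))

edge_manager = to_edge_with_preserved_ops(
    ep,
    preserve_ops=[torch.ops.aten.scaled_dot_product_attention.default],
    # Preserved aten ops are not valid edge-dialect ops, so IR validity
    # checks are relaxed, just as the export script above does.
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
# The preserved op should appear as a single node instead of a decomposition.
print(edge_manager.exported_program().graph_module.code)

The updated export.py relies on this same mechanism for both SDPA and linalg_vector_norm.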

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 12 additions & 28 deletions
@@ -23,29 +23,6 @@
 from torch import nn


-# These are just to prevent to_edge from decomposing SDPA
-# A better method is to use the to_edge_transform_and_lower API for CoreML
-# and not decompose SDPA
-@torch.library.custom_op("coreml::sdpa", mutates_args=())
-def sdpa(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
-    return torch.ops.aten.scaled_dot_product_attention.default(
-        q, k, v, attn_mask=attn_mask
-    )
-
-
-@torch.library.register_fake("coreml::sdpa")
-def _(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
-) -> torch.Tensor:
-    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
-    expected_shape = list(q.shape)
-    expected_shape[-1] = v.shape[-1]
-    return q.new_empty(expected_shape)
-
-
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
         return n
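
The comment removed above points at to_edge_transform_and_lower as the cleaner long-term route: hand the exported program to the CoreML partitioner so delegated ops such as SDPA never go through aten decomposition at all. The sketch below is only my reading of that suggestion, not code from this commit; it assumes coremltools is installed and that CoreMLPartitioner can be constructed with defaults.

import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, q, k, v, mask):
        attn = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask
        )
        return self.proj(attn)


q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
mask = torch.zeros(8, 8)

ep = torch.export.export(TinyBlock(), (q, k, v, mask))

# Ops claimed by the partitioner are lowered to the CoreML delegate before
# aten decompositions apply, so neither a custom coreml::sdpa wrapper nor a
# preserve_ops list is needed on this path.
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[CoreMLPartitioner()],
)
print(lowered.exported_program().graph_module.code)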
@@ -149,10 +126,15 @@ def _norm(self, x):
         torch.Tensor: The normalized tensor.

         """
-        x_max, _ = torch.abs(x).max(-1, keepdim=True)
-        x = x / x_max  # This makes the op more stable in FP16
-        eps = self.eps / (x_max * x_max)
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
+        # CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable
+        # We instead use (x * sqrt(n)) / norm(x, dim=-1)
+        # Using torch.norm and preserving this op in CoreML improves stability
+        # Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
+        # In future, we want to add CoreML support for the functional RMSNorm op
+        rms_norm_eps0 = (
+            x * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+        ) / torch.linalg.vector_norm(x, dim=-1, keepdim=True)
+        return rms_norm_eps0

     def forward(self, x):
         """
@@ -352,7 +334,9 @@ def forward(
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)

-        output = torch.ops.coreml.sdpa(q, k, v, attn_mask)
+        output = torch.ops.aten.scaled_dot_product_attention.default(
+            q, k, v, attn_mask=attn_mask
+        )
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         output = self.wo(output)
         return output, new_k, new_v
