
Fix triu/tril CoreML lowering error in to_edge_transform_and_lower #11107


Merged: 7 commits, May 28, 2025
backends/apple/coreml/partition/coreml_partitioner.py (20 additions, 4 deletions)
@@ -110,17 +110,33 @@ def ops_to_not_decompose(
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
         op_support = OperatorsSupportedForCoreMLBackend()
+        _logged_warnings = set()
+
+        # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
+        # TODO: upstream fixes, but pending ET consuming a new published version of coremltools with the
+        # desired changes, we need to manually block them here
+        do_not_decompose_blocklist = [
+            # https://github.com/apple/coremltools/blob/release/8.3/coremltools/converters/mil/frontend/torch/ops.py#L6965-L6966
+            torch.ops.aten.triu.default,
+            # https://github.com/apple/coremltools/blob/release/8.3/coremltools/converters/mil/frontend/torch/ops.py#L6997-L6998
+            torch.ops.aten.tril.default,
+        ]
         for node in ep.graph.nodes:
             if node.op == "call_function" and isinstance(
                 node.target, torch._ops.OpOverload
             ):
                 try:
-                    if op_support.is_node_supported(None, node):
+                    if (
+                        op_support.is_node_supported(None, node)
+                        and node.target not in do_not_decompose_blocklist
+                    ):
                         do_not_decompose.append(node.target)
                 except Exception as e:
                     # CoreML's op_support.is_node_supported will sometimes throw
                     # for unsupported ops, rather than returning False
-                    logger.warning(
-                        f"Encountered exception when checking node support: {e}"
-                    )
+                    warn_str = f"Encountered exception when checking node support: {e}"
+                    if warn_str not in _logged_warnings:
+                        logger.warning(warn_str)
+                        _logged_warnings.add(warn_str)
 
         return do_not_decompose, None
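Note (not part of the diff): a minimal sketch of how the blocklist above plays out end to end. It assumes a macOS environment with coremltools installed; the model, names, and shapes are illustrative rather than taken from this PR.

# Illustrative sketch: export a model that uses triu/tril, then lower it with the
# CoreML partitioner. With the blocklist in place, aten.triu/aten.tril are left out
# of the do-not-decompose list, so they are decomposed into ops coremltools can
# convert instead of failing during CoreML lowering.
import torch

from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)
from executorch.exir import to_edge_transform_and_lower


class MaskModel(torch.nn.Module):
    def forward(self, x):
        return torch.triu(x, diagonal=1) + torch.tril(x)


ep = torch.export.export(MaskModel(), (torch.randn(4, 4),), strict=True)
edge = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])
et_program = edge.to_executorch()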
backends/apple/coreml/test/test_coreml_partitioner.py (15 additions, 1 deletion)
@@ -90,6 +90,13 @@ def forward(self, q, k, v, mask):
                     q, k, v, attn_mask=mask
                 )
 
+                # triu/tril should be ignored by do_not_decompose
+                # because otherwise they fail during CoreML lowering
+                offset1 = torch.triu(mask, diagonal=1)
+                offset2 = torch.tril(mask)
+                offset = offset1 + offset2
+                offset = torch.sum(offset)
+
                 # Add non-functional and alias ops
                 # These will be removed by ExecuTorch in non-decomposition
                 # table because they cannot be functionalized
@@ -102,7 +109,7 @@ def forward(self, q, k, v, mask):
                 out = out.sub_(4.0)
                 out = torch.ops.aten.view_copy.default(out, (-1,))
                 out = out.select(0, 0)
-                return out
+                return out + offset
 
         model = Model()
         model.eval()
@@ -118,6 +125,13 @@ def forward(self, q, k, v, mask):
         mask = torch.randn(seq_len, max_seq_length)
         example_inputs = (q, k, v, mask)
         ep = torch.export.export(model, example_inputs, strict=True)
+        self.assertTrue(
+            "torch.ops.aten.triu.default" in ep.graph_module.code,
+        )
+        self.assertTrue(
+            "torch.ops.aten.tril.default" in ep.graph_module.code,
+        )
+
         coreml_partitioner = CoreMLPartitioner()
 
         # Using to_edge_transform_and_lower, we expect SDPA will be preserved and show up in delegated graph
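Note (not part of the diff): the new assertions rely on torch.export keeping aten.triu/aten.tril in the exported graph before any decomposition or lowering happens. A small standalone sketch of that check, with illustrative names:

# Illustrative sketch: torch.export preserves aten.triu.default / aten.tril.default
# in the exported graph, which is what the added assertTrue calls verify before the
# program is handed to the CoreML partitioner.
import torch


class TriangleModel(torch.nn.Module):
    def forward(self, mask):
        return torch.triu(mask, diagonal=1) + torch.tril(mask)


ep = torch.export.export(TriangleModel(), (torch.randn(8, 8),), strict=True)
print("torch.ops.aten.triu.default" in ep.graph_module.code)  # expected: True
print("torch.ops.aten.tril.default" in ep.graph_module.code)  # expected: True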
exir/program/_program.py (11 additions, 4 deletions)
@@ -1017,6 +1017,13 @@ def _sanity_check_graph_for_non_decomp_ops(
 def _remove_invalid_ops_for_not_decompose(
     ops_to_not_decompose: List[torch._ops.OpOverload],
 ) -> List[torch._ops.OpOverload]:
+    _logged_warnings = set()
+
+    def log_warning(warn_str):
+        if warn_str not in _logged_warnings:
+            logging.warn(warn_str)
+            _logged_warnings.add(warn_str)
+
     # To address https://github.com/pytorch/executorch/issues/8781
     def keep(op):
         # Explicit allow list
@@ -1034,18 +1041,18 @@ def keep(op):
         schema = op._schema
         native_schema = _pybind_schema_to_native_schema(schema)
         if native_schema is None:
-            logging.warn(
+            log_warning(
                 f"Torchgen is not able to parse the schema of {op._schema}. This is not fatal."
             )
         else:
             if native_schema.is_mutable:
-                logging.warn(
+                log_warning(
                     f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable."
                 )
                 return False
 
             if native_schema.aliased_return_names() != [None]:
-                logging.warn(
+                log_warning(
                     f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
                 )
                 return False
@@ -1067,7 +1074,7 @@ def keep(op):
             torch.ops.aten.unbind.int,
             torch.ops.aten.split_with_sizes.default,
         ]:
-            logging.warn(
+            log_warning(
                 f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist."
             )
             return False
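Note (not part of the diff): the warn-once pattern used in both files, shown in isolation as a small sketch. The helper name below is illustrative; the PR's version lives inside _remove_invalid_ops_for_not_decompose and calls logging.warn.

# Illustrative sketch: identical warning strings are logged only the first time they
# are seen, so a loop over many graph nodes or ops does not flood the log with
# repeated messages.
import logging

logger = logging.getLogger(__name__)
_logged_warnings = set()


def log_warning_once(warn_str: str) -> None:
    if warn_str not in _logged_warnings:
        logger.warning(warn_str)
        _logged_warnings.add(warn_str)


for op_name in ["aten.triu.default", "aten.triu.default", "aten.tril.default"]:
    # The duplicate triu message is suppressed on its second occurrence.
    log_warning_once(f"Skipping op {op_name}")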