Skip to content

Commit 1fe46a4

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add RemoveAssert pass to remove _assert_tensor_metadata nodes (#7277)
Summary: `_assert_tensor_metadata` nodes is added to the result of exported graphs in D66988295. (More background in T209705957, mostly to relax aten.to constraint). Add a pass to remove this op when calling to_edge. Differential Revision: D67057219
1 parent cb04347 commit 1fe46a4

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

exir/passes/remove_graph_asserts_pass.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2929
torch.ops.aten._assert_scalar.default,
3030
torch.ops.aten.sym_constrain_range_for_size.default,
3131
torch.ops.aten.sym_constrain_range.default,
32+
torch.ops.aten._assert_tensor_metadata.default,
3233
)
3334
):
3435
module.graph.erase_node(node)
@@ -37,3 +38,25 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
3738
module.graph.eliminate_dead_code()
3839

3940
return PassResult(graph_module, True)
41+
42+
43+
class RemoveNonCoreAtenOpGraphAssertsPass(PassBase):
44+
"""
45+
Remove assert ops from the graph that're not Aten Canonical.
46+
"""
47+
48+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
49+
for module in graph_module.modules():
50+
if not isinstance(module, torch.fx.GraphModule):
51+
continue
52+
53+
for node in module.graph.nodes:
54+
if node.op == "call_function" and (
55+
node.target in (torch.ops.aten._assert_tensor_metadata.default,)
56+
):
57+
module.graph.erase_node(node)
58+
59+
module.recompile()
60+
module.graph.eliminate_dead_code()
61+
62+
return PassResult(graph_module, True)

exir/program/_program.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@
4040
from executorch.exir.passes.normalize_view_copy_base_pass import (
4141
NormalizeViewCopyBasePass,
4242
)
43-
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
43+
from executorch.exir.passes.remove_graph_asserts_pass import (
44+
RemoveGraphAssertsPass,
45+
RemoveNonCoreAtenOpGraphAssertsPass,
46+
)
4447
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
4548
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
4649
from executorch.exir.passes.replace_view_copy_with_view_pass import (
@@ -722,13 +725,20 @@ def _generate_edge_program(
722725
program: ExportedProgram,
723726
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
724727
) -> ExportedProgram:
728+
729+
# Remove invalid assert ops, such as _assert_tensor_metadata
730+
gm = program.graph_module
731+
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
732+
assert gm_res is not None
733+
gm = gm_res.graph_module
734+
725735
if config._check_ir_validity:
726736
try:
727737
EXIRATenDialectVerifier(
728738
edge_compile_config=config,
729739
class_only=False,
730740
exception_list=ops_set_to_not_decompose,
731-
)(program.graph_module)
741+
)(gm)
732742
except ExportError as e:
733743
logging.info(f"Input program {name} is not in ATen dialect.")
734744
raise e
@@ -745,7 +755,6 @@ def _generate_edge_program(
745755
if not config._skip_dim_order:
746756
passes.append(MemoryFormatOpsPass())
747757

748-
gm = program.graph_module
749758
for p in passes:
750759
gm_res = p(gm)
751760
assert gm_res is not None

0 commit comments

Comments
 (0)