Skip to content

Commit 60db2c8

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add RemoveAssert pass to remove _assert_tensor_metadata nodes (#7277)
Summary: `_assert_tensor_metadata` nodes is added to the result of exported graphs in D66988295. (More background in T209705957, mostly to relax aten.to constraint). Add a pass to remove this op when calling to_edge. Differential Revision: D67057219
1 parent 61b9e1b commit 60db2c8

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

exir/passes/remove_graph_asserts_pass.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,31 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2929
torch.ops.aten._assert_scalar.default,
3030
torch.ops.aten.sym_constrain_range_for_size.default,
3131
torch.ops.aten.sym_constrain_range.default,
32+
torch.ops.aten._assert_tensor_metadata.default,
33+
)
34+
):
35+
module.graph.erase_node(node)
36+
37+
module.recompile()
38+
module.graph.eliminate_dead_code()
39+
40+
return PassResult(graph_module, True)
41+
42+
class RemoveNonCoreAtenOpPass(PassBase):
43+
"""
44+
Remove ops from the graph that're not Aten Canonical.
45+
"""
46+
47+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
48+
for module in graph_module.modules():
49+
if not isinstance(module, torch.fx.GraphModule):
50+
continue
51+
52+
for node in module.graph.nodes:
53+
if node.op == "call_function" and (
54+
node.target
55+
in (
56+
torch.ops.aten._assert_tensor_metadata.default,
3257
)
3358
):
3459
module.graph.erase_node(node)

exir/program/_program.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from executorch.exir.passes.normalize_view_copy_base_pass import (
4141
NormalizeViewCopyBasePass,
4242
)
43-
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
43+
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass, RemoveNonCoreAtenOpPass
4444
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
4545
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
4646
from executorch.exir.passes.replace_view_copy_with_view_pass import (
@@ -722,13 +722,20 @@ def _generate_edge_program(
722722
program: ExportedProgram,
723723
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
724724
) -> ExportedProgram:
725+
726+
# Remove invalid assert ops, such as _assert_tensor_metadata
727+
gm = program.graph_module
728+
gm_res = RemoveNonCoreAtenOpPass()(gm)
729+
assert gm_res is not None
730+
gm = gm_res.graph_module
731+
725732
if config._check_ir_validity:
726733
try:
727734
EXIRATenDialectVerifier(
728735
edge_compile_config=config,
729736
class_only=False,
730737
exception_list=ops_set_to_not_decompose,
731-
)(program.graph_module)
738+
)(gm)
732739
except ExportError as e:
733740
logging.info(f"Input program {name} is not in ATen dialect.")
734741
raise e
@@ -745,7 +752,6 @@ def _generate_edge_program(
745752
if not config._skip_dim_order:
746753
passes.append(MemoryFormatOpsPass())
747754

748-
gm = program.graph_module
749755
for p in passes:
750756
gm_res = p(gm)
751757
assert gm_res is not None

0 commit comments

Comments
 (0)