Skip to content

Commit a5bcbd2

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add RemoveAssert pass to remove _assert_tensor_metadata nodes (#7277)
Summary: `_assert_tensor_metadata` nodes is added to the result of exported graphs in D66988295. (More background in T209705957, mostly to relax aten.to constraint). Add a pass to remove this op when calling to_edge. Differential Revision: D67057219
1 parent 61b9e1b commit a5bcbd2

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

exir/passes/remove_graph_asserts_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2929
torch.ops.aten._assert_scalar.default,
3030
torch.ops.aten.sym_constrain_range_for_size.default,
3131
torch.ops.aten.sym_constrain_range.default,
32+
torch.ops.aten._assert_tensor_metadata.default,
3233
)
3334
):
3435
module.graph.erase_node(node)

exir/program/_program.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,13 +722,20 @@ def _generate_edge_program(
722722
program: ExportedProgram,
723723
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
724724
) -> ExportedProgram:
725+
726+
# Remove invalid assert ops, such as _assert_tensor_metadata
727+
gm = program.graph_module
728+
gm_res = RemoveGraphAssertsPass()(gm)
729+
assert gm_res is not None
730+
gm = gm_res.graph_module
731+
725732
if config._check_ir_validity:
726733
try:
727734
EXIRATenDialectVerifier(
728735
edge_compile_config=config,
729736
class_only=False,
730737
exception_list=ops_set_to_not_decompose,
731-
)(program.graph_module)
738+
)(gm)
732739
except ExportError as e:
733740
logging.info(f"Input program {name} is not in ATen dialect.")
734741
raise e
@@ -745,7 +752,6 @@ def _generate_edge_program(
745752
if not config._skip_dim_order:
746753
passes.append(MemoryFormatOpsPass())
747754

748-
gm = program.graph_module
749755
for p in passes:
750756
gm_res = p(gm)
751757
assert gm_res is not None

0 commit comments

Comments
 (0)