Skip to content

Commit 80526dd

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add RemoveAssert pass to remove _assert_tensor_metadata nodes
Summary: `_assert_tensor_metadata` nodes is added in D66988295. Differential Revision: D67057219
1 parent 343aa0c commit 80526dd

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

exir/passes/executorch_prim_ops_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def trunc(a: _SymScalar) -> _SymScalar:
141141
torch.ops.aten._local_scalar_dense.default,
142142
torch.ops.aten.sym_constrain_range_for_size.default,
143143
torch.ops.aten.sym_constrain_range.default,
144+
torch.ops.aten._assert_tensor_metadata.default,
144145
}
145146
)
146147

exir/program/_program.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,13 +722,20 @@ def _generate_edge_program(
722722
program: ExportedProgram,
723723
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
724724
) -> ExportedProgram:
725+
726+
# Remove invalid assert ops, such as _assert_tensor_metadata
727+
gm = program.graph_module
728+
gm_res = RemoveGraphAssertsPass()(gm)
729+
assert gm_res is not None
730+
gm = gm_res.graph_module
731+
725732
if config._check_ir_validity:
726733
try:
727734
EXIRATenDialectVerifier(
728735
edge_compile_config=config,
729736
class_only=False,
730737
exception_list=ops_set_to_not_decompose,
731-
)(program.graph_module)
738+
)(gm)
732739
except ExportError as e:
733740
logging.info(f"Input program {name} is not in ATen dialect.")
734741
raise e
@@ -745,7 +752,6 @@ def _generate_edge_program(
745752
if not config._skip_dim_order:
746753
passes.append(MemoryFormatOpsPass())
747754

748-
gm = program.graph_module
749755
for p in passes:
750756
gm_res = p(gm)
751757
assert gm_res is not None

0 commit comments

Comments
 (0)