40
40
from executorch .exir .passes .normalize_view_copy_base_pass import (
41
41
NormalizeViewCopyBasePass ,
42
42
)
43
- from executorch .exir .passes .remove_graph_asserts_pass import RemoveGraphAssertsPass
43
+ from executorch .exir .passes .remove_graph_asserts_pass import (
44
+ RemoveGraphAssertsPass ,
45
+ RemoveNonCoreAtenOpGraphAssertsPass ,
46
+ )
44
47
from executorch .exir .passes .remove_mixed_type_operators import RemoveMixedTypeOperators
45
48
from executorch .exir .passes .replace_aten_with_edge_pass import aten_to_edge
46
49
from executorch .exir .passes .replace_view_copy_with_view_pass import (
@@ -722,13 +725,20 @@ def _generate_edge_program(
722
725
program : ExportedProgram ,
723
726
ops_set_to_not_decompose : Optional [List [torch ._ops .OpOverload ]] = None ,
724
727
) -> ExportedProgram :
728
+
729
+ # Remove invalid assert ops, such as _assert_tensor_metadata
730
+ gm = program .graph_module
731
+ gm_res = RemoveNonCoreAtenOpGraphAssertsPass ()(gm )
732
+ assert gm_res is not None
733
+ gm = gm_res .graph_module
734
+
725
735
if config ._check_ir_validity :
726
736
try :
727
737
EXIRATenDialectVerifier (
728
738
edge_compile_config = config ,
729
739
class_only = False ,
730
740
exception_list = ops_set_to_not_decompose ,
731
- )(program . graph_module )
741
+ )(gm )
732
742
except ExportError as e :
733
743
logging .info (f"Input program { name } is not in ATen dialect." )
734
744
raise e
@@ -745,7 +755,6 @@ def _generate_edge_program(
745
755
if not config ._skip_dim_order :
746
756
passes .append (MemoryFormatOpsPass ())
747
757
748
- gm = program .graph_module
749
758
for p in passes :
750
759
gm_res = p (gm )
751
760
assert gm_res is not None
0 commit comments