@@ -654,6 +654,55 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
654
654
return passes
655
655
656
656
657
+ def _generate_edge_program (
658
+ name : str , config : EdgeCompileConfig , program : ExportedProgram
659
+ ) -> ExportedProgram :
660
+
661
+ if config ._check_ir_validity :
662
+ try :
663
+ EXIRATenDialectVerifier ()(program .graph_module )
664
+ except ExportError as e :
665
+ logging .info (f"Input program { name } is not in ATen dialect." )
666
+ raise e
667
+
668
+ pre_op_replace_passes , post_op_replace_passes = _get_aten_to_edge_passes (config )
669
+
670
+ passes = []
671
+ passes .append (
672
+ ReplaceViewOpsWithViewCopyOpsPass ()
673
+ ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
674
+ passes .extend (pre_op_replace_passes )
675
+ if config ._use_edge_ops :
676
+ passes .append (OpReplacePass ())
677
+
678
+ gm = program .graph_module
679
+ for p in passes :
680
+ gm_res = p (gm )
681
+ assert gm_res is not None
682
+ gm = gm_res .graph_module
683
+
684
+ edge_program = ExportedProgram (
685
+ root = gm ,
686
+ graph = gm .graph ,
687
+ graph_signature = _get_updated_graph_signature (program .graph_signature , gm ),
688
+ state_dict = program .state_dict ,
689
+ range_constraints = program .range_constraints ,
690
+ module_call_graph = program .module_call_graph ,
691
+ example_inputs = program .example_inputs ,
692
+ verifier = EXIREdgeDialectVerifier (
693
+ check_edge_ops = config ._use_edge_ops ,
694
+ enable = config ._check_ir_validity ,
695
+ class_only = True ,
696
+ ),
697
+ constants = program .constants ,
698
+ )
699
+ # Lift the tensor constants created in ScalarToTensorPass
700
+ edge_program = lift_constant_tensor_pass (edge_program )
701
+ edge_program = _transform (edge_program , * post_op_replace_passes )
702
+
703
+ return edge_program
704
+
705
+
657
706
def to_edge (
658
707
programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
659
708
constant_methods : Optional [Dict [str , Any ]] = None ,
@@ -681,52 +730,12 @@ def to_edge(
681
730
aten_programs = programs
682
731
683
732
edge_programs : Dict [str , ExportedProgram ] = {}
733
+
684
734
for name , program in aten_programs .items ():
685
735
# Decompose to Core ATen
686
736
program = program .run_decompositions (_default_decomposition_table ())
737
+ edge_programs [name ] = _generate_edge_program (name , config , program )
687
738
688
- if config ._check_ir_validity :
689
- try :
690
- EXIRATenDialectVerifier ()(program .graph_module )
691
- except ExportError as e :
692
- logging .info (f"Input program { name } is not in ATen dialect." )
693
- raise e
694
-
695
- pre_op_replace_passes , post_op_replace_passes = _get_aten_to_edge_passes (config )
696
-
697
- passes = []
698
- passes .append (
699
- ReplaceViewOpsWithViewCopyOpsPass ()
700
- ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
701
- passes .extend (pre_op_replace_passes )
702
- if config ._use_edge_ops :
703
- passes .append (OpReplacePass ())
704
-
705
- gm = program .graph_module
706
- for p in passes :
707
- gm_res = p (gm )
708
- assert gm_res is not None
709
- gm = gm_res .graph_module
710
-
711
- edge_program = ExportedProgram (
712
- root = gm ,
713
- graph = gm .graph ,
714
- graph_signature = _get_updated_graph_signature (program .graph_signature , gm ),
715
- state_dict = program .state_dict ,
716
- range_constraints = program .range_constraints ,
717
- module_call_graph = program .module_call_graph ,
718
- example_inputs = program .example_inputs ,
719
- verifier = EXIREdgeDialectVerifier (
720
- check_edge_ops = config ._use_edge_ops ,
721
- enable = config ._check_ir_validity ,
722
- class_only = True ,
723
- ),
724
- constants = program .constants ,
725
- )
726
- # Lift the tensor constants created in ScalarToTensorPass
727
- edge_program = lift_constant_tensor_pass (edge_program )
728
- edge_program = _transform (edge_program , * post_op_replace_passes )
729
- edge_programs [name ] = edge_program
730
739
return EdgeProgramManager (edge_programs , constant_methods , config )
731
740
732
741
0 commit comments