Skip to content

Commit 2a4fcb4

Browse files
tarun292facebook-github-bot
authored andcommitted
Refactor to_edge into smaller functions that can be reused
Differential Revision: D56401457
1 parent cc12d9b commit 2a4fcb4

File tree

1 file changed

+51
-42
lines changed

1 file changed

+51
-42
lines changed

exir/program/_program.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,55 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
654654
return passes
655655

656656

657+
def _generate_edge_program(
658+
name: str, config: EdgeCompileConfig, program: ExportedProgram
659+
) -> ExportedProgram:
660+
661+
if config._check_ir_validity:
662+
try:
663+
EXIRATenDialectVerifier()(program.graph_module)
664+
except ExportError as e:
665+
logging.info(f"Input program {name} is not in ATen dialect.")
666+
raise e
667+
668+
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
669+
670+
passes = []
671+
passes.append(
672+
ReplaceViewOpsWithViewCopyOpsPass()
673+
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
674+
passes.extend(pre_op_replace_passes)
675+
if config._use_edge_ops:
676+
passes.append(OpReplacePass())
677+
678+
gm = program.graph_module
679+
for p in passes:
680+
gm_res = p(gm)
681+
assert gm_res is not None
682+
gm = gm_res.graph_module
683+
684+
edge_program = ExportedProgram(
685+
root=gm,
686+
graph=gm.graph,
687+
graph_signature=_get_updated_graph_signature(program.graph_signature, gm),
688+
state_dict=program.state_dict,
689+
range_constraints=program.range_constraints,
690+
module_call_graph=program.module_call_graph,
691+
example_inputs=program.example_inputs,
692+
verifier=EXIREdgeDialectVerifier(
693+
check_edge_ops=config._use_edge_ops,
694+
enable=config._check_ir_validity,
695+
class_only=True,
696+
),
697+
constants=program.constants,
698+
)
699+
# Lift the tensor constants created in ScalarToTensorPass
700+
edge_program = lift_constant_tensor_pass(edge_program)
701+
edge_program = _transform(edge_program, *post_op_replace_passes)
702+
703+
return edge_program
704+
705+
657706
def to_edge(
658707
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
659708
constant_methods: Optional[Dict[str, Any]] = None,
@@ -681,52 +730,12 @@ def to_edge(
681730
aten_programs = programs
682731

683732
edge_programs: Dict[str, ExportedProgram] = {}
733+
684734
for name, program in aten_programs.items():
685735
# Decompose to Core ATen
686736
program = program.run_decompositions(_default_decomposition_table())
737+
edge_programs[name] = _generate_edge_program(name, config, program)
687738

688-
if config._check_ir_validity:
689-
try:
690-
EXIRATenDialectVerifier()(program.graph_module)
691-
except ExportError as e:
692-
logging.info(f"Input program {name} is not in ATen dialect.")
693-
raise e
694-
695-
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
696-
697-
passes = []
698-
passes.append(
699-
ReplaceViewOpsWithViewCopyOpsPass()
700-
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
701-
passes.extend(pre_op_replace_passes)
702-
if config._use_edge_ops:
703-
passes.append(OpReplacePass())
704-
705-
gm = program.graph_module
706-
for p in passes:
707-
gm_res = p(gm)
708-
assert gm_res is not None
709-
gm = gm_res.graph_module
710-
711-
edge_program = ExportedProgram(
712-
root=gm,
713-
graph=gm.graph,
714-
graph_signature=_get_updated_graph_signature(program.graph_signature, gm),
715-
state_dict=program.state_dict,
716-
range_constraints=program.range_constraints,
717-
module_call_graph=program.module_call_graph,
718-
example_inputs=program.example_inputs,
719-
verifier=EXIREdgeDialectVerifier(
720-
check_edge_ops=config._use_edge_ops,
721-
enable=config._check_ir_validity,
722-
class_only=True,
723-
),
724-
constants=program.constants,
725-
)
726-
# Lift the tensor constants created in ScalarToTensorPass
727-
edge_program = lift_constant_tensor_pass(edge_program)
728-
edge_program = _transform(edge_program, *post_op_replace_passes)
729-
edge_programs[name] = edge_program
730739
return EdgeProgramManager(edge_programs, constant_methods, config)
731740

732741

0 commit comments

Comments
 (0)