Skip to content

[exir] Refactor EdgeProgramManager.transform #10275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/digantdesai/35/base
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,6 @@ def _transform(
isinstance(p, (list, Verifier)) for p in passes
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"

for p in list(passes):
print(type(p))
pm = PassManager(list(passes))
res = pm(self.graph_module)
transformed_gm = res.graph_module if res is not None else self.graph_module
Expand Down Expand Up @@ -1442,22 +1440,34 @@ def transform(
"""
compile_config = compile_config or self.compile_config
new_programs: Dict[str, ExportedProgram] = {}

def _transform_and_verify(
program: ExportedProgram,
passes: Sequence[PassType],
verifier: type[Verifier],
) -> ExportedProgram:
# Overwrite the original verifier with the new one
# This should be a no-op for the most cases where compile_config is none.
new_program = _transform(program, *passes, override_verifiers=[verifier])
# ExportedProgram constructor should call the verifier, but
# the validate() function in the constructor is marked for deprecation.
verifier()(new_program.graph_module)
return new_program

verifier = EXIREdgeDialectVerifier(
edge_compile_config=compile_config, class_only=True
)
if isinstance(passes, dict):
for name, program in self._edge_programs.items():
if name in passes.keys():
new_programs[name] = _transform(program, *passes[name])
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
new_programs[name].graph_module
new_programs[name] = _transform_and_verify(
program, passes[name], verifier
)
else:
new_programs[name] = copy.deepcopy(program)

else: # apply passes to every method
for name, program in self._edge_programs.items():
new_programs[name] = _transform(program, *passes)
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
new_programs[name].graph_module
)
new_programs[name] = _transform_and_verify(program, passes, verifier)

return EdgeProgramManager(
new_programs, copy.deepcopy(self._config_methods), compile_config
Expand Down
Loading