Skip to content

Commit f354e16

Browse files
committed
[exir] Refactor EdgeProgramManager.transform
Pull Request resolved: #10275 Mainly refactor, but also update the dialect verifier in the EP created by the `_transform` when the edge_config has been updated. ghstack-source-id: 278881823 @exported-using-ghexport Differential Revision: [D73205728](https://our.internmc.facebook.com/intern/diff/D73205728/)
1 parent ac204d7 commit f354e16

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

exir/program/_program.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,6 @@ def _transform(
234234
isinstance(p, (list, Verifier)) for p in passes
235235
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
236236

237-
for p in list(passes):
238-
print(type(p))
239237
pm = PassManager(list(passes))
240238
res = pm(self.graph_module)
241239
transformed_gm = res.graph_module if res is not None else self.graph_module
@@ -1442,22 +1440,34 @@ def transform(
14421440
"""
14431441
compile_config = compile_config or self.compile_config
14441442
new_programs: Dict[str, ExportedProgram] = {}
1443+
1444+
def _transform_and_verify(
1445+
program: ExportedProgram,
1446+
passes: Sequence[PassType],
1447+
verifier: EXIREdgeDialectVerifier,
1448+
) -> ExportedProgram:
1449+
# Overwrite the original verifier with the new one
1450+
# This should be a no-op for the most cases where compile_config is none.
1451+
new_program = _transform(program, *passes, override_verifiers=[verifier])
1452+
# ExportedProgram constructor should call the verifier, but
1453+
# the validate() function in the constructor is marked for deprecation.
1454+
verifier()(new_program.graph_module)
1455+
return new_program
1456+
1457+
verifier = EXIREdgeDialectVerifier(
1458+
edge_compile_config=compile_config, class_only=True
1459+
)
14451460
if isinstance(passes, dict):
14461461
for name, program in self._edge_programs.items():
14471462
if name in passes.keys():
1448-
new_programs[name] = _transform(program, *passes[name])
1449-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1450-
new_programs[name].graph_module
1463+
new_programs[name] = _transform_and_verify(
1464+
program, passes[name], verifier
14511465
)
14521466
else:
14531467
new_programs[name] = copy.deepcopy(program)
1454-
14551468
else: # apply passes to every method
14561469
for name, program in self._edge_programs.items():
1457-
new_programs[name] = _transform(program, *passes)
1458-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1459-
new_programs[name].graph_module
1460-
)
1470+
new_programs[name] = _transform_and_verify(program, passes, verifier)
14611471

14621472
return EdgeProgramManager(
14631473
new_programs, copy.deepcopy(self._config_methods), compile_config

0 commit comments

Comments
 (0)