Skip to content

Commit ac204d7

Browse files
committed
[exir] Allow verifiers in _transform
Pull Request resolved: #10274 This is to allow users to transform an ExportedProgram using passes in places where it may result in a dialect that is not compliant with the original creation context. For example, if an ExportedProgram was created in an edge dialect and now needs to be run and transformed in a way that is not compliant with the EdgeDialectVerifier, such as in a delegate preprocess() function, then the user may want to override the verifier with their own or simply disable it. ghstack-source-id: 278881821 @exported-using-ghexport Differential Revision: [D73205727](https://our.internmc.facebook.com/intern/diff/D73205727/)
1 parent ef99fff commit ac204d7

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

exir/program/_program.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,30 @@ def _get_updated_graph_signature(
212212
return new_signature
213213

214214

215-
def _transform(self, *passes: PassType) -> "ExportedProgram":
215+
def _transform(
216+
self,
217+
*passes: PassType,
218+
override_verifiers: None | list[Type[Verifier]] = None,
219+
) -> "ExportedProgram":
220+
"""
221+
Transforms the program according to the provided passes.
222+
223+
Args:
224+
self: The ExportedProgram instance to transform
225+
*passes: A sequence of passes to apply to the program
226+
override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
227+
This is needed if the transforms yields illegal graph that the default verifier cannot handle.
228+
229+
Returns:
230+
ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
231+
"""
232+
# A user friendly check to avoid vararg surprises, PEP 3102
233+
assert not any(
234+
isinstance(p, (list, Verifier)) for p in passes
235+
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
236+
237+
for p in list(passes):
238+
print(type(p))
216239
pm = PassManager(list(passes))
217240
res = pm(self.graph_module)
218241
transformed_gm = res.graph_module if res is not None else self.graph_module
@@ -221,7 +244,9 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
221244
if transformed_gm is self.graph_module and not res.modified:
222245
return self
223246

224-
return _update_exported_program_graph_module(self, transformed_gm)
247+
return _update_exported_program_graph_module(
248+
self, transformed_gm, override_verifiers
249+
)
225250

226251

227252
def _update_exported_program_graph_module(

0 commit comments

Comments
 (0)