Skip to content

Commit d6d68d5

Browse files
committed
[exir] Allow verifiers in _transform
This is to allow users to transform an ExportedProgram using passes in places where it may result in a dialect that is not compliant with the original creation context. For example, if an ExportedProgram was created in an edge dialect and now needs to be run and transformed in a way that is not compliant with the EdgeDialectVerifier, such as in a delegate preprocess() function, then the user may want to override the verifier with their own or simply disable it. Differential Revision: [D73205727](https://our.internmc.facebook.com/intern/diff/D73205727/) ghstack-source-id: 278776368 Pull Request resolved: #10274
1 parent e42d013 commit d6d68d5

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

exir/program/_program.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,23 @@ def _get_updated_graph_signature(
212212
return new_signature
213213

214214

215-
def _transform(self, *passes: PassType) -> "ExportedProgram":
215+
def _transform(
216+
self,
217+
*passes: PassType,
218+
override_verifiers: None | list[Type[Verifier]] = None,
219+
) -> "ExportedProgram":
220+
"""
221+
Transforms the program according to the provided passes.
222+
223+
Args:
224+
self: The ExportedProgram instance to transform
225+
*passes: A sequence of passes to apply to the program
226+
override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
227+
This is needed if the transforms yields illegal graph that the default verifier cannot handle.
228+
229+
Returns:
230+
ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
231+
"""
216232
pm = PassManager(list(passes))
217233
res = pm(self.graph_module)
218234
transformed_gm = res.graph_module if res is not None else self.graph_module
@@ -221,7 +237,9 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
221237
if transformed_gm is self.graph_module and not res.modified:
222238
return self
223239

224-
return _update_exported_program_graph_module(self, transformed_gm)
240+
return _update_exported_program_graph_module(
241+
self, transformed_gm, override_verifiers
242+
)
225243

226244

227245
def _update_exported_program_graph_module(

0 commit comments

Comments
 (0)