Skip to content

Commit b6282dc

Browse files
angelayifacebook-github-bot
authored andcommitted
Move SpecPropPass after backend passes
Summary: SpecPropPass is used to construct the `meta["spec"]` metadata which is later used for memory planning. Specifically, it stores the shapes/dtypes used for memory planning (symbolic shapes are converted to their upper bound value in the SymShapeEvalPass), and in the memory planning pass, it stores the lifetimes of tensors and their buffer sizes. This diff moves SpecPropPass to after user-defined passes are run so that user passes do not affect the `spec` metadata (ex. If users insert a node then they would need to update meta["spec"] because we do not recalculate it). Originally this was put before user-defined passes because Turing team has a hacky pass to change `spec.dtype` due to the CPU kernel not matching the hardware kernel, but from a high level I think this should still go at the end. Turing team will fix + remove their "legalize_dtype" pass in D47640453 so that it is not dependent on meta["spec"]. Reviewed By: zhxchen17 Differential Revision: D47572569 fbshipit-source-id: efb77df1c867925c0b22a53d8435dae515e033c8
1 parent a1fea89 commit b6282dc

File tree

2 files changed

+1
-24
lines changed

2 files changed

+1
-24
lines changed

exir/emit/test/test_emit.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -641,29 +641,6 @@ def forward(self, x):
641641
program.execution_plan[0].values[4].val, schema.OptionalTensorList
642642
)
643643

644-
def test_broken_call_function(self) -> None:
645-
def f(x: torch.Tensor) -> torch.Tensor:
646-
number = x.argmax()
647-
x.add(number)
648-
return x.sum()
649-
650-
module = exir.capture(
651-
f, (torch.tensor([1.0]),), exir.CaptureConfig(pt2_mode=True)
652-
).to_edge()
653-
654-
def bad_pass(module: torch.fx.GraphModule) -> PassResult:
655-
for node in module.graph.nodes:
656-
if node.op == "call_function":
657-
node.meta["spec"] = f
658-
return PassResult(module, True)
659-
660-
with self.assertRaisesRegex(
661-
InternalError, ".*Here is the failing node in the graph module:\ngraph().*"
662-
):
663-
module.to_executorch(
664-
exir.ExecutorchBackendConfig(passes=[bad_pass])
665-
).program
666-
667644
def test_emit_map(self) -> None:
668645
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
669646
def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

exir/program/_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
235235
def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]:
236236
# pyre-ignore
237237
passes: List[PassType] = [
238-
SpecPropPass(),
239238
*config.passes,
239+
SpecPropPass(),
240240
EdgeToBackendOpsPass(),
241241
RemoveAssertAsyncPass(),
242242
SymShapeEvalPass(),

0 commit comments

Comments
 (0)