|
@@ -20,6 +20,7 @@
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.emit import emit_program
+from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import (
@@ -45,6 +46,7 @@
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
 from executorch.exir.tests.models import MLP, Mul
+from functorch.experimental import control_flow

 from torch import nn
 from torch.fx import GraphModule, subgraph_rewriter
@@ -826,6 +828,70 @@ def test_debug_handle_generator_pass(self) -> None:
         for node in graph_module.graph.nodes:
             self.assertIn("debug_handle", node.meta)

+    def test_debug_handle_generator_pass_with_control_flow(self) -> None:
+        def true_nested(y: torch.Tensor) -> torch.Tensor:
+            y = y + y
+            y = torch.mm(y, y)
+            return y
+
+        def false_nested(y: torch.Tensor) -> torch.Tensor:
+            return torch.mm(y, y)
+
+        def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
+            z = control_flow.cond(pred2, true_nested, false_nested, [x])
+            return x + z
+
+        def false_fn(x: torch.Tensor, _) -> torch.Tensor:
+            return x.cos()
+
+        def map_fn(
+            x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
+        ) -> torch.Tensor:
+            x = x.cos()
+            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
+            x = x + y
+            return x.sin()
+
+        def f(
+            xs: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
+        ) -> torch.Tensor:
+            y = torch.mm(y, y)
+            return control_flow.map(map_fn, xs, pred1, pred2, y)
+
+        inputs = (
+            torch.ones(2, 2),
+            torch.tensor([False]),
+            torch.tensor([False]),
+            torch.ones(2, 2),
+        )
+
+        graph_module = exir.capture(
+            f,
+            inputs,
+            exir.CaptureConfig(),
+        ).exported_program.graph_module
+
+        def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
+            queue = [graph_module]
+            while queue:
+                current_graph_module = queue.pop(0)
+                for node in current_graph_module.graph.nodes:
+                    self.assertIn("debug_handle", node.meta)
+                control_flow_submodules = [
+                    submodule
+                    for _, submodule, _ in get_control_flow_submodules(
+                        current_graph_module
+                    )
+                ]
+                queue.extend(control_flow_submodules)
+
+        DebugHandleGeneratorPass()(graph_module)
+        check_debug_handle_metadata(graph_module)
+
+        # Check debug handle still preserved after ScalarToTensorPass
+        ScalarToTensorPass()(graph_module)
+        check_debug_handle_metadata(graph_module)
+
     def test_symint_conversion(self) -> None:
         def f(x: torch.Tensor) -> torch.Tensor:
             return x + x.shape[0]