|
1 | 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
2 | 2 |
|
| 3 | +# pyre-strict |
| 4 | + |
| 5 | + |
| 6 | +from typing import Sequence |
3 | 7 |
|
4 | 8 | import executorch.backends.cadence.aot.ops_registrations # noqa
|
5 | 9 | import torch
|
|
9 | 13 | )
|
10 | 14 | from executorch.backends.cadence.aot.pass_utils import count_node
|
11 | 15 | from executorch.exir.dialects._ops import ops as exir_ops
|
12 |
| -from executorch.exir.pass_base import ExportPass |
| 16 | +from executorch.exir.pass_base import ExportPass, NodeMetadata |
13 | 17 | from later.unittest import TestCase
|
14 | 18 |
|
15 | 19 |
|
@@ -68,3 +72,30 @@ def test_graph_with_single_im2row(self) -> None:
|
68 | 72 | # Check graph has a single im2row node.
|
69 | 73 | self.assertEqual(len([gm.graph.nodes]), 1)
|
70 | 74 | self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
|
| 75 | + |
| 76 | + |
| 77 | +class TestHigherOrderOps(TestCase): |
| 78 | + def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule: |
| 79 | + builder = GraphBuilder() |
| 80 | + x = builder.placeholder("x", torch.randn(*x_shape)) |
| 81 | + add = builder.call_operator( |
| 82 | + exir_ops.edge.aten.add.Tensor, |
| 83 | + (x, x), # pyre-ignore |
| 84 | + ) |
| 85 | + builder.output([x, add]) |
| 86 | + gm = builder.get_graph_module() |
| 87 | + # Check if graph module is valid by running exportpass on it. |
| 88 | + gm = ExportPass().call(gm).graph_module |
| 89 | + return gm |
| 90 | + |
| 91 | + def test_call_map(self) -> None: |
| 92 | + builder = GraphBuilder() |
| 93 | + x_shape = (4, 8, 8) |
| 94 | + x = builder.placeholder("x", torch.randn(*x_shape)) |
| 95 | + map_node = builder.call_map( |
| 96 | + self._get_inner_graph(x_shape[1:]), [x], [], NodeMetadata({}) |
| 97 | + ) |
| 98 | + builder.output([map_node]) |
| 99 | + gm = builder.get_graph_module() |
| 100 | + # Check if graph module is valid by running exportpass on it. |
| 101 | + ExportPass().call(gm).graph_module |
0 commit comments