Skip to content

Commit 8470cb9

Browse files
authored
import test graph builder to oss, import im2row
Differential Revision: D65913991 Pull Request resolved: #6878
1 parent 97cb892 commit 8470cb9

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,20 @@ python_library(
162162
"//executorch/exir/passes:spec_prop_pass",
163163
],
164164
)
165+
166+
python_unittest(
167+
name = "test_graph_builder",
168+
srcs = [
169+
"tests/test_graph_builder.py",
170+
],
171+
typing = True,
172+
deps = [
173+
"//caffe2:torch",
174+
"//executorch/backends/cadence/aot:graph_builder",
175+
"//executorch/backends/cadence/aot:pass_utils",
176+
"//executorch/exir:pass_base",
177+
"//executorch/exir/dialects:lib",
178+
"//later:lib",
179+
":ops_registrations"
180+
],
181+
)

backends/cadence/aot/pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,12 @@ def get_node_names_list_from_gm(
8989
continue
9090
graph_nodes.append(node.name)
9191
return graph_nodes
92+
93+
94+
def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
95+
"""Count the number of nodes with target `target` in the graph."""
96+
total = 0
97+
for node in graph_module.graph.nodes:
98+
if node.op == "call_function" and node.target == target:
99+
total += 1
100+
return total
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
4+
import executorch.backends.cadence.aot.ops_registrations # noqa
5+
import torch
6+
from executorch.backends.cadence.aot.graph_builder import (
7+
GraphBuilder,
8+
single_op_builder,
9+
)
10+
from executorch.backends.cadence.aot.pass_utils import count_node
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass
13+
from later.unittest import TestCase
14+
15+
16+
class TestGraphBuilder(TestCase):
17+
def test_graph_with_single_im2row(self) -> None:
18+
# Create a graph with a single im2row node.
19+
builder = GraphBuilder()
20+
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
21+
pad_value = builder.placeholder("pad", torch.randn(1))
22+
channels_last = False
23+
im2row = builder.call_operator(
24+
exir_ops.edge.cadence.im2row.default,
25+
# pyre-ignore
26+
(
27+
x,
28+
(2, 2),
29+
(1, 1),
30+
(0, 0),
31+
(1, 1),
32+
pad_value,
33+
channels_last,
34+
),
35+
)
36+
builder.output([im2row])
37+
gm = builder.get_graph_module()
38+
# Check if graph module is valid by running exportpass on it.
39+
gm = ExportPass().call(gm).graph_module
40+
41+
# Check graph has a single im2row node.
42+
self.assertEqual(len([gm.graph.nodes]), 1)
43+
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
44+
45+
46+
class TestSingleOpBuilderUtility(TestCase):
47+
def test_graph_with_single_im2row(self) -> None:
48+
# Create a graph with a single im2row node.
49+
x = torch.randn(1, 3, 224, 224)
50+
pad_value = torch.randn(1)
51+
channels_last = False
52+
gm = single_op_builder(
53+
(x, pad_value),
54+
exir_ops.edge.cadence.im2row.default,
55+
(
56+
x,
57+
(2, 2),
58+
(1, 1),
59+
(0, 0),
60+
(1, 1),
61+
pad_value,
62+
channels_last,
63+
),
64+
)
65+
# Check if graph module is valid by running exportpass on it.
66+
gm = ExportPass().call(gm).graph_module
67+
68+
# Check graph has a single im2row node.
69+
self.assertEqual(len([gm.graph.nodes]), 1)
70+
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)

0 commit comments

Comments
 (0)