Skip to content

Commit 7394db0

Browse files
Zonglin Pengfacebook-github-bot
authored andcommitted
import test graph builder to oss, import im2row
Summary: titled Differential Revision: D65913991
1 parent 20e2db0 commit 7394db0

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

backends/cadence/aot/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,20 @@ python_library(
162162
"//executorch/exir/passes:spec_prop_pass",
163163
],
164164
)
165+
166+
python_unittest(
167+
name = "test_graph_builder",
168+
srcs = [
169+
"tests/test_graph_builder.py",
170+
],
171+
typing = True,
172+
deps = [
173+
"//caffe2:torch",
174+
"//executorch/backends/cadence/aot:graph_builder",
175+
"//executorch/backends/cadence/aot:pass_utils",
176+
"//executorch/exir:pass_base",
177+
"//executorch/exir/dialects:lib",
178+
"//later:lib",
179+
":ops_registrations"
180+
],
181+
)

backends/cadence/aot/pass_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,12 @@ def get_node_names_list_from_gm(
8989
continue
9090
graph_nodes.append(node.name)
9191
return graph_nodes
92+
93+
94+
def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
95+
"""Count the number of nodes with target `target` in the graph."""
96+
total = 0
97+
for node in graph_module.graph.nodes:
98+
if node.op == "call_function" and node.target == target:
99+
total += 1
100+
return total
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
4+
import torch
5+
import executorch.backends.cadence.aot.ops_registrations # noqa
6+
from executorch.backends.cadence.aot.graph_builder import (
7+
GraphBuilder,
8+
single_op_builder,
9+
)
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass
12+
from later.unittest import TestCase
13+
from executorch.backends.cadence.aot.pass_utils import count_node
14+
15+
class TestGraphBuilder(TestCase):
16+
def test_graph_with_single_im2row(self) -> None:
17+
# Create a graph with a single im2row node.
18+
builder = GraphBuilder()
19+
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
20+
pad_value = builder.placeholder("pad", torch.randn(1))
21+
channels_last = False
22+
im2row = builder.call_operator(
23+
exir_ops.edge.cadence.im2row.default,
24+
# pyre-ignore
25+
(
26+
x,
27+
(2, 2),
28+
(1, 1),
29+
(0, 0),
30+
(1, 1),
31+
pad_value,
32+
channels_last,
33+
),
34+
)
35+
builder.output([im2row])
36+
gm = builder.get_graph_module()
37+
# Check if graph module is valid by running exportpass on it.
38+
gm = ExportPass().call(gm).graph_module
39+
40+
# Check graph has a single im2row node.
41+
self.assertEqual(len([gm.graph.nodes]), 1)
42+
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
43+
44+
45+
class TestSingleOpBuilderUtility(TestCase):
46+
def test_graph_with_single_im2row(self) -> None:
47+
# Create a graph with a single im2row node.
48+
x = torch.randn(1, 3, 224, 224)
49+
pad_value = torch.randn(1)
50+
channels_last = False
51+
gm = single_op_builder(
52+
(x, pad_value),
53+
exir_ops.edge.cadence.im2row.default,
54+
(
55+
x,
56+
(2, 2),
57+
(1, 1),
58+
(0, 0),
59+
(1, 1),
60+
pad_value,
61+
channels_last,
62+
),
63+
)
64+
# Check if graph module is valid by running exportpass on it.
65+
gm = ExportPass().call(gm).graph_module
66+
67+
# Check graph has a single im2row node.
68+
self.assertEqual(len([gm.graph.nodes]), 1)
69+
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)

0 commit comments

Comments
 (0)