Commit 5785fc3

add unit test for op_add (#7087)
add op_add shapes to generate as binaries (#7087)

Summary: Generates the add-model PTEs for Cadence to execute on. A graph builder will be used in later diffs.

Test Plan: Imported from GitHub, without a `Test Plan:` line. {F1968254537}

Reviewed By: hsharma35
Differential Revision: D66510372
Pulled By: zonglinpeng
1 parent dedf77b commit 5785fc3
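
For context, a quick standalone illustration (not part of this commit) of the torch.add semantics the new parameterized tests exercise: alpha scales the second operand (out = x + alpha * y), and mismatched shapes follow normal broadcasting rules.

import torch

x = torch.ones(2, 3)
y = torch.full((2, 3), 2.0)

# torch.add computes x + alpha * y; with alpha=2.23 each element is 1 + 2.23 * 2 = 5.46
out = torch.add(x, y, alpha=2.23)
print(torch.allclose(out, torch.full((2, 3), 5.46)))  # True

# Broadcasting, as in the (1, 32, 64) + (1, 1, 64) cases in the new test
a = torch.randn(1, 32, 64)
b = torch.randn(1, 1, 64)
print(torch.add(a, b).shape)  # torch.Size([1, 32, 64])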

6 files changed: 173 additions, 7 deletions
backends/cadence/aot/TARGETS

Lines changed: 20 additions & 0 deletions

@@ -50,6 +50,26 @@ python_library(
     ],
 )
 
+python_library(
+    name = "export_example",
+    srcs = [
+        "export_example.py",
+    ],
+    deps = [
+        ":passes",
+        ":utils",
+        ":ops_registrations",
+        ":replace_ops",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot/quantizer:fusion_pass",
+        "//executorch/backends/cadence/runtime:runtime",
+        "//executorch/backends/cadence/aot/quantizer:quantizer",
+        "//executorch/backends/transforms:decompose_sdpa",
+        "//executorch/backends/transforms:remove_clone_ops",
+        "//executorch/exir:lib",
+        "//executorch/devtools:lib",
+    ],
+)
 
 python_library(
     name = "pass_utils",

backends/cadence/aot/export_example.py

Lines changed: 8 additions & 6 deletions

@@ -60,6 +60,7 @@ def export_model(
     model: nn.Module,
     example_inputs: Tuple[Any, ...],
     file_name: str = "CadenceDemoModel",
+    run_and_compare: bool = True,
 ):
     # create work directory for outputs and model binary
     working_dir = tempfile.mkdtemp(dir="/tmp")
@@ -112,9 +113,10 @@
     )
 
     # TODO: move to test infra
-    runtime.run_and_compare(
-        executorch_prog=exec_prog,
-        inputs=example_inputs,
-        ref_outputs=ref_outputs,
-        working_dir=working_dir,
-    )
+    if run_and_compare:
+        runtime.run_and_compare(
+            executorch_prog=exec_prog,
+            inputs=example_inputs,
+            ref_outputs=ref_outputs,
+            working_dir=working_dir,
+        )
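
A rough usage sketch of the new run_and_compare flag (the Add module, inputs, and file name below are made up for illustration; this assumes the export_model signature shown in the hunk above): passing run_and_compare=False generates the .pte artifact without invoking runtime.run_and_compare.

import torch
import torch.nn as nn

from executorch.backends.cadence.aot.export_example import export_model


class Add(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


# Export only: skip the run-and-compare step and just emit the model binary.
export_model(
    Add(),
    (torch.randn(7, 5, 6), torch.randn(7, 5, 6)),
    file_name="add_example",  # hypothetical name; the default is "CadenceDemoModel"
    run_and_compare=False,
)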

backends/cadence/aot/utils.py

Lines changed: 2 additions & 1 deletion

@@ -162,7 +162,8 @@ def print_ops_info(
 
     # Print the final ops and their counts in a tabular format
     logging.info(
-        tabulate(
+        "\n"
+        + tabulate(
             sorted_ops_count,
             headers=[
                 "Final Operators ",  # one character longer than the longest op name

backends/cadence/runtime/TARGETS

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,8 @@ python_library(
     srcs = [
         "__init__.py",
        "executor.py",
+        "runtime.py",
+        "utils.py"
     ] + glob([
         "xtsc-cfg/**/*",
     ]),

examples/cadence/operators/TARGETS

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+
+oncall("odai_jarvis")
+
+
+python_unittest(
+    name = "test_add_op",
+    srcs = [
+        "test_add_op.py",
+    ],
+    typing = True,
+    supports_static_listing = False,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        "//caffe2:torch",
+        "//executorch/backends/cadence/aot:ops_registrations",
+        "//executorch/backends/cadence/aot:export_example",
+        "//executorch/backends/cadence/aot:compiler",
+    ],
+)
examples/cadence/operators/test_add_op.py

Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+import unittest
+from typing import Tuple
+
+from parameterized import parameterized
+
+from executorch.backends.cadence.aot.ops_registrations import *  # noqa
+
+import torch
+import torch.nn as nn
+from executorch.backends.cadence.aot.export_example import export_model
+
+
+class ATenOpTestCases(unittest.TestCase):
+    @parameterized.expand(
+        [
+            [(7, 5, 6), (7, 5, 6)],
+            [(7, 5, 6), (1)],
+            [(1), (7, 5, 6)],
+            [(1), (7, 5, 6), 2.23],
+            [(1), (7, 5, 6), -1.0],
+            [(1), (7, 5, 6), -2.23],
+            [(7, 5, 6), (7, 5, 6), 1.23],
+            [(6, 7), (6, 7)],
+            [(6, 7), (6, 7), 2],
+            # Broadcast tests (should be optimized on G3)
+            [(1, 32, 64), (1, 1, 64)],
+            [(1, 32, 64), (64)],
+            [(1, 1, 32), (32)],
+            [(16, 1, 16), (1, 1, 16)],
+            [(16, 1, 16), (16)],
+            [(1, 4, 8, 8), (1, 1, 8, 8)],
+            [(1, 4, 8, 8), (8, 8)],
+            # Broadcast tests (should go to portable ops)
+            [(1, 10, 1, 8), (4, 1, 4, 1)],
+            [(1, 1, 16), (1, 8, 1), 2.5],
+            # # aten.upsample_nearest2d tests
+            [(5, 6, 6, 8), (5, 6, 6, 8)],
+            [(1, 1, 12, 16), (1, 1, 12, 16)],
+        ]
+    )
+    def test_aten_add_out(
+        self, Xshape: Tuple[int], Yshape: Tuple[int], alpha: float = 1
+    ) -> None:
+        class AddTensor(nn.Module):
+            def __init__(self, alpha: float):
+                super().__init__()
+                self.alpha = alpha
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return torch.add(x, y, alpha=self.alpha)
+
+        model = AddTensor(alpha)
+
+        X = torch.randn(Xshape)
+        Y = torch.randn(Yshape)
+
+        model.eval()
+        export_model(
+            model, (X, Y), file_name=self._testMethodName, run_and_compare=False
+        )
+
+    @parameterized.expand(
+        [
+            [(7, 5, 6), (7, 5, 6)],
+            [(7, 5, 6), (1)],
+            [(1), (7, 5, 6)],
+            [(1), (7, 5, 6), 2.23],
+            [(1), (7, 5, 6), -1.0],
+            [(1), (7, 5, 6), -2.23],
+            [(7, 5, 6), (7, 5, 6), 1.23],
+            [(6, 7), (6, 7)],
+            [(6, 7), (6, 7), 2],
+            # Broadcast tests (should be optimized on G3)
+            [(1, 32, 64), (1, 1, 64)],
+            [(1, 32, 64), (64)],
+            [(1, 1, 32), (32)],
+            [(16, 1, 16), (1, 1, 16)],
+            [(16, 1, 16), (16)],
+            [(1, 4, 8, 8), (1, 1, 8, 8)],
+            [(1, 4, 8, 8), (8, 8)],
+            # Broadcast tests (should go to portable ops)
+            [(1, 10, 1, 8), (4, 1, 4, 1)],
+            [(1, 1, 16), (1, 8, 1), 2.5],
+            # # aten.upsample_nearest2d tests
+            [(5, 6, 6, 8), (5, 6, 6, 8)],
+            [(1, 1, 12, 16), (1, 1, 12, 16)],
+        ]
+    )
+    def test_aten_add_scalar_out(
+        self, Xshape: Tuple[int], Yshape: Tuple[int], alpha: float = 1
+    ) -> None:
+        # Tensor-Scalar addition
+        class AddScalar(nn.Module):
+            def __init__(self, alpha: float):
+                super().__init__()
+                self.alpha = alpha
+
+            def forward(self, x: torch.Tensor, y: float):
+                return torch.add(x, y, alpha=self.alpha)
+
+        model = AddScalar(alpha)
+
+        X = torch.randn(Xshape)
+        Y = 2.34
+
+        model.eval()
+        export_model(
+            model, (X, Y), file_name=self._testMethodName, run_and_compare=False
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
