
Commit 37a3301

zonglinpeng authored and facebook-github-bot committed
add all g3 tests using facto, fix sub scalar model, shrink input rank and size (#7707)
Summary:
Pull Request resolved: #7707

All test cases have to live in the same file to retain the same ordering; they will be deduplicated in follow-up diffs. Fixed the FACTO constraints and the sub-scalar test cases.

Reproducing the test case from internal testing:

```
✗ Fail: on_device_ai/Assistant/Jarvis/nightly:test_g3_nightly - test_g3_sub_tensor_out_11 (on_device_ai.Assistant.Jarvis.nightly.test_g3_nightly.TestOperators) (25.3s)
/data/users/zonglinpeng/fbsource 7c1566d4aa2e+
****************************************************************************************************
OrderedDict([('alpha', 1.6130766369937761)])
[tensor([[ 254, -199]], dtype=torch.int32), tensor([[-22.2500, 168.7500], [147.8750, 247.8750]])]
```

vs.

```
✓ Pass: executorch/examples/cadence/operators:test_g3_ops - test_g3_sub_tensor_out_11 (executorch.examples.cadence.operators.test_g3_ops.ATenOpTestCases) (1.0s)
****************************************************************************************************
[tensor([[ 254, -199]], dtype=torch.int32), tensor([[-22.2500, 168.7500], [147.8750, 247.8750]])]
OrderedDict([('alpha', 1.6130766369937761)])
```

Differential Revision: D68195603
1 parent 4e80560 commit 37a3301

File tree

3 files changed: +304 -12 lines changed

examples/cadence/operators/facto_util.py

Lines changed: 38 additions & 12 deletions
```diff
@@ -26,8 +26,12 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
             | "mul.Tensor"
             | "div.Tensor"
         ):
-            tensor_constraints.append(
-                cp.Dtype.In(lambda deps: [torch.float]),
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float]),
+                    cp.Size.Le(lambda deps, r, d: 2),
+                    cp.Rank.Le(lambda deps: 2),
+                ]
             )
         case (
             "add.Tensor"
@@ -37,35 +41,60 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
             | "mul.Scalar"
             | "div.Scalar"
         ):
-            tensor_constraints.append(
-                cp.Dtype.In(lambda deps: [torch.float, torch.int]),
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
+                    cp.Size.Le(lambda deps, r, d: 2),
+                    cp.Rank.Le(lambda deps: 2),
+                ]
+            )
+        case "native_layer_norm.default":
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
+                    cp.Size.Le(lambda deps, r, d: 2**4),
+                    cp.Rank.Le(lambda deps: 2**4),
+                ]
             )
         case _:
-            tensor_constraints.append(
-                cp.Dtype.In(lambda deps: [torch.float, torch.int]),
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
+                    cp.Size.Le(lambda deps, r, d: 2),
+                    cp.Rank.Le(lambda deps: 2),
+                ]
             )
     tensor_constraints.extend(
         [
             cp.Value.Ge(lambda deps, dtype, struct: -(2**8)),
             cp.Value.Le(lambda deps, dtype, struct: 2**8),
             cp.Rank.Ge(lambda deps: 1),
-            cp.Rank.Le(lambda deps: 2**2),
             cp.Size.Ge(lambda deps, r, d: 1),
-            cp.Size.Le(lambda deps, r, d: 2**2),
         ]
     )
 
 
+def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
+    match op_name:
+        case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar":
+            return [ScalarDtype.int]
+        case _:
+            return [ScalarDtype.float, ScalarDtype.int]
+
+
 def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
     # minimal example to test add.Tensor using FACTO
     spec = SpecDictDB[op_name]
+    tensor_constraints = []
+    # common tensor constraints
+    apply_tensor_contraints(op_name, tensor_constraints)
 
     for index, in_spec in enumerate(copy.deepcopy(spec.inspec)):
         if in_spec.type.is_scalar():
             if in_spec.name != "alpha":
                 spec.inspec[index].constraints.extend(
                     [
-                        cp.Dtype.In(lambda deps: [ScalarDtype.float, ScalarDtype.int]),
+                        cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
                         cp.Value.Ge(lambda deps, dtype: -(2**8)),
                         cp.Value.Le(lambda deps, dtype: 2**2),
                         cp.Size.Ge(lambda deps, r, d: 1),
@@ -80,9 +109,6 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
                 ]
             )
         elif in_spec.type.is_tensor():
-            tensor_constraints = []
-            # common tensor constraints
-            apply_tensor_contraints(op_name, tensor_constraints)
             spec.inspec[index].constraints.extend(tensor_constraints)
 
     return [
```
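For context, a minimal sketch of how the updated generator can be driven directly (assuming `facto_util` is importable, as in the new test file below; the printed values are whatever FACTO samples, not fixed outputs):

```python
# Hypothetical standalone driver for the generator changed above.
from executorch.examples.cadence.operators import facto_util

# facto_testcase_gen(op_name) returns (posargs, inkwargs) pairs; with the new
# apply_scalar_contraints, the *.Scalar ops only draw int-valued scalars.
for posargs, inkwargs in facto_util.facto_testcase_gen("sub.Scalar"):
    print(posargs, inkwargs)
```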

examples/cadence/operators/targets.bzl

Lines changed: 2 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 
 TESTS_LIST = [
     "add_op",
+    "g3_ops",
     "quantized_conv1d_op",
     "quantized_linear_op",
 ]
@@ -46,5 +47,6 @@ def _define_test_target(test_name):
             "fbcode//executorch/backends/cadence/aot:ops_registrations",
             "fbcode//executorch/backends/cadence/aot:export_example",
             "fbcode//executorch/backends/cadence/aot:compiler",
+            "fbcode//executorch/examples/cadence/operators:facto_util",
         ],
     )
```
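With `g3_ops` added to `TESTS_LIST`, `_define_test_target` should emit a test target for the new file; a hedged sketch of invoking it (the exact Buck command and flags may differ from this assumption):

```
buck2 test fbcode//executorch/examples/cadence/operators:test_g3_ops
```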
examples/cadence/operators/test_g3_ops.py

Lines changed: 264 additions & 0 deletions

```python
import unittest
from typing import Any, cast, List, OrderedDict, Tuple

from executorch.examples.cadence.operators import facto_util

from parameterized import parameterized

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

import torch
import torch.nn as nn
from executorch.backends.cadence.aot.export_example import export_model


class ATenOpTestCases(unittest.TestCase):
    def run_and_verify(self, model: nn.Module, inputs: Tuple[Any, ...]) -> None:
        model.eval()
        export_model(
            model, inputs, file_name=self._testMethodName, run_and_compare=False
        )

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("add.Tensor")])
    @torch.no_grad()
    def test_g3_add_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class AddTensor(nn.Module):
            def __init__(self, alpha: float):
                super().__init__()
                self.alpha = alpha

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.add(x, y, alpha=self.alpha)

        model = AddTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("add.Scalar")])
    @torch.no_grad()
    def test_aten_add_Scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class AddScalar(nn.Module):
            def __init__(self, alpha: float):
                super().__init__()
                self.alpha = alpha

            def forward(self, x: torch.Tensor, y: float):
                return torch.add(x, y, alpha=self.alpha)

        inputs = posargs[:-1]  # posargs = [x_tensor, y_scalar, alpha_scalar]
        alpha = posargs[-1]
        model = AddScalar(alpha)

        self.run_and_verify(model, tuple(inputs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("sub.Tensor")])
    @torch.no_grad()
    def test_g3_sub_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class SubTensor(nn.Module):
            def __init__(self, alpha: float):
                super().__init__()
                self.alpha = alpha

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.sub(x, y, alpha=self.alpha)

        model = SubTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("sub.Scalar")])
    @torch.no_grad()
    def test_g3_sub_scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        # Tensor-Scalar subtraction
        class SubScalar(torch.nn.Module):
            def __init__(self, other):
                super().__init__()
                self.other = other

            def forward(self, x):
                return torch.ops.aten.sub.Scalar(x, self.other)

        inputs = posargs[0]  # posargs = [x_tensor, y_scalar, alpha_scalar]
        model = SubScalar(posargs[1])

        self.run_and_verify(model, (inputs,))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("div.Tensor")])
    @torch.no_grad()
    def test_g3_div_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class DivTensor(nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.div(x, y + 1)

        model = DivTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("div.Scalar")])
    @torch.no_grad()
    def test_g3_div_scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class DivScalar(nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.div(x, y + 1)

        model = DivScalar(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("exp.default")])
    @torch.no_grad()
    def test_g3_exp_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class Exp(nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.exp(x)

        model = Exp(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("mul.Tensor")])
    @torch.no_grad()
    def test_g3_mul_tensor_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class MulTensor(nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return x * y

        model = MulTensor(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("mul.Scalar")])
    @torch.no_grad()
    def test_g3_mul_scalar_out(
        self,
        posargs: List[str],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class MulScalar(nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return x * y

        model = MulScalar(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("native_layer_norm.default")])
    @torch.no_grad()
    def test_g3_native_layer_norm_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        inputs, normalized_shape, weight, bias, _ = posargs
        model = nn.LayerNorm(normalized_shape, eps=1e-5)
        if weight is not None:
            weight = cast(torch.Tensor, weight)
            model.weight = nn.Parameter(torch.rand_like(weight))
        if bias is not None:
            bias = cast(torch.Tensor, bias)
            model.bias = nn.Parameter(torch.rand_like(bias))

        self.run_and_verify(model, (inputs,))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("neg.default")])
    @torch.no_grad()
    def test_g3_neg_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class Neg(nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.neg(x)

        model = Neg(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("rsqrt.default")])
    @torch.no_grad()
    def test_g3_rsqrt_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        class Rsqrt(nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.ops.aten.rsqrt(x)

        model = Rsqrt(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("sigmoid.default")])
    @torch.no_grad()
    def test_g3_sigmoid_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        model = nn.Sigmoid(**inkwargs)

        self.run_and_verify(model, tuple(posargs))

    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
    @parameterized.expand([*facto_util.facto_testcase_gen("_softmax.default")])
    @torch.no_grad()
    def test_g3__softmax_out(
        self,
        posargs: List[int],
        inkwargs: OrderedDict[str, str],
    ) -> None:
        inputs, _, _ = posargs
        model = nn.Softmax(dim=-1)

        self.run_and_verify(model, (inputs,))


if __name__ == "__main__":
    unittest.main()
```
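The `test_g3_sub_scalar_out` case above reflects the sub-scalar model fix from the commit title: the scalar operand is bound into the module, so the exported graph takes only the tensor input. A self-contained sketch of that pattern (values are illustrative, not real FACTO output):

```python
import torch


class SubScalar(torch.nn.Module):
    """Captures the scalar so forward() takes only the tensor input."""

    def __init__(self, other: float) -> None:
        super().__init__()
        self.other = other

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calls the same ATen overload the test exercises.
        return torch.ops.aten.sub.Scalar(x, self.other)


x = torch.randn(2, 2)
print(SubScalar(3.0)(x))  # matches torch.sub(x, 3.0)
```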
