
Commit 37223a7

zonglinpeng authored and facebook-github-bot committed
add mean to g3 nightly (#8059)
Summary: titled expect loss in some cases

Reviewed By: hsharma35

Differential Revision: D68845592
1 parent e55a3b0 commit 37223a7

2 files changed: +60 −1 lines

examples/cadence/operators/facto_util.py
Lines changed: 31 additions & 1 deletion

@@ -5,6 +5,8 @@
 import copy
 from typing import List, OrderedDict, Tuple
 
+import facto.specdb.function as fn
+
 import torch
 from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
 from facto.inputgen.specs.model import ConstraintProducer as cp
@@ -22,7 +24,16 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
             tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
-                    cp.Rank.Le(lambda deps: 2**3),
+                    cp.Rank.Le(lambda deps: 2**2),
+                    cp.Value.Ge(lambda deps, dtype, struct: -2),
+                    cp.Value.Le(lambda deps, dtype, struct: 2),
+                ]
+            )
+        case "mean.dim":
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float]),
+                    cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
         case "exp.default":
@@ -86,8 +97,27 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
                     cp.Value.Le(lambda deps, dtype: 2),
                 ]
             )
+        elif in_spec.type.is_scalar_type():
+            spec.inspec[index].constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
+                ]
+            )
         elif in_spec.type.is_tensor():
             spec.inspec[index].constraints.extend(tensor_constraints)
+        elif in_spec.type.is_dim_list():
+            spec.inspec[index].constraints.extend(
+                [
+                    cp.Length.Ge(lambda deps: 1),
+                    cp.Optional.Eq(lambda deps: False),
+                ]
+            )
+        elif in_spec.type.is_bool():
+            spec.inspec[index].constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.bool]),
+                ]
+            )
 
     return [
         (posargs, inkwargs)
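
For intuition, here is a minimal sketch (not part of the commit) of what the new "mean.dim" tensor constraints amount to in plain PyTorch terms: FACTO is told to generate only float tensors of rank at most 2**2 == 4. The helper name below is hypothetical and exists only for illustration.

import torch

def satisfies_mean_dim_tensor_constraints(x: torch.Tensor) -> bool:
    # cp.Dtype.In(lambda deps: [torch.float])  -> dtype restricted to float32
    # cp.Rank.Le(lambda deps: 2**2)            -> at most four dimensions
    return x.dtype == torch.float and x.dim() <= 2**2

print(satisfies_mean_dim_tensor_constraints(torch.rand(3, 4)))           # True
print(satisfies_mean_dim_tensor_constraints(torch.rand(1, 2, 3, 4, 5)))  # False: rank 5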

examples/cadence/operators/test_g3_ops.py
Lines changed: 29 additions & 0 deletions

@@ -259,6 +259,35 @@ def test_g3__softmax_out(
 
         self.run_and_verify(model, (inputs,))
 
+    # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
+    @parameterized.expand([*facto_util.facto_testcase_gen("mean.dim")])
+    def test_g3_mean_dim_out(
+        self,
+        posargs: List[int],
+        inkwargs: OrderedDict[str, str],
+    ) -> None:
+        class Meandim(nn.Module):
+            def forward(
+                self,
+                x: torch.Tensor,
+                dim_list: Tuple[int],
+                keepdim: bool,
+                dtype: torch.dtype = torch.float32,
+            ) -> torch.Tensor:
+                return torch.ops.aten.mean.dim(
+                    x,
+                    dim_list,
+                    keepdim,
+                    dtype=dtype,
+                )
+
+        model = Meandim()
+
+        self.run_and_verify(
+            model,
+            inputs=tuple(posargs),
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
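
For reference, a standalone sketch of the ATen overload the new test exercises, assuming a recent PyTorch build: torch.ops.aten.mean.dim reduces over the listed dimensions, optionally keeps them as size-1 dimensions, and can accumulate in an explicit dtype.

import torch

x = torch.rand(2, 3, 4)
# Average over dim 1, keep it as a size-1 dimension, accumulate in float32.
out = torch.ops.aten.mean.dim(x, [1], True, dtype=torch.float32)
print(out.shape)  # torch.Size([2, 1, 4])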
