Skip to content

Commit e92bb7a

Browse files
authored
add mean to g3 nightly
Differential Revision: D68845592 Pull Request resolved: #8059
1 parent 682c636 commit e92bb7a

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

examples/cadence/operators/facto_util.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,16 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
2222
tensor_constraints.extend(
2323
[
2424
cp.Dtype.In(lambda deps: [torch.float]),
25-
cp.Rank.Le(lambda deps: 2**3),
25+
cp.Rank.Le(lambda deps: 2**2),
26+
cp.Value.Ge(lambda deps, dtype, struct: -2),
27+
cp.Value.Le(lambda deps, dtype, struct: 2),
28+
]
29+
)
30+
case "mean.dim":
31+
tensor_constraints.extend(
32+
[
33+
cp.Dtype.In(lambda deps: [torch.float]),
34+
cp.Rank.Le(lambda deps: 2**2),
2635
]
2736
)
2837
case "exp.default":
@@ -86,8 +95,27 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
8695
cp.Value.Le(lambda deps, dtype: 2),
8796
]
8897
)
98+
elif in_spec.type.is_scalar_type():
99+
spec.inspec[index].constraints.extend(
100+
[
101+
cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
102+
]
103+
)
89104
elif in_spec.type.is_tensor():
90105
spec.inspec[index].constraints.extend(tensor_constraints)
106+
elif in_spec.type.is_dim_list():
107+
spec.inspec[index].constraints.extend(
108+
[
109+
cp.Length.Ge(lambda deps: 1),
110+
cp.Optional.Eq(lambda deps: False),
111+
]
112+
)
113+
elif in_spec.type.is_bool():
114+
spec.inspec[index].constraints.extend(
115+
[
116+
cp.Dtype.In(lambda deps: [torch.bool]),
117+
]
118+
)
91119

92120
return [
93121
(posargs, inkwargs)

examples/cadence/operators/test_g3_ops.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,35 @@ def test_g3__softmax_out(
259259

260260
self.run_and_verify(model, (inputs,))
261261

262+
# pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
263+
@parameterized.expand([*facto_util.facto_testcase_gen("mean.dim")])
264+
def test_g3_mean_dim_out(
265+
self,
266+
posargs: List[int],
267+
inkwargs: OrderedDict[str, str],
268+
) -> None:
269+
class Meandim(nn.Module):
270+
def forward(
271+
self,
272+
x: torch.Tensor,
273+
dim_list: Tuple[int],
274+
keepdim: bool,
275+
dtype: torch.dtype = torch.float32,
276+
) -> torch.Tensor:
277+
return torch.ops.aten.mean.dim(
278+
x,
279+
dim_list,
280+
keepdim,
281+
dtype=dtype,
282+
)
283+
284+
model = Meandim()
285+
286+
self.run_and_verify(
287+
model,
288+
inputs=tuple(posargs),
289+
)
290+
262291

263292
if __name__ == "__main__":
264293
unittest.main()

0 commit comments

Comments
 (0)