Skip to content

Commit 6ae27f6

Browse files
zonglinpengfacebook-github-bot
authored andcommitted
add size and rank of nightly facto (#7962)
Summary: x2 scale number of testcases Differential Revision: D68690963
1 parent 3f8c4d8 commit 6ae27f6

File tree

1 file changed

+12
-34
lines changed

1 file changed

+12
-34
lines changed

examples/cadence/operators/facto_util.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,36 @@
1818

1919
def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
2020
match op_name:
21-
case (
22-
"sigmoid.default"
23-
| "_softmax.default"
24-
| "rsqrt.default"
25-
| "exp.default"
26-
| "mul.Tensor"
27-
| "div.Tensor"
28-
):
21+
case "sigmoid.default" | "rsqrt.default":
2922
tensor_constraints.extend(
3023
[
3124
cp.Dtype.In(lambda deps: [torch.float]),
32-
cp.Size.Le(lambda deps, r, d: 2),
33-
cp.Rank.Le(lambda deps: 2),
25+
cp.Rank.Le(lambda deps: 2**3),
3426
]
3527
)
36-
case (
37-
"add.Tensor"
38-
| "sub.Tensor"
39-
| "add.Scalar"
40-
| "sub.Scalar"
41-
| "mul.Scalar"
42-
| "div.Scalar"
43-
):
28+
case "exp.default":
4429
tensor_constraints.extend(
4530
[
46-
cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
47-
cp.Size.Le(lambda deps, r, d: 2),
48-
cp.Rank.Le(lambda deps: 2),
49-
]
50-
)
51-
case "native_layer_norm.default":
52-
tensor_constraints.extend(
53-
[
54-
cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
55-
cp.Size.Le(lambda deps, r, d: 2**4),
56-
cp.Rank.Le(lambda deps: 2**4),
31+
cp.Rank.Le(lambda deps: 2**3),
32+
cp.Value.Ge(lambda deps, dtype, struct: -1),
33+
cp.Value.Le(lambda deps, dtype, struct: 1),
5734
]
5835
)
5936
case _:
6037
tensor_constraints.extend(
6138
[
62-
cp.Dtype.In(lambda deps: [torch.float, torch.int32]),
63-
cp.Size.Le(lambda deps, r, d: 2),
64-
cp.Rank.Le(lambda deps: 2),
39+
cp.Rank.Le(lambda deps: 2**2),
6540
]
6641
)
6742
tensor_constraints.extend(
6843
[
69-
cp.Value.Ge(lambda deps, dtype, struct: -(2**8)),
70-
cp.Value.Le(lambda deps, dtype, struct: 2**8),
44+
cp.Dtype.In(lambda deps: [torch.int, torch.float]),
45+
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
46+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
47+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
7148
cp.Rank.Ge(lambda deps: 1),
7249
cp.Size.Ge(lambda deps, r, d: 1),
50+
cp.Size.Le(lambda deps, r, d: 2**9),
7351
]
7452
)
7553

0 commit comments

Comments
 (0)