Skip to content

Commit 851f5fc

Browse files
authored
fix where self out contraint to make a, b numerical
Differential Revision: D75644590 Pull Request resolved: #11240
1 parent 456ed6a commit 851f5fc

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
MAX_CASES = 50
2121

2222

23-
def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
24-
additional_tensor_constraints = [
23+
def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
24+
tensor_constraints = [
2525
cp.Dtype.In(lambda deps: [torch.int, torch.float]),
2626
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
2727
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
@@ -33,17 +33,28 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
3333

3434
match op_name:
3535
case "where.self":
36-
additional_tensor_constraints = [
37-
cp.Dtype.In(lambda deps: [torch.float, torch.int, torch.bool]),
38-
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
39-
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
40-
cp.Value.Le(lambda deps, dtype, struct: 2**4),
41-
cp.Rank.Ge(lambda deps: 1),
42-
cp.Size.Ge(lambda deps, r, d: 1),
43-
cp.Size.Le(lambda deps, r, d: 2**9),
44-
]
36+
if index == 0: # condition
37+
tensor_constraints = [
38+
cp.Dtype.In(lambda deps: [torch.bool]),
39+
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
40+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
41+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
42+
cp.Rank.Ge(lambda deps: 1),
43+
cp.Size.Ge(lambda deps, r, d: 1),
44+
cp.Size.Le(lambda deps, r, d: 2**9),
45+
]
46+
else:
47+
tensor_constraints = [
48+
cp.Dtype.In(lambda deps: [torch.float, torch.int]),
49+
cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
50+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
51+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
52+
cp.Rank.Ge(lambda deps: 1),
53+
cp.Size.Ge(lambda deps, r, d: 1),
54+
cp.Size.Le(lambda deps, r, d: 2**9),
55+
]
4556
case "sigmoid.default":
46-
additional_tensor_constraints.extend(
57+
tensor_constraints.extend(
4758
[
4859
cp.Dtype.In(lambda deps: [torch.float]),
4960
cp.Rank.Le(lambda deps: 2**2),
@@ -52,7 +63,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
5263
]
5364
)
5465
case "rsqrt.default":
55-
additional_tensor_constraints.extend(
66+
tensor_constraints.extend(
5667
[
5768
cp.Dtype.In(lambda deps: [torch.float]),
5869
cp.Rank.Le(lambda deps: 2**2),
@@ -63,35 +74,35 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
6374
]
6475
)
6576
case "mean.dim":
66-
additional_tensor_constraints.extend(
77+
tensor_constraints.extend(
6778
[
6879
cp.Dtype.In(lambda deps: [torch.float]),
6980
cp.Rank.Le(lambda deps: 2**2),
7081
]
7182
)
7283
case "exp.default":
73-
additional_tensor_constraints.extend(
84+
tensor_constraints.extend(
7485
[
7586
cp.Rank.Le(lambda deps: 2**3),
7687
cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
7788
cp.Value.Le(lambda deps, dtype, struct: 2**2),
7889
]
7990
)
8091
case "slice_copy.Tensor":
81-
additional_tensor_constraints.extend(
92+
tensor_constraints.extend(
8293
[
8394
cp.Rank.Le(lambda deps: 2),
8495
cp.Value.Ge(lambda deps, dtype, struct: 1),
8596
cp.Value.Le(lambda deps, dtype, struct: 2),
8697
]
8798
)
8899
case _:
89-
additional_tensor_constraints.extend(
100+
tensor_constraints.extend(
90101
[
91102
cp.Rank.Le(lambda deps: 2**2),
92103
]
93104
)
94-
tensor_constraints.extend(additional_tensor_constraints)
105+
return tensor_constraints
95106

96107

97108
def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
@@ -107,9 +118,6 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
107118
def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, str]]]:
108119
# minimal example to test add.Tensor using FACTO
109120
spec = SpecDictDB[op_name]
110-
tensor_constraints = []
111-
# common tensor constraints
112-
apply_tensor_contraints(op_name, tensor_constraints)
113121

114122
for index, in_spec in enumerate(copy.deepcopy(spec.inspec)):
115123
if in_spec.type.is_scalar():
@@ -142,7 +150,9 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
142150
]
143151
)
144152
elif in_spec.type.is_tensor():
145-
spec.inspec[index].constraints.extend(tensor_constraints)
153+
spec.inspec[index].constraints.extend(
154+
apply_tensor_contraints(op_name, index)
155+
)
146156
elif in_spec.type.is_dim_list():
147157
spec.inspec[index].constraints.extend(
148158
[

0 commit comments

Comments
 (0)