20
20
MAX_CASES = 50
21
21
22
22
23
- def apply_tensor_contraints (op_name : str , tensor_constraints : list [ object ] ) -> None :
24
- additional_tensor_constraints = [
23
+ def apply_tensor_contraints (op_name : str , index : int ) -> list [ object ] :
24
+ tensor_constraints = [
25
25
cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
26
26
cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
27
27
cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
@@ -33,17 +33,28 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
33
33
34
34
match op_name :
35
35
case "where.self" :
36
- additional_tensor_constraints = [
37
- cp .Dtype .In (lambda deps : [torch .float , torch .int , torch .bool ]),
38
- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
39
- cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
40
- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
41
- cp .Rank .Ge (lambda deps : 1 ),
42
- cp .Size .Ge (lambda deps , r , d : 1 ),
43
- cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
44
- ]
36
+ if index == 0 : # condition
37
+ tensor_constraints = [
38
+ cp .Dtype .In (lambda deps : [torch .bool ]),
39
+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
40
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
41
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
42
+ cp .Rank .Ge (lambda deps : 1 ),
43
+ cp .Size .Ge (lambda deps , r , d : 1 ),
44
+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
45
+ ]
46
+ else :
47
+ tensor_constraints = [
48
+ cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
49
+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
50
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
51
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
52
+ cp .Rank .Ge (lambda deps : 1 ),
53
+ cp .Size .Ge (lambda deps , r , d : 1 ),
54
+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
55
+ ]
45
56
case "sigmoid.default" :
46
- additional_tensor_constraints .extend (
57
+ tensor_constraints .extend (
47
58
[
48
59
cp .Dtype .In (lambda deps : [torch .float ]),
49
60
cp .Rank .Le (lambda deps : 2 ** 2 ),
@@ -52,7 +63,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
52
63
]
53
64
)
54
65
case "rsqrt.default" :
55
- additional_tensor_constraints .extend (
66
+ tensor_constraints .extend (
56
67
[
57
68
cp .Dtype .In (lambda deps : [torch .float ]),
58
69
cp .Rank .Le (lambda deps : 2 ** 2 ),
@@ -63,35 +74,35 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
63
74
]
64
75
)
65
76
case "mean.dim" :
66
- additional_tensor_constraints .extend (
77
+ tensor_constraints .extend (
67
78
[
68
79
cp .Dtype .In (lambda deps : [torch .float ]),
69
80
cp .Rank .Le (lambda deps : 2 ** 2 ),
70
81
]
71
82
)
72
83
case "exp.default" :
73
- additional_tensor_constraints .extend (
84
+ tensor_constraints .extend (
74
85
[
75
86
cp .Rank .Le (lambda deps : 2 ** 3 ),
76
87
cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
77
88
cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
78
89
]
79
90
)
80
91
case "slice_copy.Tensor" :
81
- additional_tensor_constraints .extend (
92
+ tensor_constraints .extend (
82
93
[
83
94
cp .Rank .Le (lambda deps : 2 ),
84
95
cp .Value .Ge (lambda deps , dtype , struct : 1 ),
85
96
cp .Value .Le (lambda deps , dtype , struct : 2 ),
86
97
]
87
98
)
88
99
case _:
89
- additional_tensor_constraints .extend (
100
+ tensor_constraints .extend (
90
101
[
91
102
cp .Rank .Le (lambda deps : 2 ** 2 ),
92
103
]
93
104
)
94
- tensor_constraints . extend ( additional_tensor_constraints )
105
+ return tensor_constraints
95
106
96
107
97
108
def apply_scalar_contraints (op_name : str ) -> list [ScalarDtype ]:
@@ -107,9 +118,6 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
107
118
def facto_testcase_gen (op_name : str ) -> List [Tuple [List [str ], OrderedDict [str , str ]]]:
108
119
# minimal example to test add.Tensor using FACTO
109
120
spec = SpecDictDB [op_name ]
110
- tensor_constraints = []
111
- # common tensor constraints
112
- apply_tensor_contraints (op_name , tensor_constraints )
113
121
114
122
for index , in_spec in enumerate (copy .deepcopy (spec .inspec )):
115
123
if in_spec .type .is_scalar ():
@@ -142,7 +150,9 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
142
150
]
143
151
)
144
152
elif in_spec .type .is_tensor ():
145
- spec .inspec [index ].constraints .extend (tensor_constraints )
153
+ spec .inspec [index ].constraints .extend (
154
+ apply_tensor_contraints (op_name , index )
155
+ )
146
156
elif in_spec .type .is_dim_list ():
147
157
spec .inspec [index ].constraints .extend (
148
158
[
0 commit comments