@@ -20,56 +20,39 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
20
20
match op_name :
21
21
case (
22
22
"sigmoid.default"
23
- | "_softmax.default"
24
23
| "rsqrt.default"
25
- | "exp.default"
26
- | "mul.Tensor"
27
- | "div.Tensor"
28
24
):
29
25
tensor_constraints .extend (
30
26
[
31
27
cp .Dtype .In (lambda deps : [torch .float ]),
32
- cp .Size .Le (lambda deps , r , d : 2 ),
33
- cp .Rank .Le (lambda deps : 2 ),
28
+ cp .Rank .Le (lambda deps : 2 ** 3 ),
34
29
]
35
30
)
36
31
case (
37
- "add.Tensor"
38
- | "sub.Tensor"
39
- | "add.Scalar"
40
- | "sub.Scalar"
41
- | "mul.Scalar"
42
- | "div.Scalar"
32
+ "exp.default"
43
33
):
44
34
tensor_constraints .extend (
45
35
[
46
- cp .Dtype .In (lambda deps : [torch .float , torch .int32 ]),
47
- cp .Size .Le (lambda deps , r , d : 2 ),
48
- cp .Rank .Le (lambda deps : 2 ),
49
- ]
50
- )
51
- case "native_layer_norm.default" :
52
- tensor_constraints .extend (
53
- [
54
- cp .Dtype .In (lambda deps : [torch .float , torch .int32 ]),
55
- cp .Size .Le (lambda deps , r , d : 2 ** 4 ),
56
- cp .Rank .Le (lambda deps : 2 ** 4 ),
36
+ cp .Rank .Le (lambda deps : 2 ** 3 ),
37
+ cp .Value .Ge (lambda deps , dtype , struct : - 1 ),
38
+ cp .Value .Le (lambda deps , dtype , struct : 1 ),
57
39
]
58
40
)
59
41
case _:
60
42
tensor_constraints .extend (
61
43
[
62
- cp .Dtype .In (lambda deps : [torch .float , torch .int32 ]),
63
- cp .Size .Le (lambda deps , r , d : 2 ),
64
- cp .Rank .Le (lambda deps : 2 ),
44
+ cp .Rank .Le (lambda deps : 2 ** 2 ),
65
45
]
66
46
)
67
47
tensor_constraints .extend (
68
48
[
69
- cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 8 )),
70
- cp .Value .Le (lambda deps , dtype , struct : 2 ** 8 ),
49
+ cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
50
+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
51
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
52
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
71
53
cp .Rank .Ge (lambda deps : 1 ),
72
54
cp .Size .Ge (lambda deps , r , d : 1 ),
55
+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
73
56
]
74
57
)
75
58
0 commit comments