|
18 | 18 |
|
19 | 19 | def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
|
20 | 20 | match op_name:
|
21 |
| - case ( |
22 |
| - "sigmoid.default" |
23 |
| - | "_softmax.default" |
24 |
| - | "rsqrt.default" |
25 |
| - | "exp.default" |
26 |
| - | "mul.Tensor" |
27 |
| - | "div.Tensor" |
28 |
| - ): |
| 21 | + case "sigmoid.default" | "rsqrt.default": |
29 | 22 | tensor_constraints.extend(
|
30 | 23 | [
|
31 | 24 | cp.Dtype.In(lambda deps: [torch.float]),
|
32 |
| - cp.Size.Le(lambda deps, r, d: 2), |
33 |
| - cp.Rank.Le(lambda deps: 2), |
| 25 | + cp.Rank.Le(lambda deps: 2**3), |
34 | 26 | ]
|
35 | 27 | )
|
36 |
| - case ( |
37 |
| - "add.Tensor" |
38 |
| - | "sub.Tensor" |
39 |
| - | "add.Scalar" |
40 |
| - | "sub.Scalar" |
41 |
| - | "mul.Scalar" |
42 |
| - | "div.Scalar" |
43 |
| - ): |
| 28 | + case "exp.default": |
44 | 29 | tensor_constraints.extend(
|
45 | 30 | [
|
46 |
| - cp.Dtype.In(lambda deps: [torch.float, torch.int32]), |
47 |
| - cp.Size.Le(lambda deps, r, d: 2), |
48 |
| - cp.Rank.Le(lambda deps: 2), |
49 |
| - ] |
50 |
| - ) |
51 |
| - case "native_layer_norm.default": |
52 |
| - tensor_constraints.extend( |
53 |
| - [ |
54 |
| - cp.Dtype.In(lambda deps: [torch.float, torch.int32]), |
55 |
| - cp.Size.Le(lambda deps, r, d: 2**4), |
56 |
| - cp.Rank.Le(lambda deps: 2**4), |
| 31 | + cp.Rank.Le(lambda deps: 2**3), |
| 32 | + cp.Value.Ge(lambda deps, dtype, struct: -1), |
| 33 | + cp.Value.Le(lambda deps, dtype, struct: 1), |
57 | 34 | ]
|
58 | 35 | )
|
59 | 36 | case _:
|
60 | 37 | tensor_constraints.extend(
|
61 | 38 | [
|
62 |
| - cp.Dtype.In(lambda deps: [torch.float, torch.int32]), |
63 |
| - cp.Size.Le(lambda deps, r, d: 2), |
64 |
| - cp.Rank.Le(lambda deps: 2), |
| 39 | + cp.Rank.Le(lambda deps: 2**2), |
65 | 40 | ]
|
66 | 41 | )
|
67 | 42 | tensor_constraints.extend(
|
68 | 43 | [
|
69 |
| - cp.Value.Ge(lambda deps, dtype, struct: -(2**8)), |
70 |
| - cp.Value.Le(lambda deps, dtype, struct: 2**8), |
| 44 | + cp.Dtype.In(lambda deps: [torch.int, torch.float]), |
| 45 | + cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]), |
| 46 | + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), |
| 47 | + cp.Value.Le(lambda deps, dtype, struct: 2**4), |
71 | 48 | cp.Rank.Ge(lambda deps: 1),
|
72 | 49 | cp.Size.Ge(lambda deps, r, d: 1),
|
| 50 | + cp.Size.Le(lambda deps, r, d: 2**9), |
73 | 51 | ]
|
74 | 52 | )
|
75 | 53 |
|
|
0 commit comments