|
7 | 7 |
|
8 | 8 |
|
9 | 9 | class TestNegConverter(DispatchTestCase):
|
10 |
| - def test_neg(self): |
| 10 | + @parameterized.expand( |
| 11 | + [ |
| 12 | + ("2d_dim_dtype_float", (2, 2), torch.float), |
| 13 | + ("3d_dim_dtype_float", (2, 2, 2), torch.float), |
| 14 | + |
| 15 | + ] |
| 16 | + ) |
| 17 | + def test_neg_float(self, _, x, type): |
11 | 18 | class neg(nn.Module):
|
12 | 19 | def forward(self, input):
|
13 | 20 | return torch.neg(input)
|
14 |
| - |
15 |
| - inputs = [torch.randn(1, 10)] |
| 21 | + |
| 22 | + inputs = [torch.randn(x, dtype=type)] |
16 | 23 | self.run_test(
|
17 | 24 | neg(),
|
18 | 25 | inputs,
|
19 | 26 | expected_ops={torch.ops.aten.neg.default},
|
20 | 27 | )
|
21 | 28 |
|
| 29 | + @parameterized.expand( |
| 30 | + [ |
| 31 | + ("2d_dim_dtype_int", (2, 2), torch.int32, 0, 5), |
| 32 | + ("3d_dim_dtype_int", (2, 2, 2), torch.int32, 0, 5), |
| 33 | + ] |
| 34 | + ) |
| 35 | + |
| 36 | + def test_neg_int(self, _, x, type, min, max): |
| 37 | + class neg(nn.Module): |
| 38 | + def forward(self, input): |
| 39 | + return torch.neg(input) |
| 40 | + |
| 41 | + inputs = [torch.randint(min, max, (x), dtype=type)] |
| 42 | + |
| 43 | + self.run_test( |
| 44 | + neg(), |
| 45 | + inputs, |
| 46 | + expected_ops={torch.ops.aten.neg.default}, |
| 47 | + ) |
22 | 48 |
|
23 | 49 | if __name__ == "__main__":
|
24 | 50 | run_tests()
|
0 commit comments