Skip to content

Commit ef62a7a

Browse files
committed
add test case for other ops
1 parent 4a32b90 commit ef62a7a

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

tests/py/dynamo/converters/test_neg_aten.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,44 @@
77

88

99
class TestNegConverter(DispatchTestCase):
10-
def test_neg(self):
10+
@parameterized.expand(
11+
[
12+
("2d_dim_dtype_float", (2, 2), torch.float),
13+
("3d_dim_dtype_float", (2, 2, 2), torch.float),
14+
15+
]
16+
)
17+
def test_neg_float(self, _, x, type):
1118
class neg(nn.Module):
1219
def forward(self, input):
1320
return torch.neg(input)
14-
15-
inputs = [torch.randn(1, 10)]
21+
22+
inputs = [torch.randn(x, dtype=type)]
1623
self.run_test(
1724
neg(),
1825
inputs,
1926
expected_ops={torch.ops.aten.neg.default},
2027
)
2128

29+
@parameterized.expand(
30+
[
31+
("2d_dim_dtype_int", (2, 2), torch.int32, 0, 5),
32+
("3d_dim_dtype_int", (2, 2, 2), torch.int32, 0, 5),
33+
]
34+
)
35+
36+
def test_neg_int(self, _, x, type, min, max):
37+
class neg(nn.Module):
38+
def forward(self, input):
39+
return torch.neg(input)
40+
41+
inputs = [torch.randint(min, max, (x), dtype=type)]
42+
43+
self.run_test(
44+
neg(),
45+
inputs,
46+
expected_ops={torch.ops.aten.neg.default},
47+
)
2248

2349
if __name__ == "__main__":
2450
run_tests()

0 commit comments

Comments
 (0)