Skip to content

Commit 71a452b

Browse files
committed
combining split tests
1 parent e77c7d3 commit 71a452b

File tree

1 file changed

+1
-45
lines changed

1 file changed

+1
-45
lines changed

tests/py/dynamo/converters/test_split_aten.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def forward(self, input):
3333
@parameterized.expand(
3434
[
3535
("split_size_or_sections_list_no_dim_list", [1, 4]),
36+
("split_size_or_sections_list_no_dim_not_full_list", [1, 3]),
3637
]
3738
)
3839
def test_split_list(self, _, split_size_or_tensor):
@@ -52,29 +53,6 @@ def forward(self, input):
5253
disable_passes=True,
5354
)
5455

55-
@parameterized.expand(
56-
[
57-
("split_size_or_sections_list_no_dim_not_full_list", [1, 3]),
58-
]
59-
)
60-
def test_split_not_full_list(self, _, split_size_or_tensor):
61-
class TestModule(torch.nn.Module):
62-
def __init__(self):
63-
super().__init__()
64-
65-
def forward(self, input):
66-
out = torch.split(input, split_size_or_tensor)
67-
return out
68-
69-
input = [torch.randn(10).reshape(5, 2)]
70-
with self.assertRaises(RuntimeError):
71-
self.run_test(
72-
TestModule(),
73-
input,
74-
expected_ops={torch.ops.aten.split_with_sizes.default},
75-
disable_passes=True,
76-
)
77-
7856
@parameterized.expand(
7957
[
8058
("split_size_or_sections_dims", 2, 1),
@@ -97,28 +75,6 @@ def forward(self, input):
9775
disable_passes=True,
9876
)
9977

100-
@parameterized.expand(
101-
[
102-
("split_size_or_sections_list_dims", [1, 1], 1),
103-
]
104-
)
105-
def test_split_dim(self, _, split_size_or_tensor, dim):
106-
class TestModule(torch.nn.Module):
107-
def __init__(self):
108-
super().__init__()
109-
110-
def forward(self, input):
111-
out = torch.split(input, split_size_or_tensor, dim)
112-
return out
113-
114-
input = [torch.randn(10).reshape(5, 2)]
115-
self.run_test(
116-
TestModule(),
117-
input,
118-
expected_ops={torch.ops.aten.split_with_sizes.default},
119-
disable_passes=True,
120-
)
121-
12278
@parameterized.expand(
12379
[
12480
("split_size_or_sections_list_dims", [1, 1], 1),

0 commit comments

Comments
 (0)