@@ -33,6 +33,7 @@ def forward(self, input):
33
33
@parameterized .expand (
34
34
[
35
35
("split_size_or_sections_list_no_dim_list" , [1 , 4 ]),
36
+ ("split_size_or_sections_list_no_dim_not_full_list" , [1 , 3 ]),
36
37
]
37
38
)
38
39
def test_split_list (self , _ , split_size_or_tensor ):
@@ -52,29 +53,6 @@ def forward(self, input):
52
53
disable_passes = True ,
53
54
)
54
55
55
- @parameterized .expand (
56
- [
57
- ("split_size_or_sections_list_no_dim_not_full_list" , [1 , 3 ]),
58
- ]
59
- )
60
- def test_split_not_full_list (self , _ , split_size_or_tensor ):
61
- class TestModule (torch .nn .Module ):
62
- def __init__ (self ):
63
- super ().__init__ ()
64
-
65
- def forward (self , input ):
66
- out = torch .split (input , split_size_or_tensor )
67
- return out
68
-
69
- input = [torch .randn (10 ).reshape (5 , 2 )]
70
- with self .assertRaises (RuntimeError ):
71
- self .run_test (
72
- TestModule (),
73
- input ,
74
- expected_ops = {torch .ops .aten .split_with_sizes .default },
75
- disable_passes = True ,
76
- )
77
-
78
56
@parameterized .expand (
79
57
[
80
58
("split_size_or_sections_dims" , 2 , 1 ),
@@ -97,28 +75,6 @@ def forward(self, input):
97
75
disable_passes = True ,
98
76
)
99
77
100
- @parameterized .expand (
101
- [
102
- ("split_size_or_sections_list_dims" , [1 , 1 ], 1 ),
103
- ]
104
- )
105
- def test_split_dim (self , _ , split_size_or_tensor , dim ):
106
- class TestModule (torch .nn .Module ):
107
- def __init__ (self ):
108
- super ().__init__ ()
109
-
110
- def forward (self , input ):
111
- out = torch .split (input , split_size_or_tensor , dim )
112
- return out
113
-
114
- input = [torch .randn (10 ).reshape (5 , 2 )]
115
- self .run_test (
116
- TestModule (),
117
- input ,
118
- expected_ops = {torch .ops .aten .split_with_sizes .default },
119
- disable_passes = True ,
120
- )
121
-
122
78
@parameterized .expand (
123
79
[
124
80
("split_size_or_sections_list_dims" , [1 , 1 ], 1 ),
0 commit comments