3
3
from parameterized import parameterized
4
4
from torch .testing ._internal .common_utils import run_tests
5
5
from torch_tensorrt import Input
6
+ from torch_tensorrt .dynamo .conversion import UnsupportedOperatorException
6
7
7
8
8
9
# FIXME: check about implicit and explicit batch
9
10
class TestSplitConverterNoDim (DispatchTestCase ):
10
11
@parameterized .expand (
11
12
[
12
13
("split_size_or_sections_no_dim" , 2 ),
13
- ("split_size_or_sections_list_no_dim" , [1 , 4 ]),
14
- ("split_size_or_sections_list_no_dim_not_full_split" , [1 , 3 ]),
15
14
]
16
15
)
17
16
def test_split (self , _ , split_size_or_tensor ):
@@ -23,20 +22,62 @@ def forward(self, input):
23
22
out = torch .split (input , split_size_or_tensor )
24
23
return out
25
24
26
- input = torch .arange (10 ).reshape (5 , 2 )
25
+ input = [ torch .randn (10 ).reshape (5 , 2 )]
27
26
self .run_test (
28
27
TestModule (),
29
28
input ,
30
- expected_ops = {torch .ops .aten .split .default },
29
+ expected_ops = {torch .ops .aten .split .Tensor },
30
+ disable_passes = True ,
31
31
)
32
32
33
+ @parameterized .expand (
34
+ [
35
+ ("split_size_or_sections_list_no_dim_list" , [1 , 4 ]),
36
+ ]
37
+ )
38
+ def test_split_list (self , _ , split_size_or_tensor ):
39
+ class TestModule (torch .nn .Module ):
40
+ def __init__ (self ):
41
+ super ().__init__ ()
42
+
43
+ def forward (self , input ):
44
+ out = torch .split (input , split_size_or_tensor )
45
+ return out
46
+
47
+ input = [torch .randn (10 ).reshape (5 , 2 )]
48
+ self .run_test (
49
+ TestModule (),
50
+ input ,
51
+ expected_ops = {torch .ops .aten .split_with_sizes .default },
52
+ disable_passes = True ,
53
+ )
33
54
34
- class TestSplitConverterWithDim (DispatchTestCase ):
35
55
@parameterized .expand (
36
56
[
37
- ("split_size_or_sections_dim" , 2 , 1 ),
38
- ("split_size_or_sections_list_dim" , [1 , 4 ], 1 ),
39
- ("split_size_or_sections_list_dim_not_full_split" , [1 , 3 ], 1 ),
57
+ ("split_size_or_sections_list_no_dim_not_full_list" , [1 , 3 ]),
58
+ ]
59
+ )
60
+ def test_split_not_full_list (self , _ , split_size_or_tensor ):
61
+ class TestModule (torch .nn .Module ):
62
+ def __init__ (self ):
63
+ super ().__init__ ()
64
+
65
+ def forward (self , input ):
66
+ out = torch .split (input , split_size_or_tensor )
67
+ return out
68
+
69
+ input = [torch .randn (10 ).reshape (5 , 2 )]
70
+ with self .assertRaises (RuntimeError ):
71
+ self .run_test (
72
+ TestModule (),
73
+ input ,
74
+ expected_ops = {torch .ops .aten .split_with_sizes .default },
75
+ disable_passes = True ,
76
+ )
77
+
78
+ @parameterized .expand (
79
+ [
80
+ ("split_size_or_sections_dims" , 2 , 1 ),
40
81
]
41
82
)
42
83
def test_split (self , _ , split_size_or_tensor , dim ):
@@ -45,25 +86,90 @@ def __init__(self):
45
86
super ().__init__ ()
46
87
47
88
def forward (self , input ):
48
- out = torch .split (split_size_or_tensor , dim )
89
+ out = torch .split (input , split_size_or_tensor , dim )
90
+ return out
91
+
92
+ input = [torch .randn (10 ).reshape (5 , 2 )]
93
+ self .run_test (
94
+ TestModule (),
95
+ input ,
96
+ expected_ops = {torch .ops .aten .split .Tensor },
97
+ disable_passes = True ,
98
+ )
99
+
100
+ @parameterized .expand (
101
+ [
102
+ ("split_size_or_sections_list_dims" , [1 , 1 ], 1 ),
103
+ ]
104
+ )
105
+ def test_split_dim (self , _ , split_size_or_tensor , dim ):
106
+ class TestModule (torch .nn .Module ):
107
+ def __init__ (self ):
108
+ super ().__init__ ()
109
+
110
+ def forward (self , input ):
111
+ out = torch .split (input , split_size_or_tensor , dim )
49
112
return out
50
113
51
- input = torch .arange (10 ).reshape (2 , 5 )
114
+ input = [ torch .randn (10 ).reshape (5 , 2 )]
52
115
self .run_test (
53
116
TestModule (),
54
117
input ,
55
- expected_ops = {torch .ops .aten .split .default },
118
+ expected_ops = {torch .ops .aten .split_with_sizes .default },
119
+ disable_passes = True ,
56
120
)
57
121
122
+ @parameterized .expand (
123
+ [
124
+ ("split_size_or_sections_list_dims" , [1 , 1 ], 1 ),
125
+ ]
126
+ )
127
+ def test_split_dim_list (self , _ , split_size_or_tensor , dim ):
128
+ class TestModule (torch .nn .Module ):
129
+ def __init__ (self ):
130
+ super ().__init__ ()
131
+
132
+ def forward (self , input ):
133
+ out = torch .split (input , split_size_or_tensor , dim )
134
+ return out
135
+
136
+ input = [torch .randn (10 ).reshape (5 , 2 )]
137
+ self .run_test (
138
+ TestModule (),
139
+ input ,
140
+ expected_ops = {torch .ops .aten .split_with_sizes .default },
141
+ disable_passes = True ,
142
+ )
58
143
59
- class TestSplitConverterDynamicShape (DispatchTestCase ):
60
144
@parameterized .expand (
61
145
[
62
- ("select_split_size_or_sections_dim" , 2 , 1 ),
63
- ("select_split_size_or_sections_list_dim" , [1 , 4 ], 1 ),
146
+ ("split_size_or_sections_list_dims_not_full_list" , [1 , 1 ], 1 ),
64
147
]
65
148
)
66
- def test_split (self , _ , split_size_or_tensor , dim ):
149
+ def test_split_dim_list (self , _ , split_size_or_tensor , dim ):
150
+ class TestModule (torch .nn .Module ):
151
+ def __init__ (self ):
152
+ super ().__init__ ()
153
+
154
+ def forward (self , input ):
155
+ out = torch .split (input , split_size_or_tensor , dim )
156
+ return out
157
+
158
+ input = [torch .randn (15 ).reshape (5 , 3 )]
159
+ with self .assertRaises (RuntimeError ):
160
+ self .run_test (
161
+ TestModule (),
162
+ input ,
163
+ expected_ops = {torch .ops .aten .split_with_sizes .default },
164
+ disable_passes = True ,
165
+ )
166
+
167
+ @parameterized .expand (
168
+ [
169
+ ("select_split_size_or_sections_dim_dynamic_shape" , 2 , 1 ),
170
+ ]
171
+ )
172
+ def test_split_dynamic (self , _ , split_size_or_tensor , dim ):
67
173
class TestModule (torch .nn .Module ):
68
174
def __init__ (self ):
69
175
super ().__init__ ()
@@ -82,17 +188,16 @@ def forward(self, input):
82
188
self .run_test_with_dynamic_shape (
83
189
TestModule (),
84
190
input_specs ,
85
- expected_ops = {torch .ops .aten .split .default },
191
+ expected_ops = {torch .ops .aten .split .Tensor },
192
+ disable_passes = True ,
86
193
)
87
194
88
-
89
- class TestSplitSymIntConverterImplicitBatch (DispatchTestCase ):
90
195
@parameterized .expand (
91
196
[
92
197
("select_chunk_dim" , 6 , 0 ),
93
198
]
94
199
)
95
- def test_chunk (self , _ , chunk , dim ):
200
+ def test_split_dynamic (self , _ , chunk , dim ):
96
201
class TestModule (torch .nn .Module ):
97
202
def __init__ (self ):
98
203
super ().__init__ ()
@@ -102,11 +207,13 @@ def forward(self, input):
102
207
return out
103
208
104
209
input = [torch .randn (11 )]
105
- self .run_test (
106
- TestModule (),
107
- input ,
108
- expected_ops = {torch .ops .aten .split .default },
109
- )
210
+ with self .assertRaises (UnsupportedOperatorException ):
211
+ self .run_test (
212
+ TestModule (),
213
+ input ,
214
+ expected_ops = {torch .ops .aten .split .Tensor },
215
+ disable_passes = True ,
216
+ )
110
217
111
218
112
219
if __name__ == "__main__" :
0 commit comments