1
+ import unittest
2
+
1
3
import torch
2
4
from parameterized import parameterized
3
5
from torch .testing ._internal .common_utils import run_tests
6
+ from torch_tensorrt import Input
4
7
5
8
from .harness import DispatchTestCase
6
9
@@ -27,6 +30,7 @@ def forward(self, input):
27
30
self .run_test (
28
31
TestChunk (),
29
32
input ,
33
+ use_dynamo_tracer = True ,
30
34
)
31
35
32
36
@parameterized .expand (
@@ -51,6 +55,7 @@ def forward(self, input):
51
55
self .run_test (
52
56
TestChunk (),
53
57
input ,
58
+ use_dynamo_tracer = True ,
54
59
)
55
60
56
61
@parameterized .expand (
@@ -75,6 +80,104 @@ def forward(self, input):
75
80
self .run_test (
76
81
TestChunk (),
77
82
input ,
83
+ use_dynamo_tracer = True ,
84
+ )
85
+
86
+
87
+ #######################Dynamic cases################
88
+ ####The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
89
+ @unittest .skip ("Pending aten.split converter. Currently tested by E2E" )
90
+ class TestChunkDynamicConverter (DispatchTestCase ):
91
+ @parameterized .expand (
92
+ [
93
+ ((1 ,), (1 ,), (3 ,), 3 , 0 ),
94
+ ((3 ,), (3 ,), (4 ,), 3 , 0 ),
95
+ ((4 ,), (4 ,), (6 ,), 3 , 0 ),
96
+ ((6 ,), (6 ,), (9 ,), 3 , 0 ),
97
+ ((3 ,), (3 ,), (4 ,), 1 , - 1 ),
98
+ ((3 ,), (3 ,), (4 ,), 3 , - 1 ),
99
+ ((3 ,), (3 ,), (4 ,), 4 , - 1 ),
100
+ ]
101
+ )
102
+ def test_chunk_1D (self , min_shape , opt_shape , max_shape , chunks , dim ):
103
+ class TestChunk (torch .nn .Module ):
104
+ def forward (self , input ):
105
+ out = torch .ops .aten .chunk .default (input , chunks , dim )
106
+ return out
107
+
108
+ input_specs = [
109
+ Input (
110
+ min_shape = min_shape ,
111
+ opt_shape = opt_shape ,
112
+ max_shape = max_shape ,
113
+ ),
114
+ ]
115
+ self .run_test_with_dynamic_shape (
116
+ TestChunk (),
117
+ input_specs ,
118
+ use_dynamo_tracer = True ,
119
+ )
120
+
121
+ @parameterized .expand (
122
+ [
123
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 1 , 0 ),
124
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 3 , 0 ),
125
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 4 , 0 ),
126
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 2 , - 2 ),
127
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 6 , - 2 ),
128
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 3 , 1 ),
129
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 4 , 1 ),
130
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 5 , - 1 ),
131
+ ]
132
+ )
133
+ def test_chunk_2D (self , min_shape , opt_shape , max_shape , chunks , dim ):
134
+ class TestChunk (torch .nn .Module ):
135
+ def forward (self , input ):
136
+ out = torch .ops .aten .chunk .default (input , chunks , dim )
137
+ return out
138
+
139
+ input_specs = [
140
+ Input (
141
+ min_shape = min_shape ,
142
+ opt_shape = opt_shape ,
143
+ max_shape = max_shape ,
144
+ ),
145
+ ]
146
+ self .run_test_with_dynamic_shape (
147
+ TestChunk (),
148
+ input_specs ,
149
+ use_dynamo_tracer = True ,
150
+ )
151
+
152
+ @parameterized .expand (
153
+ [
154
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 1 , 0 ),
155
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 3 , - 3 ),
156
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 3 , 1 ),
157
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 4 , 1 ),
158
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 6 , - 2 ),
159
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 1 , 2 ),
160
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 3 , - 1 ),
161
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 4 , - 1 ),
162
+ ]
163
+ )
164
+ def test_chunk_3D (self , min_shape , opt_shape , max_shape , chunks , dim ):
165
+ class TestChunk (torch .nn .Module ):
166
+ def forward (self , input ):
167
+ out = torch .ops .aten .chunk .default (input , chunks , dim )
168
+ return out
169
+
170
+ input_specs = [
171
+ Input (
172
+ min_shape = min_shape ,
173
+ opt_shape = opt_shape ,
174
+ max_shape = max_shape ,
175
+ ),
176
+ ]
177
+ self .run_test_with_dynamic_shape (
178
+ TestChunk (),
179
+ input_specs ,
180
+ use_dynamo_tracer = True ,
78
181
)
79
182
80
183
0 commit comments