5
5
from torch .fx .node import Target
6
6
from torch_tensorrt .dynamo ._SourceIR import SourceIR
7
7
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
8
- from torch_tensorrt .dynamo .conversion .converter_utils import extend_attr_to_tuple
8
+ from torch_tensorrt .dynamo .conversion .converter_utils import (
9
+ extend_attr_to_tuple ,
10
+ get_positive_dim ,
11
+ )
9
12
from torch_tensorrt .fx .converters .converter_utils import (
10
13
has_dynamic_shape ,
11
14
set_layer_name ,
@@ -116,37 +119,69 @@ def adaptive_avg_poolNd(
116
119
output_size : Sequence [int ],
117
120
) -> TRTTensor :
118
121
input_rank = len (input .shape )
119
- if input_rank == 3 :
120
- input = impl .shuffle .reshape (ctx , target , source_ir , f"{ name } _reshape" , input , (1 , * input .shape ))
122
+
123
+ if input_rank == 3 : # TRT doesn't support 3D pooling
124
+ input = impl .shuffle .reshape (
125
+ ctx , target , source_ir , f"{ name } _reshape" , input , (1 , * input .shape )
126
+ )
121
127
122
128
extend_len = len (output_size )
129
+ output_size = list (output_size )
130
+ original_input = input
123
131
124
- # pad the input based on output_size if the dim of output is larger than input
125
- pad = []
132
+ # repeat_interleave the input if the dim of output is larger than input
126
133
input_shape = input .shape
127
- for i in range (1 , extend_len + 1 ):
128
- input_dim = input_shape [- i ]
129
- output_dim = output_size [- i ]
134
+ insert_axises = []
135
+ for axis in range (1 , extend_len + 1 ):
136
+ axis = - axis
137
+ positive_axis = get_positive_dim (
138
+ axis , input_rank
139
+ ) # this is for calculating new shapes below
140
+ input_dim = input_shape [axis ]
141
+ output_dim = output_size [axis ]
130
142
diff = output_dim - input_dim
131
- if diff > 0 :
132
- if diff % 2 == 0 :
133
- pad .append (diff // 2 )
134
- pad .append (diff // 2 )
135
- else :
136
- pad .append (diff // 2 + 1 )
137
- pad .append (diff // 2 + 1 )
138
- else :
139
- pad .append (0 )
140
- pad .append (0 )
141
-
142
- input = impl .pad .replication_padNd (
143
- ctx ,
144
- target ,
145
- source_ir ,
146
- f"{ name } _replication_padNd" ,
147
- input ,
148
- pad ,
149
- )
143
+ if diff > 0 : # the dim of output is larger than input
144
+ times = output_dim // input_dim
145
+ remainder = output_dim % input_dim
146
+ if (
147
+ diff == 2 and remainder == 2
148
+ ): # case 1: output_dim - input_dim == 2 and is not an integral multiple
149
+ insert_axises .append (axis )
150
+ remainder -= 1
151
+ output_size [axis ] -= 1
152
+
153
+ if (
154
+ remainder + 1 == input_dim
155
+ ): # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
156
+ remainder = 0
157
+ times += 1
158
+
159
+ flags = []
160
+ concat_list = []
161
+ for j in range (input_dim ):
162
+ single_elem = impl .select .select (
163
+ ctx , target , source_ir , f"{ name } _select_{ axis } _{ j } " , input , axis , j
164
+ )
165
+ new_shape = list (single_elem .shape )
166
+ new_shape .insert (positive_axis , 1 )
167
+ single_elem = impl .shuffle .reshape (
168
+ ctx ,
169
+ target ,
170
+ source_ir ,
171
+ f"{ name } _reshape_{ axis } _{ j } " ,
172
+ single_elem ,
173
+ new_shape ,
174
+ )
175
+ if remainder > 0 or j in flags :
176
+ concat_list .extend ([single_elem ] * (times + 1 ))
177
+ remainder -= 2
178
+ flags .append (input_dim - j - 1 )
179
+ else :
180
+ concat_list .extend ([single_elem ] * times )
181
+ out = impl .cat .cat (
182
+ ctx , target , source_ir , f"{ name } _cat_{ axis } " , concat_list , axis
183
+ )
184
+ input = out
150
185
151
186
stride = tuple (
152
187
input .shape [- extend_len + i ] // output_size [i ] for i in range (extend_len )
@@ -155,6 +190,20 @@ def adaptive_avg_poolNd(
155
190
input .shape [- extend_len + i ] - (output_size [i ] - 1 ) * stride [i ]
156
191
for i in range (extend_len )
157
192
)
193
+
194
+ # Don't have to pool, directly return
195
+ if all (s == 1 for s in stride ) and all (k == 1 for k in kernel_size ):
196
+ if input_rank == 3 : # reshape back to 3D
197
+ input = impl .shuffle .reshape (
198
+ ctx ,
199
+ target ,
200
+ source_ir ,
201
+ f"{ name } _reshape_back" ,
202
+ input ,
203
+ (* input .shape [1 :],),
204
+ )
205
+ return input
206
+
158
207
layer = ctx .net .add_pooling_nd (
159
208
input = input , type = trt .PoolingType .AVERAGE , window_size = kernel_size
160
209
)
@@ -163,7 +212,78 @@ def adaptive_avg_poolNd(
163
212
164
213
output = layer .get_output (0 )
165
214
166
- if input_rank == 3 :
167
- output = impl .shuffle .reshape (ctx , target , source_ir , f"{ name } _reshape_back" , output , (* output .shape [1 :],))
215
+ # For case 1, we need to split the output and insert the mid of input
216
+ for axis in insert_axises :
217
+ positive_axis = get_positive_dim (axis , input_rank )
218
+ input_dim = input_shape [axis ]
219
+ output_dim = output_size [axis ]
220
+ if input_dim % 2 == 1 :
221
+ mid = impl .select .select (
222
+ ctx ,
223
+ target ,
224
+ source_ir ,
225
+ f"{ name } _select_{ axis } " ,
226
+ original_input ,
227
+ axis ,
228
+ input_dim // 2 ,
229
+ )
230
+ new_shape = list (mid .shape )
231
+ new_shape .insert (positive_axis , 1 )
232
+ mid = impl .shuffle .reshape (
233
+ ctx , target , source_ir , f"{ name } _reshape_{ axis } " , mid , new_shape
234
+ )
235
+ split_output = impl .split .split (
236
+ ctx , target , source_ir , f"{ name } _split_{ axis } " , output , 2 , axis
237
+ )
238
+ split_output .insert (1 , mid )
239
+ output = impl .cat .cat (
240
+ ctx , target , source_ir , f"{ name } _cat_{ axis } " , split_output , axis
241
+ )
242
+ else :
243
+ mid1 = impl .select .select (
244
+ ctx ,
245
+ target ,
246
+ source_ir ,
247
+ f"{ name } _select_{ axis } " ,
248
+ original_input ,
249
+ axis ,
250
+ input_dim // 2 - 1 ,
251
+ )
252
+ new_shape = list (mid1 .shape )
253
+ new_shape .insert (positive_axis , 1 )
254
+ mid1 = impl .shuffle .reshape (
255
+ ctx , target , source_ir , f"{ name } _reshape_{ axis } " , mid1 , new_shape
256
+ )
257
+ mid2 = impl .select .select (
258
+ ctx ,
259
+ target ,
260
+ source_ir ,
261
+ f"{ name } _select_{ axis } " ,
262
+ original_input ,
263
+ axis ,
264
+ input_dim // 2 ,
265
+ )
266
+ mid2 = impl .shuffle .reshape (
267
+ ctx , target , source_ir , f"{ name } _reshape_{ axis } " , mid2 , new_shape
268
+ )
269
+ split_output = impl .split .split (
270
+ ctx ,
271
+ target ,
272
+ source_ir ,
273
+ f"{ name } _split_{ axis } " ,
274
+ output ,
275
+ [output_dim // 2 , 1 , output_dim // 2 ],
276
+ axis ,
277
+ )
278
+ split_output [1 ] = mid1
279
+ split_output .insert (2 , mid2 )
280
+ output = impl .cat .cat (
281
+ ctx , target , source_ir , f"{ name } _cat_{ axis } " , split_output , axis
282
+ )
283
+
284
+ if input_rank == 3 : # reshape back to 3D
285
+ output = impl .shuffle .reshape (
286
+ ctx , target , source_ir , f"{ name } _reshape_back" , output , (* output .shape [1 :],)
287
+ )
168
288
169
289
return output
0 commit comments