
Commit 98524e0

feat: support adaptive avg pool2d and pool3d dynamo converters
1 parent 922fd11

File tree

2 files changed: +245 -101 lines changed

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 149 additions & 29 deletions
```diff
@@ -5,7 +5,10 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_positive_dim,
+)
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
```
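
The newly imported `get_positive_dim` normalizes negative axis indices against the tensor rank. A minimal sketch of the behavior the converter relies on (the real helper lives in `converter_utils`; this stand-in assumes single-int dims):

```python
# Stand-in for converter_utils.get_positive_dim (assumption: equivalent
# behavior for a single int dim). Maps a possibly negative axis into
# [0, dim_size) so it can be used with list.insert() further below.
def get_positive_dim(dim: int, dim_size: int) -> int:
    return dim % dim_size if dim < 0 else dim

assert get_positive_dim(-1, 4) == 3  # last axis of a rank-4 tensor
assert get_positive_dim(2, 4) == 2   # non-negative axes pass through unchanged
```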
```diff
@@ -116,37 +119,69 @@ def adaptive_avg_poolNd(
     output_size: Sequence[int],
 ) -> TRTTensor:
     input_rank = len(input.shape)
-    if input_rank == 3:
-        input = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape))
+
+    if input_rank == 3:  # TRT doesn't support 3D pooling
+        input = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
+        )
 
     extend_len = len(output_size)
+    output_size = list(output_size)
+    original_input = input
 
-    # pad the input based on output_size if the dim of output is larger than input
-    pad = []
+    # repeat_interleave the input if the dim of output is larger than input
     input_shape = input.shape
-    for i in range(1, extend_len + 1):
-        input_dim = input_shape[-i]
-        output_dim = output_size[-i]
+    insert_axises = []
+    for axis in range(1, extend_len + 1):
+        axis = -axis
+        positive_axis = get_positive_dim(
+            axis, input_rank
+        )  # this is for calculating new shapes below
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
         diff = output_dim - input_dim
-        if diff > 0:
-            if diff % 2 == 0:
-                pad.append(diff // 2)
-                pad.append(diff // 2)
-            else:
-                pad.append(diff // 2 + 1)
-                pad.append(diff // 2 + 1)
-        else:
-            pad.append(0)
-            pad.append(0)
-
-    input = impl.pad.replication_padNd(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_replication_padNd",
-        input,
-        pad,
-    )
+        if diff > 0:  # the dim of output is larger than input
+            times = output_dim // input_dim
+            remainder = output_dim % input_dim
+            if (
+                diff == 2 and remainder == 2
+            ):  # case 1: output_dim - input_dim == 2 and is not an integral multiple
+                insert_axises.append(axis)
+                remainder -= 1
+                output_size[axis] -= 1
+
+            if (
+                remainder + 1 == input_dim
+            ):  # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
+                remainder = 0
+                times += 1
+
+            flags = []
+            concat_list = []
+            for j in range(input_dim):
+                single_elem = impl.select.select(
+                    ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
+                )
+                new_shape = list(single_elem.shape)
+                new_shape.insert(positive_axis, 1)
+                single_elem = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_{axis}_{j}",
+                    single_elem,
+                    new_shape,
+                )
+                if remainder > 0 or j in flags:
+                    concat_list.extend([single_elem] * (times + 1))
+                    remainder -= 2
+                    flags.append(input_dim - j - 1)
+                else:
+                    concat_list.extend([single_elem] * times)
+            out = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", concat_list, axis
+            )
+            input = out
 
     stride = tuple(
         input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
```
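
Where the old code replication-padded the input when an output dim exceeded the input dim, the new loop upsamples by repeating slices: each of the `input_dim` slices along the axis is repeated `times = output_dim // input_dim` times, and the `remainder` extra slots are handed out symmetrically from both ends via `flags` (hence `remainder -= 2` per front slice, one for it and one for its mirror). A pure-Python model of that distribution (a hypothetical helper, with lists standing in for slices along one axis; the case-1/case-2 adjustments in the diff run before this point):

```python
# Hypothetical pure-Python model of the repeat distribution above
# (lists stand in for slices of the TRT tensor along one axis).
def distribute_repeats(elems, output_dim):
    input_dim = len(elems)
    times, remainder = divmod(output_dim, input_dim)
    flags, out = [], []
    for j, e in enumerate(elems):
        if remainder > 0 or j in flags:
            out.extend([e] * (times + 1))
            remainder -= 2                   # one extra copy spent at each end
            flags.append(input_dim - j - 1)  # mirror index also gets an extra copy
        else:
            out.extend([e] * times)
    return out

print(distribute_repeats(["a", "b", "c"], 8))
# ['a', 'a', 'a', 'b', 'b', 'c', 'c', 'c']
```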
```diff
@@ -155,6 +190,20 @@ def adaptive_avg_poolNd(
         input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
         for i in range(extend_len)
     )
+
+    # Don't have to pool, directly return
+    if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
+        if input_rank == 3:  # reshape back to 3D
+            input = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_back",
+                input,
+                (*input.shape[1:],),
+            )
+        return input
+
     layer = ctx.net.add_pooling_nd(
         input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
     )
```
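
With every input dim now at least as large as the corresponding output dim, the adaptive pool reduces to a fixed average pool with `stride[i] = input_dim // output_dim` and `kernel_size[i] = input_dim - (output_size[i] - 1) * stride[i]`; when both come out as 1 on every axis the pooling is an identity, which is what the new early return short-circuits. A worked example:

```python
# Worked example of the stride/kernel computation for a (10, 7) input
# pooled to a (5, 7) output (plain Python, no TRT).
input_dims, output_dims = (10, 7), (5, 7)
stride = tuple(i // o for i, o in zip(input_dims, output_dims))
kernel = tuple(i - (o - 1) * s for i, o, s in zip(input_dims, output_dims, stride))
print(stride, kernel)  # (2, 1) (2, 1): only the first axis is actually pooled
```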
```diff
@@ -163,7 +212,78 @@ def adaptive_avg_poolNd(
 
     output = layer.get_output(0)
 
-    if input_rank == 3:
-        output = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],))
+    # For case 1, we need to split the output and insert the mid of input
+    for axis in insert_axises:
+        positive_axis = get_positive_dim(axis, input_rank)
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
+        if input_dim % 2 == 1:
+            mid = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2,
+            )
+            new_shape = list(mid.shape)
+            new_shape.insert(positive_axis, 1)
+            mid = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid, new_shape
+            )
+            split_output = impl.split.split(
+                ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
+            )
+            split_output.insert(1, mid)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+        else:
+            mid1 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2 - 1,
+            )
+            new_shape = list(mid1.shape)
+            new_shape.insert(positive_axis, 1)
+            mid1 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
+            )
+            mid2 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2,
+            )
+            mid2 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
+            )
+            split_output = impl.split.split(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_split_{axis}",
+                output,
+                [output_dim // 2, 1, output_dim // 2],
+                axis,
+            )
+            split_output[1] = mid1
+            split_output.insert(2, mid2)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+
+    if input_rank == 3:  # reshape back to 3D
+        output = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
+        )
 
     return output
```
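
A smoke test of what the commit enables through the dynamo path might look like the following (a sketch; the module, shapes, and output sizes are illustrative, not taken from the commit):

```python
# Hypothetical end-to-end check for the new adaptive avg pool converters.
import torch
import torch_tensorrt

class AdaptivePools(torch.nn.Module):
    def forward(self, x4d: torch.Tensor, x5d: torch.Tensor):
        # pool2d over an NCHW input and pool3d over an NCDHW input.
        return (
            torch.nn.functional.adaptive_avg_pool2d(x4d, (3, 3)),
            torch.nn.functional.adaptive_avg_pool3d(x5d, (2, 2, 2)),
        )

inputs = [torch.randn(1, 3, 5, 5).cuda(), torch.randn(1, 3, 4, 4, 4).cuda()]
trt_mod = torch_tensorrt.compile(
    AdaptivePools().eval().cuda(), ir="dynamo", inputs=inputs
)
print([tuple(t.shape) for t in trt_mod(*inputs)])
# [(1, 3, 3, 3), (1, 3, 2, 2, 2)]
```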
