
Commit 5760875

qxy11 (Wei Wei) authored and committed

[fx2trt] Add ops needed to lower XrayVideo 2022a model to fx2trt (#79)

Summary: Pull Request resolved: pytorch/fx2trt#79

Add support for:
- repeat_interleave
- adaptive_avg_pool3d
- max_pool3d

Also modify the tile converter to support dims that are not all ints.

Reviewed By: frank-wei

Differential Revision: D35823002

fbshipit-source-id: ec428c97c5832e62847a96176f875204e05f94d1

1 parent 4a0ca3e commit 5760875
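For context, the newly supported ops are typical of 3D video models. A minimal, hypothetical module of the kind this commit makes lowerable (module, names, and shapes below are illustrative only, not taken from the XrayVideo model):

import torch
import torch.nn as nn

# Hypothetical block exercising the three ops this commit adds converters for.
class VideoBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.adaptive = nn.AdaptiveAvgPool3d((16, 16, 16))

    def forward(self, x):
        x = x.repeat_interleave(2, dim=1)  # repeat each channel twice
        x = self.pool(x)
        return self.adaptive(x)

out = VideoBlock()(torch.randn(1, 3, 32, 64, 64))  # -> torch.Size([1, 6, 16, 16, 16])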

File tree

8 files changed: +459, -40 lines changed

fx/converters/acc_ops_converters.py

Lines changed: 74 additions & 40 deletions
@@ -506,8 +506,17 @@ def acc_ops_size(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
+    input_t = kwargs["input"]
+    if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
+        if (
+            not has_dynamic_shape(input_t.shape)
+            and network.has_implicit_batch_dimension
+        ):
+            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
+        return input_t.shape
+
+    # input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
+    input_val = input_t
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"size received input {input_val} that is not part "
@@ -779,13 +788,8 @@ def acc_ops_tile(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"tile received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input")
 
     dims = tuple(cast(Sequence[int], kwargs["dims"]))
     n_input_dims = len(input_val.shape) + (
@@ -822,9 +826,28 @@ def acc_ops_tile(
     if network.has_implicit_batch_dimension:
         assert dims[0] == 1, "Can't tile the batch dim when it's implicit."
         dims = dims[1:]
-
     starts = [0] * len(dims)
-    shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    shapes = []
+    if all(isinstance(d, int) for d in dims):
+        shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    else:
+        shape = []
+        for i, (s, d) in enumerate(zip(input_val.shape, dims)):
+            if isinstance(d, TRTTensor) and len(d.shape) == 0:
+                d = prepend_ones(network, d, f"{name}_{i}", 1)
+            else:
+                d = get_trt_tensor(network, d, f"{name}_{i}")
+            shape.append(d)
+            mul = add_binary_elementwise_layer(
+                network,
+                s,
+                d,
+                trt.ElementWiseOperation.PROD,
+                target,
+                f"{name}_mul_{i}",
+            )
+            shapes.append(mul)
+        dims = shape
     # If there's dynmaic dim then there would be negative dims in shapes which is not allowed.
     # Here we build a dummy shapes array.
     if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
@@ -838,9 +861,16 @@ def acc_ops_tile(
         starts_tensor = network.add_constant(
             (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
         ).get_output(0)
-        dims_tensor = network.add_constant(
-            (len(dims),), np.ascontiguousarray(dims, np.int32)
-        ).get_output(0)
+        if all(isinstance(d, int) for d in dims):
+            dims_tensor = network.add_constant(
+                (len(dims),), np.ascontiguousarray(dims, np.int32)
+            ).get_output(0)
+        else:
+            assert all(isinstance(d, TRTTensor) for d in dims)
+            concat_dims_layer = network.add_concatenation(inputs=dims)
+            concat_dims_layer.axis = 0
+            concat_dims_layer.name = f"{name}_tile_dim"
+            dims_tensor = concat_dims_layer.get_output(0)
         input_shape_layer = network.add_shape(input_val)
         input_shape_layer.name = f"{name}_slice_input_shape"
         slice_shapes_tensor = add_binary_elementwise_layer(
@@ -1880,7 +1910,8 @@ def acc_ops_max_pool1d(
 
 
 @tensorrt_converter(acc_ops.max_pool2d)
-def acc_ops_max_pool2d(
+@tensorrt_converter(acc_ops.max_pool3d)
+def acc_ops_max_poolnd(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -1894,26 +1925,27 @@ def acc_ops_max_pool2d(
             f"MaxPool2d received input {input_val} that is not part "
             "of the TensorRT region!"
         )
-
-    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
-    stride = extend_attr_to_tuple(kwargs["stride"], 2)
-    padding = extend_attr_to_tuple(kwargs["padding"], 2)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
+    extend_len = 2 if target == acc_ops.max_pool2d else 3
+    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
+    stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
+    padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
+    dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
     ceil_mode = kwargs["ceil_mode"]
 
     if len(stride) == 0 or stride[0] == None:
         stride = kernel_size
 
-    if dilation != (1, 1):
+    ones = (1,) * extend_len
+    if dilation != ones:
         raise RuntimeError(
             f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
         )
 
-    layer = network.add_pooling(
+    layer = network.add_pooling_nd(
         input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
     )
-    layer.stride = stride
-    layer.padding = padding
+    layer.stride_nd = stride
+    layer.padding_nd = padding
     set_layer_name(layer, target, name)
 
     if ceil_mode:
@@ -2093,8 +2125,8 @@ def acc_ops_unsqueeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"unsqueeze received input {input_val} that is not part "
@@ -2161,8 +2193,9 @@ def acc_ops_topk(
     return layer.get_output(0), layer.get_output(1)
 
 
+@tensorrt_converter(acc_ops.adaptive_avg_pool3d)
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
-def acc_ops_adaptive_avg_pool2d(
+def acc_ops_adaptive_avg_poolnd(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -2177,30 +2210,32 @@ def acc_ops_adaptive_avg_pool2d(
             "of the TensorRT region!"
         )
 
-    assert (
-        input_val.shape[-1] != -1 and input_val.shape[-1] != -1
+    extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
+    assert all(
+        input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
     ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."
 
-    output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], 2))
-    for input_dim, output_dim in zip(input_val.shape[-2:], output_size):
+    output_size = cast(
+        Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
+    )
+    for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
         if input_dim % output_dim != 0:
             raise RuntimeError(
                 "For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
                 f"Got input dim {input_dim}, output dim {output_dim}"
             )
 
-    stride = (
-        input_val.shape[-2] // output_size[0],
-        input_val.shape[-1] // output_size[1],
+    stride = tuple(
+        input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
     )
-    kernel_size = (
-        input_val.shape[-2] - (output_size[0] - 1) * stride[0],
-        input_val.shape[-1] - (output_size[1] - 1) * stride[1],
+    kernel_size = tuple(
+        input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
+        for i in range(extend_len)
     )
-    layer = network.add_pooling(
+    layer = network.add_pooling_nd(
         input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
     )
-    layer.stride = stride
+    layer.stride_nd = stride
     set_layer_name(layer, target, name)
 
     return layer.get_output(0)
@@ -2781,7 +2816,6 @@ def acc_ops_getitem(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     slices = kwargs["idx"]
-
     if not isinstance(input_val, TRTTensor):
         return operator.getitem(input_val, slices)  # type: ignore[arg-type]
 
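A note on the tile change above: kwargs["dims"] may now mix Python ints with scalar TRT tensors, so the converter computes each output dimension as input_dim * repeat with an elementwise PROD layer and concatenates the results into a runtime shape tensor. A rough pure-Python analogue of that shape computation (a sketch only; tiled_shape is a made-up helper, not the fx2trt API):

import torch

def tiled_shape(input_shape, dims):
    # dims entries may be Python ints or 0-d tensors carrying runtime values
    parts = []
    for s, d in zip(input_shape, dims):
        if isinstance(d, torch.Tensor):
            d = d.reshape(1)  # promote the 0-d scalar to rank 1, as the converter does
        parts.append(torch.as_tensor([s]) * d)  # per-dim product, like the PROD layer
    return torch.cat(parts)  # like the converter's concatenation layer

x = torch.randn(2, 3)
dims = (2, torch.tensor(4))  # mixed static / runtime repeats
assert tuple(tiled_shape(x.shape, dims).tolist()) == tuple(torch.tile(x, (2, 4)).shape)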

test/converters/acc_op/test_adaptive_avgpool.py

Lines changed: 44 additions & 0 deletions
@@ -48,6 +48,50 @@ def forward(self, x):
             TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool2d}
         )
 
+    @parameterized.expand(
+        [
+            ((16, 16, 16),),
+            ((32, 16, 4),),
+            (32,),
+        ]
+    )
+    def test_adaptive_avgpool3d(
+        self,
+        output_size,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.AdaptiveAvgPool3d(output_size)
+
+            def forward(self, x):
+                return self.pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 64, 64)]
+        self.run_test(TestModule(), inputs, expected_ops={acc_ops.adaptive_avg_pool3d})
+
+    def test_adaptive_avgpool3d_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16))
+
+            def forward(self, x):
+                return self.pool(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, 32, 64, 64),
+                dtype=torch.float32,
+                shape_ranges=[
+                    ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64))
+                ],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.adaptive_avg_pool3d}
+        )
+
 
 if __name__ == "__main__":
     run_tests()
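The tests above only use output sizes that evenly divide the input dims, which is the one case the converter accepts: it emits a fixed pooling layer with stride = input_dim // output_dim and kernel = input_dim - (output_dim - 1) * stride. A quick plain-PyTorch sanity check of that equivalence (illustrative, not part of the test file):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 64, 64)
out_size = (16, 16, 16)
in_dims = x.shape[-3:]
stride = tuple(i // o for i, o in zip(in_dims, out_size))  # the converter's stride formula
kernel = tuple(i - (o - 1) * s for i, o, s in zip(in_dims, out_size, stride))  # its kernel formula
adaptive = nn.AdaptiveAvgPool3d(out_size)(x)
fixed = nn.AvgPool3d(kernel_size=kernel, stride=stride)(x)
assert torch.allclose(adaptive, fixed, atol=1e-6)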

test/converters/acc_op/test_maxpool.py

Lines changed: 76 additions & 0 deletions
@@ -95,6 +95,56 @@ def forward(self, x):
             TestModule(), input_specs, expected_ops={acc_ops.max_pool2d}
         )
 
+    @parameterized.expand(
+        [
+            ("default", 1),
+            ("stride", 1, 2),
+            ("tuple_parameters", 2, (1, 1, 1), (1, 1, 1)),
+            param("padding", 2, padding=1),
+            param("ceil_mode", 1, ceil_mode=True),
+        ]
+    )
+    def test_max_pool3d(
+        self,
+        test_name,
+        kernel_size,
+        stride=1,
+        padding=0,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.max_pool = torch.nn.MaxPool3d(
+                    kernel_size, stride, padding, ceil_mode=ceil_mode
+                )
+
+            def forward(self, x):
+                return self.max_pool(x)
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool3d})
+
+    def test_max_pool3d_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.max_pool = torch.nn.MaxPool3d(1, 1)
+
+            def forward(self, x):
+                return self.max_pool(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={acc_ops.max_pool3d}
+        )
+
     @parameterized.expand(
         [
             ("default", 1),
@@ -158,6 +208,32 @@ def forward(self, x):
         inputs = [torch.randn(1, 3, 224, 224)]
         self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool2d})
 
+    @parameterized.expand(
+        [
+            ("default", 1),
+            param("stride", 2, stride=()),
+        ]
+    )
+    def test_stride_none_max_pool3d(
+        self,
+        test_name,
+        kernel_size,
+        stride=None,
+        padding=0,
+        ceil_mode=False,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.nn.functional.max_pool3d(
+                    x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode
+                )
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(TestModule(), inputs, expected_ops={acc_ops.max_pool3d})
+
 
 if __name__ == "__main__":
     run_tests()
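test_stride_none_max_pool3d exercises the converter's stride fallback (if len(stride) == 0 or stride[0] == None: stride = kernel_size), which mirrors PyTorch's own default. A two-line check (illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32, 32)
# stride=None defaults to kernel_size, the same fallback the converter applies
assert torch.equal(
    F.max_pool3d(x, kernel_size=2, stride=None),
    F.max_pool3d(x, kernel_size=2, stride=2),
)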
