Skip to content

Commit 6ebec00

Browse files
kparichayWei Wei
authored andcommitted
[fx2trt] Support conv1d converter for module (#45)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/45 Extend the support for conv1d converter for the module for fx2trt As 2d convolution is not supported by TensorRT, this patch maps conv1d operation to conv2d with a pair of unsqueeze and squeeze. Reviewed By: wushirong Differential Revision: D35395174 fbshipit-source-id: 5b564609998f103d083d1451e5d64f6224eaf695
1 parent 5da146d commit 6ebec00

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

fx/converters/convolution.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
2+
import numpy as np
3+
24
import tensorrt as trt
35
import torch
46
from fx2trt_oss.fx.converter_registry import tensorrt_converter
@@ -22,6 +24,20 @@ def common_conv(network, mod, dimension, input_val, layer_name, is_quantized):
2224
kernel = to_numpy(mod.weight() if is_quantized else mod.weight)
2325
bias = to_numpy(mod.bias() if is_quantized else mod.bias)
2426

27+
if dimension == 1:
28+
# Append unsqueeze before conv2d to calculate conv1d
29+
unsqueeze_layer = network.add_shuffle(input=input_val)
30+
unsqueeze_layer.reshape_dims = (*input_val.shape, 1)
31+
unsqueeze_layer.name = f"{layer_name}_unsqueeze"
32+
input_val = unsqueeze_layer.get_output(0)
33+
34+
padding = padding + (0,)
35+
kernel = np.expand_dims(kernel, -1)
36+
kernel_size = kernel.shape[2:]
37+
if bias is not None:
38+
bias = bias[None]
39+
# bias = np.expand_dims(bias, -1)
40+
2541
layer = network.add_convolution_nd(
2642
input=input_val,
2743
num_output_maps=mod.out_channels,
@@ -39,7 +55,15 @@ def common_conv(network, mod, dimension, input_val, layer_name, is_quantized):
3955
# Assume the dtype of activation is torch.quint8
4056
mark_as_int8_layer(layer, get_dyn_range(mod.scale, mod.zero_point, torch.quint8))
4157

42-
return layer.get_output(0)
58+
result = layer.get_output(0)
59+
if dimension == 1:
60+
# Append squeeze after conv2d to calculate conv1d
61+
squeeze_layer = network.add_shuffle(input=result)
62+
squeeze_layer.reshape_dims = tuple(result.shape[:-1])
63+
squeeze_layer.name = f"{layer_name}_squeeze"
64+
result = squeeze_layer.get_output(0)
65+
66+
return result
4367

4468

4569
def common_conv_relu(network, mod, dimension, input_val, layer_name, is_quantized):
@@ -62,6 +86,20 @@ def common_conv_relu(network, mod, dimension, input_val, layer_name, is_quantize
6286
return layer.get_output(0)
6387

6488

89+
@tensorrt_converter(torch.nn.modules.conv.Conv1d)
90+
def conv1d(network, submod, args, kwargs, layer_name):
91+
# args/kwargs should have already been normalized to kwargs
92+
assert len(args) == 0
93+
input_val = kwargs["input"]
94+
95+
if not isinstance(input_val, trt.tensorrt.ITensor):
96+
raise RuntimeError(f"Conv1d received input {input_val} that is not part "
97+
"of the TensorRT region!")
98+
99+
if layer_name is None:
100+
raise RuntimeError("layer name is none")
101+
return common_conv(network, submod, dimension=1, input_val=input_val, layer_name=layer_name, is_quantized=False)
102+
65103
@tensorrt_converter(torch.nn.modules.conv.Conv2d)
66104
def conv2d(network, submod, args, kwargs, layer_name):
67105
# args/kwargs should have already been normalized to kwargs

test/converters/vanilla/test_convolution.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,39 @@
88

99

1010
class TestConvolutionConverter(VanillaTestCase):
11+
@parameterized.expand(
12+
[
13+
("default", 1),
14+
param("no_bias", 1, bias=False),
15+
("tuple_parameters", 1, (1), (0)),
16+
param("non_zero_padding", 1, padding=1),
17+
param("dilation", 1, dilation=2),
18+
param("groups", 1, groups=3),
19+
]
20+
)
21+
def test_conv1d(
22+
self,
23+
test_name,
24+
kernel_size,
25+
stride=1,
26+
padding=0,
27+
dilation=1,
28+
groups=1,
29+
bias=True,
30+
):
31+
class TestModule(torch.nn.Module):
32+
def __init__(self):
33+
super().__init__()
34+
self.conv = torch.nn.Conv1d(
35+
3, 6, kernel_size, stride, padding, dilation, groups, bias
36+
)
37+
38+
def forward(self, x):
39+
return self.conv(x)
40+
41+
inputs = [torch.randn(1, 3, 224)]
42+
self.run_test(TestModule(), inputs, expected_ops={torch.nn.modules.conv.Conv1d})
43+
1144
@parameterized.expand(
1245
[
1346
("default", 1),

0 commit comments

Comments
 (0)