Commit 251405d
feat: support deconv (1d, 2d, and Nd) dynamo converter (#2337)
1 parent 117161a · commit 251405d

4 files changed: +397 −15 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 32 additions & 15 deletions
@@ -1348,7 +1348,7 @@ def aten_ops_less(
 
 
 def conv_param_validator(conv_node: Node) -> bool:
-    return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
+    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
 
 
 @dynamo_tensorrt_converter(
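conv_param_validator previously rejected any aten.convolution node whose transposed flag (args[6]) was set; it now only rejects nodes with a nonzero output_padding (args[7]), since transposed nodes are handled by the new deconvolution path. For reference, here is a sketch (not part of the commit) of the positional argument layout that the validator and converter index into:

import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(3, 4, 3, 3)  # transposed-conv weight layout: (C_in, C_out // groups, kH, kW)

# (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
#  args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8]
y = torch.ops.aten.convolution.default(
    x, w, None, [1, 1], [0, 0], [1, 1], True, [0, 0], 1
)
print(y.shape)  # torch.Size([1, 4, 10, 10])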
@@ -1361,20 +1361,37 @@ def aten_ops_convolution(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.conv.convNd(
-        network,
-        target,
-        source_ir=SourceIR.ATEN,
-        name=name,
-        is_conv1d=len(args[3]) == 1,
-        input=args[0],
-        weight=args[1],
-        bias=args[2],
-        stride=args[3],
-        padding=args[4],
-        dilation=args[5],
-        groups=args[8],
-    )
+    is_transposed = args[6]
+    if not is_transposed:
+        return impl.conv.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=len(args[3]) == 1,
+            input=args[0],
+            weight=args[1],
+            bias=args[2],
+            stride=args[3],
+            padding=args[4],
+            dilation=args[5],
+            groups=args[8],
+        )
+    else:
+        return impl.deconv.deconvNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_deconv1d=len(args[3]) == 1,
+            input=args[0],
+            weight=args[1],
+            bias=args[2],
+            stride=args[3],
+            padding=args[4],
+            dilation=args[5],
+            groups=args[8],
+        )
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.linear.default)  # type: ignore[misc]
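aten_ops_convolution now branches on args[6]: regular convolutions still go to impl.conv.convNd, while transposed ones are forwarded to the new impl.deconv.deconvNd with the same argument mapping. Below is a sketch of how such nodes arise; it assumes torch.export from PyTorch >= 2.1, and the exact lowering can vary by version.

import torch

class TConv(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.deconv = torch.nn.ConvTranspose2d(3, 8, kernel_size=3, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.deconv(x)

ep = torch.export.export(TConv(), (torch.randn(1, 3, 16, 16),))
# The printed graph normally contains an aten.convolution node with
# transposed=True, which this converter dispatches to deconvNd.
print(ep.graph_module.graph)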

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
     cast,
     condition,
     conv,
+    deconv,
     elementwise,
     embedding,
     linear,
py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
1+
from typing import Optional, Sequence, Union
2+
3+
import numpy as np
4+
5+
# @manual=//deeplearning/trt/python:py_tensorrt
6+
import tensorrt as trt
7+
import torch
8+
from torch.fx.node import Target
9+
from torch_tensorrt.dynamo.conversion import impl
10+
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
extend_attr_to_tuple,
12+
get_trt_tensor,
13+
)
14+
from torch_tensorrt.fx.converters.converter_utils import (
15+
SourceIR,
16+
get_dyn_range,
17+
has_dynamic_shape,
18+
mark_as_int8_layer,
19+
set_layer_name,
20+
to_numpy,
21+
)
22+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
23+
24+
25+
def deconvNd(
26+
network: TRTNetwork,
27+
target: Union[Target, str],
28+
source_ir: Optional[SourceIR],
29+
name: str,
30+
is_deconv1d: bool,
31+
input: TRTTensor,
32+
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
33+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
34+
stride: Optional[Union[int, Sequence[int]]],
35+
padding: Optional[Union[int, Sequence[int]]],
36+
groups: Optional[int],
37+
dilation: Optional[Union[int, Sequence[int]]],
38+
scale: Optional[Union[torch.Tensor, float]] = None,
39+
zero_point: Optional[Union[torch.Tensor, float]] = None,
40+
) -> TRTTensor:
41+
if has_dynamic_shape(input.shape):
42+
assert input.shape[1] != -1, "Channel dim can't be dynamic for deconvolution."
43+
44+
if is_deconv1d:
45+
# Apply an unsqueeze operation to transform the deconv1d problem into deconv2d
46+
input = impl.unsqueeze.unsqueeze(
47+
network, target, source_ir, name + "_unsqueeze_deconv1d", input, -1
48+
)
49+
50+
# Process bias terms
51+
if isinstance(bias, (torch.Tensor, np.ndarray)):
52+
# Transform the bias constant into a Numpy array
53+
bias = to_numpy(bias)
54+
55+
elif isinstance(bias, TRTTensor):
56+
bias = get_trt_tensor(network, bias, f"{name}_bias")
57+
58+
elif bias is not None:
59+
raise RuntimeError(
60+
f"Deconvolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
61+
)
62+
63+
# Process weight terms
64+
if network.has_explicit_precision or isinstance(weight, TRTTensor):
65+
weight = get_trt_tensor(network, weight, f"{name}_weight")
66+
# Append new dimension (unsqueeze) if the deconvolution is 1d
67+
if is_deconv1d:
68+
input = impl.unsqueeze.unsqueeze(
69+
network, target, source_ir, name + "_unsqueeze_weight", weight, -1
70+
)
71+
72+
elif isinstance(weight, (torch.Tensor, np.ndarray)):
73+
# Transform the weight constant into a Numpy array
74+
weight = to_numpy(weight)
75+
76+
# Append new dimension (unsqueeze) if the deconvolution is 1d
77+
if is_deconv1d:
78+
weight = np.expand_dims(weight, axis=-1)
79+
80+
else:
81+
raise RuntimeError(
82+
f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
83+
)
84+
85+
# add deconv layer
86+
deconv_layer = network.add_deconvolution_nd(
87+
input=input,
88+
num_output_maps=weight.shape[0],
89+
kernel_shape=weight.shape[2:],
90+
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
91+
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
92+
)
93+
94+
# If the weight is a TRTTensor, set it as an input of the layer
95+
if isinstance(weight, TRTTensor):
96+
deconv_layer.set_input(1, weight)
97+
98+
# If the bias is a TRTTensor, set it as an input of the layer
99+
if isinstance(bias, TRTTensor):
100+
deconv_layer.set_input(2, bias)
101+
102+
# Cast certain fields to tuples, in accordance with TRT requirements
103+
padding = (padding,) if isinstance(padding, int) else padding
104+
stride = (stride,) if isinstance(stride, int) else stride
105+
dilation = (dilation,) if isinstance(dilation, int) else dilation
106+
107+
# Expand parameters manually for Conv1D computations
108+
if is_deconv1d:
109+
padding = (tuple(padding) + (0,)) if padding is not None else padding
110+
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
111+
dilation = (
112+
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
113+
)
114+
115+
set_layer_name(deconv_layer, target, name, source_ir)
116+
117+
# Set relevant attributes of deconvolution layer
118+
if padding is not None:
119+
deconv_layer.padding_nd = padding
120+
if stride is not None:
121+
deconv_layer.stride_nd = stride
122+
if dilation is not None:
123+
deconv_layer.dilation_nd = dilation
124+
if groups is not None:
125+
deconv_layer.num_groups = groups
126+
127+
# Handle quantization cases
128+
if scale is not None and zero_point is not None:
129+
# Assume the dtype of activation is torch.quint8
130+
mark_as_int8_layer(deconv_layer, get_dyn_range(scale, zero_point, torch.quint8))
131+
132+
result = deconv_layer.get_output(0)
133+
134+
if is_deconv1d:
135+
# Apply a squeeze operation to transform the deconv2d problem back into deconv1d
136+
result = impl.squeeze.squeeze(
137+
network, target, source_ir, name + "_squeeze_deconv1d", result, -1
138+
)
139+
140+
return result
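deconvNd handles the 1D case by unsqueezing the input and weight to 2D, extending stride/padding/dilation with a trailing unit dimension, and squeezing the result back. The equivalence this relies on can be checked at the PyTorch level; the snippet below is an illustration, not part of the commit.

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16)  # (N, C_in, L)
w = torch.randn(3, 5, 4)   # (C_in, C_out, kernel)
y1d = F.conv_transpose1d(x, w, stride=2, padding=1)

# Same computation as a 2D transposed conv over a trailing unit dimension,
# mirroring the unsqueeze -> deconv2d -> squeeze path in deconvNd above.
y2d = F.conv_transpose2d(
    x.unsqueeze(-1), w.unsqueeze(-1), stride=(2, 1), padding=(1, 0)
).squeeze(-1)

assert torch.allclose(y1d, y2d, atol=1e-5)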
