
Commit 5348ac2

fix: Centralize FX conv impl, add feature
- Centralize the convolution implementation in FX, shared across all source IRs (aten, acc, nn)
- Enable pass-through of build errors in e2e tests to ensure errors are not hidden
- Allow conv layers to take bias inputs in FX, per new functionality in TRT
1 parent dd31c9a commit 5348ac2
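All three IR-specific converters below now delegate to a single `convolution.convNd` entry point in `py/torch_tensorrt/fx/converters/impl/convolution.py`. That module is not shown in this commit, so the following is only a condensed sketch of the logic it presumably consolidates, inferred from the deleted acc converter code and the new call sites (constant-tensor weights and bias only; no ITensor-weight or explicit-precision paths):

```python
# Hypothetical, simplified convNd: NOT the code added by this commit, just a
# sketch of the behavior inferred from the deleted converters and new call sites.
from typing import Optional, Sequence, Union

import tensorrt as trt
import torch


def _extend(val: Union[int, Sequence[int]], fill: int) -> tuple:
    # Expand a 1-d attribute (stride/padding/dilation) to 2-d for the
    # unsqueeze -> conv2d -> squeeze emulation of conv1d.
    vals = list(val) if isinstance(val, (list, tuple)) else [val]
    return tuple(vals + [fill])


def convNd(
    network: trt.INetworkDefinition,
    target,                # fx node target, module, or op name (used for layer naming)
    source_ir,             # SourceIR enum member (ACC, ATEN, NN, ...)
    name: str,
    is_conv1d: bool,
    input_val: trt.ITensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: Union[int, Sequence[int]],
    padding: Union[int, Sequence[int]],
    dilation: Union[int, Sequence[int]],
    groups: Optional[int],
) -> trt.ITensor:
    if not isinstance(input_val, trt.ITensor):
        raise RuntimeError(f"Conv {name} received an input that is not a TensorRT ITensor")

    if is_conv1d:
        # Compute conv1d as unsqueeze -> conv2d -> squeeze, as the old acc converter did.
        unsqueeze = network.add_shuffle(input=input_val)
        unsqueeze.reshape_dims = tuple([*input_val.shape, 1])
        input_val = unsqueeze.get_output(0)
        weight = weight.unsqueeze(-1)
        stride = _extend(stride, 1)
        padding = _extend(padding, 0)
        dilation = _extend(dilation, 1)

    kernel = weight.detach().cpu().numpy()
    layer = network.add_convolution_nd(
        input=input_val,
        num_output_maps=kernel.shape[0],
        kernel_shape=tuple(kernel.shape[2:]),
        kernel=kernel,
        bias=bias.detach().cpu().numpy() if bias is not None else trt.Weights(),
    )
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.dilation_nd = dilation
    if groups is not None:
        layer.num_groups = groups
    layer.name = f"[{layer.type.name}]-[{source_ir}]-[{name}]"  # stand-in for set_layer_name

    out = layer.get_output(0)
    if is_conv1d:
        # Drop the helper spatial dimension added for the 2D convolution.
        squeeze = network.add_shuffle(input=out)
        squeeze.reshape_dims = tuple(out.shape[:-1])
        out = squeeze.get_output(0)
    return out
```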

File tree

6 files changed: +310 −314 lines

py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@ def test_resnet18(ir):
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
+        "pass_through_build_failures": True,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -57,6 +58,7 @@ def test_mobilenet_v2(ir):
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
+        "pass_through_build_failures": True,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -87,6 +89,7 @@ def test_efficientnet_b0(ir):
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "ir": ir,
+        "pass_through_build_failures": True,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -126,6 +129,7 @@ def test_bert_base_uncased(ir):
         "enabled_precisions": {torch.float},
         "truncate_long_and_double": True,
         "ir": ir,
+        "pass_through_build_failures": True,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)

@@ -160,6 +164,7 @@ def test_resnet18_half(ir):
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.half},
         "ir": ir,
+        "pass_through_build_failures": True,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
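For reference, the new test flag is just a compile-spec entry passed through to `torchtrt.compile`. A minimal usage sketch follows; the model, input shape, and the `"dynamo"` ir string are placeholders rather than values from this commit — only `pass_through_build_failures` is the new piece:

```python
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")

compile_spec = {
    "inputs": [torchtrt.Input((1, 3, 224, 224))],   # placeholder shape
    "device": torchtrt.Device("cuda:0"),
    "enabled_precisions": {torch.float},
    "ir": "dynamo",  # placeholder; the tests parametrize this value
    # Surface TensorRT build errors instead of hiding them (per the commit message).
    "pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
```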

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 44 additions & 149 deletions
@@ -26,7 +26,7 @@
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
-from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl import activation, convolution

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -96,86 +96,20 @@ def acc_ops_conv1d(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
-    unsqueeze_layer = network.add_shuffle(input=input_val)
-    unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
-    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
-    input_val = unsqueeze_layer.get_output(0)
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
-
-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
-    if bias is not None:
-        bias = bias[None]
-    weight = kwargs["weight"]
-
-    if network.has_explicit_precision or isinstance(weight, TRTTensor):
-        weight = get_trt_tensor(network, weight, f"{name}_weight")
-        # Expand 1d weight with unsqueeze for calculation
-        unsqueeze_weight_layer = network.add_shuffle(input=weight)
-        unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
-        set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
-        weight = unsqueeze_weight_layer.get_output(0)
-        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
-        # will need to use uninitialized weight and set it later to support
-        # ITensor weights
-        dummy_weight = trt.Weights()
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
-        )
-
-        layer.set_input(1, weight)
-    else:
-        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
-        weight = to_numpy(weight)
-        weight = np.expand_dims(weight, -1)
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=weight,
-            bias=bias,
-        )
-    # expand params to 2d for computation
-    padding = list(kwargs["padding"])
-    padding.append(0)
-    stride = extend_attr_to_tuple(kwargs["stride"], 2)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
-
-    set_layer_name(layer, target, name)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
-    layer.dilation_nd = dilation
-    if kwargs["groups"] is not None:
-        layer.num_groups = kwargs["groups"]
-
-    result = layer.get_output(0)
-    squeeze_layer = network.add_shuffle(input=result)
-    squeeze_layer.reshape_dims = tuple(result.shape[:-1])
-    set_layer_name(squeeze_layer, target, name + "_squeeze")
-    return squeeze_layer.get_output(0)
+    return convolution.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ACC,
+        name=name,
+        is_conv1d=True,
+        input_val=kwargs["input"],
+        weight=kwargs["weight"],
+        bias=kwargs["bias"],
+        stride=kwargs["stride"],
+        padding=kwargs["padding"],
+        dilation=kwargs["dilation"],
+        groups=kwargs["groups"],
+    )


 @tensorrt_converter(acc_ops.conv3d)
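The deleted acc_ops_conv1d body computed a 1-D convolution as unsqueeze -> conv2d -> squeeze; the shared convNd presumably keeps that trick behind its is_conv1d flag. A plain-PyTorch check of the equivalence, with arbitrarily chosen shapes (not from this commit):

```python
# Demonstrates (in plain PyTorch, outside TensorRT) that conv1d can be computed
# as conv2d on inputs/weights with a trailing unit dimension, then squeezed.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16)   # (N, C_in, L)
w = torch.randn(4, 3, 5)    # (C_out, C_in, K)
b = torch.randn(4)

ref = F.conv1d(x, w, b, stride=1, padding=2)

x2 = x.unsqueeze(-1)        # (N, C_in, L, 1)
w2 = w.unsqueeze(-1)        # (C_out, C_in, K, 1)
out = F.conv2d(x2, w2, b, stride=(1, 1), padding=(2, 0)).squeeze(-1)

assert torch.allclose(ref, out, atol=1e-5)
```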
@@ -187,63 +121,20 @@ def acc_ops_convnd(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
-
-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
-
-    if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
-        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
-        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
-        # will need to use uninitialized weight and set it later to support
-        # ITensor weights
-        dummy_weight = trt.Weights()
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
-        )
-
-        layer.set_input(1, weight)
-    else:
-        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
-        weight = to_numpy(kwargs["weight"])
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=weight,
-            bias=bias,
-        )
-
-    set_layer_name(layer, target, name)
-    layer.stride_nd = kwargs["stride"]
-    layer.padding_nd = kwargs["padding"]
-    layer.dilation_nd = kwargs["dilation"]
-    if kwargs["groups"] is not None:
-        layer.num_groups = kwargs["groups"]
-
-    return layer.get_output(0)
+    return convolution.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ACC,
+        name=name,
+        is_conv1d=False,
+        input_val=kwargs["input"],
+        weight=kwargs["weight"],
+        bias=kwargs["bias"],
+        stride=kwargs["stride"],
+        padding=kwargs["padding"],
+        dilation=kwargs["dilation"],
+        groups=kwargs["groups"],
+    )


 @tensorrt_converter(acc_ops.conv_transpose2d)
@@ -268,32 +159,36 @@ def acc_ops_conv_transposend(
             input_val.shape[1] != -1
         ), "Channel dim can't be dynamic for transpose convolution."

-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"ConvTranspose {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
+    if not isinstance(kwargs["bias"], TRTTensor):
+        if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
+            raise RuntimeError(
+                f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
+            )
+        bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
+    else:
+        bias = kwargs["bias"]

     if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
         weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
         weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
         # will need to use uninitialized weight and set it later to support
         # ITensor weights
-        dummy_weight = trt.Weights()

         # nn.ConvTranspose2d/3d weight size is (in_channels, out_channels/groups, kernel_0, kernel_1, [kernel_2])
         layer = network.add_deconvolution_nd(
             input=input_val,
             num_output_maps=weight.shape[1] * kwargs["groups"],
             kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
+            kernel=trt.Weights(),
+            bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
         )

         layer.set_input(1, weight)
+
+        # If the bias is a TRTTensor, set it as an input of the layer
+        if isinstance(bias, TRTTensor):
+            bias = get_trt_tensor(network, bias, f"{name}_bias")
+            layer.set_input(2, bias)
     else:
         if not isinstance(kwargs["weight"], torch.Tensor):
             raise RuntimeError(
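The new bias path above follows the pattern TensorRT already uses for ITensor weights: create the layer with empty trt.Weights() placeholders and attach the real tensors with set_input. A small, self-contained sketch of that pattern; the shapes, constant tensors, and network setup are arbitrary examples, not from this commit:

```python
# Condensed illustration of the ITensor-bias pattern used in the diff above.
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 8, 16, 16))

# Deconvolution weights are (in_channels, out_channels / groups, kH, kW).
kernel_arr = np.ones((8, 4, 3, 3), dtype=np.float32)
bias_arr = np.zeros((4,), dtype=np.float32)
kernel = network.add_constant(kernel_arr.shape, trt.Weights(kernel_arr)).get_output(0)
bias = network.add_constant(bias_arr.shape, trt.Weights(bias_arr)).get_output(0)

layer = network.add_deconvolution_nd(
    input=x,
    num_output_maps=4,
    kernel_shape=(3, 3),
    kernel=trt.Weights(),  # placeholder; real kernel attached below
    bias=trt.Weights(),    # placeholder; real bias attached below
)
layer.set_input(1, kernel)  # input index 1: kernel ITensor
layer.set_input(2, bias)    # input index 2: bias ITensor (newer TRT versions)

network.mark_output(layer.get_output(0))
```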

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 28 additions & 5 deletions
@@ -22,7 +22,7 @@

 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl import activation, convolution

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -129,13 +129,36 @@ def aten_ops_convolution(
     # we do not handle output_padding.
     if args[7] not in ([0], [0, 0], [0, 0, 0]):
         raise RuntimeError(f"Target {target} has non-0 output_padding")
+
     if len(kwargs_new["stride"]) == 1:
-        return acc_ops_converters.acc_ops_conv1d(
-            network, target, None, kwargs_new, name
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=True,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
         )
     else:
-        return acc_ops_converters.acc_ops_convnd(
-            network, target, None, kwargs_new, name
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=False,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
         )


py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 5 additions & 2 deletions
@@ -99,14 +99,17 @@ def get_positive_dim(dim: int, dim_size: int) -> int:


 def set_layer_name(
-    layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
+    layer: TRTLayer,
+    target: Union[Target, torch.nn.Module, str],
+    name: str,
+    source_ir: Optional[SourceIR] = None,
 ) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"

     Args:
         layer (TRTLayer): A TensorRT layer of which we want to set the name.
-        target (Target): A fx node.target. For call_function node, it's the function that
+        target (Target): A fx node.target or submodule. For call_function node, it's the function that
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
         source_ir: (Optional[SourceIR]): The IR producing the op.
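The widened `target` annotation lets the shared impl modules pass a plain string (or an nn.Module) as the op identifier instead of an FX Target. A hypothetical call under that assumption; the SourceIR import location and the identity-layer setup are illustrative only, not from this commit:

```python
import tensorrt as trt
# SourceIR import location is an assumption for this example.
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, SourceIR

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 8, 8))
layer = network.add_identity(x)

# `target` is a plain string here, which the old annotation (Target only) did not allow.
set_layer_name(layer, "identity", "identity_0", source_ir=SourceIR.ATEN)
print(layer.name)
```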
