Skip to content

Commit 692921e

Browse files
committed
Move fixes into Dynamo directory
1 parent 7ff9309 commit 692921e

File tree

16 files changed

+162
-63
lines changed

16 files changed

+162
-63
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch.utils._pytree as pytree
1010
from torch._dynamo.utils import detect_fake_mode
1111
from torch._functorch.aot_autograd import _aot_export_function
12-
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
12+
from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
1313
from torch._ops import OpOverload
1414
from torch_tensorrt.dynamo import CompilationSettings
1515
from torch_tensorrt.dynamo.compile import compile_module
@@ -100,7 +100,7 @@ def _pretraced_backend(
100100
+ "Returning GraphModule forward instead.",
101101
exc_info=True,
102102
)
103-
return gm.forward
103+
return gm
104104
else:
105105
logger.critical(
106106
"Halting compilation on build failure since "
@@ -114,6 +114,13 @@ def _pretraced_backend(
114114

115115
@torch.utils._python_dispatch._disable_current_modes() # type: ignore
116116
def constant_fold(gm: torch.fx.GraphModule) -> Any:
117+
"""Adapted from:
118+
https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
119+
120+
Folds constants in the graph module, not skipping constructors
121+
122+
Modifies the graph in-place and replaces node with constants
123+
"""
117124
cf = ConstantFolder(gm, skip_constructors=False)
118125
cf.run()
119126

@@ -141,10 +148,13 @@ def aot_export_for_compile(
141148
decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
142149
) -> torch.fx.GraphModule:
143150
"""Adapted from:
144-
https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
151+
https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
145152
146153
Removed check for input aliasing in resultant subgraph - TRT is functional-only
154+
155+
Exports the function to ATen for torch compile
147156
"""
157+
# Trace function with input arguments and decompositions
148158
with torch.no_grad():
149159
fx_g, metadata, in_spec, out_spec = _aot_export_function(
150160
func,

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
361361
outputs = (args[0],)
362362

363363
for output_idx in range(len(outputs)):
364-
from torch_tensorrt.fx.converters import get_trt_tensor
364+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
365365

366366
output = outputs[output_idx]
367367

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def aten_ops_clone(
538538
)
539539

540540

541-
@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
541+
@dynamo_tensorrt_converter(torch.ops.aten.expand.default) # type: ignore[misc]
542542
def aten_ops_expand(
543543
network: TRTNetwork,
544544
target: Target,
@@ -568,7 +568,7 @@ def amax_param_validator(amax_node: Node) -> bool:
568568

569569
@dynamo_tensorrt_converter(
570570
torch.ops.aten.amax.default, capability_validator=amax_param_validator
571-
)
571+
) # type: ignore[misc]
572572
def aten_ops_amax(
573573
network: TRTNetwork,
574574
target: Target,
@@ -982,12 +982,13 @@ def aten_ops_isinf(
982982

983983

984984
def conv_param_validator(conv_node: Node) -> bool:
985+
# Output padding and transposed convolutions not supported currently
985986
return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
986987

987988

988989
@dynamo_tensorrt_converter(
989990
torch.ops.aten.convolution.default, capability_validator=conv_param_validator
990-
)
991+
) # type: ignore[misc]
991992
def aten_ops_convolution(
992993
network: TRTNetwork,
993994
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import functools
22
import logging
33
import re
4-
from typing import Any, List, Optional, Tuple
4+
from typing import Any, List, Optional, Tuple, Union
55

6+
import numpy as np
67
import tensorrt as trt
78
import torch
89
from torch.fx.node import Target
910
from torch_tensorrt.fx.converters.converter_utils import (
1011
Frameworks,
1112
get_axes_for_reduce_op,
13+
to_numpy,
1214
unified_dtype_converter,
1315
)
1416
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
@@ -187,4 +189,76 @@ def extend_attr_to_tuple(
187189

188190
if isinstance(val, list):
189191
val = tuple(val)
190-
return val
192+
193+
if isinstance(val, tuple):
194+
return val
195+
else:
196+
raise AssertionError(f"Could not extend attribute {val}")
197+
198+
199+
def create_constant(
200+
network: TRTNetwork,
201+
value: Union[int, float, np.ndarray, torch.Tensor],
202+
name: str,
203+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]],
204+
) -> TRTTensor:
205+
"""
206+
Add a TensorRT constant layer whose value is `value` to `network`.
207+
Args:
208+
network (TRTNetwork): A TensorRT network to which we want to add
209+
a constant layer.
210+
value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array,
211+
or a PyTorch tensor that will be used as value of the added TensorRT Constant layer.
212+
name (str): Name of the added TensorRT Constant layer.
213+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
214+
If a dtype is given, we will convert the type of the given `value` to this dtype.
215+
Returns:
216+
A TensorRT ITensor that represents the given value.
217+
"""
218+
constant = network.add_constant(
219+
(1,) if isinstance(value, (int, float)) else value.shape,
220+
to_numpy(value, dtype).copy(),
221+
)
222+
constant.name = name
223+
return constant.get_output(0)
224+
225+
226+
def get_trt_tensor(
227+
network: TRTNetwork,
228+
input_val: Any,
229+
name: str,
230+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
231+
) -> TRTTensor:
232+
"""
233+
Given a value of random type, we try to convert it to a TensorRT ITensor.
234+
An runtime error is raised if we're not able to do that.
235+
Args:
236+
network (TRTNetwork): A TensorRT network. If we want to
237+
add a TensorRT Constant layer, we will add it to this network.
238+
input_val (Any): An value that we want to convert to a TensorRT ITensor.
239+
name (str): The name of the created TensorRT Constant layer if there's
240+
one.
241+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
242+
If dtype is provided, the given value will be converted to this dtype.
243+
Returns:
244+
A TensorRT ITensor that represents the given value.
245+
"""
246+
# TRT can not add constant for bool type. We do a work around to 1) cast it to int and 2)cast to bool later
247+
# This is useful for logical operations which require input to be bool type
248+
if isinstance(input_val, bool):
249+
input_val = int(input_val)
250+
elif isinstance(input_val, torch.Tensor) and (
251+
input_val.dtype == torch.bool or input_val.dtype == torch.int64
252+
):
253+
input_val = input_val.to(torch.int32)
254+
elif isinstance(input_val, np.ndarray) and (
255+
input_val.dtype == np.bool_ or input_val.dtype == np.int64
256+
):
257+
input_val = input_val.astype(np.int32)
258+
259+
if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
260+
return create_constant(network, input_val, name, dtype)
261+
elif isinstance(input_val, TRTTensor):
262+
return input_val
263+
else:
264+
raise AssertionError(f"Cannot convert {input_val} to TRT constant")

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
from typing import Optional
22

3+
import tensorrt as trt
34
import torch
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable
7-
from torch_tensorrt.dynamo.conversion.impl.slice import expand
8-
from torch_tensorrt.fx.converters.converter_utils import (
9-
broadcast,
7+
from torch_tensorrt.dynamo.conversion.converter_utils import (
8+
broadcastable,
109
get_trt_tensor,
11-
set_layer_name,
1210
)
11+
from torch_tensorrt.dynamo.conversion.impl.slice import expand
12+
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
1313
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1414

15-
import tensorrt as trt
16-
1715

1816
def where(
1917
network: TRTNetwork,

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
import torch
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo.conversion import impl
10-
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
10+
from torch_tensorrt.dynamo.conversion.converter_utils import (
11+
extend_attr_to_tuple,
12+
get_trt_tensor,
13+
)
1114
from torch_tensorrt.fx.converters.converter_utils import (
1215
SourceIR,
1316
get_dyn_range,
14-
get_trt_tensor,
1517
has_dynamic_shape,
1618
mark_as_int8_layer,
1719
set_layer_name,
@@ -27,8 +29,8 @@ def convNd(
2729
name: str,
2830
is_conv1d: bool,
2931
input: TRTTensor,
30-
weight: Union[TRTTensor, torch.Tensor],
31-
bias: Optional[Union[TRTTensor, torch.Tensor]],
32+
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
33+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
3234
stride: Optional[Union[int, Sequence[int]]],
3335
padding: Optional[Union[int, Sequence[int]]],
3436
dilation: Optional[Union[int, Sequence[int]]],
@@ -97,19 +99,28 @@ def convNd(
9799
if isinstance(bias, TRTTensor):
98100
conv_layer.set_input(2, bias)
99101

102+
# Cast certain fields to tuples, in accordance with TRT requirements
103+
padding = (padding,) if isinstance(padding, int) else padding
104+
stride = (stride,) if isinstance(stride, int) else stride
105+
dilation = (dilation,) if isinstance(dilation, int) else dilation
106+
100107
# Expand parameters manually for Conv1D computations
101108
if is_conv1d:
102-
padding = tuple(padding) + (0,)
103-
stride = extend_attr_to_tuple(stride, 2)
104-
dilation = extend_attr_to_tuple(dilation, 2)
109+
padding = (tuple(padding) + (0,)) if padding is not None else padding
110+
stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride
111+
dilation = (
112+
extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
113+
)
105114

106115
set_layer_name(conv_layer, target, name, source_ir)
107116

108117
# Set relevant attributes of convolution layer
109-
conv_layer.padding_nd = padding
110-
conv_layer.stride_nd = stride
111-
conv_layer.dilation_nd = dilation
112-
118+
if padding is not None:
119+
conv_layer.padding_nd = padding
120+
if stride is not None:
121+
conv_layer.stride_nd = stride
122+
if dilation is not None:
123+
conv_layer.dilation_nd = dilation
113124
if groups is not None:
114125
conv_layer.num_groups = groups
115126

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import torch
77
from torch.fx.node import Target
88
from torch_tensorrt.dynamo._SourceIR import SourceIR
9-
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
9+
from torch_tensorrt.dynamo.conversion.converter_utils import (
10+
cast_trt_tensor,
11+
get_trt_tensor,
12+
)
1013
from torch_tensorrt.fx.converters.converter_utils import (
1114
broadcast,
12-
get_trt_tensor,
1315
set_layer_name,
1416
squeeze_left,
1517
)

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
import tensorrt as trt
55
from torch.fx.node import Target
66
from torch_tensorrt.dynamo._SourceIR import SourceIR
7+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
78
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
89
convert_binary_elementwise,
910
)
1011
from torch_tensorrt.dynamo.conversion.impl.unary import sign
1112
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
12-
from torch_tensorrt.fx.converters.converter_utils import (
13-
get_trt_tensor,
14-
set_layer_name,
15-
squeeze_left,
16-
)
13+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
1714
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1815
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1916

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor, set_layer_name
6+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
7+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
78
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
89

910

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
import tensorrt as trt
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.fx.converters.converter_utils import (
7-
broadcast,
8-
get_trt_tensor,
9-
set_layer_name,
10-
)
6+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
7+
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
118
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
129
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1310

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
5+
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
56
from torch_tensorrt.fx.converters.converter_utils import (
67
get_positive_dim,
7-
get_trt_tensor,
88
set_layer_name,
99
)
1010
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,14 +2711,8 @@ def acc_ops_linear(
27112711
"dim for linear and it can't be the last dim."
27122712
)
27132713

2714-
if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)):
2715-
weight = get_trt_tensor(
2716-
network,
2717-
kwargs["weight"].t()
2718-
if isinstance(kwargs["weight"], torch.Tensor)
2719-
else kwargs["weight"].T,
2720-
f"{name}_weight",
2721-
)
2714+
if isinstance(kwargs["weight"], torch.Tensor):
2715+
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
27222716
if target not in (acc_ops.linear, torch.ops.aten.linear):
27232717
weight_op = trt.MatrixOperation.TRANSPOSE
27242718
else:

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def create_constant(
271271
"""
272272
constant = network.add_constant(
273273
(1,) if isinstance(value, (int, float)) else value.shape,
274-
to_numpy(value, dtype).copy(),
274+
to_numpy(value, dtype),
275275
)
276276
constant.name = name
277277
return constant.get_output(0)
@@ -311,7 +311,7 @@ def get_trt_tensor(
311311
elif isinstance(input_val, np.ndarray) and (
312312
input_val.dtype == np.bool_ or input_val.dtype == np.int64
313313
):
314-
input_val = input_val.astype(np.int32)
314+
input_val = input_val.to(np.int32)
315315

316316
if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
317317
return create_constant(network, input_val, name, dtype)

py/torch_tensorrt/fx/converters/impl/convolution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def convNd(
5050
)
5151

5252
# Process bias terms
53-
if isinstance(bias, (torch.Tensor, np.ndarray)):
53+
if isinstance(bias, torch.Tensor):
5454
# Transform the bias constant into a Numpy array
5555
bias = to_numpy(bias)
5656

@@ -75,7 +75,7 @@ def convNd(
7575
network, target, tuple(), kwargs, name + "_unsqueeze_weight"
7676
)
7777

78-
elif isinstance(weight, (torch.Tensor, np.ndarray)):
78+
elif isinstance(weight, torch.Tensor):
7979
# Transform the weight constant into a Numpy array
8080
weight = to_numpy(weight)
8181

0 commit comments

Comments
 (0)