Commit a055858
fix: Add special cases where input of graph is output
- TRT does not allow a graph input to also be a graph output; however, many real models hit exactly this situation, especially when an input is cloned or copied and then returned
- The current converters register these operators as no-ops, causing TRT engine building to fail on such inputs
- Instead of creating an identity layer for every clone or copy node, check whether the node is the only operator on a placeholder (input) and insert the identity layer only in that case (a minimal repro sketch follows the commit metadata below)
- Coalesce the implementations of clone and to_copy, which are effectively the same operator
- Add test cases to validate the new behavior
- Add a new boilerplate converter validator utility to support this case
1 parent 40f8064 commit a055858
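The failure mode this commit targets can be reproduced by a module that simply returns a clone of its input. The sketch below is illustrative only: it assumes a CUDA build of Torch-TensorRT with the Dynamo frontend available, and the module and variable names are made up for the example.

import torch
import torch_tensorrt

class ReturnClonedInput(torch.nn.Module):
    def forward(self, x):
        # After aten tracing, this clone reads the graph placeholder directly and
        # feeds the graph output. Previously it was converted as a no-op, leaving
        # the TRT engine input marked as the engine output, which TRT rejects.
        return x.clone()

model = ReturnClonedInput().eval().cuda()
inputs = [torch.randn((8, 2, 10)).cuda()]

# With this change the clone is lowered to an identity layer instead of being
# dropped, so engine building succeeds on this pattern.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)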

5 files changed: +161 -65 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 67 additions & 20 deletions
@@ -1,10 +1,14 @@
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    is_only_operator_on_placeholder,
+)
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .converter_registry import dynamo_tensorrt_converter
@@ -447,29 +451,59 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(to_copy_node: Node) -> bool:
-    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
-
-    # Validate input node has convertible kwargs
-    if "dtype" in to_copy_node.kwargs:
-        if to_copy_node.kwargs["dtype"] in allowed_casts:
-            return True
+def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+    """Return validator for to_copy node with placeholder restrictions"""
+
+    def validate_dtype(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+
+        Based on data type being casted to
+        """
+        allowed_casts = {
+            torch.float,
+            torch.int32,
+            torch.bool,
+            torch.int8,
+            torch.float16,
+        }
+
+        # Validate input node has convertible kwargs
+        if "dtype" in to_copy_node.kwargs:
+            if to_copy_node.kwargs["dtype"] in allowed_casts:
+                return True
+            else:
+                _LOGGER.debug(
+                    f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                )
+                return False
         else:
             _LOGGER.debug(
-                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
            )
             return False
-    else:
-        _LOGGER.debug(
-            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+
+    def validator(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+        and the placeholder restriction is satisfied
+        """
+        # The placeholder restriction is satisfied if placeholder_only is the same
+        # truth value as is_only_operator_on_placeholder(to_copy_node)
+        return validate_dtype(to_copy_node) and (
+            (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
         )
-        return False
+
+    return validator
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+    torch.ops.aten.clone.default,
+    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
 )  # type: ignore[misc]
-def aten_ops_to_copy_dtype(
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=False),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_dtype(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -482,28 +516,41 @@ def aten_ops_to_copy_dtype(
         SourceIR.ATEN,
         name,
         args[0],
-        kwargs["dtype"],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=False,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
-def aten_ops_clone(
+@dynamo_tensorrt_converter(
+    torch.ops.aten.clone.default,
+    capability_validator=is_only_operator_on_placeholder,
+)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=True),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_placeholder(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.cast.clone(
+    # For clone or copy nodes where the input is also the output,
+    # we need to force cast to ensure a layer is added to the TRT engine
+    # since TRT engine inputs cannot also be TRT engine outputs
+    return impl.cast.to_copy(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=True,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)  # type: ignore[misc]
 def aten_ops_expand(
     network: TRTNetwork,
     target: Target,
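The placeholder_only flag and the XOR in the validator above split clone/_to_copy nodes between the two converters: aten_ops_clone_copy_dtype takes ordinary copies, while aten_ops_clone_copy_placeholder takes copies whose input is a placeholder feeding the graph output. A standalone sketch of just that routing term, using a hypothetical routes_to helper that mirrors (not placeholder_only) ^ is_only_operator_on_placeholder(node):

def routes_to(placeholder_only: bool, is_only_op_on_placeholder: bool) -> bool:
    # Accept the node when placeholder_only has the same truth value as
    # is_only_operator_on_placeholder(node).
    return (not placeholder_only) ^ is_only_op_on_placeholder

# Ordinary copies are claimed by the placeholder_only=False registration...
assert routes_to(placeholder_only=False, is_only_op_on_placeholder=False)
assert not routes_to(placeholder_only=False, is_only_op_on_placeholder=True)

# ...while input-feeding-output copies go to the placeholder_only=True one.
assert routes_to(placeholder_only=True, is_only_op_on_placeholder=True)
assert not routes_to(placeholder_only=True, is_only_op_on_placeholder=False)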

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 38 additions & 7 deletions
@@ -43,27 +43,49 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False
 
     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
         if isinstance(arg, torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
@@ -80,9 +102,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
+
     Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise returns
+    input unchanged
+
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
@@ -185,10 +210,16 @@ def extend_attr_to_tuple(
 
     if isinstance(val, list):
         val = tuple(val)
-    return val
+
+    if isinstance(val, tuple):
+        return val
+    else:
+        raise AssertionError(f"Object {val} could not be extended to tuple")
 
 
-def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
+def cast_int_or_float_to_bool(
+    network: TRTNetwork, name: str, tensor: TRTTensor
+) -> TRTTensor:
     if tensor.dtype != trt.bool:
         return cast_trt_tensor(network, tensor, trt.bool, name)
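For intuition about the new is_only_operator_on_placeholder helper, the same predicate can be mirrored on a toy FX graph. The sketch below re-implements the check inline so it runs with only PyTorch installed; the module name is illustrative, and torch.clone is used so the op traces as a call_function node:

import torch
import torch.fx

class CloneOnly(torch.nn.Module):
    def forward(self, x):
        return torch.clone(x)

gm = torch.fx.symbolic_trace(CloneOnly())
for node in gm.graph.nodes:
    if node.op != "call_function":
        continue
    # Same two conditions as is_only_operator_on_placeholder: the node reads a
    # placeholder directly and one of its users is the graph output node.
    takes_placeholder = any(
        isinstance(arg, torch.fx.Node) and arg.op == "placeholder"
        for arg in node.args
    )
    feeds_output = any(user.op == "output" for user in node.users)
    print(node.target, takes_placeholder and feeds_output)  # clone -> True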

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 21 additions & 19 deletions
@@ -3,7 +3,12 @@
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
 LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
     name: str,
     input: TRTTensor,
     dtype: TRTDataType,
+    force_layer: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"to_copy received input {input} that is not a TensorRT ITensor"
         )
 
-    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
-    return casted_tensor
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
+    # If the cast is forced, insert an identity layer even when the dtype
+    # does not change
+    if force_layer:
+        trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt_dtype)
+        identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
+        return identity_layer.get_output(0)
+    else:
+        casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+        return casted_tensor
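The force_layer path exists because a TensorRT network will not accept a tensor that is both a network input and a network output; adding an identity layer gives the output a distinct producer even when no dtype change is needed. Below is a rough standalone sketch of that idea against the TensorRT Python API (assumes a local tensorrt install; the shape and names are arbitrary):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 10))

# Essentially what to_copy(..., force_layer=True) inserts into the network.
identity_layer = network.add_identity(x)
identity_layer.set_output_type(0, trt.float32)
y = identity_layer.get_output(0)

# Marking x itself as the output would make the input also the output and fail
# at build time; marking the identity layer's output is accepted.
network.mark_output(y)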

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 8 additions & 19 deletions
@@ -12,7 +12,6 @@
     broadcast,
     get_trt_tensor,
     set_layer_name,
-    squeeze_left,
 )
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
@@ -94,10 +93,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False
 
     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
+        lhs_dtype = lhs_val.dtype
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
+        rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -122,23 +121,13 @@
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = np.array([lhs_val], dtype=rhs_dtype)
-
-    # When lhs is scalar, and rhs has shape [1,], then currently the assert
-    # will fail because lhs shape has fewer dimensions than rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shape compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
-            rhs_val = squeeze_left(rhs_val)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
+        )
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
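With this change lhs_dtype and rhs_dtype stay as TRT dtypes, and the conversion to a NumPy dtype happens only at the point a Python scalar operand is materialized as an array. A small sketch of that conversion path, assuming the torch_tensorrt.fx utilities imported above and using illustrative values:

import numpy as np
import tensorrt as trt
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

lhs_dtype = trt.float16  # dtype taken directly from a TRT ITensor
rhs_val = 2              # Python scalar paired with that tensor

# The scalar is wrapped using the NumPy equivalent of the tensor's TRT dtype,
# so the constant matches the tensor operand's precision.
rhs_arr = np.array([rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY))
print(rhs_arr.dtype)  # expected: float16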

tests/py/dynamo/conversion/test_casts.py

Lines changed: 27 additions & 0 deletions
@@ -35,6 +35,19 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_clone_direct(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                return x.clone()
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
 
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_to_copy_direct(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float, copy=True)
+
+        inputs = [torch.rand((1, 3, 10)).float()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
