
Commit a57c97f

fix: Add special cases where input of graph is output
- TRT does not allow inputs of graphs to also be outputs; however, many scenarios encountered in real models produce this situation, especially when the input is cloned or copied and then returned
- The current converters register these operators as no-ops, causing TRT engine building to fail on such inputs
- Instead of requiring an identity layer for every clone or copy node, we check whether the node is the only operator on a placeholder (input) and insert the identity layer only when needed
- Coalesce the implementations of clone and to_copy, which are effectively the same operator
- Add test cases to validate the new behavior
- Add a new boilerplate converter validator utility to support this case
1 parent f8bcbdd commit a57c97f
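
To illustrate the scenario this commit targets, here is a minimal sketch (not part of the commit) of a model whose only operation on the input is a clone that is returned directly, so the traced graph has the shape placeholder -> clone -> output:

    import torch
    import torch.nn as nn


    class ReturnClonedInput(nn.Module):
        # Hypothetical module: the clone of the input is returned directly,
        # so the graph input would also become the graph output unless a
        # layer (here, an identity) is inserted during conversion.
        def forward(self, x):
            return x.clone()


    model = ReturnClonedInput()
    print(model(torch.randn(8, 2, 10)).shape)  # torch.Size([8, 2, 10])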

File tree

5 files changed: +151, -63 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 64 additions & 20 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import tensorrt as trt
 import torch
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_int_int_div_trt_tensor,
     cast_trt_tensor,
+    is_only_operator_on_placeholder,
 )
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
@@ -505,29 +506,59 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(to_copy_node: Node) -> bool:
-    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
-
-    # Validate input node has convertible kwargs
-    if "dtype" in to_copy_node.kwargs:
-        if to_copy_node.kwargs["dtype"] in allowed_casts:
-            return True
+def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+    """Return validator for to_copy node with placeholder restrictions"""
+
+    def validate_dtype(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+
+        Based on data type being casted to
+        """
+        allowed_casts = {
+            torch.float,
+            torch.int32,
+            torch.bool,
+            torch.int8,
+            torch.float16,
+        }
+
+        # Validate input node has convertible kwargs
+        if "dtype" in to_copy_node.kwargs:
+            if to_copy_node.kwargs["dtype"] in allowed_casts:
+                return True
+            else:
+                _LOGGER.debug(
+                    f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                )
+                return False
         else:
             _LOGGER.debug(
-                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
             )
             return False
-    else:
-        _LOGGER.debug(
-            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+
+    def validator(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+        and the placeholder restriction is satisfied
+        """
+        # The placeholder restriction is satsfied if placeholder_only is the same
+        # truth value as is_only_operator_on_placeholder(to_copy_node)
+        return validate_dtype(to_copy_node) and (
+            (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
        )
-        return False
+
+    return validator
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+    torch.ops.aten.clone.default,
+    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
 )  # type: ignore[misc]
-def aten_ops_to_copy_dtype(
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=False),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_dtype(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -540,28 +571,41 @@ def aten_ops_to_copy_dtype(
         SourceIR.ATEN,
         name,
         args[0],
-        kwargs["dtype"],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=False,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
-def aten_ops_clone(
+@dynamo_tensorrt_converter(
+    torch.ops.aten.clone.default,
+    capability_validator=is_only_operator_on_placeholder,
+)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=True),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_placeholder(
     network: TRTNetwork,
     target: Target,
    args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.cast.clone(
+    # For clone or copy nodes where the input is also the output,
+    # we need to force cast to ensure a layer is added to the TRT engine
+    # since TRT engine inputs cannot also be TRT engine outputs
+    return impl.cast.to_copy(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=True,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)  # type: ignore[misc]
 def aten_ops_expand(
     network: TRTNetwork,
     target: Target,
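
As a note on the registrations above, the placeholder restriction inside the validator is an XOR: with placeholder_only=False the converter accepts nodes that are not the sole operator on a placeholder, and with placeholder_only=True it accepts exactly the remaining ones. A minimal sketch (assumed names, not part of the diff) of that split:

    # Truth-table sketch of the placeholder restriction used by the validators.
    # `direct` stands in for is_only_operator_on_placeholder(node).
    def accepts(placeholder_only: bool, direct: bool) -> bool:
        return (not placeholder_only) ^ direct


    for direct in (False, True):
        dtype_path = accepts(placeholder_only=False, direct=direct)
        placeholder_path = accepts(placeholder_only=True, direct=direct)
        # Exactly one of the two registrations accepts a given node
        assert dtype_path != placeholder_path
        print(f"direct={direct}: dtype_path={dtype_path}, placeholder_path={placeholder_path}")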

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 30 additions & 5 deletions
@@ -43,27 +43,49 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False
 
     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
         if isinstance(arg, torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
@@ -80,9 +102,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
+
     Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise returns
+    input unchanged
+
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
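
To show the pattern the new is_only_operator_on_placeholder helper is written to detect, a small hand-built FX graph (illustrative only) where a call_function node consumes a placeholder and feeds the graph output:

    import torch
    import torch.fx

    # placeholder -> clone -> output, the shape flagged by the new utility
    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    cloned = graph.call_function(torch.clone, (x,))
    graph.output(cloned)

    node = cloned
    is_direct = (
        node.op == "call_function"
        and any(arg.op == "placeholder" for arg in node.args if isinstance(arg, torch.fx.Node))
        and any(user.op == "output" for user in node.users)
    )
    print(is_direct)  # True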

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 21 additions & 19 deletions
@@ -3,7 +3,12 @@
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
 LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
     name: str,
     input: TRTTensor,
     dtype: TRTDataType,
+    force_layer: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"to_copy received input {input} that is not a TensorRT ITensor"
         )
 
-    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
-    return casted_tensor
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
+    # If cast is forced, insert identity layer regardless of whether the dtype
+    # doesn't change
+    if force_layer:
+        trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt_dtype)
+        identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
+        return identity_layer.get_output(0)
+    else:
+        casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+        return casted_tensor
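
For context on the forced path added above, a rough standalone sketch of inserting a TensorRT identity layer so the returned tensor is produced by a layer rather than aliasing the engine input (assumes an already-constructed INetworkDefinition; names are illustrative):

    import tensorrt as trt


    def force_identity_cast(
        network: trt.INetworkDefinition, tensor: trt.ITensor, dtype: trt.DataType
    ) -> trt.ITensor:
        # Add an identity layer even when the dtype is unchanged, so that the
        # value feeding the graph output is no longer the engine input itself.
        layer = network.add_identity(tensor)
        layer.set_output_type(0, dtype)
        layer.name = f"forced_cast_{tensor.name}"
        return layer.get_output(0)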

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 9 additions & 19 deletions
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -11,7 +12,6 @@
     broadcast,
     get_trt_tensor,
     set_layer_name,
-    squeeze_left,
 )
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
@@ -75,10 +75,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False
 
     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
+        lhs_dtype = lhs_val.dtype
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
+        rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -103,23 +103,13 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
-
-    # When lhs is scalar, and rhs has shape [1,], then currently the assert
-    # will fail because lhs shape has fewer dimensions than rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shape compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
-            rhs_val = squeeze_left(rhs_val)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
+        )
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
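
A small sketch of the scalar-promotion change above: the Python scalar is wrapped as a NumPy array whose dtype mirrors the TRT operand's dtype (the mapping dict below is only a stand-in for unified_dtype_converter(..., Frameworks.NUMPY)):

    import numpy as np
    import tensorrt as trt

    # Stand-in mapping from TRT dtypes to NumPy dtypes, for illustration only
    TRT_TO_NUMPY = {trt.float32: np.float32, trt.float16: np.float16, trt.int32: np.int32}

    lhs_dtype = trt.float16  # dtype of the TRT tensor operand in this sketch
    rhs_val = np.array([2], dtype=TRT_TO_NUMPY[lhs_dtype])
    print(rhs_val.dtype)  # float16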

tests/py/dynamo/conversion/test_casts.py

Lines changed: 27 additions & 0 deletions
@@ -35,6 +35,19 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_clone_direct(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                return x.clone()
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
 
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_to_copy_direct(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float, copy=True)
+
+        inputs = [torch.rand((1, 3, 10)).float()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
