
Commit cbd22ac

fix: Add special cases where input of graph is output
- TRT does not allow inputs of graphs to also be outputs; however, many real models produce exactly this situation, especially when an input is cloned or copied and then returned
- The current converters register these operators as no-ops, causing TRT engine building to fail on such inputs
- Instead of requiring an identity layer for every clone or copy node, we check whether the node is the only operator on a placeholder (input) and insert the identity layer only in that case
- Coalesce the implementations of clone and to_copy, which are effectively the same operator
- Add test cases to validate the new behavior
- Add a new boilerplate converter validator utility to support this case
1 parent 0e5a497 commit cbd22ac
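
For context, a minimal sketch (not taken from the repository) of the kind of model the commit message describes: the only operation applied to the graph input is a clone, so without an inserted identity layer the TRT engine input would also have to serve as the engine output.

import torch
import torch.nn as nn

class CloneOnly(nn.Module):
    """Hypothetical module: clone is the only operator on the input."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Lowers to torch.ops.aten.clone.default; treating it as a no-op
        # would make the graph input double as the graph output.
        return x.clone()

model = CloneOnly().eval()
example = torch.randn(8, 2, 10)
print(model(example).shape)  # torch.Size([8, 2, 10])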

File tree

5 files changed: 153 additions & 50 deletions


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 64 additions & 20 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import tensorrt as trt
 import torch
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_int_int_div_trt_tensor,
     cast_trt_tensor,
+    is_only_operator_on_placeholder,
 )
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
@@ -366,29 +367,59 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(to_copy_node: Node) -> bool:
-    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
-
-    # Validate input node has convertible kwargs
-    if "dtype" in to_copy_node.kwargs:
-        if to_copy_node.kwargs["dtype"] in allowed_casts:
-            return True
+def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+    """Return validator for to_copy node with placeholder restrictions"""
+
+    def validate_dtype(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+
+        Based on data type being casted to
+        """
+        allowed_casts = {
+            torch.float,
+            torch.int32,
+            torch.bool,
+            torch.int8,
+            torch.float16,
+        }
+
+        # Validate input node has convertible kwargs
+        if "dtype" in to_copy_node.kwargs:
+            if to_copy_node.kwargs["dtype"] in allowed_casts:
+                return True
+            else:
+                _LOGGER.debug(
+                    f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                )
+                return False
         else:
             _LOGGER.debug(
-                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
             )
             return False
-    else:
-        _LOGGER.debug(
-            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+
+    def validator(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+        and the placeholder restriction is satisfied
+        """
+        # The placeholder restriction is satisfied if placeholder_only is the same
+        # truth value as is_only_operator_on_placeholder(to_copy_node)
+        return validate_dtype(to_copy_node) and (
+            (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
        )
-        return False
+
+    return validator
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+    torch.ops.aten.clone.default,
+    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
 )  # type: ignore[misc]
-def aten_ops_to_copy_dtype(
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=False),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_dtype(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -401,28 +432,41 @@ def aten_ops_to_copy_dtype(
         SourceIR.ATEN,
         name,
         args[0],
-        kwargs["dtype"],
+        kwargs.get("dtype", args[0].dtype),
+        force_cast=False,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
-def aten_ops_clone(
+@dynamo_tensorrt_converter(
+    torch.ops.aten.clone.default,
+    capability_validator=is_only_operator_on_placeholder,
+)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=True),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_placeholder(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.cast.clone(
+    # For clone or copy nodes where the input is also the output,
+    # we need to force cast to ensure a layer is added to the TRT engine
+    # since TRT engine inputs cannot also be TRT engine outputs
+    return impl.cast.to_copy(
         network,
         target,
         SourceIR.ATEN,
         name,
        args[0],
+        kwargs.get("dtype", args[0].dtype),
+        force_cast=True,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@dynamo_tensorrt_converter(torch.ops.aten.expand.default)  # type: ignore[misc]
 def aten_ops_expand(
     network: TRTNetwork,
     target: Target,
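
As a sanity check on the XOR above (a standalone sketch, not repository code): the expression is True exactly when placeholder_only matches the result of the placeholder check, so each _to_copy node is claimed by exactly one of the two registered converters.

def placeholder_restriction(placeholder_only: bool, is_placeholder_output: bool) -> bool:
    # Mirrors the check in to_copy_dtype_validator: accept the node only when
    # placeholder_only has the same truth value as the placeholder test.
    return (not placeholder_only) ^ is_placeholder_output

for is_placeholder_output in (False, True):
    claimed_by = [
        name
        for name, placeholder_only in (
            ("clone_copy_dtype", False),
            ("clone_copy_placeholder", True),
        )
        if placeholder_restriction(placeholder_only, is_placeholder_output)
    ]
    print(is_placeholder_output, claimed_by)  # exactly one converter each time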

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 30 additions & 5 deletions
@@ -43,27 +43,49 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False
 
     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
         if isinstance(arg, torch.fx.Node)
     ):
         return False
 
     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
@@ -80,9 +102,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
+
     Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise returns
+    input unchanged
+
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
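
A small illustrative check of the new utility (a sketch; the FX graph is hand-built because tracing a Module may emit call_method rather than call_function nodes, and the predicate is restated inline so the snippet stands alone): a call_function node that consumes a placeholder and feeds the graph output satisfies it.

import torch
import torch.fx

def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
    # Re-statement of the utility added in the diff above, for illustration only
    return (
        node.op == "call_function"
        and any(
            arg.op == "placeholder"
            for arg in node.args
            if isinstance(arg, torch.fx.Node)
        )
        and any(user.op == "output" for user in list(node.users.keys()))
    )

graph = torch.fx.Graph()
x = graph.placeholder("x")
cloned = graph.call_function(torch.clone, (x,))  # clone consumes the placeholder
graph.output(cloned)                             # ...and feeds the graph output

print(is_only_operator_on_placeholder(cloned))  # True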

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 21 additions & 19 deletions
@@ -3,7 +3,12 @@
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
 LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
     name: str,
     input: TRTTensor,
     dtype: TRTDataType,
+    force_cast: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"to_copy received input {input} that is not a TensorRT ITensor"
         )
 
-    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
-    return casted_tensor
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
+    # If cast is forced, insert an identity layer regardless of whether
+    # the dtype changes
+    if force_cast:
+        trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt_dtype)
+        identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
+        return identity_layer.get_output(0)
+    else:
+        casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+        return casted_tensor
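
The forced path above relies on a plain TensorRT pattern: an identity layer produces a distinct ITensor to mark as the output, so the engine input is no longer also the engine output. A minimal sketch of that pattern, assuming a TensorRT 8.x Python install (none of these layer or tensor names come from the repository):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

x = network.add_input("x", trt.float32, (1, 3, 10))

# Marking `x` itself as an output would be rejected at engine build time;
# the identity layer gives the output its own tensor to mark instead.
identity_layer = network.add_identity(x)
identity_layer.set_output_type(0, trt.float32)
identity_layer.name = "forced_cast_x"

network.mark_output(identity_layer.get_output(0))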

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 11 additions & 6 deletions
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -75,10 +76,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False
 
     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
+        lhs_dtype = lhs_val.dtype
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
+        rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -103,9 +104,13 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
+        )
 
     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
@@ -116,9 +121,9 @@ def convert_binary_elementwise(
     # tensor. This is safe because broadcast will pad dimensions on the left
     # (prepend) to make lhs and rhs shape compatible.
     if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
+        if isinstance(lhs_val, (torch.Tensor, np.ndarray)):
             lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
+        if isinstance(rhs_val, (torch.Tensor, np.ndarray)):
             rhs_val = squeeze_left(rhs_val)
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
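
The scalar-wrapping change above keeps the Python scalar in the same dtype as the TRT-tensor operand, now via a NumPy array. A minimal sketch of the idea, using a hypothetical stand-in for unified_dtype_converter(dtype, Frameworks.NUMPY) (the real mapping lives in torch_tensorrt.fx.converters.converter_utils):

import numpy as np
import tensorrt as trt

# Hypothetical stand-in for unified_dtype_converter(dtype, Frameworks.NUMPY)
TRT_TO_NUMPY = {
    trt.float32: np.float32,
    trt.float16: np.float16,
    trt.int32: np.int32,
}

lhs_dtype = trt.float16  # dtype of the TRT-tensor operand, e.g. a half tensor
rhs_scalar = 2.0         # Python scalar on the other side of the op

rhs_val = np.array([rhs_scalar], dtype=TRT_TO_NUMPY[lhs_dtype])
print(rhs_val.dtype)  # float16, matching the TRT operand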

tests/py/dynamo/conversion/test_casts.py

Lines changed: 27 additions & 0 deletions
@@ -35,6 +35,19 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_clone_direct(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                return x.clone()
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
 
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_to_copy_direct(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float, copy=True)
+
+        inputs = [torch.rand((1, 3, 10)).float()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
