
Commit ac007ce

fix: Add special cases for clone and to_copy where input of graph is output (#2265)
1 parent 8ebb599 commit ac007ce
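Note: TensorRT does not allow an engine input tensor to also serve directly as an engine output. Graphs whose only operation maps a placeholder straight to the output (e.g. a bare clone) therefore need a real layer inserted. A minimal sketch of the failing pattern this commit targets (module name is illustrative):

    import torch

    class CloneOnly(torch.nn.Module):
        # After lowering, aten.clone.default is the only operator between the
        # graph input and the graph output, so without a forced layer the TRT
        # engine input would also be the engine output.
        def forward(self, x):
            return x.clone()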

File tree

    py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
    py/torch_tensorrt/dynamo/conversion/converter_utils.py
    py/torch_tensorrt/dynamo/conversion/impl/cast.py
    py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
    tests/py/dynamo/conversion/test_casts.py

5 files changed: +153 -67 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 65 additions & 19 deletions
@@ -1,10 +1,13 @@
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    is_only_operator_on_placeholder,
+)
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

 from .converter_registry import dynamo_tensorrt_converter
@@ -441,29 +444,59 @@ def aten_ops_permute(
     )


-def to_copy_dtype_validator(to_copy_node: Node) -> bool:
-    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
-
-    # Validate input node has convertible kwargs
-    if "dtype" in to_copy_node.kwargs:
-        if to_copy_node.kwargs["dtype"] in allowed_casts:
-            return True
+def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+    """Return validator for to_copy node with placeholder restrictions"""
+
+    def validate_dtype(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+
+        Based on data type being casted to
+        """
+        allowed_casts = {
+            torch.float,
+            torch.int32,
+            torch.bool,
+            torch.int8,
+            torch.float16,
+        }
+
+        # Validate input node has convertible kwargs
+        if "dtype" in to_copy_node.kwargs:
+            if to_copy_node.kwargs["dtype"] in allowed_casts:
+                return True
+            else:
+                _LOGGER.debug(
+                    f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                )
+                return False
         else:
             _LOGGER.debug(
-                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
             )
             return False
-    else:
-        _LOGGER.debug(
-            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+
+    def validator(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+        and the placeholder restriction is satisfied
+        """
+        # The placeholder restriction is satisfied if placeholder_only is the same
+        # truth value as is_only_operator_on_placeholder(to_copy_node)
+        return validate_dtype(to_copy_node) and (
+            (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
         )
-        return False
+
+    return validator


 @dynamo_tensorrt_converter(
-    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+    torch.ops.aten.clone.default,
+    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
 )  # type: ignore[misc]
-def aten_ops_to_copy_dtype(
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=False),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_dtype(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
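Note: the expression `(not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)` is an equivalence (XNOR) test: it is True exactly when the two flags agree. A small self-contained check of that identity:

    # `^` on bools is XOR; negating one operand turns it into an equality test.
    for placeholder_only in (False, True):
        for on_placeholder in (False, True):
            assert ((not placeholder_only) ^ on_placeholder) == (
                placeholder_only == on_placeholder
            )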
@@ -476,24 +509,37 @@ def aten_ops_to_copy_dtype(
         SourceIR.ATEN,
         name,
         args[0],
-        kwargs["dtype"],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=False,
     )


-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
-def aten_ops_clone(
+@dynamo_tensorrt_converter(
+    torch.ops.aten.clone.default,
+    capability_validator=is_only_operator_on_placeholder,
+)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=True),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_placeholder(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.cast.clone(
+    # For clone or copy nodes where the input is also the output,
+    # we need to force cast to ensure a layer is added to the TRT engine
+    # since TRT engine inputs cannot also be TRT engine outputs
+    return impl.cast.to_copy(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=True,
     )
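Note: taken together, the registrations partition the clone/_to_copy nodes: aten_ops_clone_copy_dtype handles nodes that are not the sole operator between a placeholder and the graph output and performs a plain cast (force_layer=False), while aten_ops_clone_copy_placeholder handles the placeholder-to-output case and always materializes a layer (force_layer=True).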

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 31 additions & 6 deletions
@@ -45,27 +45,49 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name


+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"

     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False

     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
         if isinstance(arg, torch.fx.Node)
     ):
         return False

     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
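Note: a sketch of what the new helper detects, using a hand-built FX graph (assumes torch_tensorrt is importable; the graph construction itself is plain torch.fx):

    import torch
    import torch.fx
    from torch_tensorrt.dynamo.conversion.converter_utils import (
        is_only_operator_on_placeholder,
    )

    # Build a graph whose only operation consumes a placeholder and feeds the
    # output directly -- exactly the shape the validator is meant to flag.
    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    cloned = graph.call_function(torch.ops.aten.clone.default, (x,))
    graph.output(cloned)

    clone_node = next(n for n in graph.nodes if n.op == "call_function")
    print(is_only_operator_on_placeholder(clone_node))  # expected: True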
@@ -82,9 +104,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
+
     Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise returns
+    input unchanged
+
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
@@ -191,7 +216,7 @@ def extend_attr_to_tuple(
     if isinstance(val, tuple):
         return val
     else:
-        raise AssertionError(f"Could not extend attribute {val}")
+        raise AssertionError(f"Object {val} could not be extended to tuple")


 def cast_int_or_float_to_bool(

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 21 additions & 19 deletions
@@ -3,7 +3,12 @@

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

 LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
     name: str,
     input: TRTTensor,
     dtype: TRTDataType,
+    force_layer: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"to_copy received input {input} that is not a TensorRT ITensor"
         )

-    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
-    return casted_tensor
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
+    # If cast is forced, insert an identity layer even when the dtype
+    # does not change
+    if force_layer:
+        trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt_dtype)
+        identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
+        return identity_layer.get_output(0)
+    else:
+        casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+        return casted_tensor
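Note: the force_layer branch exists because TensorRT rejects engines whose output ITensor is the same object as an input ITensor, so even a dtype-preserving copy must become a layer. A minimal raw-TensorRT sketch of what that branch builds (assumes the tensorrt Python package; shapes and names are illustrative):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    inp = network.add_input("x", trt.float32, (1, 3, 10))
    # The identity layer yields a distinct output ITensor even though the
    # values (and here the dtype) are unchanged.
    identity = network.add_identity(inp)
    identity.set_output_type(0, trt.float32)
    network.mark_output(identity.get_output(0))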

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 9 additions & 23 deletions
@@ -11,11 +11,7 @@
     cast_trt_tensor,
     get_trt_tensor,
 )
-from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
-    set_layer_name,
-    squeeze_left,
-)
+from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -96,10 +92,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False

     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
+        lhs_dtype = lhs_val.dtype
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
+        rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True

     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -124,23 +120,13 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = np.array([lhs_val], dtype=rhs_dtype)
-
-    # When lhs is scalar, and rhs has shape [1,], then currently the assert
-    # will fail because lhs shape has fewer dimensions than rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shape compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
-            rhs_val = squeeze_left(rhs_val)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
+        )

     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
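Note: the refactor keeps lhs_dtype/rhs_dtype as TRT dtypes and converts to a NumPy dtype only at the point where a Python scalar is wrapped. A small sketch of that conversion step (assumes a working torch_tensorrt install):

    import numpy as np
    import tensorrt as trt
    from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

    # Wrap a Python scalar so its array dtype matches a TRT tensor's dtype,
    # mirroring the updated scalar-promotion path above.
    lhs_dtype = trt.DataType.HALF
    rhs_val = np.array(
        [2.0], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
    )
    print(rhs_val.dtype)  # expected: float16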

tests/py/dynamo/conversion/test_casts.py

Lines changed: 27 additions & 0 deletions
@@ -35,6 +35,19 @@ def forward(self, x):
             disable_passes=True,
         )

+    def test_clone_direct(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                return x.clone()
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+

 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
             disable_passes=True,
         )

+    def test_to_copy_direct(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float, copy=True)
+
+        inputs = [torch.rand((1, 3, 10)).float()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+

 if __name__ == "__main__":
     run_tests()
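Note: to exercise just the new cases locally, an invocation along these lines should work (exact flags depend on the repo's test setup): python -m pytest tests/py/dynamo/conversion/test_casts.py -k "clone_direct or to_copy_direct"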
