fix: Add special cases for clone and to_copy where input of graph is output #2265

Merged (1 commit) on Sep 20, 2023
84 changes: 65 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,10 +1,13 @@
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import (
is_only_operator_on_placeholder,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import dynamo_tensorrt_converter
@@ -441,29 +444,59 @@ def aten_ops_permute(
)


def to_copy_dtype_validator(to_copy_node: Node) -> bool:
allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}

# Validate input node has convertible kwargs
if "dtype" in to_copy_node.kwargs:
if to_copy_node.kwargs["dtype"] in allowed_casts:
return True
def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
"""Return validator for to_copy node with placeholder restrictions"""

def validate_dtype(to_copy_node: Node) -> bool:
"""Returns true if the to_copy node can be converted to TRT

Based on data type being casted to
"""
allowed_casts = {
torch.float,
torch.int32,
torch.bool,
torch.int8,
torch.float16,
}

# Validate input node has convertible kwargs
if "dtype" in to_copy_node.kwargs:
if to_copy_node.kwargs["dtype"] in allowed_casts:
return True
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
)
return False
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
)
return False
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"

def validator(to_copy_node: Node) -> bool:
"""Returns true if the to_copy node can be converted to TRT
and the placeholder restriction is satisfied
"""
# The placeholder restriction is satisfied if placeholder_only is the same
# truth value as is_only_operator_on_placeholder(to_copy_node)
return validate_dtype(to_copy_node) and (
(not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
)
return False

return validator


@dynamo_tensorrt_converter(
torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
torch.ops.aten.clone.default,
capability_validator=lambda node: not is_only_operator_on_placeholder(node),
) # type: ignore[misc]
def aten_ops_to_copy_dtype(
@dynamo_tensorrt_converter(
torch.ops.aten._to_copy.default,
capability_validator=to_copy_dtype_validator(placeholder_only=False),
) # type: ignore[misc]
def aten_ops_clone_copy_dtype(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
@@ -476,24 +509,37 @@ def aten_ops_to_copy_dtype(
SourceIR.ATEN,
name,
args[0],
kwargs["dtype"],
kwargs.get("dtype", args[0].dtype),
force_layer=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.clone.default) # type: ignore[misc]
def aten_ops_clone(
@dynamo_tensorrt_converter(
torch.ops.aten.clone.default,
capability_validator=is_only_operator_on_placeholder,
) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten._to_copy.default,
capability_validator=to_copy_dtype_validator(placeholder_only=True),
) # type: ignore[misc]
def aten_ops_clone_copy_placeholder(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.cast.clone(
# For clone or copy nodes where the input is also the output,
# we need to force cast to ensure a layer is added to the TRT engine
# since TRT engine inputs cannot also be TRT engine outputs
return impl.cast.to_copy(
network,
target,
SourceIR.ATEN,
name,
args[0],
kwargs.get("dtype", args[0].dtype),
force_layer=True,
)
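
Illustrative only, not part of this diff: a minimal sketch of the truth table implemented by the XOR in to_copy_dtype_validator, using plain booleans in place of real FX nodes. With placeholder_only=True the validator accepts exactly the input-is-output subgraphs, and with placeholder_only=False it accepts every other _to_copy node, so each node is routed to exactly one of the two converters registered above.

def placeholder_restriction(placeholder_only: bool, is_placeholder_case: bool) -> bool:
    # Mirrors (not placeholder_only) ^ is_only_operator_on_placeholder(node)
    return (not placeholder_only) ^ is_placeholder_case

# placeholder_only=True selects only the placeholder-to-output case
assert placeholder_restriction(True, True)
assert not placeholder_restriction(True, False)
# placeholder_only=False selects every other case
assert not placeholder_restriction(False, True)
assert placeholder_restriction(False, False)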


37 changes: 31 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -45,27 +45,49 @@ def get_node_name(node: torch.fx.Node) -> str:
return node_name


def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
"""Detects whether a call_function node is the only operator on a placeholder"""
# Returns true if the node operates on a placeholder and is a direct output
return (
node.op == "call_function"
and any(
arg.op == "placeholder"
for arg in node.args
if isinstance(arg, torch.fx.Node)
)
and any(user.op == "output" for user in list(node.users.keys()))
)


Collaborator:
Why is any used here? I believe the above is for the to_copy and clone cases, which have only one node arg. Or are there different cases for this?

Collaborator (Author):
any is used here because we are checking whether any of the users of the node are outputs, meaning that the node is the only function between a placeholder and an output. We are effectively searching for subgraphs where an input is followed by a function, which is followed by an output.
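
Illustrative only, not part of this PR: a hand-built torch.fx graph of the shape described above (placeholder, then a single call_function, then output). The import path assumes the is_only_operator_on_placeholder helper added in this diff.

import torch
from torch.fx import Graph
from torch_tensorrt.dynamo.conversion.converter_utils import (
    is_only_operator_on_placeholder,
)

graph = Graph()
x = graph.placeholder("x")
cloned = graph.call_function(torch.ops.aten.clone.default, args=(x,))
graph.output(cloned)

# The clone node consumes a placeholder and feeds the graph output directly,
# so both any(...) checks in the helper succeed.
print(is_only_operator_on_placeholder(cloned))  # expected: True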

def dynamic_unsupported(node: torch.fx.Node) -> bool:
# Validate that none of the inputs to the node have Dynamic shapes
assert isinstance(
node, torch.fx.Node
), "Inputs to validator functions must be FX Nodes"

# Check node value itself
if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
if ("val" in node.meta) and getattr(
node.meta["val"], "_has_symbolic_sizes_strides", False
):
return False

# Check node arguments individually
if any(
getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
(
("val" in arg.meta)
and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
)
for arg in node.args
if isinstance(arg, torch.fx.Node)
):
return False

# Check node keyword arguments individually
if any(
getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
(
("val" in kwarg.meta)
and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
)
for kwarg in node.kwargs.values()
if isinstance(kwarg, torch.fx.Node)
):
@@ -82,9 +104,12 @@ def cast_trt_tensor(
target: Target = "",
source_ir: Optional[SourceIR] = None,
) -> TRTTensor:
"""
Given a TRT Tensor, convert that Tensor to the specified dtype
"""Given a TRT Tensor, convert that Tensor to the specified dtype

Adds an Identity layer to the network which performs the conversion
if the input's dtype is different from the cast type. Otherwise returns
input unchanged

Args:
network (TRTNetwork): A TensorRT network
input_val (TRTTensor): A TRT Tensor to cast to a new data type
@@ -191,7 +216,7 @@ def extend_attr_to_tuple(
if isinstance(val, tuple):
return val
else:
raise AssertionError(f"Could not extend attribute {val}")
raise AssertionError(f"Object {val} could not be extended to tuple")


def cast_int_or_float_to_bool(
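
Illustrative only, not part of this diff: a simplified sketch of the behavior the updated cast_trt_tensor docstring above describes, assuming the standard TensorRT Python API (add_identity, set_output_type, get_output). The real helper also names the layer and handles source_ir, which this sketch omits.

def cast_sketch(network, input_val, cast_dtype):
    # Return the tensor unchanged when no conversion is needed
    if input_val.dtype == cast_dtype:
        return input_val
    # Otherwise add an Identity layer that performs the dtype conversion
    identity_layer = network.add_identity(input_val)
    identity_layer.set_output_type(0, cast_dtype)
    return identity_layer.get_output(0)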
40 changes: 21 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -3,7 +3,12 @@

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import (
Frameworks,
unified_dtype_converter,
)
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
name: str,
input: TRTTensor,
dtype: TRTDataType,
force_layer: bool = False,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"to_copy received input {input} that is not a TensorRT ITensor"
)

casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
return casted_tensor


def clone(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"clone received input {input} that is not a TensorRT ITensor"
)

LOGGER.debug(f"Evaluating clone on object with name: {name}")

return input
# If cast is forced, insert an identity layer regardless of whether
# the dtype changes
if force_layer:
trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
target_str = ConverterRegistry.qualified_name_or_str(target)
target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"

identity_layer = network.add_identity(input)
identity_layer.set_output_type(0, trt_dtype)
identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
return identity_layer.get_output(0)
else:
casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
return casted_tensor
32 changes: 9 additions & 23 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -11,11 +11,7 @@
cast_trt_tensor,
get_trt_tensor,
)
from torch_tensorrt.fx.converters.converter_utils import (
broadcast,
set_layer_name,
squeeze_left,
)
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -96,10 +92,10 @@ def convert_binary_elementwise(
is_rhs_trt_tensor = False

if isinstance(lhs_val, TRTTensor):
lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
lhs_dtype = lhs_val.dtype
is_lhs_trt_tensor = True
if isinstance(rhs_val, TRTTensor):
rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
rhs_dtype = rhs_val.dtype
is_rhs_trt_tensor = True

if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -124,23 +120,13 @@
# dtype but we don't have a way to detect whether it makes sense for the
# scalar to be float or half. Hence we go with the lhs dtype.
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
rhs_val = np.array([rhs_val], dtype=lhs_dtype)
rhs_val = np.array(
[rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
)
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
lhs_val = np.array([lhs_val], dtype=rhs_dtype)

# When lhs is scalar, and rhs has shape [1,], then currently the assert
# will fail because lhs shape has fewer dimensions than rhs shape. This
# happens when using implicit batch dimension, when we removed the 1st
# dimension from input tensor, causing it to have shape [] - a scalar. We
# fix it by reducing the rhs constant with a squeeze_left, so it becomes a
# scalar too. More generally, we squeeze_left on input if it's a constant
# tensor. This is safe because broadcast will pad dimensions on the left
# (prepend) to make lhs and rhs shape compatible.
if network.has_implicit_batch_dimension:
if isinstance(lhs_val, torch.Tensor):
lhs_val = squeeze_left(lhs_val)
if isinstance(rhs_val, torch.Tensor):
rhs_val = squeeze_left(rhs_val)
lhs_val = np.array(
[lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
)

lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
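
Illustrative only, not part of this diff: how a plain Python scalar operand is promoted to the TRT tensor operand's dtype before get_trt_tensor, mirroring the change above. Frameworks and unified_dtype_converter come from torch_tensorrt.fx.utils; trt.float16 stands in for the dtype of a real TRT tensor.

import numpy as np
import tensorrt as trt
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

lhs_dtype = trt.float16  # dtype of the TRT tensor operand (lhs_val.dtype)
rhs_scalar = 2           # plain Python scalar operand

# Promote the scalar to a one-element array with the lhs dtype
rhs_val = np.array(
    [rhs_scalar], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
)
print(rhs_val.dtype)  # float16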
27 changes: 27 additions & 0 deletions tests/py/dynamo/conversion/test_casts.py
@@ -35,6 +35,19 @@ def forward(self, x):
disable_passes=True,
)

def test_clone_direct(self):
class Clone(nn.Module):
def forward(self, x):
return x.clone()

inputs = [torch.randn((8, 2, 10))]
self.run_test(
Clone(),
inputs,
expected_ops={torch.ops.aten.clone.default},
disable_passes=True,
)


class TestToCopyConverter(DispatchTestCase):
def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
disable_passes=True,
)

def test_to_copy_direct(self):
class ToCopyFloat(nn.Module):
def forward(self, x):
return x.to(dtype=torch.float, copy=True)

inputs = [torch.rand((1, 3, 10)).float()]
self.run_test(
ToCopyFloat(),
inputs,
expected_ops={torch.ops.aten._to_copy.default},
precision=torch.float,
disable_passes=True,
)


if __name__ == "__main__":
run_tests()