fix: Handle dynamic shapes in where ops #2853

Merged
merged 10 commits on Jun 12, 2024
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -96,7 +96,7 @@ def has_static_shapes(node: torch.fx.Node) -> bool:
return not _has_dynamic_shapes(node=node)


def has_dynamic_shapes(node: torch.fx.Node) -> bool:
def node_has_dynamic_shapes(node: torch.fx.Node) -> bool:
"""Returns True if a node has dynamic args, kwargs, or outputs"""
return _has_dynamic_shapes(node=node)

@@ -438,7 +438,7 @@ def __getitem__(
# 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
if candidate.capability_validator(node) and (
self.assume_dynamic_shape_support
or not has_dynamic_shapes(node)
or not node_has_dynamic_shapes(node)
or candidate.supports_dynamic_shapes
):
return (
@@ -447,7 +447,7 @@ def __getitem__(
)
else:
# Assuming FX converters don't have dynamic shapes supported
if not has_dynamic_shapes(node):
if not node_has_dynamic_shapes(node):
return converters, calling_convention

raise KeyError(
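
The lookup above accepts a candidate converter when its capability validator passes and at least one of the dynamic-shape conditions holds; the supports_dynamic_shapes flag set on the @dynamo_tensorrt_converter decorator (as done for aten.where.self in the next file) feeds the last condition. A minimal standalone sketch of that boolean test, using shortened stand-in names rather than the registry's actual API:

def accepts(validator_ok, assume_dynamic, node_is_dynamic, converter_supports_dynamic):
    # Mirrors: capability_validator(node) and (assume_dynamic_shape_support
    #          or not node_has_dynamic_shapes(node) or supports_dynamic_shapes)
    return validator_ok and (
        assume_dynamic or not node_is_dynamic or converter_supports_dynamic
    )

assert accepts(True, False, True, True)        # dynamic node, converter opted in (case 4)
assert not accepts(True, False, True, False)   # dynamic node, converter without support is skipped
assert accepts(True, False, False, False)      # static node, any validated converter is accepted
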
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -667,7 +667,7 @@ def aten_ops_split(
)


@dynamo_tensorrt_converter(torch.ops.aten.where.self)
@dynamo_tensorrt_converter(torch.ops.aten.where.self, supports_dynamic_shapes=True)
def aten_ops_where(
ctx: ConversionContext,
target: Target,
173 changes: 160 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -15,11 +15,8 @@
ConverterRegistry,
DynamoConverterImplSignature,
)
from torch_tensorrt.fx.converters.converter_utils import (
broadcast,
get_axes_for_reduce_op,
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

from ..types import Shape, TRTDataType, TRTLayer, TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -177,9 +174,7 @@ def broadcast_to_same_shape(
Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape

"""
lhs_val, rhs_val = broadcast(
ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
)
lhs_val, rhs_val = broadcast(ctx, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs")

lhs_val_shape = lhs_val.shape
rhs_val_shape = rhs_val.shape
@@ -216,11 +211,6 @@ def broadcast_to_same_shape(
return lhs_val, rhs_val


get_axes_for_reduce_op = functools.partial(
get_axes_for_reduce_op, has_implicit_batch_dimension=False
)


def extend_attr_to_tuple(
val: Any,
num_elem: int,
@@ -692,3 +682,160 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
for i in range(len(shape) - 2, -1, -1):
strides[i] = strides[i + 1] * shape[i + 1]
return strides


def broadcast(
ctx: ConversionContext,
a: TRTTensor,
b: TRTTensor,
a_name: str,
b_name: str,
preset_diff: int = 0,
) -> Tuple[TRTTensor, TRTTensor]:
"""
Broadcast two TensorRT tensors to the same number of dimensions by
prepending 1s to the shape of the tensor with fewer dimensions.

Args:
ctx (ConversionContext): A ConversionContext containing the TensorRT network
a (TRTTensor): A TensorRT ITensor.
b (TRTTensor): A TensorRT ITensor.
a_name (str): Name of tensor a.
b_name (str): Name of tensor b.
preset_diff (int): The desired difference in the number of dimensions after
broadcasting. A positive value means that, after broadcasting, tensor `a`
will have `preset_diff` more dimensions than `b`. This is used for matmul,
where the tensors are not always broadcast to the same rank: matmul supports
Matrix x Vector, and in that case the broadcast vector should have one
dimension fewer than the matrix tensor.

Returns:
Two TensorRT ITensors that are broadcasted to the same number of dimensions.
"""
a_shape = tuple(a.shape)
b_shape = tuple(b.shape)

diff = len(a_shape) - len(b_shape) - preset_diff
if diff > 0:
b = prepend_ones(ctx, b, f"{b_name}_broadcast", diff)
elif diff < 0:
a = prepend_ones(ctx, a, f"{a_name}_broadcast", -diff)

return a, b
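
A shape-only sketch of the rank alignment performed by broadcast; align_ranks is a hypothetical helper used purely for illustration, since only the number of dimensions is changed here (size broadcasting is left to the TensorRT layers that consume the tensors):

def align_ranks(a_shape, b_shape, preset_diff=0):
    # Same rank arithmetic as broadcast(), applied to plain shape tuples.
    diff = len(a_shape) - len(b_shape) - preset_diff
    if diff > 0:
        b_shape = (1,) * diff + tuple(b_shape)
    elif diff < 0:
        a_shape = (1,) * -diff + tuple(a_shape)
    return tuple(a_shape), tuple(b_shape)

assert align_ranks((2, 3, 4), (4,)) == ((2, 3, 4), (1, 1, 4))
# matmul-style Matrix x Vector: keep the vector one rank below the matrix
assert align_ranks((2, 3, 4), (4,), preset_diff=1) == ((2, 3, 4), (1, 4))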


def get_axes_for_reduce_op(
dim: Union[int, Sequence[int]],
) -> int:
"""
The TensorRT reduce layer relies on the binary representation of axes to
determine which dims to reduce. For example, to reduce over dims 1 and 2,
axes should be 6 (binary 110).

Args:
dim (Union[int, Sequence[int]]): An integer or a sequence of integers
that will be used to generate axes for TensorRT.

Returns:
An integer whose binary form can be used as the axes argument for a
TensorRT reduce layer.
"""
if isinstance(dim, int):
dim = (dim,)

axes = 0
for d in dim:
axes |= 1 << d

return axes
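
For reference, the bitmask behavior described in the docstring, exercised with the function defined just above:

assert get_axes_for_reduce_op(0) == 0b001        # reduce over dim 0
assert get_axes_for_reduce_op((1, 2)) == 0b110   # 6, the docstring's example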


def has_dynamic_shape(shape: Shape) -> bool:
"""
Determine whether the given shape has a dynamic dim, i.e. whether it contains a -1.

Args:
shape (Shape): Shape of a tensor, essentially a sequence of integers.

Returns:
A boolean value indicating whether there is a dynamic dim in the shape.
"""
count = 0
for s in shape:
count += 1 if s == -1 else 0
return count > 0
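
A quick illustration, assuming the usual convention that -1 marks a dynamic dimension:

assert has_dynamic_shape((-1, 3, 224, 224)) is True
assert has_dynamic_shape((8, 3, 224, 224)) is False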


def prepend_ones(
ctx: ConversionContext,
tensor: TRTTensor,
name: str,
num_prepend_ones: int,
) -> TRTTensor:
"""
Prepend 1s to the shape of TensorRT ITensor `tensor`.

Args:
ctx (ConversionContext): A ConversionContext containing the TensorRT network
tensor (TRTTensor): A TensorRT tensor.
name (str): Name of the TensorRT Shuffle layer which is used to prepend
1s.
num_prepend_ones (int): Number of 1s that will be prepended.

Returns:
A TensorRT ITensor which contains the same values as `tensor` but with
1s prepended to the beginning of its shape.
"""
layer = ctx.net.add_shuffle(tensor)

# If there are dynamic dims in the tensor's shape, we need to use a shape layer
# to compute the final shape.
if has_dynamic_shape(tensor.shape):
tensor_shape_layer = ctx.net.add_shape(tensor)
tensor_shape = tensor_shape_layer.get_output(0)
tensor_shape = cast_trt_tensor(
ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
)
tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
prepend_shape_layer = ctx.net.add_constant(
(num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
)
prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
reshape_dim_layer = ctx.net.add_concatenation(
[prepend_shape_layer.get_output(0), tensor_shape]
)
reshape_dim_layer.axis = 0
reshape_dim_layer.name = f"{name}_broadcast_final_shape"
layer.set_input(1, reshape_dim_layer.get_output(0))
else:
layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)

layer.name = name
return layer.get_output(0)
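
A NumPy sketch of the shape arithmetic the dynamic-shape branch performs at runtime; the arrays below are stand-ins for the ITensors produced by add_shape and add_constant, and the concrete values are illustrative:

import numpy as np

runtime_shape = np.array([8, 3, 224, 224], dtype=np.int32)  # stand-in for the IShapeLayer output
ones_prefix = np.ones((2,), dtype=np.int32)                 # stand-in for the constant prefix
final_shape = np.concatenate([ones_prefix, runtime_shape])  # what the concat layer feeds to the shuffle
assert tuple(final_shape) == (1, 1, 8, 3, 224, 224)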


def set_layer_name(
layer: TRTLayer,
target: Union[Target, torch.nn.Module, str],
name: str,
source_ir: Optional[SourceIR] = None,
) -> None:
"""
Set the TensorRT layer name to "[TensorRT Layer Type]-[Original Op Name]-[FX Node Name with Suffix]"

Args:
layer (TRTLayer): A TensorRT layer of which we want to set the name.
target (Target): An fx node.target or a submodule. For a call_function node, it is the function
that the node represents.
name (str): Consists of fx node.name with optional suffix.
source_ir: (Optional[SourceIR]): The IR producing the op.
"""

source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN

target_name = (
f"{source_ir}_ops.{target}"
if isinstance(target, str)
else f"{source_ir}_ops.{target.__name__}"
)
layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
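
An illustrative (hypothetical) example of the name produced for an elementwise layer created while converting aten.where:

layer_type, target_name, node_name = "ELEMENTWISE", "aten_ops.where.self", "where_1"
assert f"[{layer_type}]-[{target_name}]-[{node_name}]" == "[ELEMENTWISE]-[aten_ops.where.self]-[where_1]"
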
85 changes: 23 additions & 62 deletions py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
@@ -8,10 +8,12 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
broadcastable,
cast_trt_tensor,
get_trt_tensor,
prepend_ones,
set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.slice import expand
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
from torch_tensorrt.fx.types import TRTTensor


@@ -30,73 +32,32 @@ def where(
x_shape = list(input.shape)
y_shape = list(other.shape)
condition_shape = list(condition.shape)
max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))

output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))

# expand shape
if not isinstance(condition, TRTTensor):
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
if condition_shape != output_shape:
condition = (
condition.expand(output_shape)
if isinstance(condition, torch.Tensor)
else np.broadcast_to(condition, output_shape)
)
condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
else:
assert condition.dtype == trt.bool, "mask dtype is not bool!"
if condition_shape != output_shape:
condition_val = expand(
ctx, target, source_ir, f"{name}_expand", condition, output_shape
)
else:
condition_val = condition
condition = get_trt_tensor(ctx, condition, f"{name}_condition")

if condition.dtype != trt.bool:
condition = cast_trt_tensor(ctx, condition, trt.float32, f"{name}_cast")
condition = ne(ctx, target, source_ir, f"{name}_cond_zero", condition, 0)

diff = max_shape_len - len(condition_shape)
if diff > 0:
condition = prepend_ones(ctx, condition, f"{name}_condition_broadcast", diff)

if not isinstance(input, TRTTensor):
if x_shape != output_shape:
# special case where 1 element in input
if len(input.shape) == 0:
input = (
input.unsqueeze(0)
if isinstance(input, torch.Tensor)
else np.expand_dims(input, axis=0)
)
input = (
input.expand(output_shape)
if isinstance(input, torch.Tensor)
else np.broadcast_to(input, output_shape)
)
x_val = get_trt_tensor(ctx, input, f"{name}_x")
else:
x_val = input
if x_shape != output_shape:
x_val = expand(
ctx, target, source_ir, f"{name}_x_expand", input, output_shape
)
input = get_trt_tensor(ctx, input, f"{name}_x")
diff = max_shape_len - len(x_shape)
if diff > 0:
input = prepend_ones(ctx, input, f"{name}_input_broadcast", diff)

if not isinstance(other, TRTTensor):
if y_shape != output_shape:
# special case where 1 element in other
if len(other.shape) == 0:
other = (
other.unsqueeze(0)
if isinstance(other, torch.Tensor)
else np.expand_dims(other, axis=0)
)
other = (
other.expand(output_shape)
if isinstance(other, torch.Tensor)
else np.broadcast_to(other, output_shape)
)
y_val = get_trt_tensor(ctx, other, f"{name}_y")
else:
y_val = other
if y_shape != output_shape:
y_val = expand(
ctx, target, source_ir, f"{name}_y_expand", y_val, output_shape
)
other = get_trt_tensor(ctx, other, f"{name}_y")
diff = max_shape_len - len(y_shape)
if diff > 0:
other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)

return select(ctx, target, source_ir, name, x_val, y_val, condition_val)
return select(ctx, target, source_ir, name, input, other, condition)
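
A shape-level sketch (example shapes are illustrative) of the rank arithmetic the rewritten where performs; size broadcasting, including any dynamic (-1) dimensions, is left to the downstream select layer at runtime rather than computed ahead of time:

condition_shape, x_shape, y_shape = (4,), (2, 3, 4), ()
max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))  # 3
assert max_shape_len - len(condition_shape) == 2  # condition gets 2 leading ones -> (1, 1, 4)
assert max_shape_len - len(x_shape) == 0          # input is already at full rank
assert max_shape_len - len(y_shape) == 3          # scalar `other` gets 3 leading ones -> (1, 1, 1)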


def select(
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/types.py
@@ -0,0 +1,23 @@
from typing import Sequence, Tuple

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt

if hasattr(trt, "__version__"):
TRTNetwork = trt.INetworkDefinition
TRTTensor = trt.tensorrt.ITensor
TRTLayer = trt.ILayer
TRTPluginFieldCollection = trt.PluginFieldCollection
TRTPlugin = trt.IPluginV2
TRTDataType = trt.DataType
TRTElementWiseOp = trt.ElementWiseOperation
else:
TRTNetwork = "trt.INetworkDefinition"
TRTTensor = "trt.tensorrt.ITensor"
TRTLayer = "trt.ILayer"
TRTPluginFieldCollection = "trt.PluginFieldCollection"
TRTPlugin = "trt.IPluginV2"
TRTDataType = "trt.DataType"
TRTElementWiseOp = "trt.ElementWiseOperation"

Shape = Sequence[int]