Commit ac702b7

fix: Handle dynamic shapes in where ops (#2853)
1 parent 6042cea · commit ac702b7

6 files changed: +279 −79 lines changed

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,7 @@ def has_static_shapes(node: torch.fx.Node) -> bool:
     return not _has_dynamic_shapes(node=node)
 
 
-def has_dynamic_shapes(node: torch.fx.Node) -> bool:
+def node_has_dynamic_shapes(node: torch.fx.Node) -> bool:
     """Returns True if a node has dynamic args, kwargs, or outputs"""
     return _has_dynamic_shapes(node=node)
 
@@ -438,7 +438,7 @@ def __getitem__(
                         # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
                         if candidate.capability_validator(node) and (
                             self.assume_dynamic_shape_support
-                            or not has_dynamic_shapes(node)
+                            or not node_has_dynamic_shapes(node)
                             or candidate.supports_dynamic_shapes
                         ):
                             return (
@@ -447,7 +447,7 @@ def __getitem__(
                             )
                 else:
                     # Assuming FX converters don't have dynamic shapes supported
-                    if not has_dynamic_shapes(node):
+                    if not node_has_dynamic_shapes(node):
                         return converters, calling_convention
 
         raise KeyError(
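
Note: the rename distinguishes this node-level check from the shape-level helper has_dynamic_shape added to converter_utils.py below. As a rough sketch (not the registry's actual code), the gate in __getitem__ accepts a converter when its validator passes and any one of three conditions holds; node_has_dynamic_shapes comes from the hunk above, while candidate, node, and assume_dynamic_shape_support are stand-ins for the registry's own state:

# Simplified restatement of the selection condition shown in the hunk above.
def converter_is_usable(candidate, node, assume_dynamic_shape_support: bool) -> bool:
    return candidate.capability_validator(node) and (
        assume_dynamic_shape_support           # global override flag
        or not node_has_dynamic_shapes(node)   # node is fully static
        or candidate.supports_dynamic_shapes   # converter opted in, e.g. where.self
    )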

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
@@ -695,7 +695,7 @@ def aten_ops_split(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.where.self)
+@dynamo_tensorrt_converter(torch.ops.aten.where.self, supports_dynamic_shapes=True)
 def aten_ops_where(
     ctx: ConversionContext,
     target: Target,
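
Note: supports_dynamic_shapes=True is what allows this converter to be selected for nodes with dynamic inputs under the registry logic above. A hedged sketch of the opt-in registration; only the decorator line is taken from this diff, while the remaining parameters and return type follow the usual dynamo converter signature and are assumptions, with the body omitted:

@dynamo_tensorrt_converter(torch.ops.aten.where.self, supports_dynamic_shapes=True)
def aten_ops_where(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    ...  # dispatches into impl.condition.where (body not shown in this diff)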

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 160 additions & 13 deletions
@@ -15,11 +15,8 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
-    get_axes_for_reduce_op,
-)
-from torch_tensorrt.fx.types import TRTDataType, TRTTensor
+
+from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -177,9 +174,7 @@ def broadcast_to_same_shape(
         Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
 
     """
-    lhs_val, rhs_val = broadcast(
-        ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
-    )
+    lhs_val, rhs_val = broadcast(ctx, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs")
 
     lhs_val_shape = lhs_val.shape
     rhs_val_shape = rhs_val.shape
@@ -216,11 +211,6 @@ def broadcast_to_same_shape(
     return lhs_val, rhs_val
 
 
-get_axes_for_reduce_op = functools.partial(
-    get_axes_for_reduce_op, has_implicit_batch_dimension=False
-)
-
-
 def extend_attr_to_tuple(
     val: Any,
     num_elem: int,
@@ -692,3 +682,160 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
     for i in range(len(shape) - 2, -1, -1):
         strides[i] = strides[i + 1] * shape[i + 1]
     return strides
+
+
+def broadcast(
+    ctx: ConversionContext,
+    a: TRTTensor,
+    b: TRTTensor,
+    a_name: str,
+    b_name: str,
+    preset_diff: int = 0,
+) -> Tuple[TRTTensor, TRTTensor]:
+    """
+    Broadcast two TensorRT tensors to the same number of dimensions by
+    prepending 1s to the tensor with less number of dimensions.
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        a (TRTTensor): A TensorRT ITensor.
+        b (TRTTensor): A TensorRT ITensor.
+        a_name (str): Name of tensor a.
+        b_name (str): Name of tensor b.
+        preset_diff (int): The difference of number of dimensions after broadcast.
+            A positive number means after broadcast, tensor `a` would have `preset_diff`
+            more dimensions than `b`. This is used in matmul, since we need to broadcast
+            tensors but not always to the same number of dimension. The reason is that
+            matmul supports Matrix x Vector and in this case broadcasted vector should
+            have 1 less number of dimensions than the matrix tensor.
+
+    Returns:
+        Two TensorRT ITensors that are broadcasted to the same number of dimensions.
+    """
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+
+    diff = len(a_shape) - len(b_shape) - preset_diff
+    if diff > 0:
+        b = prepend_ones(ctx, b, f"{b_name}_broadcast", diff)
+    elif diff < 0:
+        a = prepend_ones(ctx, a, f"{a_name}_broadcast", -diff)
+
+    return a, b
+
+
+def get_axes_for_reduce_op(
+    dim: Union[int, Sequence[int]],
+) -> int:
+    """
+    TensorRT reduce layer relies on the binary representation of axes to
+    determine which dims to reduce. For example, if we want to reduce on
+    dim 1 and 2 then axes should be 6(110).
+
+    Args:
+        dim (Union[int, Sequence[int]]): An integer or a sequence of integers
+            that will be used to generate axes for TensorRT.
+
+    Returns:
+        An integer which binary form can be used as axes for TensorRT reduce
+        layer.
+    """
+    if isinstance(dim, int):
+        dim = (dim,)
+
+    axes = 0
+    for d in dim:
+        axes |= 1 << d
+
+    return axes
+
+
+def has_dynamic_shape(shape: Shape) -> bool:
+    """
+    Determine if the given shape has dynamic dim. i.e. if there're -1 in shape.
+
+    Args:
+        shape (Shape): Shape of a tensor. Essentially is a sequence of integers.
+
+    Returns:
+        A boolean value indicates whether there's dynamic dim in the shape.
+    """
+    count = 0
+    for s in shape:
+        count += 1 if s == -1 else 0
+    return count > 0
+
+
+def prepend_ones(
+    ctx: ConversionContext,
+    tensor: TRTTensor,
+    name: str,
+    num_prepend_ones: int,
+) -> TRTTensor:
+    """
+    Prepend 1s to the shape of TensorRT ITensor `tensor`.
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network
+        tensor (TRTTensor): A TensorRT tensor.
+        name (str): Name of the TensorRT Shuffle layer which is used to prepend
+            1s.
+        num_prepend_ones (int): Number of 1s that will be prepend.
+
+    Returns:
+        A Tensorrt ITensor which contains the same value as `tensor` but with
+        more 1s prepended to the beginning of `tensor` shape.
+    """
+    layer = ctx.net.add_shuffle(tensor)
+
+    # If there're dynamic dim in tensor's shape, we need to use shape layer to
+    # compute the final shape.
+    if has_dynamic_shape(tensor.shape):
+        tensor_shape_layer = ctx.net.add_shape(tensor)
+        tensor_shape = tensor_shape_layer.get_output(0)
+        tensor_shape = cast_trt_tensor(
+            ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
+        )
+        tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
+        prepend_shape_layer = ctx.net.add_constant(
+            (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
+        )
+        prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
+        reshape_dim_layer = ctx.net.add_concatenation(
+            [prepend_shape_layer.get_output(0), tensor_shape]
+        )
+        reshape_dim_layer.axis = 0
+        reshape_dim_layer.name = f"{name}_broadcast_final_shape"
+        layer.set_input(1, reshape_dim_layer.get_output(0))
+    else:
+        layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
+
+    layer.name = name
+    return layer.get_output(0)
+
+
+def set_layer_name(
+    layer: TRTLayer,
+    target: Union[Target, torch.nn.Module, str],
+    name: str,
+    source_ir: Optional[SourceIR] = None,
+) -> None:
+    """
+    Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
+
+    Args:
+        layer (TRTLayer): A TensorRT layer of which we want to set the name.
+        target (Target): A fx node.target or submodule. For call_function node, it's the function that
+            the node represents.
+        name (str): Consists of fx node.name with optional suffix.
+        source_ir: (Optional[SourceIR]): The IR producing the op.
+    """
+
+    source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+
+    target_name = (
+        f"{source_ir}_ops.{target}"
+        if isinstance(target, str)
+        else f"{source_ir}_ops.{target.__name__}"
+    )
+    layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
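
Note: get_axes_for_reduce_op and has_dynamic_shape are pure Python and can be sanity-checked without building a network. A small standalone sketch, assuming the import path matches the module patched above:

from torch_tensorrt.dynamo.conversion.converter_utils import (
    get_axes_for_reduce_op,
    has_dynamic_shape,
)

# Reducing over dims 1 and 2 sets bits 1 and 2 of the axes mask: 0b110 == 6.
assert get_axes_for_reduce_op((1, 2)) == 6
assert get_axes_for_reduce_op(0) == 1

# Any -1 entry marks the shape as dynamic.
assert has_dynamic_shape((-1, 3, 224, 224)) is True
assert has_dynamic_shape((8, 3, 224, 224)) is False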

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 23 additions & 62 deletions
@@ -8,10 +8,12 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
+    cast_trt_tensor,
     get_trt_tensor,
+    prepend_ones,
+    set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.impl.slice import expand
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -30,73 +32,32 @@ def where(
     x_shape = list(input.shape)
     y_shape = list(other.shape)
     condition_shape = list(condition.shape)
+    max_shape_len = max(len(x_shape), len(y_shape), len(condition_shape))
 
-    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
-
-    # expand shape
     if not isinstance(condition, TRTTensor):
-        assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
-        if condition_shape != output_shape:
-            condition = (
-                condition.expand(output_shape)
-                if isinstance(condition, torch.Tensor)
-                else np.broadcast_to(condition, output_shape)
-            )
-        condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
-    else:
-        assert condition.dtype == trt.bool, "mask dtype is not bool!"
-        if condition_shape != output_shape:
-            condition_val = expand(
-                ctx, target, source_ir, f"{name}_expand", condition, output_shape
-            )
-        else:
-            condition_val = condition
+        condition = get_trt_tensor(ctx, condition, f"{name}_condition")
+
+    if condition.dtype != trt.bool:
+        condition = cast_trt_tensor(ctx, condition, trt.float32, f"{name}_cast")
+        condition = ne(ctx, target, source_ir, f"{name}_cond_zero", condition, 0)
+
+    diff = max_shape_len - len(condition_shape)
+    if diff > 0:
+        condition = prepend_ones(ctx, condition, f"{name}_condition_broadcast", diff)
 
     if not isinstance(input, TRTTensor):
-        if x_shape != output_shape:
-            # special case where 1 element in input
-            if len(input.shape) == 0:
-                input = (
-                    input.unsqueeze(0)
-                    if isinstance(input, torch.Tensor)
-                    else np.expand_dims(input, axis=0)
-                )
-            input = (
-                input.expand(output_shape)
-                if isinstance(input, torch.Tensor)
-                else np.broadcast_to(input, output_shape)
-            )
-        x_val = get_trt_tensor(ctx, input, f"{name}_x")
-    else:
-        x_val = input
-        if x_shape != output_shape:
-            x_val = expand(
-                ctx, target, source_ir, f"{name}_x_expand", input, output_shape
-            )
+        input = get_trt_tensor(ctx, input, f"{name}_x")
+    diff = max_shape_len - len(x_shape)
+    if diff > 0:
+        input = prepend_ones(ctx, input, f"{name}_input_broadcast", diff)
 
     if not isinstance(other, TRTTensor):
-        if y_shape != output_shape:
-            # special case where 1 element in other
-            if len(other.shape) == 0:
-                other = (
-                    other.unsqueeze(0)
-                    if isinstance(other, torch.Tensor)
-                    else np.expand_dims(other, axis=0)
-                )
-            other = (
-                other.expand(output_shape)
-                if isinstance(other, torch.Tensor)
-                else np.broadcast_to(other, output_shape)
-            )
-        y_val = get_trt_tensor(ctx, other, f"{name}_y")
-    else:
-        y_val = other
-        if y_shape != output_shape:
-            y_val = expand(
-                ctx, target, source_ir, f"{name}_y_expand", y_val, output_shape
-            )
+        other = get_trt_tensor(ctx, other, f"{name}_y")
+    diff = max_shape_len - len(y_shape)
+    if diff > 0:
+        other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)
 
-    return select(ctx, target, source_ir, name, x_val, y_val, condition_val)
+    return select(ctx, target, source_ir, name, input, other, condition)
 
 
 def select(
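
Note: the rewritten where no longer computes a full static broadcast shape; it only rank-aligns condition, input, and other with prepend_ones and leaves per-dimension (including dynamic-dimension) broadcasting to the select layer, casting non-boolean conditions through ne(..., 0). A pure-Python illustration of the rank alignment, with prepend_ones emulated by tuple arithmetic:

def rank_align(shape, target_rank):
    # Mirror of the `diff > 0` branches above: prepend 1s until ranks match.
    diff = target_rank - len(shape)
    return (1,) * diff + tuple(shape) if diff > 0 else tuple(shape)

condition_shape, x_shape, y_shape = (2, 3), (-1, 2, 3), (3,)
max_shape_len = max(len(condition_shape), len(x_shape), len(y_shape))

print(rank_align(condition_shape, max_shape_len))  # (1, 2, 3)
print(rank_align(x_shape, max_shape_len))          # (-1, 2, 3), dynamic dim preserved
print(rank_align(y_shape, max_shape_len))          # (1, 1, 3)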

py/torch_tensorrt/dynamo/types.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from typing import Sequence, Tuple
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+
+if hasattr(trt, "__version__"):
+    TRTNetwork = trt.INetworkDefinition
+    TRTTensor = trt.tensorrt.ITensor
+    TRTLayer = trt.ILayer
+    TRTPluginFieldCollection = trt.PluginFieldCollection
+    TRTPlugin = trt.IPluginV2
+    TRTDataType = trt.DataType
+    TRTElementWiseOp = trt.ElementWiseOperation
+else:
+    TRTNetwork = "trt.INetworkDefinition"
+    TRTTensor = "trt.tensorrt.ITensor"
+    TRTLayer = "trt.ILayer"
+    TRTPluginFieldCollection = "trt.PluginFieldCollection"
+    TRTPlugin = "trt.IPluginV2"
+    TRTDataType = "trt.DataType"
+    TRTElementWiseOp = "trt.ElementWiseOperation"
+
+Shape = Sequence[int]
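
Note: the hasattr(trt, "__version__") guard falls back to string placeholders when the imported tensorrt module is a stub without the real bindings. A small hedged usage sketch; the describe helper is hypothetical and only shows how the Shape alias is meant to appear in signatures:

from torch_tensorrt.dynamo.types import Shape

def describe(shape: Shape) -> str:
    # -1 entries denote dynamic dimensions, mirroring has_dynamic_shape().
    return "dynamic" if any(s == -1 for s in shape) else "static"

print(describe((-1, 3, 224, 224)))  # dynamic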
