
Commit 4da01bc

chore: Remove dependency from fx.converters.converter_utils
1 parent f29596c commit 4da01bc

File tree

2 files changed: +206 -12 lines


py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 182 additions & 12 deletions
@@ -15,13 +15,8 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-from torch_tensorrt.fx.converters.converter_utils import (  # noqa: F401
-    broadcast,
-    get_axes_for_reduce_op,
-    prepend_ones,
-    set_layer_name,
-)
-from torch_tensorrt.fx.types import TRTDataType, TRTTensor
+
+from ..types import Shape, TRTDataType, TRTLayer, TRTNetwork, TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -218,11 +213,6 @@ def broadcast_to_same_shape(
     return lhs_val, rhs_val
 
 
-get_axes_for_reduce_op = functools.partial(
-    get_axes_for_reduce_op, has_implicit_batch_dimension=False
-)
-
-
 def extend_attr_to_tuple(
     val: Any,
     num_elem: int,
@@ -694,3 +684,183 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
     for i in range(len(shape) - 2, -1, -1):
         strides[i] = strides[i + 1] * shape[i + 1]
     return strides
+
+
+def broadcast(
+    network: TRTNetwork,
+    a: TRTTensor,
+    b: TRTTensor,
+    a_name: str,
+    b_name: str,
+    preset_diff: int = 0,
+) -> Tuple[TRTTensor, TRTTensor]:
+    """
+    Broadcast two TensorRT tensors to the same number of dimensions by
+    prepending 1s to the tensor with fewer dimensions.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        a (TRTTensor): A TensorRT ITensor.
+        b (TRTTensor): A TensorRT ITensor.
+        a_name (str): Name of tensor a.
+        b_name (str): Name of tensor b.
+        preset_diff (int): The difference in the number of dimensions after
+            broadcast. A positive number means that after broadcast, tensor `a`
+            will have `preset_diff` more dimensions than `b`. This is used in
+            matmul, since we need to broadcast tensors but not always to the
+            same number of dimensions: matmul supports Matrix x Vector, in
+            which case the broadcasted vector should have one dimension fewer
+            than the matrix tensor.
+
+    Returns:
+        Two TensorRT ITensors broadcasted to the same number of dimensions.
+    """
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+
+    diff = len(a_shape) - len(b_shape) - preset_diff
+    if diff > 0:
+        b = prepend_ones(network, b, f"{b_name}_broadcast", diff)
+    elif diff < 0:
+        a = prepend_ones(network, a, f"{a_name}_broadcast", -diff)
+
+    return a, b
+
+
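
For reference, the rank-alignment arithmetic can be sanity-checked without a live TensorRT network. The helper below is a hypothetical sketch (not part of this commit) that applies the same `diff` computation to plain shape tuples:

def aligned_ranks(a_shape, b_shape, preset_diff=0):
    # Mirrors broadcast(): prepend 1s to whichever shape has fewer dims.
    diff = len(a_shape) - len(b_shape) - preset_diff
    if diff > 0:
        b_shape = (1,) * diff + tuple(b_shape)
    elif diff < 0:
        a_shape = (1,) * -diff + tuple(a_shape)
    return tuple(a_shape), tuple(b_shape)

aligned_ranks((4, 3, 2), (2,))                 # ((4, 3, 2), (1, 1, 2))
aligned_ranks((4, 3, 2), (2,), preset_diff=1)  # ((4, 3, 2), (1, 2)): Matrix x Vector
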
+def get_axes_for_reduce_op(
+    dim: Union[int, Sequence[int]],
+    has_implicit_batch_dimension: bool = False,
+) -> int:
+    """
+    TensorRT reduce layers rely on the binary representation of axes to
+    determine which dims to reduce. For example, to reduce over dims 1
+    and 2, axes should be 6 (binary 110).
+
+    Args:
+        dim (Union[int, Sequence[int]]): An integer or a sequence of integers
+            that will be used to generate axes for TensorRT.
+        has_implicit_batch_dimension (bool): Whether the TensorRT network is
+            using an implicit batch dimension.
+
+    Returns:
+        An integer whose binary form can be used as the axes argument for a
+        TensorRT reduce layer.
+    """
+    if isinstance(dim, int):
+        dim = (dim,)
+
+    if has_implicit_batch_dimension:
+        assert 0 not in dim, "Can't reduce over batch dimension when it's implicit."
+
+    axes = 0
+    for d in dim:
+        axes |= 1 << (d - (1 if has_implicit_batch_dimension else 0))
+
+    return axes
+
+
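
The axes bitmask is easiest to see with worked values; each result follows directly from the shift-and-OR loop above:

get_axes_for_reduce_op(1)        # 1 << 1 == 2  (binary 010)
get_axes_for_reduce_op((1, 2))   # (1 << 1) | (1 << 2) == 6  (binary 110)
# With an implicit batch dimension, each dim is shifted down by one,
# so dim 1 maps to bit 0:
get_axes_for_reduce_op(1, has_implicit_batch_dimension=True)  # 1
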
+def has_dynamic_shape(shape: Shape) -> bool:
+    """
+    Determine whether the given shape has a dynamic dim, i.e. whether any
+    entry in the shape is -1.
+
+    Args:
+        shape (Shape): Shape of a tensor, essentially a sequence of integers.
+
+    Returns:
+        A boolean indicating whether there is a dynamic dim in the shape.
+    """
+    count = 0
+    for s in shape:
+        count += 1 if s == -1 else 0
+    return count > 0
+
+
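
A quick sanity check of the -1 convention:

has_dynamic_shape((1, 3, 224, 224))   # False: fully static
has_dynamic_shape((-1, 3, 224, 224))  # True: dynamic batch dim
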
+def type_cast(
+    network: TRTNetwork,
+    target: Target,
+    name: str,
+    input: TRTTensor,
+    cast_type: TRTDataType,
+) -> TRTTensor:
+    """
+    Cast the input tensor to `cast_type` via an identity layer.
+    """
+    layer_i = network.add_identity(input)
+    layer_i.set_output_type(0, cast_type)
+    set_layer_name(layer_i, target, f"{name}_dtype_change")
+    return layer_i.get_output(0)
+
+
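
A typical call site, mirrored inside `prepend_ones` below, casts a shape tensor to int32 before concatenating it with a constant. Hypothetical usage, assuming `network` and a shape tensor `s` already exist:

s_i32 = type_cast(network, "shape", "x_shape_casted", s, trt.int32)
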
+def prepend_ones(
+    network: TRTNetwork,
+    tensor: TRTTensor,
+    name: str,
+    num_prepend_ones: int,
+) -> TRTTensor:
+    """
+    Prepend 1s to the shape of TensorRT ITensor `tensor`.
+
+    Args:
+        network (TRTNetwork): The TensorRT network that `tensor` belongs to.
+        tensor (TRTTensor): A TensorRT tensor.
+        name (str): Name of the TensorRT Shuffle layer which is used to
+            prepend 1s.
+        num_prepend_ones (int): Number of 1s that will be prepended.
+
+    Returns:
+        A TensorRT ITensor which contains the same value as `tensor`, but with
+        `num_prepend_ones` 1s prepended to the beginning of its shape.
+    """
+    layer = network.add_shuffle(tensor)
+
+    # If there are dynamic dims in the tensor's shape, we need to use a shape
+    # layer to compute the final shape.
+    if has_dynamic_shape(tensor.shape):
+        tensor_shape_layer = network.add_shape(tensor)
+        tensor_shape = tensor_shape_layer.get_output(0)
+        tensor_shape = type_cast(
+            network, "shape", name + "shape_casted", tensor_shape, trt.int32
+        )
+        tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
+        prepend_shape_layer = network.add_constant(
+            (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
+        )
+        prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
+        reshape_dim_layer = network.add_concatenation(
+            [prepend_shape_layer.get_output(0), tensor_shape]
+        )
+        reshape_dim_layer.axis = 0
+        reshape_dim_layer.name = f"{name}_broadcast_final_shape"
+        layer.set_input(1, reshape_dim_layer.get_output(0))
+    else:
+        layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
+
+    layer.name = name
+    return layer.get_output(0)
+
+
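
The static path is a plain reshape via `reshape_dims`; the dynamic path builds the target shape at runtime by concatenating a constant of 1s with the tensor's own shape and feeding the result into the shuffle layer's second input. A minimal standalone check of the static path, assuming a local TensorRT install (hypothetical, not part of this commit):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (3, 2))
out = prepend_ones(network, x, "x_broadcast", num_prepend_ones=2)
print(tuple(out.shape))  # (1, 1, 3, 2)
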
+def set_layer_name(
+    layer: TRTLayer,
+    target: Union[Target, torch.nn.Module, str],
+    name: str,
+    source_ir: Optional[SourceIR] = None,
+) -> None:
+    """
+    Set the TensorRT layer name to
+    "[TensorRT Layer Type]-[Original Op Name]-[FX Node Name with Suffix]".
+
+    Args:
+        layer (TRTLayer): A TensorRT layer for which we want to set the name.
+        target (Target): An fx node.target or submodule. For a call_function
+            node, it is the function that the node represents.
+        name (str): Consists of fx node.name with an optional suffix.
+        source_ir (Optional[SourceIR]): The IR producing the op.
+    """
+
+    source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+
+    target_name = (
+        f"{source_ir}_ops.{target}"
+        if isinstance(target, str)
+        else f"{source_ir}_ops.{target.__name__}"
+    )
+    layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
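
The resulting names make engine inspection greppable by layer type, op, and FX node. A stand-in layer object shows the formatting (hypothetical; the exact rendering of `source_ir` depends on SourceIR.__str__):

from types import SimpleNamespace

fake = SimpleNamespace(type=SimpleNamespace(name="SHUFFLE"), name="")
set_layer_name(fake, "reshape", "reshape_2", SourceIR.ATEN)
print(fake.name)  # e.g. "[SHUFFLE]-[aten_ops.reshape]-[reshape_2]"
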

py/torch_tensorrt/dynamo/types.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from typing import Sequence, Tuple
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+
+if hasattr(trt, "__version__"):
+    TRTNetwork = trt.INetworkDefinition
+    TRTTensor = trt.tensorrt.ITensor
+    TRTLayer = trt.ILayer
+    TRTPluginFieldCollection = trt.PluginFieldCollection
+    TRTPlugin = trt.IPluginV2
+    TRTDataType = trt.DataType
+    TRTElementWiseOp = trt.ElementWiseOperation
+else:
+    TRTNetwork = "trt.INetworkDefinition"
+    TRTTensor = "trt.tensorrt.ITensor"
+    TRTLayer = "trt.ILayer"
+    TRTPluginFieldCollection = "trt.PluginFieldCollection"
+    TRTPlugin = "trt.IPluginV2"
+    TRTDataType = "trt.DataType"
+    TRTElementWiseOp = "trt.ElementWiseOperation"
+
+Shape = Sequence[int]
+ShapeRange = Tuple[Shape, Shape, Shape]
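
The `hasattr(trt, "__version__")` guard presumably keeps this module importable when only a stub of tensorrt is present (e.g. in a docs build): the aliases degrade to forward-reference strings, which are still valid in annotations. A hypothetical consumer:

from torch_tensorrt.dynamo.types import Shape

def first_dim(shape: Shape) -> int:
    # Shape is just Sequence[int], so plain tuples and lists work.
    return shape[0]
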
