
Commit 68f08ea

chore: Remove dependency from fx.converters.converter_utils
1 parent 7827fa4 commit 68f08ea
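
In practical terms, the dynamo converter utilities no longer re-export helpers from the fx path; the helpers are now defined locally (see the diff below). A minimal sketch of what this implies for downstream imports, using only module paths that appear in this diff (the before/after call sites are illustrative, not taken from the commit):

# Before this commit, dynamo's converter_utils re-exported the fx helpers:
#   from torch_tensorrt.fx.converters.converter_utils import (
#       broadcast, get_axes_for_reduce_op, prepend_ones, set_layer_name,
#   )

# After this commit, the helpers live in the dynamo module itself:
from torch_tensorrt.dynamo.conversion.converter_utils import (
    broadcast,
    get_axes_for_reduce_op,
    prepend_ones,
    set_layer_name,
)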

File tree

2 files changed: +206 -12 lines changed


py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 182 additions & 12 deletions
@@ -15,13 +15,8 @@
     ConverterRegistry,
     DynamoConverterImplSignature,
 )
-from torch_tensorrt.fx.converters.converter_utils import (  # noqa: F401
-    broadcast,
-    get_axes_for_reduce_op,
-    prepend_ones,
-    set_layer_name,
-)
-from torch_tensorrt.fx.types import TRTDataType, TRTTensor
+
+from ..types import Shape, TRTDataType, TRTLayer, TRTNetwork, TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -218,11 +213,6 @@ def broadcast_to_same_shape(
     return lhs_val, rhs_val
 
 
-get_axes_for_reduce_op = functools.partial(
-    get_axes_for_reduce_op, has_implicit_batch_dimension=False
-)
-
-
 def extend_attr_to_tuple(
     val: Any,
     num_elem: int,
@@ -686,3 +676,183 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
     for i in range(len(shape) - 2, -1, -1):
         strides[i] = strides[i + 1] * shape[i + 1]
     return strides
+
+
+def broadcast(
+    network: TRTNetwork,
+    a: TRTTensor,
+    b: TRTTensor,
+    a_name: str,
+    b_name: str,
+    preset_diff: int = 0,
+) -> Tuple[TRTTensor, TRTTensor]:
+    """
+    Broadcast two TensorRT tensors to the same number of dimensions by
+    prepending 1s to the shape of the tensor with fewer dimensions.
+
+    Args:
+        network (TRTNetwork): TensorRT network object.
+        a (TRTTensor): A TensorRT ITensor.
+        b (TRTTensor): A TensorRT ITensor.
+        a_name (str): Name of tensor a.
+        b_name (str): Name of tensor b.
+        preset_diff (int): The difference in the number of dimensions after broadcast.
+            A positive number means that after broadcast, tensor `a` will have
+            `preset_diff` more dimensions than `b`. This is used in matmul, since we
+            need to broadcast tensors but not always to the same number of dimensions.
+            The reason is that matmul supports Matrix x Vector, in which case the
+            broadcasted vector should have one fewer dimension than the matrix tensor.
+
+    Returns:
+        Two TensorRT ITensors that are broadcasted to the same number of dimensions.
+    """
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+
+    diff = len(a_shape) - len(b_shape) - preset_diff
+    if diff > 0:
+        b = prepend_ones(network, b, f"{b_name}_broadcast", diff)
+    elif diff < 0:
+        a = prepend_ones(network, a, f"{a_name}_broadcast", -diff)
+
+    return a, b
+
+
+def get_axes_for_reduce_op(
+    dim: Union[int, Sequence[int]],
+    has_implicit_batch_dimension: bool = False,
+) -> int:
+    """
+    The TensorRT reduce layer relies on the binary representation of axes to
+    determine which dims to reduce. For example, to reduce on dims 1 and 2,
+    axes should be 6 (binary 110).
+
+    Args:
+        dim (Union[int, Sequence[int]]): An integer or a sequence of integers
+            that will be used to generate axes for TensorRT.
+        has_implicit_batch_dimension (bool): Whether the TensorRT network is
+            using an implicit batch dimension.
+
+    Returns:
+        An integer whose binary form can be used as axes for the TensorRT
+        reduce layer.
+    """
+    if isinstance(dim, int):
+        dim = (dim,)
+
+    if has_implicit_batch_dimension:
+        assert 0 not in dim, "Can't reduce over batch dimension when it's implicit."
+
+    axes = 0
+    for d in dim:
+        axes |= 1 << (d - (1 if has_implicit_batch_dimension else 0))
+
+    return axes
+
+
+def has_dynamic_shape(shape: Shape) -> bool:
+    """
+    Determine whether the given shape has a dynamic dim, i.e. whether there is a -1 in the shape.
+
+    Args:
+        shape (Shape): Shape of a tensor. Essentially a sequence of integers.
+
+    Returns:
+        A boolean indicating whether there is a dynamic dim in the shape.
+    """
+    count = 0
+    for s in shape:
+        count += 1 if s == -1 else 0
+    return count > 0
+
+
+def type_cast(
+    network: TRTNetwork,
+    target: Target,
+    name: str,
+    input: TRTTensor,
+    cast_type: TRTDataType,
+) -> TRTTensor:
+    """
+    This function helps to cast the input tensor to `cast_type`.
+    """
+    layer_i = network.add_identity(input)
+    layer_i.set_output_type(0, cast_type)
+    set_layer_name(layer_i, target, f"{name}_dtype_change")
+    return layer_i.get_output(0)
+
+
+def prepend_ones(
+    network: TRTNetwork,
+    tensor: TRTTensor,
+    name: str,
+    num_prepend_ones: int,
+) -> TRTTensor:
+    """
+    Prepend 1s to the shape of TensorRT ITensor `tensor`.
+
+    Args:
+        network (TRTNetwork): The TensorRT network that `tensor`
+            belongs to.
+        tensor (TRTTensor): A TensorRT tensor.
+        name (str): Name of the TensorRT Shuffle layer which is used to prepend
+            1s.
+        num_prepend_ones (int): Number of 1s that will be prepended.
+
+    Returns:
+        A TensorRT ITensor which contains the same value as `tensor` but with
+        more 1s prepended to the beginning of the `tensor` shape.
+    """
+    layer = network.add_shuffle(tensor)
+
+    # If there are dynamic dims in the tensor's shape, we need to use a shape
+    # layer to compute the final shape.
+    if has_dynamic_shape(tensor.shape):
+        tensor_shape_layer = network.add_shape(tensor)
+        tensor_shape = tensor_shape_layer.get_output(0)
+        tensor_shape = type_cast(
+            network, "shape", name + "shape_casted", tensor_shape, trt.int32
+        )
+        tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
+        prepend_shape_layer = network.add_constant(
+            (num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
+        )
+        prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
+        reshape_dim_layer = network.add_concatenation(
+            [prepend_shape_layer.get_output(0), tensor_shape]
+        )
+        reshape_dim_layer.axis = 0
+        reshape_dim_layer.name = f"{name}_broadcast_final_shape"
+        layer.set_input(1, reshape_dim_layer.get_output(0))
+    else:
+        layer.reshape_dims = (1,) * num_prepend_ones + tuple(tensor.shape)
+
+    layer.name = name
+    return layer.get_output(0)
+
+
+def set_layer_name(
+    layer: TRTLayer,
+    target: Union[Target, torch.nn.Module, str],
+    name: str,
+    source_ir: Optional[SourceIR] = None,
+) -> None:
+    """
+    Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
+
+    Args:
+        layer (TRTLayer): A TensorRT layer whose name we want to set.
+        target (Target): An fx node.target or submodule. For a call_function
+            node, it's the function that the node represents.
+        name (str): Consists of fx node.name with an optional suffix.
+        source_ir: (Optional[SourceIR]): The IR producing the op.
+    """
+
+    source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+
+    target_name = (
+        f"{source_ir}_ops.{target}"
+        if isinstance(target, str)
+        else f"{source_ir}_ops.{target.__name__}"
+    )
+    layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"

py/torch_tensorrt/dynamo/types.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from typing import Sequence, Tuple
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+
+if hasattr(trt, "__version__"):
+    TRTNetwork = trt.INetworkDefinition
+    TRTTensor = trt.tensorrt.ITensor
+    TRTLayer = trt.ILayer
+    TRTPluginFieldCollection = trt.PluginFieldCollection
+    TRTPlugin = trt.IPluginV2
+    TRTDataType = trt.DataType
+    TRTElementWiseOp = trt.ElementWiseOperation
+else:
+    TRTNetwork = "trt.INetworkDefinition"
+    TRTTensor = "trt.tensorrt.ITensor"
+    TRTLayer = "trt.ILayer"
+    TRTPluginFieldCollection = "trt.PluginFieldCollection"
+    TRTPlugin = "trt.IPluginV2"
+    TRTDataType = "trt.DataType"
+    TRTElementWiseOp = "trt.ElementWiseOperation"
+
+Shape = Sequence[int]
+ShapeRange = Tuple[Shape, Shape, Shape]
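
The hasattr(trt, "__version__") guard means the aliases fall back to string forward references when the real tensorrt bindings are absent (a stub module, for instance), so annotations still import cleanly. A hedged usage sketch, assuming a working tensorrt install and the module path shown in this diff; the (min, opt, max) reading of ShapeRange is our assumption, not stated in the commit:

from torch_tensorrt.dynamo.types import Shape, ShapeRange

def is_dynamic(shape: Shape) -> bool:
    # Same convention as has_dynamic_shape above: -1 marks a dynamic dim.
    return any(s == -1 for s in shape)

# ShapeRange bundles three shapes, e.g. (min, opt, max) for a dynamic input.
batch_range: ShapeRange = ((1, 3, 224, 224), (8, 3, 224, 224), (32, 3, 224, 224))
print(is_dynamic((-1, 3, 224, 224)))  # True
print(is_dynamic(batch_range[0]))     # False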
