Skip to content

Commit 6addbc3

Browse files
committed
chore: Use proper helper functions in code copied from FX
1 parent 4da01bc commit 6addbc3

File tree

2 files changed

+17
-38
lines changed

2 files changed

+17
-38
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
DynamoConverterImplSignature,
1717
)
1818

19-
from ..types import Shape, TRTDataType, TRTLayer, TRTNetwork, TRTTensor
19+
from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
2020

2121
_LOGGER: logging.Logger = logging.getLogger(__name__)
2222

@@ -174,9 +174,7 @@ def broadcast_to_same_shape(
174174
Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
175175
176176
"""
177-
lhs_val, rhs_val = broadcast(
178-
ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
179-
)
177+
lhs_val, rhs_val = broadcast(ctx, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs")
180178

181179
lhs_val_shape = lhs_val.shape
182180
rhs_val_shape = rhs_val.shape
@@ -687,7 +685,7 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
687685

688686

689687
def broadcast(
690-
network: TRTNetwork,
688+
ctx: ConversionContext,
691689
a: TRTTensor,
692690
b: TRTTensor,
693691
a_name: str,
@@ -699,7 +697,7 @@ def broadcast(
699697
prepending 1s to the tensor with less number of dimensions.
700698
701699
Args:
702-
network (TRTNetwork): TensorRT network object.
700+
ctx (ConversionContext): A ConversionContext containing the TensorRT network
703701
a (TRTTensor): A TensorRT ITensor.
704702
b (TRTTensor): A TensorRT ITensor.
705703
a_name (str): Name of tensor a.
@@ -719,9 +717,9 @@ def broadcast(
719717

720718
diff = len(a_shape) - len(b_shape) - preset_diff
721719
if diff > 0:
722-
b = prepend_ones(network, b, f"{b_name}_broadcast", diff)
720+
b = prepend_ones(ctx, b, f"{b_name}_broadcast", diff)
723721
elif diff < 0:
724-
a = prepend_ones(network, a, f"{a_name}_broadcast", -diff)
722+
a = prepend_ones(ctx, a, f"{a_name}_broadcast", -diff)
725723

726724
return a, b
727725

@@ -774,24 +772,8 @@ def has_dynamic_shape(shape: Shape) -> bool:
774772
return count > 0
775773

776774

777-
def type_cast(
778-
network: TRTNetwork,
779-
target: Target,
780-
name: str,
781-
input: TRTTensor,
782-
cast_type: TRTDataType,
783-
) -> TRTTensor:
784-
"""
785-
This function helps to cast the input type to cast_type
786-
"""
787-
layer_i = network.add_identity(input)
788-
layer_i.set_output_type(0, cast_type)
789-
set_layer_name(layer_i, target, f"{name}_dtype_change")
790-
return layer_i.get_output(0)
791-
792-
793775
def prepend_ones(
794-
network: TRTNetwork,
776+
ctx: ConversionContext,
795777
tensor: TRTTensor,
796778
name: str,
797779
num_prepend_ones: int,
@@ -800,8 +782,7 @@ def prepend_ones(
800782
Prepend 1s to the shape of TensorRT ITensor `tensor`.
801783
802784
Args:
803-
network (TRTNetwork): The TensorRT network that `tensor`
804-
belongs to.
785+
ctx (ConversionContext): A ConversionContext containing the TensorRT network
805786
tensor (TRTTensor): A TensorRT tensor.
806787
name (str): Name of the TensorRT Shuffle layer which is used to prepend
807788
1s.
@@ -811,22 +792,22 @@ def prepend_ones(
811792
A Tensorrt ITensor which contains the same value as `tensor` but with
812793
more 1s prepended to the beginning of `tensor` shape.
813794
"""
814-
layer = network.add_shuffle(tensor)
795+
layer = ctx.net.add_shuffle(tensor)
815796

816797
# If there're dynamic dim in tensor's shape, we need to use shape layer to
817798
# compute the final shape.
818799
if has_dynamic_shape(tensor.shape):
819-
tensor_shape_layer = network.add_shape(tensor)
800+
tensor_shape_layer = ctx.net.add_shape(tensor)
820801
tensor_shape = tensor_shape_layer.get_output(0)
821-
tensor_shape = type_cast(
822-
network, "shape", name + "shape_casted", tensor_shape, trt.int32
802+
tensor_shape = cast_trt_tensor(
803+
ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
823804
)
824805
tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
825-
prepend_shape_layer = network.add_constant(
806+
prepend_shape_layer = ctx.net.add_constant(
826807
(num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
827808
)
828809
prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
829-
reshape_dim_layer = network.add_concatenation(
810+
reshape_dim_layer = ctx.net.add_concatenation(
830811
[prepend_shape_layer.get_output(0), tensor_shape]
831812
)
832813
reshape_dim_layer.axis = 0

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,19 @@ def where(
3636
condition = get_trt_tensor(ctx, condition, f"{name}_condition")
3737
diff = max_shape_len - len(condition_shape)
3838
if diff > 0:
39-
condition = prepend_ones(
40-
ctx.net, condition, f"{name}_condition_broadcast", diff
41-
)
39+
condition = prepend_ones(ctx, condition, f"{name}_condition_broadcast", diff)
4240

4341
if not isinstance(input, TRTTensor):
4442
input = get_trt_tensor(ctx, input, f"{name}_x")
4543
diff = max_shape_len - len(x_shape)
4644
if diff > 0:
47-
input = prepend_ones(ctx.net, input, f"{name}_input_broadcast", diff)
45+
input = prepend_ones(ctx, input, f"{name}_input_broadcast", diff)
4846

4947
if not isinstance(other, TRTTensor):
5048
other = get_trt_tensor(ctx, other, f"{name}_y")
5149
diff = max_shape_len - len(y_shape)
5250
if diff > 0:
53-
other = prepend_ones(ctx.net, other, f"{name}_other_broadcast", diff)
51+
other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)
5452

5553
return select(ctx, target, source_ir, name, input, other, condition)
5654

0 commit comments

Comments
 (0)