Skip to content

Commit 0814b4b

Browse files
committed
chore: Use proper helper functions in code copied from fx
1 parent 68f08ea commit 0814b4b

File tree

2 files changed

+17
-38
lines changed

2 files changed

+17
-38
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
DynamoConverterImplSignature,
1717
)
1818

19-
from ..types import Shape, TRTDataType, TRTLayer, TRTNetwork, TRTTensor
19+
from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
2020

2121
_LOGGER: logging.Logger = logging.getLogger(__name__)
2222

@@ -174,9 +174,7 @@ def broadcast_to_same_shape(
174174
Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
175175
176176
"""
177-
lhs_val, rhs_val = broadcast(
178-
ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
179-
)
177+
lhs_val, rhs_val = broadcast(ctx, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs")
180178

181179
lhs_val_shape = lhs_val.shape
182180
rhs_val_shape = rhs_val.shape
@@ -679,7 +677,7 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
679677

680678

681679
def broadcast(
682-
network: TRTNetwork,
680+
ctx: ConversionContext,
683681
a: TRTTensor,
684682
b: TRTTensor,
685683
a_name: str,
@@ -691,7 +689,7 @@ def broadcast(
691689
prepending 1s to the tensor with less number of dimensions.
692690
693691
Args:
694-
network (TRTNetwork): TensorRT network object.
692+
ctx (ConversionContext): A ConversionContext containing the TensorRT network
695693
a (TRTTensor): A TensorRT ITensor.
696694
b (TRTTensor): A TensorRT ITensor.
697695
a_name (str): Name of tensor a.
@@ -711,9 +709,9 @@ def broadcast(
711709

712710
diff = len(a_shape) - len(b_shape) - preset_diff
713711
if diff > 0:
714-
b = prepend_ones(network, b, f"{b_name}_broadcast", diff)
712+
b = prepend_ones(ctx, b, f"{b_name}_broadcast", diff)
715713
elif diff < 0:
716-
a = prepend_ones(network, a, f"{a_name}_broadcast", -diff)
714+
a = prepend_ones(ctx, a, f"{a_name}_broadcast", -diff)
717715

718716
return a, b
719717

@@ -766,24 +764,8 @@ def has_dynamic_shape(shape: Shape) -> bool:
766764
return count > 0
767765

768766

769-
def type_cast(
770-
network: TRTNetwork,
771-
target: Target,
772-
name: str,
773-
input: TRTTensor,
774-
cast_type: TRTDataType,
775-
) -> TRTTensor:
776-
"""
777-
This function helps to cast the input type to cast_type
778-
"""
779-
layer_i = network.add_identity(input)
780-
layer_i.set_output_type(0, cast_type)
781-
set_layer_name(layer_i, target, f"{name}_dtype_change")
782-
return layer_i.get_output(0)
783-
784-
785767
def prepend_ones(
786-
network: TRTNetwork,
768+
ctx: ConversionContext,
787769
tensor: TRTTensor,
788770
name: str,
789771
num_prepend_ones: int,
@@ -792,8 +774,7 @@ def prepend_ones(
792774
Prepend 1s to the shape of TensorRT ITensor `tensor`.
793775
794776
Args:
795-
network (TRTNetwork): The TensorRT network that `tensor`
796-
belongs to.
777+
ctx (ConversionContext): A ConversionContext containing the TensorRT network
797778
tensor (TRTTensor): A TensorRT tensor.
798779
name (str): Name of the TensorRT Shuffle layer which is used to prepend
799780
1s.
@@ -803,22 +784,22 @@ def prepend_ones(
803784
A Tensorrt ITensor which contains the same value as `tensor` but with
804785
more 1s prepended to the beginning of `tensor` shape.
805786
"""
806-
layer = network.add_shuffle(tensor)
787+
layer = ctx.net.add_shuffle(tensor)
807788

808789
# If there're dynamic dim in tensor's shape, we need to use shape layer to
809790
# compute the final shape.
810791
if has_dynamic_shape(tensor.shape):
811-
tensor_shape_layer = network.add_shape(tensor)
792+
tensor_shape_layer = ctx.net.add_shape(tensor)
812793
tensor_shape = tensor_shape_layer.get_output(0)
813-
tensor_shape = type_cast(
814-
network, "shape", name + "shape_casted", tensor_shape, trt.int32
794+
tensor_shape = cast_trt_tensor(
795+
ctx, tensor_shape, trt.int32, name + "shape_casted", "shape"
815796
)
816797
tensor_shape_layer.name = f"{name}_broadcast_orig_shape"
817-
prepend_shape_layer = network.add_constant(
798+
prepend_shape_layer = ctx.net.add_constant(
818799
(num_prepend_ones,), np.ones((num_prepend_ones,), dtype=np.int32)
819800
)
820801
prepend_shape_layer.name = f"{name}_broadcast_prepend_ones"
821-
reshape_dim_layer = network.add_concatenation(
802+
reshape_dim_layer = ctx.net.add_concatenation(
822803
[prepend_shape_layer.get_output(0), tensor_shape]
823804
)
824805
reshape_dim_layer.axis = 0

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,19 @@ def where(
3636
condition = get_trt_tensor(ctx, condition, f"{name}_condition")
3737
diff = max_shape_len - len(condition_shape)
3838
if diff > 0:
39-
condition = prepend_ones(
40-
ctx.net, condition, f"{name}_condition_broadcast", diff
41-
)
39+
condition = prepend_ones(ctx, condition, f"{name}_condition_broadcast", diff)
4240

4341
if not isinstance(input, TRTTensor):
4442
input = get_trt_tensor(ctx, input, f"{name}_x")
4543
diff = max_shape_len - len(x_shape)
4644
if diff > 0:
47-
input = prepend_ones(ctx.net, input, f"{name}_input_broadcast", diff)
45+
input = prepend_ones(ctx, input, f"{name}_input_broadcast", diff)
4846

4947
if not isinstance(other, TRTTensor):
5048
other = get_trt_tensor(ctx, other, f"{name}_y")
5149
diff = max_shape_len - len(y_shape)
5250
if diff > 0:
53-
other = prepend_ones(ctx.net, other, f"{name}_other_broadcast", diff)
51+
other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)
5452

5553
return select(ctx, target, source_ir, name, input, other, condition)
5654

0 commit comments

Comments (0)