16
16
DynamoConverterImplSignature ,
17
17
)
18
18
19
- from ..types import Shape , TRTDataType , TRTLayer , TRTNetwork , TRTTensor
19
+ from ..types import Shape , TRTDataType , TRTLayer , TRTTensor
20
20
21
21
_LOGGER : logging .Logger = logging .getLogger (__name__ )
22
22
@@ -174,9 +174,7 @@ def broadcast_to_same_shape(
174
174
Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
175
175
176
176
"""
177
- lhs_val , rhs_val = broadcast (
178
- ctx .net , lhs_val , rhs_val , f"{ name } _lhs" , f"{ name } _rhs"
179
- )
177
+ lhs_val , rhs_val = broadcast (ctx , lhs_val , rhs_val , f"{ name } _lhs" , f"{ name } _rhs" )
180
178
181
179
lhs_val_shape = lhs_val .shape
182
180
rhs_val_shape = rhs_val .shape
@@ -679,7 +677,7 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
679
677
680
678
681
679
def broadcast (
682
- network : TRTNetwork ,
680
+ ctx : ConversionContext ,
683
681
a : TRTTensor ,
684
682
b : TRTTensor ,
685
683
a_name : str ,
@@ -691,7 +689,7 @@ def broadcast(
691
689
prepending 1s to the tensor with less number of dimensions.
692
690
693
691
Args:
694
- network (TRTNetwork ): TensorRT network object.
692
+ ctx (ConversionContext ): A ConversionContext containing the TensorRT network
695
693
a (TRTTensor): A TensorRT ITensor.
696
694
b (TRTTensor): A TensorRT ITensor.
697
695
a_name (str): Name of tensor a.
@@ -711,9 +709,9 @@ def broadcast(
711
709
712
710
diff = len (a_shape ) - len (b_shape ) - preset_diff
713
711
if diff > 0 :
714
- b = prepend_ones (network , b , f"{ b_name } _broadcast" , diff )
712
+ b = prepend_ones (ctx , b , f"{ b_name } _broadcast" , diff )
715
713
elif diff < 0 :
716
- a = prepend_ones (network , a , f"{ a_name } _broadcast" , - diff )
714
+ a = prepend_ones (ctx , a , f"{ a_name } _broadcast" , - diff )
717
715
718
716
return a , b
719
717
@@ -766,24 +764,8 @@ def has_dynamic_shape(shape: Shape) -> bool:
766
764
return count > 0
767
765
768
766
769
- def type_cast (
770
- network : TRTNetwork ,
771
- target : Target ,
772
- name : str ,
773
- input : TRTTensor ,
774
- cast_type : TRTDataType ,
775
- ) -> TRTTensor :
776
- """
777
- This function helps to cast the input type to cast_type
778
- """
779
- layer_i = network .add_identity (input )
780
- layer_i .set_output_type (0 , cast_type )
781
- set_layer_name (layer_i , target , f"{ name } _dtype_change" )
782
- return layer_i .get_output (0 )
783
-
784
-
785
767
def prepend_ones (
786
- network : TRTNetwork ,
768
+ ctx : ConversionContext ,
787
769
tensor : TRTTensor ,
788
770
name : str ,
789
771
num_prepend_ones : int ,
@@ -792,8 +774,7 @@ def prepend_ones(
792
774
Prepend 1s to the shape of TensorRT ITensor `tensor`.
793
775
794
776
Args:
795
- network (TRTNetwork): The TensorRT network that `tensor`
796
- belongs to.
777
+ ctx (ConversionContext): A ConversionContext containing the TensorRT network
797
778
tensor (TRTTensor): A TensorRT tensor.
798
779
name (str): Name of the TensorRT Shuffle layer which is used to prepend
799
780
1s.
@@ -803,22 +784,22 @@ def prepend_ones(
803
784
A Tensorrt ITensor which contains the same value as `tensor` but with
804
785
more 1s prepended to the beginning of `tensor` shape.
805
786
"""
806
- layer = network .add_shuffle (tensor )
787
+ layer = ctx . net .add_shuffle (tensor )
807
788
808
789
# If there're dynamic dim in tensor's shape, we need to use shape layer to
809
790
# compute the final shape.
810
791
if has_dynamic_shape (tensor .shape ):
811
- tensor_shape_layer = network .add_shape (tensor )
792
+ tensor_shape_layer = ctx . net .add_shape (tensor )
812
793
tensor_shape = tensor_shape_layer .get_output (0 )
813
- tensor_shape = type_cast (
814
- network , "shape" , name + "shape_casted" , tensor_shape , trt . int32
794
+ tensor_shape = cast_trt_tensor (
795
+ ctx , tensor_shape , trt . int32 , name + "shape_casted" , "shape"
815
796
)
816
797
tensor_shape_layer .name = f"{ name } _broadcast_orig_shape"
817
- prepend_shape_layer = network .add_constant (
798
+ prepend_shape_layer = ctx . net .add_constant (
818
799
(num_prepend_ones ,), np .ones ((num_prepend_ones ,), dtype = np .int32 )
819
800
)
820
801
prepend_shape_layer .name = f"{ name } _broadcast_prepend_ones"
821
- reshape_dim_layer = network .add_concatenation (
802
+ reshape_dim_layer = ctx . net .add_concatenation (
822
803
[prepend_shape_layer .get_output (0 ), tensor_shape ]
823
804
)
824
805
reshape_dim_layer .axis = 0
0 commit comments