16
16
DynamoConverterImplSignature ,
17
17
)
18
18
19
- from ..types import Shape , TRTDataType , TRTLayer , TRTNetwork , TRTTensor
19
+ from ..types import Shape , TRTDataType , TRTLayer , TRTTensor
20
20
21
21
_LOGGER : logging .Logger = logging .getLogger (__name__ )
22
22
@@ -174,9 +174,7 @@ def broadcast_to_same_shape(
174
174
Tuple[TRTTensor, TRTTensor]: Two TensorRT ITensors that are broadcasted to the same shape
175
175
176
176
"""
177
- lhs_val , rhs_val = broadcast (
178
- ctx .net , lhs_val , rhs_val , f"{ name } _lhs" , f"{ name } _rhs"
179
- )
177
+ lhs_val , rhs_val = broadcast (ctx , lhs_val , rhs_val , f"{ name } _lhs" , f"{ name } _rhs" )
180
178
181
179
lhs_val_shape = lhs_val .shape
182
180
rhs_val_shape = rhs_val .shape
@@ -687,7 +685,7 @@ def calculate_strides(shape: Sequence[int]) -> Sequence[int]:
687
685
688
686
689
687
def broadcast (
690
- network : TRTNetwork ,
688
+ ctx : ConversionContext ,
691
689
a : TRTTensor ,
692
690
b : TRTTensor ,
693
691
a_name : str ,
@@ -699,7 +697,7 @@ def broadcast(
699
697
prepending 1s to the tensor with less number of dimensions.
700
698
701
699
Args:
702
- network (TRTNetwork ): TensorRT network object.
700
+ ctx (ConversionContext ): A ConversionContext containing the TensorRT network
703
701
a (TRTTensor): A TensorRT ITensor.
704
702
b (TRTTensor): A TensorRT ITensor.
705
703
a_name (str): Name of tensor a.
@@ -719,9 +717,9 @@ def broadcast(
719
717
720
718
diff = len (a_shape ) - len (b_shape ) - preset_diff
721
719
if diff > 0 :
722
- b = prepend_ones (network , b , f"{ b_name } _broadcast" , diff )
720
+ b = prepend_ones (ctx , b , f"{ b_name } _broadcast" , diff )
723
721
elif diff < 0 :
724
- a = prepend_ones (network , a , f"{ a_name } _broadcast" , - diff )
722
+ a = prepend_ones (ctx , a , f"{ a_name } _broadcast" , - diff )
725
723
726
724
return a , b
727
725
@@ -774,24 +772,8 @@ def has_dynamic_shape(shape: Shape) -> bool:
774
772
return count > 0
775
773
776
774
777
- def type_cast (
778
- network : TRTNetwork ,
779
- target : Target ,
780
- name : str ,
781
- input : TRTTensor ,
782
- cast_type : TRTDataType ,
783
- ) -> TRTTensor :
784
- """
785
- This function helps to cast the input type to cast_type
786
- """
787
- layer_i = network .add_identity (input )
788
- layer_i .set_output_type (0 , cast_type )
789
- set_layer_name (layer_i , target , f"{ name } _dtype_change" )
790
- return layer_i .get_output (0 )
791
-
792
-
793
775
def prepend_ones (
794
- network : TRTNetwork ,
776
+ ctx : ConversionContext ,
795
777
tensor : TRTTensor ,
796
778
name : str ,
797
779
num_prepend_ones : int ,
@@ -800,8 +782,7 @@ def prepend_ones(
800
782
Prepend 1s to the shape of TensorRT ITensor `tensor`.
801
783
802
784
Args:
803
- network (TRTNetwork): The TensorRT network that `tensor`
804
- belongs to.
785
+ ctx (ConversionContext): A ConversionContext containing the TensorRT network
805
786
tensor (TRTTensor): A TensorRT tensor.
806
787
name (str): Name of the TensorRT Shuffle layer which is used to prepend
807
788
1s.
@@ -811,22 +792,22 @@ def prepend_ones(
811
792
A Tensorrt ITensor which contains the same value as `tensor` but with
812
793
more 1s prepended to the beginning of `tensor` shape.
813
794
"""
814
- layer = network .add_shuffle (tensor )
795
+ layer = ctx . net .add_shuffle (tensor )
815
796
816
797
# If there're dynamic dim in tensor's shape, we need to use shape layer to
817
798
# compute the final shape.
818
799
if has_dynamic_shape (tensor .shape ):
819
- tensor_shape_layer = network .add_shape (tensor )
800
+ tensor_shape_layer = ctx . net .add_shape (tensor )
820
801
tensor_shape = tensor_shape_layer .get_output (0 )
821
- tensor_shape = type_cast (
822
- network , "shape" , name + "shape_casted" , tensor_shape , trt . int32
802
+ tensor_shape = cast_trt_tensor (
803
+ ctx , tensor_shape , trt . int32 , name + "shape_casted" , "shape"
823
804
)
824
805
tensor_shape_layer .name = f"{ name } _broadcast_orig_shape"
825
- prepend_shape_layer = network .add_constant (
806
+ prepend_shape_layer = ctx . net .add_constant (
826
807
(num_prepend_ones ,), np .ones ((num_prepend_ones ,), dtype = np .int32 )
827
808
)
828
809
prepend_shape_layer .name = f"{ name } _broadcast_prepend_ones"
829
- reshape_dim_layer = network .add_concatenation (
810
+ reshape_dim_layer = ctx . net .add_concatenation (
830
811
[prepend_shape_layer .get_output (0 ), tensor_shape ]
831
812
)
832
813
reshape_dim_layer .axis = 0
0 commit comments