@@ -506,8 +506,17 @@ def acc_ops_size(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
+    input_t = kwargs["input"]
+    if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor:
+        if (
+            not has_dynamic_shape(input_t.shape)
+            and network.has_implicit_batch_dimension
+        ):
+            return torch.Size((IMPLICIT_BATCH_DIM,) + tuple(input_t.shape))
+        return input_t.shape
+
+    # input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
+    input_val = input_t
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"size received input {input_val} that is not part "
@@ -779,13 +788,8 @@ def acc_ops_tile(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"tile received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input")
 
     dims = tuple(cast(Sequence[int], kwargs["dims"]))
     n_input_dims = len(input_val.shape) + (
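Swapping the hard `isinstance` error for `get_trt_tensor` means tile now accepts Python scalars and `torch.Tensor` weights by folding them into constants. A rough stand-in for that promotion in plain torch; the names here are illustrative, not the converter's real API:

```python
import torch
from typing import Any

def to_tensor(value: Any) -> torch.Tensor:
    """Promote scalars and tensors to a common tensor type, mimicking
    how get_trt_tensor wraps non-ITensor inputs as constant layers."""
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, (int, float)):
        return torch.tensor(value)
    raise TypeError(f"cannot promote {type(value).__name__}")

print(to_tensor(3))                 # tensor(3)
print(to_tensor(torch.ones(2, 2)))  # 2x2 tensor of ones
```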
@@ -822,9 +826,28 @@ def acc_ops_tile(
     if network.has_implicit_batch_dimension:
         assert dims[0] == 1, "Can't tile the batch dim when it's implicit."
         dims = dims[1:]
-
     starts = [0] * len(dims)
-    shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    shapes = []
+    if all(isinstance(d, int) for d in dims):
+        shapes = [i * j for i, j in zip(input_val.shape, dims)]  # type: ignore[union-attr]
+    else:
+        shape = []
+        for i, (s, d) in enumerate(zip(input_val.shape, dims)):
+            if isinstance(d, TRTTensor) and len(d.shape) == 0:
+                d = prepend_ones(network, d, f"{name}_{i}", 1)
+            else:
+                d = get_trt_tensor(network, d, f"{name}_{i}")
+            shape.append(d)
+            mul = add_binary_elementwise_layer(
+                network,
+                s,
+                d,
+                trt.ElementWiseOperation.PROD,
+                target,
+                f"{name}_mul_{i}",
+            )
+            shapes.append(mul)
+        dims = shape
     # If there's dynamic dim then there would be negative dims in shapes which is not allowed.
     # Here we build a dummy shapes array.
     if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
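The else branch handles repeats that are only known at runtime: each output dim becomes `input_dim * repeat`, computed with an elementwise PROD layer per dim. A plain-torch sketch of the same per-dim math, where torch ops stand in for the TRT layers:

```python
import torch

input_shape = (2, 3)
dims = (4, torch.tensor(5))  # second repeat known only at runtime

shapes = []
for s, d in zip(input_shape, dims):
    if isinstance(d, torch.Tensor):
        shapes.append(torch.tensor(s) * d)  # stands in for the PROD layer
    else:
        shapes.append(s * d)
print(shapes)  # [8, tensor(15)]
```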
@@ -838,9 +861,16 @@ def acc_ops_tile(
         starts_tensor = network.add_constant(
             (len(dims),), np.ascontiguousarray([0] * len(dims), np.int32)
         ).get_output(0)
-        dims_tensor = network.add_constant(
-            (len(dims),), np.ascontiguousarray(dims, np.int32)
-        ).get_output(0)
+        if all(isinstance(d, int) for d in dims):
+            dims_tensor = network.add_constant(
+                (len(dims),), np.ascontiguousarray(dims, np.int32)
+            ).get_output(0)
+        else:
+            assert all(isinstance(d, TRTTensor) for d in dims)
+            concat_dims_layer = network.add_concatenation(inputs=dims)
+            concat_dims_layer.axis = 0
+            concat_dims_layer.name = f"{name}_tile_dim"
+            dims_tensor = concat_dims_layer.get_output(0)
         input_shape_layer = network.add_shape(input_val)
         input_shape_layer.name = f"{name}_slice_input_shape"
         slice_shapes_tensor = add_binary_elementwise_layer(
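When the repeats are runtime tensors, the per-dim scalars are concatenated along axis 0 into a single 1-D shape tensor for the slice. A torch stand-in for what `add_concatenation` does here, including the reshape to 1-D that mirrors the earlier `prepend_ones` call:

```python
import torch

per_dim = [torch.tensor(4), torch.tensor(5)]  # 0-d scalars
per_dim = [d.reshape(1) for d in per_dim]     # make each (1,)
dims_tensor = torch.cat(per_dim, dim=0)       # axis=0 concat
print(dims_tensor)  # tensor([4, 5])
```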
@@ -1880,7 +1910,8 @@ def acc_ops_max_pool1d(
 
 
 @tensorrt_converter(acc_ops.max_pool2d)
-def acc_ops_max_pool2d(
+@tensorrt_converter(acc_ops.max_pool3d)
+def acc_ops_max_poolnd(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -1894,26 +1925,27 @@ def acc_ops_max_pool2d(
             f"MaxPool2d received input {input_val} that is not part "
             "of the TensorRT region!"
         )
-
-    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], 2)
-    stride = extend_attr_to_tuple(kwargs["stride"], 2)
-    padding = extend_attr_to_tuple(kwargs["padding"], 2)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
+    extend_len = 2 if target == acc_ops.max_pool2d else 3
+    kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
+    stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
+    padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
+    dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len)
     ceil_mode = kwargs["ceil_mode"]
 
     if len(stride) == 0 or stride[0] == None:
         stride = kernel_size
 
-    if dilation != (1, 1):
+    ones = (1,) * extend_len
+    if dilation != ones:
         raise RuntimeError(
             f"Only support dilation=(1, 1) for maxpool, but got {dilation}"
         )
 
-    layer = network.add_pooling(
+    layer = network.add_pooling_nd(
         input=input_val, type=trt.PoolingType.MAX, window_size=kernel_size
     )
-    layer.stride = stride
-    layer.padding = padding
+    layer.stride_nd = stride
+    layer.padding_nd = padding
     set_layer_name(layer, target, name)
 
     if ceil_mode:
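The merged converter keys everything off `extend_len`: the registered target decides the spatial rank, and each pooling attribute is broadcast to that length. A self-contained sketch of the dispatch, with `extend_attr_to_tuple` reimplemented here for illustration and string targets standing in for the `acc_ops` callables:

```python
from typing import Tuple, Union

def extend_attr_to_tuple(val: Union[int, Tuple[int, ...]], num: int) -> Tuple[int, ...]:
    """Broadcast a scalar attribute to a num-tuple; pass tuples through."""
    return (val,) * num if isinstance(val, int) else tuple(val)

def pool_attrs(target: str, kernel_size, stride):
    # One converter body serves both ranks; only extend_len changes.
    extend_len = 2 if target == "max_pool2d" else 3
    return (
        extend_attr_to_tuple(kernel_size, extend_len),
        extend_attr_to_tuple(stride, extend_len),
    )

print(pool_attrs("max_pool2d", 3, 2))  # ((3, 3), (2, 2))
print(pool_attrs("max_pool3d", 3, 2))  # ((3, 3, 3), (2, 2, 2))
```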
@@ -2093,8 +2125,8 @@ def acc_ops_unsqueeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
+    input_t = kwargs["input"]
+    input_val = get_trt_tensor(network, input_t, f"{name}_input_t")
     if not isinstance(input_val, TRTTensor):
         raise RuntimeError(
             f"unsqueeze received input {input_val} that is not part "
@@ -2161,8 +2193,9 @@ def acc_ops_topk(
     return layer.get_output(0), layer.get_output(1)
 
 
+@tensorrt_converter(acc_ops.adaptive_avg_pool3d)
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
-def acc_ops_adaptive_avg_pool2d(
+def acc_ops_adaptive_avg_poolnd(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -2177,30 +2210,32 @@ def acc_ops_adaptive_avg_pool2d(
         "of the TensorRT region!"
     )
 
-    assert (
-        input_val.shape[-1] != -1 and input_val.shape[-1] != -1
+    extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
+    assert all(
+        input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
     ), "AdaptiveAvgPool2d currently doesn't support dynamic shapes for last two dims."
 
-    output_size = cast(Sequence[int], extend_attr_to_tuple(kwargs["output_size"], 2))
-    for input_dim, output_dim in zip(input_val.shape[-2:], output_size):
+    output_size = cast(
+        Sequence[int], extend_attr_to_tuple(kwargs["output_size"], extend_len)
+    )
+    for input_dim, output_dim in zip(input_val.shape[-extend_len:], output_size):
         if input_dim % output_dim != 0:
             raise RuntimeError(
                 "For AdaptiveAvgPool, input dim has to be integer multiple of output dim."
                 f"Got input dim {input_dim}, output dim {output_dim}"
             )
 
-    stride = (
-        input_val.shape[-2] // output_size[0],
-        input_val.shape[-1] // output_size[1],
+    stride = tuple(
+        input_val.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
     )
-    kernel_size = (
-        input_val.shape[-2] - (output_size[0] - 1) * stride[0],
-        input_val.shape[-1] - (output_size[1] - 1) * stride[1],
+    kernel_size = tuple(
+        input_val.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
+        for i in range(extend_len)
     )
-    layer = network.add_pooling(
+    layer = network.add_pooling_nd(
         input=input_val, type=trt.PoolingType.AVERAGE, window_size=kernel_size
     )
-    layer.stride = stride
+    layer.stride_nd = stride
     set_layer_name(layer, target, name)
 
     return layer.get_output(0)
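The generalized stride/kernel math is the old 2d formula iterated over the last `extend_len` dims: `stride = in // out` and `kernel = in - (out - 1) * stride`. A quick self-contained check of the arithmetic for a 3d case (made-up shapes):

```python
in_shape = (8, 12, 16)  # trailing dims of a 3d input
out_size = (4, 6, 4)
extend_len = 3

stride = tuple(in_shape[i] // out_size[i] for i in range(extend_len))
kernel = tuple(
    in_shape[i] - (out_size[i] - 1) * stride[i] for i in range(extend_len)
)
print(stride)  # (2, 2, 4)
print(kernel)  # (2, 2, 4)

# Sanity check: (in - kernel) // stride + 1 == out for every dim.
assert all(
    (in_shape[i] - kernel[i]) // stride[i] + 1 == out_size[i]
    for i in range(extend_len)
)
```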
@@ -2781,7 +2816,6 @@ def acc_ops_getitem(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     input_val = kwargs["input"]
     slices = kwargs["idx"]
-
     if not isinstance(input_val, TRTTensor):
         return operator.getitem(input_val, slices)  # type: ignore[arg-type]
 