@@ -1128,3 +1128,51 @@ def add_expand(network, target, kwargs, name):
1128
1128
layer = network .add_slice (input_val , start = start , shape = shape , stride = stride )
1129
1129
set_layer_name (layer , target , name )
1130
1130
return layer .get_output (0 )
1131
+
1132
+
1133
+ def add_slice (network , target , kwargs , name ):
1134
+ input_val = kwargs ["input" ]
1135
+ if not isinstance (input_val , TRTTensor ):
1136
+ raise RuntimeError (
1137
+ f"slice_tensor received input { input_val } that is not part "
1138
+ "of the TensorRT region!"
1139
+ )
1140
+
1141
+ ranks = len (input_val .shape ) + (1 if network .has_implicit_batch_dimension else 0 )
1142
+ dim = get_positive_dim (cast (int , kwargs ["dim" ]), ranks )
1143
+ dynamic_shape = has_dynamic_shape (input_val .shape )
1144
+ if network .has_implicit_batch_dimension :
1145
+ if dim == 0 :
1146
+ raise RuntimeError (
1147
+ f"We do not support slice_tensor at batch dim when it's implicit, got { dim } !"
1148
+ )
1149
+ dim = dim - 1
1150
+ else :
1151
+ if dynamic_shape :
1152
+ # Check whether slice target dim is dynamic shape dim
1153
+ assert input_val .shape [dim ] != - 1 , "Can't chunk on dynamic shape dimension!"
1154
+
1155
+ start_int = cast (int , kwargs ["start" ])
1156
+ stop_int = cast (int , kwargs ["stop" ])
1157
+ step_int = cast (int , kwargs ["step" ])
1158
+ start = [0 ] * len (input_val .shape )
1159
+ start [dim ] = start_int
1160
+ stride = [1 ] * len (start )
1161
+ stride [dim ] = step_int
1162
+ output_shape = list (input_val .shape )
1163
+ output_shape [dim ] = (stop_int - start_int ) // step_int
1164
+
1165
+ if dynamic_shape > 0 :
1166
+ output_shape = get_shape_with_dynamic_shape (
1167
+ network , output_shape , input_val , target , name
1168
+ )
1169
+ layer = network .add_slice (
1170
+ input_val ,
1171
+ start = start ,
1172
+ shape = [] if dynamic_shape else output_shape ,
1173
+ stride = stride ,
1174
+ )
1175
+ if dynamic_shape :
1176
+ layer .set_input (2 , output_shape )
1177
+ set_layer_name (layer , target , name )
1178
+ return layer .get_output (0 )
0 commit comments