
Commit cf96dec

Adding add_slice function in operator.py
1 parent a1d94c1 commit cf96dec

File tree

1 file changed: +48 -0 lines changed


py/torch_tensorrt/fx/converters/operator.py

Lines changed: 48 additions & 0 deletions
@@ -1128,3 +1128,51 @@ def add_expand(network, target, kwargs, name):
     layer = network.add_slice(input_val, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name)
     return layer.get_output(0)
+
+
+def add_slice(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"slice_tensor received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    dim = get_positive_dim(cast(int, kwargs["dim"]), ranks)
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        if dim == 0:
+            raise RuntimeError(
+                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
+            )
+        dim = dim - 1
+    else:
+        if dynamic_shape:
+            # Check whether slice target dim is dynamic shape dim
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+
+    start_int = cast(int, kwargs["start"])
+    stop_int = cast(int, kwargs["stop"])
+    step_int = cast(int, kwargs["step"])
+    start = [0] * len(input_val.shape)
+    start[dim] = start_int
+    stride = [1] * len(start)
+    stride[dim] = step_int
+    output_shape = list(input_val.shape)
+    output_shape[dim] = (stop_int - start_int) // step_int
+
+    if dynamic_shape:
+        output_shape = get_shape_with_dynamic_shape(
+            network, output_shape, input_val, target, name
+        )
+    layer = network.add_slice(
+        input_val,
+        start=start,
+        shape=[] if dynamic_shape else output_shape,
+        stride=stride,
+    )
+    if dynamic_shape:
+        layer.set_input(2, output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
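For context, the add_* helpers in operator.py are invoked from registered converters that unpack an FX node's arguments into the kwargs dict read above ("input", "dim", "start", "stop", "step"). Below is a minimal sketch of such a wrapper; the registration decorator and module paths match those used elsewhere in torch_tensorrt.fx, but the wrapper itself is illustrative and not part of this commit:

import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.operator import add_slice


# Illustrative converter wrapper; assumes the acc_ops.slice_tensor target
# and the tensorrt_converter registry used by the other fx converters.
@tensorrt_converter(acc_ops.slice_tensor)
def acc_ops_slice_tensor(network, target, args, kwargs, name):
    # kwargs carries "input", "dim", "start", "stop", and "step",
    # which add_slice reads to build the TensorRT slice layer.
    return add_slice(network, target, kwargs, name)

As a concrete check of the shape arithmetic: for a static (2, 10) input with dim=1, start=2, stop=8, step=2, the helper builds start=[0, 2], stride=[1, 2], and output_shape=[2, (8 - 2) // 2] = [2, 3], selecting elements at indices 2, 4, and 6 along dim 1.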
