Skip to content

Commit c8811cd

Browse files
committed
select test implementation
1 parent 85c755b commit c8811cd

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def aten_ops_sigmoid(
572572
return add_sigmoid(network, target, kwargs_new, name)
573573

574574

575-
@tensorrt_converter(torch.ops.aten.select)
575+
@tensorrt_converter(torch.ops.aten.select.int)
576576
def aten_ops_select(
577577
network: TRTNetwork,
578578
target: Target,
@@ -585,4 +585,4 @@ def aten_ops_select(
585585
"dim": args[1],
586586
"index": args[2],
587587
}
588-
return add_select(network, target.kwargs_new, name)
588+
return add_select(network, target, kwargs_new, name)

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,7 +1153,8 @@ def add_select(network, target, kwargs, name):
11531153
assert (
11541154
input_val.shape[dim] != -1
11551155
), "Can't select on negative shape dimension!"
1156-
index = kwargs[2]
1156+
index = kwargs["index"]
1157+
11571158
if index >= input_val.shape[dim]:
11581159
raise RuntimeError(
11591160
f"cannot have index greater than the dimension length! {input_val.shape[dim]}"
@@ -1164,8 +1165,11 @@ def add_select(network, target, kwargs, name):
11641165
output_shape = get_shape_with_dynamic_shape(
11651166
network, output_shape, input_val, target, name
11661167
)
1167-
layer = network.add_gather(input_val, dim, index)
1168-
out = layer.getOutput(0)
1168+
input_shape = network.add_shape(input_val).get_output(0)
1169+
dim_value = torch.tensor(dim, dtype=torch.int32)
1170+
axis = network.add_constant(dim_value.shape, to_numpy(dim_value)).get_output(0)
1171+
layer = network.add_gather(input_shape, axis, index)
1172+
out = layer.get_output(0)
11691173
if len(out.shape) != 1:
11701174
layer = network.add_shuffle(out)
1171-
return layer.getOutput(0)
1175+
return layer.get_output(0)

0 commit comments

Comments (0)