Skip to content

Commit 3d2ada6

Browse files
committed
converting indices to fp32 tensors
1 parent b253bd1 commit 3d2ada6

File tree

1 file changed

+7
-1
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+7
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,13 @@ def gather(
8484
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
8585
sparse_grad: bool = False,
8686
) -> TRTTensor:
87-
gather_layer = ctx.net.add_gather(input, index, dim)
87+
indices_tensor = []
88+
89+
for i, ind in enumerate(index):
90+
indices_tensor.append(get_trt_tensor(
91+
ctx, ind, name + f"_parameter_to_fp32_tensor_{i}"
92+
))
93+
gather_layer = ctx.net.add_gather(input, indices_tensor, dim)
8894
set_layer_name(gather_layer, target, name + "_gather", source_ir)
8995
return gather_layer.get_output(0)
9096

0 commit comments

Comments
 (0)