We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b253bd1 commit 3d2ada6Copy full SHA for 3d2ada6
py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -84,7 +84,13 @@ def gather(
84
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
85
sparse_grad: bool = False,
86
) -> TRTTensor:
87
- gather_layer = ctx.net.add_gather(input, index, dim)
+ indices_tensor = []
88
+
89
+ for i, ind in enumerate(index):
90
+ indices_tensor.append(get_trt_tensor(
91
+ ctx, ind, name + f"_parameter_to_fp32_tensor_{i}"
92
+ ))
93
+ gather_layer = ctx.net.add_gather(input, indices_tensor, dim)
94
set_layer_name(gather_layer, target, name + "_gather", source_ir)
95
return gather_layer.get_output(0)
96
0 commit comments