Skip to content

Commit 2ce4933

Browse files
committed
addressing review comments-move all tensors to input tensor device
1 parent 3d05a26 commit 2ce4933

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def scatter_reduce_decomposition(
333333
reduce: str,
334334
) -> torch.Tensor:
335335
scatter_loop_tensor = input_tensor
336+
device_input_tensor = input_tensor.device
336337
# required for mean reduce operation
337338
scatter_count_tensor = torch.zeros_like(input_tensor)
338339
src_shape = list(src_tensor.shape)
@@ -344,12 +345,11 @@ def scatter_reduce_decomposition(
344345
# unsqueeze src and index in dim
345346
src_slice = torch.unsqueeze(src_slice, dim)
346347
index_slice = torch.unsqueeze(index_slice, dim)
347-
device = to_torch_device(default_device())
348348

349349
# moving tensor to default device
350-
scatter_loop_tensor = scatter_loop_tensor.to(device)
351-
index_slice = index_slice.to(device)
352-
src_slice = src_slice.to(device)
350+
scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor)
351+
index_slice = index_slice.to(device_input_tensor)
352+
src_slice = src_slice.to(device_input_tensor)
353353
if reduce == "sum":
354354
reduceOp = ReduceOperation.SUM
355355
elif reduce == "prod":

0 commit comments

Comments
 (0)