Skip to content

Commit 64442d6

Browse files
committed
addressing review comments-move all tensors to input tensor device
1 parent f640306 commit 64442d6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def scatter_reduce_decomposition(
329329
reduce: str,
330330
) -> torch.Tensor:
331331
scatter_loop_tensor = input_tensor
332+
device_input_tensor = input_tensor.device
332333
# required for mean reduce operation
333334
scatter_count_tensor = torch.zeros_like(input_tensor)
334335
src_shape = list(src_tensor.shape)
@@ -340,12 +341,11 @@ def scatter_reduce_decomposition(
340341
# unsqueeze src and index in dim
341342
src_slice = torch.unsqueeze(src_slice, dim)
342343
index_slice = torch.unsqueeze(index_slice, dim)
343-
device = to_torch_device(default_device())
344344

345345
# moving tensor to default device
346-
scatter_loop_tensor = scatter_loop_tensor.to(device)
347-
index_slice = index_slice.to(device)
348-
src_slice = src_slice.to(device)
346+
scatter_loop_tensor = scatter_loop_tensor.to(device_input_tensor)
347+
index_slice = index_slice.to(device_input_tensor)
348+
src_slice = src_slice.to(device_input_tensor)
349349
if reduce == "sum":
350350
reduceOp = ReduceOperation.SUM
351351
elif reduce == "prod":

0 commit comments

Comments
 (0)