File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
py/torch_tensorrt/dynamo/lowering Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -333,6 +333,7 @@ def scatter_reduce_decomposition(
333
333
reduce : str ,
334
334
) -> torch .Tensor :
335
335
scatter_loop_tensor = input_tensor
336
+ device_input_tensor = input_tensor .device
336
337
# required for mean reduce operation
337
338
scatter_count_tensor = torch .zeros_like (input_tensor )
338
339
src_shape = list (src_tensor .shape )
@@ -344,12 +345,11 @@ def scatter_reduce_decomposition(
344
345
# unsqueeze src and index in dim
345
346
src_slice = torch .unsqueeze (src_slice , dim )
346
347
index_slice = torch .unsqueeze (index_slice , dim )
347
- device = to_torch_device (default_device ())
348
348
349
349
# moving tensor to default device
350
- scatter_loop_tensor = scatter_loop_tensor .to (device )
351
- index_slice = index_slice .to (device )
352
- src_slice = src_slice .to (device )
350
+ scatter_loop_tensor = scatter_loop_tensor .to (device_input_tensor )
351
+ index_slice = index_slice .to (device_input_tensor )
352
+ src_slice = src_slice .to (device_input_tensor )
353
353
if reduce == "sum" :
354
354
reduceOp = ReduceOperation .SUM
355
355
elif reduce == "prod" :
You can’t perform that action at this time.
0 commit comments