File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
py/torch_tensorrt/dynamo/lowering Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -329,6 +329,7 @@ def scatter_reduce_decomposition(
329
329
reduce : str ,
330
330
) -> torch .Tensor :
331
331
scatter_loop_tensor = input_tensor
332
+ device_input_tensor = input_tensor .device
332
333
# required for mean reduce operation
333
334
scatter_count_tensor = torch .zeros_like (input_tensor )
334
335
src_shape = list (src_tensor .shape )
@@ -340,12 +341,11 @@ def scatter_reduce_decomposition(
340
341
# unsqueeze src and index in dim
341
342
src_slice = torch .unsqueeze (src_slice , dim )
342
343
index_slice = torch .unsqueeze (index_slice , dim )
343
- device = to_torch_device (default_device ())
344
344
345
345
# moving tensor to default device
346
- scatter_loop_tensor = scatter_loop_tensor .to (device )
347
- index_slice = index_slice .to (device )
348
- src_slice = src_slice .to (device )
346
+ scatter_loop_tensor = scatter_loop_tensor .to (device_input_tensor )
347
+ index_slice = index_slice .to (device_input_tensor )
348
+ src_slice = src_slice .to (device_input_tensor )
349
349
if reduce == "sum" :
350
350
reduceOp = ReduceOperation .SUM
351
351
elif reduce == "prod" :
You can’t perform that action at this time.
0 commit comments