Skip to content

Commit 648fa95

Browse files
committed
changing the device setting in conversion.py
1 parent 485adf9 commit 648fa95

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def reduce_operation_with_scatter(
318318
print("Invalid Operation for Reduce op!!")
319319

320320
operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
321-
device = to_torch_device(default_device())
321+
device = to_torch_device(scatter_tensor.device)
322322
operation_lhs = operation_lhs.to(device)
323323
operation_rhs = operation_rhs.to(device)
324324
return self.func(operation_lhs, operation_rhs)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch_tensorrt._enums import dtype
1313
from torch_tensorrt._Input import Input
1414
from torch_tensorrt.dynamo import _defaults
15+
from torch_tensorrt.dynamo._defaults import default_device
1516
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
1617
from torch_tensorrt.dynamo._settings import CompilationSettings
1718

@@ -190,7 +191,7 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
190191
break
191192

192193
if device is None:
193-
device = torch.device("cpu")
194+
device = to_torch_device(default_device())
194195
logger.warning(
195196
"Could not detect the device on which the model exists. Assuming the model is on CPU"
196197
)

0 commit comments

Comments
 (0)