Skip to content

Commit f124297

Browse files
committed
changing the device setting in conversion.py
1 parent 485adf9 commit f124297

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def reduce_operation_with_scatter(
318318
print("Invalid Operation for Reduce op!!")
319319

320320
operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
321-
device = to_torch_device(default_device())
321+
device = to_torch_device(scatter_tensor.device)
322322
operation_lhs = operation_lhs.to(device)
323323
operation_rhs = operation_rhs.to(device)
324324
return self.func(operation_lhs, operation_rhs)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch_tensorrt._enums import dtype
1313
from torch_tensorrt._Input import Input
1414
from torch_tensorrt.dynamo import _defaults
15+
from torch_tensorrt.dynamo._defaults import default_device
1516
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
1617
from torch_tensorrt.dynamo._settings import CompilationSettings
1718

@@ -186,11 +187,14 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
186187
device = None
187188
for parameter in list(module.parameters()):
188189
if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
189-
device = parameter.device
190-
break
190+
return parameter.device
191+
192+
for buffer in list(module.buffers()):
193+
if isinstance(buffer, (torch.Tensor)):
194+
return buffer.device
191195

192196
if device is None:
193-
device = torch.device("cpu")
197+
device = to_torch_device(default_device())
194198
logger.warning(
195199
"Could not detect the device on which the model exists. Assuming the model is on CPU"
196200
)

0 commit comments

Comments
 (0)