Skip to content

Commit 35f2b00

Browse files
committed
changing the device setting in conversion.py
1 parent 485adf9 commit 35f2b00

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def reduce_operation_with_scatter(
318318
print("Invalid Operation for Reduce op!!")
319319

320320
operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
321-
device = to_torch_device(default_device())
321+
device = to_torch_device(scatter_tensor.device)
322322
operation_lhs = operation_lhs.to(device)
323323
operation_rhs = operation_rhs.to(device)
324324
return self.func(operation_lhs, operation_rhs)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@
66
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
77

88
import numpy as np
9+
import tensorrt as trt
910
import torch
1011
from torch._subclasses.fake_tensor import FakeTensor
1112
from torch_tensorrt._Device import Device
1213
from torch_tensorrt._enums import dtype
1314
from torch_tensorrt._Input import Input
1415
from torch_tensorrt.dynamo import _defaults
16+
from torch_tensorrt.dynamo._defaults import default_device
1517
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
1618
from torch_tensorrt.dynamo._settings import CompilationSettings
1719

18-
import tensorrt as trt
1920
from packaging import version
2021

2122
from .types import TRTDataType
@@ -186,11 +187,14 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
186187
device = None
187188
for parameter in list(module.parameters()):
188189
if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
189-
device = parameter.device
190-
break
190+
return parameter.device
191+
192+
for buffer in list(module.buffers()):
193+
if isinstance(buffer, (torch.Tensor)):
194+
return buffer.device
191195

192196
if device is None:
193-
device = torch.device("cpu")
197+
device = to_torch_device(default_device())
194198
logger.warning(
195199
"Could not detect the device on which the model exists. Assuming the model is on CPU"
196200
)

0 commit comments

Comments
 (0)