Skip to content

Commit 3323156

Browse files
authored
empty tensor moving to default device (#2948)
1 parent e3363df commit 3323156

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import torch
55
from torch._decomp import register_decomposition
66
from torch._ops import OpOverload
7+
from torch_tensorrt.dynamo._defaults import default_device
78
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
9+
from torch_tensorrt.dynamo.utils import to_torch_device
810

911
from ._decomposition_groups import (
1012
ENABLED_TORCH_DECOMPOSITIONS,
@@ -172,6 +174,7 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
172174
perm = [0] * len(empty_size)
173175
for permute_index, permute_element in enumerate(empty_permute):
174176
perm[permute_element] = permute_index
177+
kwargs["device"] = to_torch_device(default_device())
175178
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
176179

177180

@@ -233,7 +236,11 @@ def select_scatter_decomposition(
233236
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
234237
empty_size = args[0]
235238
empty_stride = args[1]
236-
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
239+
return torch.as_strided(
240+
torch.empty(empty_size, device=to_torch_device(default_device())),
241+
empty_size,
242+
empty_stride,
243+
)
237244

238245

239246
def get_decompositions(

0 commit comments

Comments
 (0)