Skip to content

Commit 7f8bb4f

Browse files
committed
empty tensor moving to default device
1 parent 6aa439b commit 7f8bb4f

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
172172
perm = [0] * len(empty_size)
173173
for permute_index, permute_element in enumerate(empty_permute):
174174
perm[permute_element] = permute_index
175+
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
176+
kwargs[device] = default_device
175177
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
176178

177179

@@ -233,7 +235,10 @@ def select_scatter_decomposition(
233235
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
234236
empty_size = args[0]
235237
empty_stride = args[1]
236-
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
238+
default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
239+
return torch.as_strided(
240+
torch.empty(empty_size, device=default_device), empty_size, empty_stride
241+
)
237242

238243

239244
def get_decompositions(

0 commit comments

Comments
 (0)