Skip to content

Commit 5431f29

Browse files
committed
empty tensor moving to default device
1 parent 6aa439b commit 5431f29

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
172172
perm = [0] * len(empty_size)
173173
for permute_index, permute_element in enumerate(empty_permute):
174174
perm[permute_element] = permute_index
175+
default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
176+
kwargs[device] = default_device
175177
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
176178

177179

@@ -233,7 +235,8 @@ def select_scatter_decomposition(
233235
def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
234236
empty_size = args[0]
235237
empty_stride = args[1]
236-
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
238+
default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
239+
return torch.as_strided(torch.empty(empty_size, device = default_device), empty_size, empty_stride)
237240

238241

239242
def get_decompositions(

0 commit comments

Comments
 (0)