Skip to content

Commit b716d34

Browse files
committed
Upcast int tensor indices
1 parent 5081a98 commit b716d34

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

torch_np/_ndarray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,19 @@ def clip(self, min, max, out=None):
410410
)
411411

412412
### indexing ###
413+
@staticmethod
414+
def _upcast_int_indices(index):
415+
if isinstance(index, torch.Tensor):
416+
if index.dtype in [torch.int8, torch.int16, torch.int32]:
417+
return index.type(torch.int64)
418+
elif isinstance(index, tuple):
419+
return tuple(ndarray._upcast_int_indices(i) for i in index)
420+
return index
421+
413422
def __getitem__(self, index):
414-
t_index = _helpers.ndarrays_to_tensors(index)
415-
return ndarray._from_tensor_and_base(self._tensor.__getitem__(t_index), self)
423+
index = _helpers.ndarrays_to_tensors(index)
424+
index = ndarray._upcast_int_indices(index)
425+
return ndarray._from_tensor_and_base(self._tensor.__getitem__(index), self)
416426

417427
def __setitem__(self, index, value):
418428
value = asarray(value).get()

0 commit comments

Comments
 (0)