Skip to content

Commit 0c74aca

Browse files
committed
MAINT: address further review comments
1 parent abc8de6 commit 0c74aca

File tree

2 files changed

+6
-14
lines changed

2 files changed

+6
-14
lines changed

torch_np/_funcs.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -888,20 +888,14 @@ def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False
888888
return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
889889

890890

891-
def _tensor_equal(a1_t, a2_t, equal_nan=False):
891+
def _tensor_equal(a1, a2, equal_nan=False):
892892
# Implementation of array_equal/array_equiv.
893-
if a1_t.shape != a2_t.shape:
894-
return False
895893
if equal_nan:
896-
nan_loc = (torch.isnan(a1_t) == torch.isnan(a2_t)).all()
897-
if nan_loc:
898-
# check the values
899-
result = a1_t[~torch.isnan(a1_t)] == a2_t[~torch.isnan(a2_t)]
900-
else:
901-
return False
894+
return (a1.shape == a2.shape) and (
895+
(a1 == a2) | (torch.isnan(a1) & torch.isnan(a2))
896+
).all().item()
902897
else:
903-
result = a1_t == a2_t
904-
return bool(result.all())
898+
return torch.equal(a1, a2)
905899

906900

907901
@normalizer
@@ -1822,8 +1816,6 @@ def i0(x: ArrayLike):
18221816
@normalizer(return_on_failure=False)
18231817
def isscalar(a: ArrayLike):
18241818
# XXX: this is a stub
1825-
if a is False:
1826-
return False
18271819
return a.numel() == 1
18281820

18291821

torch_np/_ndarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def reshape(self, *shape, order="C"):
358358

359359
def sort(self, axis=-1, kind=None, order=None):
360360
# ndarray.sort works in-place
361-
self.tensor = _funcs._sort(self.tensor, axis, kind, order)
361+
self.tensor.copy_(_funcs._sort(self.tensor, axis, kind, order))
362362

363363
argsort = _funcs.argsort
364364
searchsorted = _funcs.searchsorted

0 commit comments

Comments
 (0)