Skip to content

Commit 63e0787

Browse files
committed
TST: un-xfail tests of diag_indices
1 parent a781856 commit 63e0787

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

torch_np/_detail/implementations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def diag_indices_from(tensor):
138138
# For more than d=2, the strided formula is only valid for arrays with
139139
# all dimensions equal, so we check first.
140140
s = tensor.shape
141-
if any(s[1:] != s[:-1]):
141+
if s[1:] != s[:-1]:
142142
raise ValueError("All dimensions of input must be of equal length")
143143
return diag_indices(s[0], tensor.ndim)
144144

torch_np/tests/numpy_tests/lib/test_index_tricks.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ndindex, r_, s_, ix_
1313
)
1414

15-
# from torch_np import diag_indices, diag_indices_from
15+
from torch_np import diag_indices, diag_indices_from
1616
from torch_np._detail._index_tricks import index_exp
1717

1818

@@ -494,7 +494,6 @@ def test_hetero_shape_handling(self):
494494
fill_diagonal(a, 2)
495495

496496

497-
@pytest.mark.xfail(reason='diag_indices not implemented')
498497
def test_diag_indices():
499498
di = diag_indices(4)
500499
a = np.array([[1, 2, 3, 4],
@@ -513,7 +512,7 @@ def test_diag_indices():
513512
d3 = diag_indices(2, 3)
514513

515514
# And use it to set the diagonal of a zeros array to 1:
516-
a = np.zeros((2, 2, 2), int)
515+
a = np.zeros((2, 2, 2), dtype=int)
517516
a[d3] = 1
518517
assert_array_equal(
519518
a, np.array([[[1, 0],
@@ -523,7 +522,6 @@ def test_diag_indices():
523522
)
524523

525524

526-
@pytest.mark.xfail(reason='diag_indices_from not implemented')
527525
class TestDiagIndicesFrom:
528526

529527
def test_diag_indices_from(self):
@@ -534,12 +532,12 @@ def test_diag_indices_from(self):
534532

535533
def test_error_small_input(self):
536534
x = np.ones(7)
537-
with assert_raises_regex(ValueError, "at least 2-d"):
535+
with assert_raises(ValueError):
538536
diag_indices_from(x)
539537

540538
def test_error_shape_mismatch(self):
541-
x = np.zeros((3, 3, 2, 3), int)
542-
with assert_raises_regex(ValueError, "equal length"):
539+
x = np.zeros((3, 3, 2, 3), dtype=int)
540+
with assert_raises(ValueError):
543541
diag_indices_from(x)
544542

545543

0 commit comments

Comments
 (0)