12
12
ndindex , r_ , s_ , ix_
13
13
)
14
14
15
- # from torch_np import diag_indices, diag_indices_from
15
+ from torch_np import diag_indices , diag_indices_from
16
16
from torch_np ._detail ._index_tricks import index_exp
17
17
18
18
@@ -494,7 +494,6 @@ def test_hetero_shape_handling(self):
494
494
fill_diagonal (a , 2 )
495
495
496
496
497
- @pytest .mark .xfail (reason = 'diag_indices not implemented' )
498
497
def test_diag_indices ():
499
498
di = diag_indices (4 )
500
499
a = np .array ([[1 , 2 , 3 , 4 ],
@@ -513,7 +512,7 @@ def test_diag_indices():
513
512
d3 = diag_indices (2 , 3 )
514
513
515
514
# And use it to set the diagonal of a zeros array to 1:
516
- a = np .zeros ((2 , 2 , 2 ), int )
515
+ a = np .zeros ((2 , 2 , 2 ), dtype = int )
517
516
a [d3 ] = 1
518
517
assert_array_equal (
519
518
a , np .array ([[[1 , 0 ],
@@ -523,7 +522,6 @@ def test_diag_indices():
523
522
)
524
523
525
524
526
- @pytest .mark .xfail (reason = 'diag_indices_from not implemented' )
527
525
class TestDiagIndicesFrom :
528
526
529
527
def test_diag_indices_from (self ):
@@ -534,12 +532,12 @@ def test_diag_indices_from(self):
534
532
535
533
def test_error_small_input (self ):
536
534
x = np .ones (7 )
537
- with assert_raises_regex (ValueError , "at least 2-d" ):
535
+ with assert_raises (ValueError ):
538
536
diag_indices_from (x )
539
537
540
538
def test_error_shape_mismatch (self ):
541
- x = np .zeros ((3 , 3 , 2 , 3 ), int )
542
- with assert_raises_regex (ValueError , "equal length" ):
539
+ x = np .zeros ((3 , 3 , 2 , 3 ), dtype = int )
540
+ with assert_raises (ValueError ):
543
541
diag_indices_from (x )
544
542
545
543
0 commit comments