8
8
from .helper import (
9
9
assert_dtype_allclose ,
10
10
get_all_dtypes ,
11
+ get_complex_dtypes ,
11
12
has_support_aspect64 ,
12
13
is_cpu_device ,
13
14
)
@@ -563,6 +564,76 @@ def test_solve_errors(self):
563
564
564
565
565
566
class TestSvd :
567
+ def set_tol (self , dtype ):
568
+ tol = 1e-06
569
+ if dtype in (inp .float32 , inp .complex64 ):
570
+ tol = 1e-05
571
+ elif not has_support_aspect64 () and dtype in (
572
+ inp .int32 ,
573
+ inp .int64 ,
574
+ None ,
575
+ ):
576
+ tol = 1e-05
577
+ return tol
578
+
579
+ def check_types_shapes (
580
+ self , dp_u , dp_s , dp_vt , np_u , np_s , np_vt , compute_vt = True
581
+ ):
582
+ if has_support_aspect64 ():
583
+ if compute_vt :
584
+ assert dp_u .dtype == np_u .dtype
585
+ assert dp_vt .dtype == np_vt .dtype
586
+ assert dp_s .dtype == np_s .dtype
587
+ else :
588
+ if compute_vt :
589
+ assert dp_u .dtype .kind == np_u .dtype .kind
590
+ assert dp_vt .dtype .kind == np_vt .dtype .kind
591
+ assert dp_s .dtype .kind == np_s .dtype .kind
592
+
593
+ if compute_vt :
594
+ assert dp_u .shape == np_u .shape
595
+ assert dp_vt .shape == np_vt .shape
596
+ assert dp_s .shape == np_s .shape
597
+
598
+ def check_decomposition (
599
+ self , dp_a , dp_u , dp_s , dp_vt , np_u , np_s , np_vt , compute_vt , tol
600
+ ):
601
+ if compute_vt :
602
+ dpnp_diag_s = inp .zeros_like (dp_a , dtype = dp_s .dtype )
603
+ for i in range (min (dp_a .shape [- 2 ], dp_a .shape [- 1 ])):
604
+ dpnp_diag_s [..., i , i ] = dp_s [..., i ]
605
+ # TODO: remove it when dpnp.dot is updated
606
+ # dpnp.dot does not support complex type
607
+ if inp .issubdtype (dp_a .dtype , inp .complexfloating ):
608
+ reconstructed = numpy .dot (
609
+ inp .asnumpy (dp_u ),
610
+ numpy .dot (inp .asnumpy (dpnp_diag_s ), inp .asnumpy (dp_vt )),
611
+ )
612
+ else :
613
+ reconstructed = inp .dot (dp_u , inp .dot (dpnp_diag_s , dp_vt ))
614
+ assert_allclose (dp_a , reconstructed , rtol = tol , atol = tol )
615
+
616
+ assert_allclose (dp_s , np_s , rtol = tol , atol = 1e-03 )
617
+
618
+ if compute_vt :
619
+ for i in range (min (dp_a .shape [- 2 ], dp_a .shape [- 1 ])):
620
+ if np_u [..., 0 , i ] * dp_u [..., 0 , i ] < 0 :
621
+ np_u [..., :, i ] = - np_u [..., :, i ]
622
+ np_vt [..., i , :] = - np_vt [..., i , :]
623
+ for i in range (numpy .count_nonzero (np_s > tol )):
624
+ assert_allclose (
625
+ inp .asnumpy (dp_u [..., :, i ]),
626
+ np_u [..., :, i ],
627
+ rtol = tol ,
628
+ atol = tol ,
629
+ )
630
+ assert_allclose (
631
+ inp .asnumpy (dp_vt [..., i , :]),
632
+ np_vt [..., i , :],
633
+ rtol = tol ,
634
+ atol = tol ,
635
+ )
636
+
566
637
@pytest .mark .parametrize ("dtype" , get_all_dtypes (no_bool = True ))
567
638
@pytest .mark .parametrize (
568
639
"shape" ,
@@ -571,71 +642,50 @@ class TestSvd:
571
642
)
572
643
def test_svd (self , dtype , shape ):
573
644
a = numpy .arange (shape [0 ] * shape [1 ], dtype = dtype ).reshape (shape )
574
- ia = inp .array (a )
645
+ dp_a = inp .array (a )
575
646
576
647
np_u , np_s , np_vt = numpy .linalg .svd (a )
577
- dpnp_u , dpnp_s , dpnp_vt = inp .linalg .svd (ia )
578
-
579
- support_aspect64 = has_support_aspect64 ()
648
+ dp_u , dp_s , dp_vt = inp .linalg .svd (dp_a )
580
649
581
- if support_aspect64 :
582
- assert dpnp_u .dtype == np_u .dtype
583
- assert dpnp_s .dtype == np_s .dtype
584
- assert dpnp_vt .dtype == np_vt .dtype
585
-
586
- assert dpnp_u .shape == np_u .shape
587
- assert dpnp_s .shape == np_s .shape
588
- assert dpnp_vt .shape == np_vt .shape
650
+ self .check_types_shapes (dp_u , dp_s , dp_vt , np_u , np_s , np_vt )
651
+ tol = self .set_tol (dtype )
652
+ self .check_decomposition (
653
+ dp_a , dp_u , dp_s , dp_vt , np_u , np_s , np_vt , True , tol
654
+ )
589
655
590
- tol = 1e-06
591
- if dtype in (inp .float32 , inp .complex64 ):
592
- tol = 1e-05
593
- elif not support_aspect64 and dtype in (inp .int32 , inp .int64 , None ):
594
- tol = 1e-05
656
+ @pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
657
+ @pytest .mark .parametrize ("compute_vt" , [True , False ], ids = ["True" , "False" ])
658
+ @pytest .mark .parametrize (
659
+ "shape" ,
660
+ [(2 , 2 ), (16 , 16 )],
661
+ ids = ["(2,2)" , "(16, 16)" ],
662
+ )
663
+ def test_svd_hermitian (self , dtype , compute_vt , shape ):
664
+ a = numpy .random .randn (* shape ) + 1j * numpy .random .randn (* shape )
665
+ a = numpy .conj (a .T ) @ a
595
666
596
- # check decomposition
597
- dpnp_diag_s = inp .zeros (shape , dtype = dpnp_s .dtype )
598
- for i in range (dpnp_s .size ):
599
- dpnp_diag_s [i , i ] = dpnp_s [i ]
667
+ a = a .astype (dtype )
668
+ dp_a = inp .array (a )
600
669
601
- # check decomposition
602
- # TODO: remove it when dpnp.dot is updated
603
- # dpnp.dot does not support complex type
604
- if inp .issubdtype (dtype , inp .complexfloating ):
605
- assert_allclose (
606
- inp .asnumpy (ia ),
607
- numpy .dot (
608
- inp .asnumpy (dpnp_u ),
609
- numpy .dot (inp .asnumpy (dpnp_diag_s ), inp .asnumpy (dpnp_vt )),
610
- ),
611
- rtol = tol ,
612
- atol = tol ,
670
+ if compute_vt :
671
+ np_u , np_s , np_vt = numpy .linalg .svd (
672
+ a , compute_uv = compute_vt , hermitian = True
613
673
)
614
- else :
615
- assert_allclose (
616
- ia ,
617
- inp .dot (dpnp_u , inp .dot (dpnp_diag_s , dpnp_vt )),
618
- rtol = tol ,
619
- atol = tol ,
674
+ dp_u , dp_s , dp_vt = inp .linalg .svd (
675
+ dp_a , compute_uv = compute_vt , hermitian = True
620
676
)
677
+ else :
678
+ np_s = numpy .linalg .svd (a , compute_uv = compute_vt , hermitian = True )
679
+ dp_s = inp .linalg .svd (dp_a , compute_uv = compute_vt , hermitian = True )
680
+ np_u = np_vt = dp_u = dp_vt = None
621
681
622
- # compare singular values
623
- assert_allclose (dpnp_s , np_s , rtol = tol , atol = 1e-03 )
624
-
625
- # change sign of vectors
626
- for i in range (min (shape [0 ], shape [1 ])):
627
- if np_u [0 , i ] * dpnp_u [0 , i ] < 0 :
628
- np_u [:, i ] = - np_u [:, i ]
629
- np_vt [i , :] = - np_vt [i , :]
630
-
631
- # compare vectors for non-zero values
632
- for i in range (numpy .count_nonzero (np_s > tol )):
633
- assert_allclose (
634
- inp .asnumpy (dpnp_u )[:, i ], np_u [:, i ], rtol = tol , atol = tol
635
- )
636
- assert_allclose (
637
- inp .asnumpy (dpnp_vt )[i , :], np_vt [i , :], rtol = tol , atol = tol
638
- )
682
+ self .check_types_shapes (
683
+ dp_u , dp_s , dp_vt , np_u , np_s , np_vt , compute_vt
684
+ )
685
+ tol = self .set_tol (dtype )
686
+ self .check_decomposition (
687
+ dp_a , dp_u , dp_s , dp_vt , np_u , np_s , np_vt , compute_vt , tol
688
+ )
639
689
640
690
def test_svd_errors (self ):
641
691
a_dp = inp .array ([[1 , 2 ], [3 , 4 ]], dtype = "float32" )
0 commit comments