@@ -386,6 +386,23 @@ def test_eig_arange(type, size):
386
386
387
387
388
388
class TestEigenvalue :
389
+ # Eigenvalue decomposition of a matrix or a batch of matrices
390
+ # by checking if the eigen equation A*v=w*v holds for given eigenvalues(w)
391
+ # and eigenvectors(v).
392
+ def assert_eigen_decomposition (self , a , w , v , rtol = 1e-5 , atol = 1e-5 ):
393
+ a_ndim = a .ndim
394
+ if a_ndim == 2 :
395
+ assert_allclose (a @ v , v @ inp .diag (w ), rtol = rtol , atol = atol )
396
+ else : # a_ndim > 2
397
+ if a_ndim > 3 :
398
+ a = a .reshape (- 1 , * a .shape [- 2 :])
399
+ w = w .reshape (- 1 , w .shape [- 1 ])
400
+ v = v .reshape (- 1 , * v .shape [- 2 :])
401
+ for i in range (a .shape [0 ]):
402
+ assert_allclose (
403
+ a [i ].dot (v [i ]), w [i ] * v [i ], rtol = rtol , atol = atol
404
+ )
405
+
389
406
@pytest .mark .parametrize (
390
407
"func" ,
391
408
[
@@ -413,11 +430,16 @@ def test_eigenvalues(self, func, shape, dtype, order):
413
430
a_order = numpy .array (a , order = order )
414
431
a_dp = inp .array (a , order = order )
415
432
433
+ # NumPy with OneMKL and with rocSOLVER sorts in ascending order,
434
+ # so w's should be directly comparable.
435
+ # However, both OneMKL and rocSOLVER pick a different convention for
436
+ # constructing eigenvectors, so v's are not directly comparible and
437
+ # we verify them through the eigen equation A*v=w*v.
416
438
if func == "eigh" :
417
- w , v = numpy .linalg .eigh (a_order )
439
+ w , _ = numpy .linalg .eigh (a_order )
418
440
w_dp , v_dp = inp .linalg .eigh (a_dp )
419
441
420
- assert_dtype_allclose ( v_dp , v )
442
+ self . assert_eigen_decomposition ( a_dp , w_dp , v_dp )
421
443
422
444
else : # eighvalsh
423
445
w = numpy .linalg .eigvalsh (a_order )
0 commit comments