Skip to content

Commit d86a957

Browse files
Update test_eigenvalues in test_linalg.py
1 parent ed8e307 commit d86a957

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,14 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
932932
sycl_queue=a_sycl_queue,
933933
)
934934

935+
# TODO: Remove this w/a when MKLD-17201 is solved.
936+
# Waiting for a host task executing an OneMKL LAPACK syevd call
937+
# on CPU causes deadlock due to serialization of all host tasks
938+
# in the queue.
939+
# We need to wait for each host tasks before calling _seyvd to avoid deadlock.
940+
if lapack_func == "_syevd" and is_cpu_device:
941+
ht_list_ev[2 * i].wait()
942+
935943
# call LAPACK extension function to get eigenvalues and eigenvectors of a portion of matrix A
936944
ht_list_ev[2 * i + 1], _ = getattr(li, lapack_func)(
937945
a_sycl_queue,
@@ -942,15 +950,6 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
942950
depends=[copy_ev],
943951
)
944952

945-
# TODO: Remove this w/a when MKLD-17201 is solved.
946-
# Waiting for a host task executing an OneMKL LAPACK syevd call
947-
# on CPU causes deadlock due to serialization of all host tasks
948-
# in the queue.
949-
# We need to wait for each host tasks before calling _seyvd again
950-
# to avoid deadlock.
951-
if lapack_func == "_syevd" and is_cpu_device:
952-
ht_list_ev[2 * i + 1].wait()
953-
954953
dpctl.SyclEvent.wait_for(ht_list_ev)
955954

956955
w = w.reshape(w_orig_shape)

tests/test_linalg.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,23 @@ def test_eig_arange(type, size):
386386

387387

388388
class TestEigenvalue:
389+
# Eigenvalue decomposition of a matrix or a batch of matrices
390+
# by checking if the eigen equation A*v=w*v holds for given eigenvalues(w)
391+
# and eigenvectors(v).
392+
def assert_eigen_decomposition(self, a, w, v, rtol=1e-5, atol=1e-5):
393+
a_ndim = a.ndim
394+
if a_ndim == 2:
395+
assert_allclose(a @ v, v @ inp.diag(w), rtol=rtol, atol=atol)
396+
else: # a_ndim > 2
397+
if a_ndim > 3:
398+
a = a.reshape(-1, *a.shape[-2:])
399+
w = w.reshape(-1, w.shape[-1])
400+
v = v.reshape(-1, *v.shape[-2:])
401+
for i in range(a.shape[0]):
402+
assert_allclose(
403+
a[i].dot(v[i]), w[i] * v[i], rtol=rtol, atol=atol
404+
)
405+
389406
@pytest.mark.parametrize(
390407
"func",
391408
[
@@ -413,11 +430,16 @@ def test_eigenvalues(self, func, shape, dtype, order):
413430
a_order = numpy.array(a, order=order)
414431
a_dp = inp.array(a, order=order)
415432

433+
# NumPy with OneMKL and with rocSOLVER sorts in ascending order,
434+
# so w's should be directly comparable.
435+
# However, both OneMKL and rocSOLVER pick a different convention for
436+
# constructing eigenvectors, so v's are not directly comparible and
437+
# we verify them through the eigen equation A*v=w*v.
416438
if func == "eigh":
417-
w, v = numpy.linalg.eigh(a_order)
439+
w, _ = numpy.linalg.eigh(a_order)
418440
w_dp, v_dp = inp.linalg.eigh(a_dp)
419441

420-
assert_dtype_allclose(v_dp, v)
442+
self.assert_eigen_decomposition(a_dp, w_dp, v_dp)
421443

422444
else: # eighvalsh
423445
w = numpy.linalg.eigvalsh(a_order)

0 commit comments

Comments
 (0)