Skip to content

Commit ffc5429

Browse files
Fix dpnp.linalg.solve() hang on CPU (#1778)
* Add wait() for each ht_lapack_ev(gesv) in loop on CPU * Add wait() for each b_ht_copy_ev in the loop
1 parent ceff0c7 commit ffc5429

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,6 +1824,7 @@ def dpnp_solve(a, b):
18241824
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
18251825

18261826
if a.ndim > 2:
1827+
is_cpu_device = exec_q.sycl_device.has_aspect_cpu
18271828
reshape = False
18281829
orig_shape_b = b_shape
18291830
if a.ndim > 3:
@@ -1850,22 +1851,27 @@ def dpnp_solve(a, b):
18501851
for i in range(batch_size):
18511852
# oneMKL LAPACK assumes fortran-like array as input, so
18521853
# allocate a memory with 'F' order for dpnp array of coefficient matrix
1853-
# and multiple dependent variables array
18541854
coeff_vecs[i] = dpnp.empty_like(
18551855
a[i], order="F", dtype=res_type, usm_type=res_usm_type
18561856
)
1857-
val_vecs[i] = dpnp.empty_like(
1858-
b[i], order="F", dtype=res_type, usm_type=res_usm_type
1859-
)
18601857

18611858
# use DPCTL tensor function to fill the coefficient matrix array
1862-
# and the array of multiple dependent variables with content
1863-
# from the input arrays
1859+
# with content from the input array
18641860
a_ht_copy_ev[i], a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
18651861
src=a_usm_arr[i],
18661862
dst=coeff_vecs[i].get_array(),
18671863
sycl_queue=a.sycl_queue,
18681864
)
1865+
1866+
# oneMKL LAPACK assumes fortran-like array as input, so
1867+
# allocate a memory with 'F' order for dpnp array of multiple
1868+
# dependent variables array
1869+
val_vecs[i] = dpnp.empty_like(
1870+
b[i], order="F", dtype=res_type, usm_type=res_usm_type
1871+
)
1872+
1873+
# use DPCTL tensor function to fill the array of multiple dependent
1874+
# variables with content from the input arrays
18691875
b_ht_copy_ev[i], b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
18701876
src=b_usm_arr[i],
18711877
dst=val_vecs[i].get_array(),
@@ -1882,6 +1888,15 @@ def dpnp_solve(a, b):
18821888
depends=[a_copy_ev, b_copy_ev],
18831889
)
18841890

1891+
# TODO: Remove this w/a when MKLD-17201 is solved.
1892+
# Waiting for a host task executing an OneMKL LAPACK gesv call
1893+
# on CPU causes deadlock due to serialization of all host tasks
1894+
# in the queue.
1895+
# We need to wait for each host tasks before calling _gesv to avoid deadlock.
1896+
if is_cpu_device:
1897+
ht_lapack_ev[i].wait()
1898+
b_ht_copy_ev[i].wait()
1899+
18851900
for i in range(batch_size):
18861901
ht_lapack_ev[i].wait()
18871902
b_ht_copy_ev[i].wait()

tests/test_usm_type.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -899,9 +899,6 @@ def test_eigenvalue(func, shape, usm_type):
899899
)
900900
def test_solve(matrix, vector, usm_type_matrix, usm_type_vector):
901901
x = dp.array(matrix, usm_type=usm_type_matrix)
902-
if x.ndim > 2 and x.device.sycl_device.is_cpu:
903-
pytest.skip("SAT-6842: reported hanging in public CI")
904-
905902
y = dp.array(vector, usm_type=usm_type_vector)
906903
z = dp.linalg.solve(x, y)
907904

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def check_x(self, a_shape, b_shape, xp, dtype):
4747
testing.assert_array_equal(b_copy, b)
4848
return result
4949

50-
@pytest.mark.skipif(is_cpu_device(), reason="SAT-6842")
5150
def test_solve(self):
5251
self.check_x((4, 4), (4,))
5352
self.check_x((5, 5), (5, 2))

0 commit comments

Comments
 (0)