Skip to content

Commit 807179a

Browse files
Workaround for dpnp.linalg.qr() to run on CUDA (#2265)
This PR adds a workaround that waits for the host task after calling `geqrf`, to avoid a race condition caused by an issue in oneMath (uxlfoundation/oneMath#626). It also updates the tests by removing old skips and adding `test_qr_large` to `TestQr`.
1 parent 9ad1bb5 commit 807179a

File tree

3 files changed

+52
-27
lines changed

3 files changed

+52
-27
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,14 @@ def _batched_qr(a, mode="reduced"):
397397
batch_size,
398398
depends=[copy_ev],
399399
)
400-
_manager.add_event_pair(ht_ev, geqrf_ev)
400+
401+
# w/a to avoid race condition on CUDA during multiple runs
402+
# TODO: Remove it once the oneMath issue is resolved
403+
# https://github.com/uxlfoundation/oneMath/issues/626
404+
if dpnp.is_cuda_backend(a_sycl_queue):
405+
ht_ev.wait()
406+
else:
407+
_manager.add_event_pair(ht_ev, geqrf_ev)
401408

402409
if mode in ["r", "raw"]:
403410
if mode == "r":
@@ -2468,7 +2475,14 @@ def dpnp_qr(a, mode="reduced"):
24682475
ht_ev, geqrf_ev = li._geqrf(
24692476
a_sycl_queue, a_t.get_array(), tau_h.get_array(), depends=[copy_ev]
24702477
)
2471-
_manager.add_event_pair(ht_ev, geqrf_ev)
2478+
2479+
# w/a to avoid race condition on CUDA during multiple runs
2480+
# TODO: Remove it once the oneMath issue is resolved
2481+
# https://github.com/uxlfoundation/oneMath/issues/626
2482+
if dpnp.is_cuda_backend(a_sycl_queue):
2483+
ht_ev.wait()
2484+
else:
2485+
_manager.add_event_pair(ht_ev, geqrf_ev)
24722486

24732487
if mode in ["r", "raw"]:
24742488
if mode == "r":

dpnp/tests/test_linalg.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,12 +2380,6 @@ class TestQr:
23802380
)
23812381
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
23822382
def test_qr(self, dtype, shape, mode):
2383-
if (
2384-
is_cuda_device()
2385-
and mode in ["complete", "reduced"]
2386-
and shape in [(16, 16), (2, 2, 4)]
2387-
):
2388-
pytest.skip("SAT-7589")
23892383
a = generate_random_numpy_array(shape, dtype, seed_value=81)
23902384
ia = dpnp.array(a)
23912385

@@ -2398,24 +2392,48 @@ def test_qr(self, dtype, shape, mode):
23982392

23992393
# check decomposition
24002394
if mode in ("complete", "reduced"):
2401-
if a.ndim == 2:
2402-
assert_almost_equal(
2403-
dpnp.dot(dpnp_q, dpnp_r),
2404-
a,
2405-
decimal=5,
2406-
)
2407-
else: # a.ndim > 2
2408-
assert_almost_equal(
2409-
dpnp.matmul(dpnp_q, dpnp_r),
2410-
a,
2411-
decimal=5,
2412-
)
2395+
assert_almost_equal(
2396+
dpnp.matmul(dpnp_q, dpnp_r),
2397+
a,
2398+
decimal=5,
2399+
)
24132400
else: # mode=="raw"
24142401
assert_dtype_allclose(dpnp_q, np_q)
24152402

24162403
if mode in ("raw", "r"):
24172404
assert_dtype_allclose(dpnp_r, np_r)
24182405

2406+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
2407+
@pytest.mark.parametrize(
2408+
"shape",
2409+
[(32, 32), (8, 16, 16)],
2410+
ids=[
2411+
"(32, 32)",
2412+
"(8, 16, 16)",
2413+
],
2414+
)
2415+
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
2416+
def test_qr_large(self, dtype, shape, mode):
2417+
a = generate_random_numpy_array(shape, dtype, seed_value=81)
2418+
ia = dpnp.array(a)
2419+
if mode == "r":
2420+
np_r = numpy.linalg.qr(a, mode)
2421+
dpnp_r = dpnp.linalg.qr(ia, mode)
2422+
else:
2423+
np_q, np_r = numpy.linalg.qr(a, mode)
2424+
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
2425+
# check decomposition
2426+
if mode in ("complete", "reduced"):
2427+
assert_almost_equal(
2428+
dpnp.matmul(dpnp_q, dpnp_r),
2429+
a,
2430+
decimal=5,
2431+
)
2432+
else: # mode=="raw"
2433+
assert_allclose(np_q, dpnp_q, atol=1e-4)
2434+
if mode in ("raw", "r"):
2435+
assert_allclose(np_r, dpnp_r, atol=1e-4)
2436+
24192437
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
24202438
@pytest.mark.parametrize(
24212439
"shape",

dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,7 @@ def test_decomposition(self, dtype):
163163
class TestQRDecomposition(unittest.TestCase):
164164

165165
@testing.for_dtypes("fdFD")
166-
# skip cases with 'complete' and 'reduce' modes on CUDA (SAT-7611)
167166
def check_mode(self, array, mode, dtype):
168-
if (
169-
is_cuda_device()
170-
and array.size > 0
171-
and mode in ["complete", "reduced"]
172-
):
173-
return
174167
a_cpu = numpy.asarray(array, dtype=dtype)
175168
a_gpu = cupy.asarray(array, dtype=dtype)
176169
result_gpu = cupy.linalg.qr(a_gpu, mode=mode)

0 commit comments

Comments
 (0)