Skip to content

Commit 9d72887

Browse files
Add w/a for dpnp.linalg.qr on CUDA
1 parent 952a798 commit 9d72887

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,12 @@ def _batched_qr(a, mode="reduced"):
399399
)
400400
_manager.add_event_pair(ht_ev, geqrf_ev)
401401

402+
# w/a to avoid raice conditional on CUDA during multiple runs
403+
# TODO: Remove it ones the OneMath issue is resolved
404+
# https://github.com/uxlfoundation/oneMath/issues/626
405+
if dpnp.is_cuda_backend(a_sycl_queue):
406+
ht_ev.wait()
407+
402408
if mode in ["r", "raw"]:
403409
if mode == "r":
404410
r = a_t[..., :k].swapaxes(-2, -1)
@@ -2470,6 +2476,12 @@ def dpnp_qr(a, mode="reduced"):
24702476
)
24712477
_manager.add_event_pair(ht_ev, geqrf_ev)
24722478

2479+
# w/a to avoid raice conditional on CUDA during multiple runs
2480+
# TODO: Remove it ones the OneMath issue is resolved
2481+
# https://github.com/uxlfoundation/oneMath/issues/626
2482+
if dpnp.is_cuda_backend(a_sycl_queue):
2483+
ht_ev.wait()
2484+
24732485
if mode in ["r", "raw"]:
24742486
if mode == "r":
24752487
r = a_t[:, :k].transpose()

0 commit comments

Comments
 (0)