Skip to content

Commit be401a6

Browse files
Fixed missing dependency events
1 parent 9116e73 commit be401a6

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

dpctl/tensor/_ctors.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,8 +1631,9 @@ def tril(x, /, *, k=0):
16311631
sycl_queue=q,
16321632
)
16331633
_manager = dpctl.utils.SequentialOrderManager
1634+
dep_evs = _manager.submitted_events
16341635
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1635-
src=x, dst=res, sycl_queue=q
1636+
src=x, dst=res, sycl_queue=q, depends=dep_evs
16361637
)
16371638
_manager.add_event_pair(hev, cpy_ev)
16381639
elif k < -shape[nd - 2]:
@@ -1652,7 +1653,10 @@ def tril(x, /, *, k=0):
16521653
sycl_queue=q,
16531654
)
16541655
_manager = dpctl.utils.SequentialOrderManager
1655-
hev, tril_ev = ti._tril(src=x, dst=res, k=k, sycl_queue=q)
1656+
dep_evs = _manager.submitted_events
1657+
hev, tril_ev = ti._tril(
1658+
src=x, dst=res, k=k, sycl_queue=q, depends=dep_evs
1659+
)
16561660
_manager.add_event_pair(hev, tril_ev)
16571661

16581662
return res
@@ -1713,8 +1717,9 @@ def triu(x, /, *, k=0):
17131717
sycl_queue=q,
17141718
)
17151719
_manager = dpctl.utils.SequentialOrderManager
1720+
dep_evs = _manager.submitted_events
17161721
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1717-
src=x, dst=res, sycl_queue=q
1722+
src=x, dst=res, sycl_queue=q, depends=dep_evs
17181723
)
17191724
_manager.add_event_pair(hev, cpy_ev)
17201725
else:
@@ -1726,7 +1731,10 @@ def triu(x, /, *, k=0):
17261731
sycl_queue=q,
17271732
)
17281733
_manager = dpctl.utils.SequentialOrderManager
1729-
hev, triu_ev = ti._triu(src=x, dst=res, k=k, sycl_queue=q)
1734+
dep_evs = _manager.submitted_events
1735+
hev, triu_ev = ti._triu(
1736+
src=x, dst=res, k=k, sycl_queue=q, depends=dep_evs
1737+
)
17301738
_manager.add_event_pair(hev, triu_ev)
17311739

17321740
return res

0 commit comments

Comments
 (0)