@@ -1631,8 +1631,9 @@ def tril(x, /, *, k=0):
1631
1631
sycl_queue = q ,
1632
1632
)
1633
1633
_manager = dpctl .utils .SequentialOrderManager
1634
+ dep_evs = _manager .submitted_events
1634
1635
hev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
1635
- src = x , dst = res , sycl_queue = q
1636
+ src = x , dst = res , sycl_queue = q , depends = dep_evs
1636
1637
)
1637
1638
_manager .add_event_pair (hev , cpy_ev )
1638
1639
elif k < - shape [nd - 2 ]:
@@ -1652,7 +1653,10 @@ def tril(x, /, *, k=0):
1652
1653
sycl_queue = q ,
1653
1654
)
1654
1655
_manager = dpctl .utils .SequentialOrderManager
1655
- hev , tril_ev = ti ._tril (src = x , dst = res , k = k , sycl_queue = q )
1656
+ dep_evs = _manager .submitted_events
1657
+ hev , tril_ev = ti ._tril (
1658
+ src = x , dst = res , k = k , sycl_queue = q , depends = dep_evs
1659
+ )
1656
1660
_manager .add_event_pair (hev , tril_ev )
1657
1661
1658
1662
return res
@@ -1713,8 +1717,9 @@ def triu(x, /, *, k=0):
1713
1717
sycl_queue = q ,
1714
1718
)
1715
1719
_manager = dpctl .utils .SequentialOrderManager
1720
+ dep_evs = _manager .submitted_events
1716
1721
hev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
1717
- src = x , dst = res , sycl_queue = q
1722
+ src = x , dst = res , sycl_queue = q , depends = dep_evs
1718
1723
)
1719
1724
_manager .add_event_pair (hev , cpy_ev )
1720
1725
else :
@@ -1726,7 +1731,10 @@ def triu(x, /, *, k=0):
1726
1731
sycl_queue = q ,
1727
1732
)
1728
1733
_manager = dpctl .utils .SequentialOrderManager
1729
- hev , triu_ev = ti ._triu (src = x , dst = res , k = k , sycl_queue = q )
1734
+ dep_evs = _manager .submitted_events
1735
+ hev , triu_ev = ti ._triu (
1736
+ src = x , dst = res , k = k , sycl_queue = q , depends = dep_evs
1737
+ )
1730
1738
_manager .add_event_pair (hev , triu_ev )
1731
1739
1732
1740
return res
0 commit comments