31
31
)
32
32
from dpctl .tensor ._manipulation_functions import _broadcast_shape_impl
33
33
from dpctl .tensor ._type_utils import _can_cast , _to_device_supported_dtype
34
- from dpctl .utils import ExecutionPlacementError
34
+ from dpctl .utils import ExecutionPlacementError , SequentialOrderManager
35
35
36
36
from ._type_utils import (
37
37
WeakComplexType ,
@@ -299,18 +299,21 @@ def _clip_none(x, val, out, order, _binary_fn):
299
299
x = dpt .broadcast_to (x , res_shape )
300
300
if val_ary .shape != res_shape :
301
301
val_ary = dpt .broadcast_to (val_ary , res_shape )
302
+ _manager = SequentialOrderManager
303
+ dep_evs = _manager .submitted_events
302
304
ht_binary_ev , binary_ev = _binary_fn (
303
- src1 = x , src2 = val_ary , dst = out , sycl_queue = exec_q
305
+ src1 = x , src2 = val_ary , dst = out , sycl_queue = exec_q , depends = dep_evs
304
306
)
307
+ _manager .add_event_pair (ht_binary_ev , binary_ev )
305
308
if not (orig_out is None or orig_out is out ):
306
309
# Copy the out data from temporary buffer to original memory
307
- ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
310
+ ht_copy_out_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
308
311
src = out ,
309
312
dst = orig_out ,
310
313
sycl_queue = exec_q ,
311
314
depends = [binary_ev ],
312
315
)
313
- ht_copy_out_ev . wait ( )
316
+ _manager . add_event_pair ( ht_copy_out_ev , copy_ev )
314
317
out = orig_out
315
318
ht_binary_ev .wait ()
316
319
return out
@@ -319,9 +322,12 @@ def _clip_none(x, val, out, order, _binary_fn):
319
322
buf = _empty_like_orderK (val_ary , res_dt )
320
323
else :
321
324
buf = dpt .empty_like (val_ary , dtype = res_dt , order = order )
325
+ _manager = SequentialOrderManager
326
+ dep_evs = _manager .submitted_events
322
327
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
323
- src = val_ary , dst = buf , sycl_queue = exec_q
328
+ src = val_ary , dst = buf , sycl_queue = exec_q , depends = dep_evs
324
329
)
330
+ _manager .add_event_pair (ht_copy_ev , copy_ev )
325
331
if out is None :
326
332
if order == "K" :
327
333
out = _empty_like_pair_orderK (
@@ -346,18 +352,17 @@ def _clip_none(x, val, out, order, _binary_fn):
346
352
sycl_queue = exec_q ,
347
353
depends = [copy_ev ],
348
354
)
355
+ _manager .add_event_pair (ht_binary_ev , binary_ev )
349
356
if not (orig_out is None or orig_out is out ):
350
357
# Copy the out data from temporary buffer to original memory
351
- ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
358
+ ht_copy_out_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
352
359
src = out ,
353
360
dst = orig_out ,
354
361
sycl_queue = exec_q ,
355
362
depends = [binary_ev ],
356
363
)
357
- ht_copy_out_ev . wait ( )
364
+ _manager . add_event_pair ( ht_copy_out_ev , cpy_ev )
358
365
out = orig_out
359
- ht_copy_ev .wait ()
360
- ht_binary_ev .wait ()
361
366
return out
362
367
363
368
@@ -444,20 +449,22 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
444
449
else :
445
450
out = dpt .empty_like (x , order = order )
446
451
452
+ _manager = SequentialOrderManager
453
+ dep_evs = _manager .submitted_events
447
454
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
448
- src = x , dst = out , sycl_queue = exec_q
455
+ src = x , dst = out , sycl_queue = exec_q , depends = dep_evs
449
456
)
457
+ _manager .add_event_pair (ht_copy_ev , copy_ev )
450
458
if not (orig_out is None or orig_out is out ):
451
459
# Copy the out data from temporary buffer to original memory
452
- ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
460
+ ht_copy_out_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
453
461
src = out ,
454
462
dst = orig_out ,
455
463
sycl_queue = exec_q ,
456
464
depends = [copy_ev ],
457
465
)
458
- ht_copy_out_ev . wait ( )
466
+ _manager . add_event_pair ( ht_copy_ev , cpy_ev )
459
467
out = orig_out
460
- ht_copy_ev .wait ()
461
468
return out
462
469
elif max is None :
463
470
return _clip_none (x , min , out , order , tei ._maximum )
@@ -665,30 +672,40 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
665
672
a_min = dpt .broadcast_to (a_min , res_shape )
666
673
if a_max .shape != res_shape :
667
674
a_max = dpt .broadcast_to (a_max , res_shape )
675
+ _manager = SequentialOrderManager
676
+ dep_ev = _manager .submitted_events
668
677
ht_binary_ev , binary_ev = ti ._clip (
669
- src = x , min = a_min , max = a_max , dst = out , sycl_queue = exec_q
678
+ src = x ,
679
+ min = a_min ,
680
+ max = a_max ,
681
+ dst = out ,
682
+ sycl_queue = exec_q ,
683
+ depends = dep_ev ,
670
684
)
685
+ _manager .add_event_pair (ht_binary_ev , binary_ev )
671
686
if not (orig_out is None or orig_out is out ):
672
687
# Copy the out data from temporary buffer to original memory
673
- ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
688
+ ht_copy_out_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
674
689
src = out ,
675
690
dst = orig_out ,
676
691
sycl_queue = exec_q ,
677
692
depends = [binary_ev ],
678
693
)
679
- ht_copy_out_ev . wait ( )
694
+ _manager . add_event_pair ( ht_copy_out_ev , cpy_ev )
680
695
out = orig_out
681
- ht_binary_ev .wait ()
682
696
return out
683
697
684
698
elif buf1_dt is None :
685
699
if order == "K" :
686
700
buf2 = _empty_like_orderK (a_max , buf2_dt )
687
701
else :
688
702
buf2 = dpt .empty_like (a_max , dtype = buf2_dt , order = order )
703
+ _manager = SequentialOrderManager
704
+ dep_ev = _manager .submitted_events
689
705
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
690
- src = a_max , dst = buf2 , sycl_queue = exec_q
706
+ src = a_max , dst = buf2 , sycl_queue = exec_q , depends = dep_ev
691
707
)
708
+ _manager .add_event_pair (ht_copy_ev , copy_ev )
692
709
if out is None :
693
710
if order == "K" :
694
711
out = _empty_like_triple_orderK (
@@ -721,28 +738,30 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
721
738
sycl_queue = exec_q ,
722
739
depends = [copy_ev ],
723
740
)
741
+ _manager .add_event_pair (ht_binary_ev , binary_ev )
724
742
if not (orig_out is None or orig_out is out ):
725
743
# Copy the out data from temporary buffer to original memory
726
- ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
744
+ ht_copy_out_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
727
745
src = out ,
728
746
dst = orig_out ,
729
747
sycl_queue = exec_q ,
730
748
depends = [binary_ev ],
731
749
)
732
- ht_copy_out_ev . wait ( )
750
+ _manager . add_event_pair ( ht_copy_out_ev , cpy_ev )
733
751
out = orig_out
734
- ht_copy_ev .wait ()
735
- ht_binary_ev .wait ()
736
752
return out
737
753
738
754
elif buf2_dt is None :
739
755
if order == "K" :
740
756
buf1 = _empty_like_orderK (a_min , buf1_dt )
741
757
else :
742
758
buf1 = dpt .empty_like (a_min , dtype = buf1_dt , order = order )
759
+ _manager = SequentialOrderManager
760
+ dep_ev = _manager .submitted_events
743
761
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
744
- src = a_min , dst = buf1 , sycl_queue = exec_q
762
+ src = a_min , dst = buf1 , sycl_queue = exec_q , depends = dep_ev
745
763
)
764
+ _manager .add_event_pair (ht_copy_ev , copy_ev )
746
765
if out is None :
747
766
if order == "K" :
748
767
out = _empty_like_triple_orderK (
@@ -775,18 +794,17 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
775
794
sycl_queue = exec_q ,
776
795
depends = [copy_ev ],
777
796
)
797
+ _manager .add_event_pair (ht_binary_ev , binary_ev )
778
798
if not (orig_out is None or orig_out is out ):
779
799
# Copy the out data from temporary buffer to original memory
780
- ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
800
+ ht_copy_out_ev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
781
801
src = out ,
782
802
dst = orig_out ,
783
803
sycl_queue = exec_q ,
784
804
depends = [binary_ev ],
785
805
)
786
- ht_copy_out_ev . wait ( )
806
+ _manager . add_event_pair ( ht_copy_out_ev , cpy_ev )
787
807
out = orig_out
788
- ht_copy_ev .wait ()
789
- ht_binary_ev .wait ()
790
808
return out
791
809
792
810
if order == "K" :
@@ -806,16 +824,21 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
806
824
buf1 = _empty_like_orderK (a_min , buf1_dt )
807
825
else :
808
826
buf1 = dpt .empty_like (a_min , dtype = buf1_dt , order = order )
827
+
828
+ _manager = SequentialOrderManager
829
+ dep_evs = _manager .submitted_events
809
830
ht_copy1_ev , copy1_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
810
- src = a_min , dst = buf1 , sycl_queue = exec_q
831
+ src = a_min , dst = buf1 , sycl_queue = exec_q , depends = dep_evs
811
832
)
833
+ _manager .add_event_pair (ht_copy1_ev , copy1_ev )
812
834
if order == "K" :
813
835
buf2 = _empty_like_orderK (a_max , buf2_dt )
814
836
else :
815
837
buf2 = dpt .empty_like (a_max , dtype = buf2_dt , order = order )
816
838
ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
817
- src = a_max , dst = buf2 , sycl_queue = exec_q
839
+ src = a_max , dst = buf2 , sycl_queue = exec_q , depends = dep_evs
818
840
)
841
+ _manager .add_event_pair (ht_copy2_ev , copy2_ev )
819
842
if out is None :
820
843
if order == "K" :
821
844
out = _empty_like_triple_orderK (
@@ -833,13 +856,13 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
833
856
x = dpt .broadcast_to (x , res_shape )
834
857
buf1 = dpt .broadcast_to (buf1 , res_shape )
835
858
buf2 = dpt .broadcast_to (buf2 , res_shape )
836
- ht_ , _ = ti ._clip (
859
+ ht_ , clip_ev = ti ._clip (
837
860
src = x ,
838
861
min = buf1 ,
839
862
max = buf2 ,
840
863
dst = out ,
841
864
sycl_queue = exec_q ,
842
865
depends = [copy1_ev , copy2_ev ],
843
866
)
844
- dpctl . SyclEvent . wait_for ([ ht_copy1_ev , ht_copy2_ev , ht_ ] )
867
+ _manager . add_event_pair ( ht_ , clip_ev )
845
868
return out
0 commit comments