Skip to content

Commit b78698e

Browse files
Transition dpctl.tensor to use SequentialOrderManager
Replace the pervasive use of SyclEvent.wait with SequentialOrderManager, which maintains sequential-order semantics by ordering submitted tasks through their events.
1 parent 96cd26e commit b78698e

18 files changed

+683
-501
lines changed

dpctl/tensor/_accumulation.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
_default_accumulation_dtype_fp_types,
2626
_to_device_supported_dtype,
2727
)
28-
from dpctl.utils import ExecutionPlacementError
28+
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
2929

3030

3131
def _accumulate_common(
@@ -125,67 +125,70 @@ def _accumulate_common(
125125
if a1 != nd:
126126
out = dpt.permute_dims(out, perm)
127127

128-
host_tasks_list = []
128+
final_ev = dpctl.SyclEvent()
129+
_manager = SequentialOrderManager
130+
depends = _manager.submitted_events
129131
if implemented_types:
130132
if not include_initial:
131133
ht_e, acc_ev = _accumulate_fn(
132134
src=arr,
133135
trailing_dims_to_accumulate=1,
134136
dst=out,
135137
sycl_queue=q,
138+
depends=depends,
136139
)
137140
else:
138141
ht_e, acc_ev = _accumulate_include_initial_fn(
139-
src=arr,
140-
dst=out,
141-
sycl_queue=q,
142+
src=arr, dst=out, sycl_queue=q, depends=depends
142143
)
143-
host_tasks_list.append(ht_e)
144+
_manager.add_event_pair(ht_e, acc_ev)
144145
if not (orig_out is None or out is orig_out):
145146
# Copy the out data from temporary buffer to original memory
146-
ht_e_cpy, _ = ti._copy_usm_ndarray_into_usm_ndarray(
147+
ht_e_cpy, acc_ev = ti._copy_usm_ndarray_into_usm_ndarray(
147148
src=out, dst=orig_out, sycl_queue=q, depends=[acc_ev]
148149
)
149-
host_tasks_list.append(ht_e_cpy)
150+
_manager.add_event_pair(ht_e_cpy, acc_ev)
150151
out = orig_out
152+
final_ev = acc_ev
151153
else:
152154
if _dtype_supported(res_dt, res_dt):
153155
tmp = dpt.empty(
154156
arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
155157
)
156158
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
157-
src=arr, dst=tmp, sycl_queue=q
159+
src=arr, dst=tmp, sycl_queue=q, depends=depends
158160
)
159-
host_tasks_list.append(ht_e_cpy)
161+
_manager.add_event_pair(ht_e_cpy, cpy_e)
160162
if not include_initial:
161-
ht_e, acc_ev = _accumulate_fn(
163+
ht_e, final_ev = _accumulate_fn(
162164
src=tmp,
163165
trailing_dims_to_accumulate=1,
164166
dst=out,
165167
sycl_queue=q,
166168
depends=[cpy_e],
167169
)
168170
else:
169-
ht_e, acc_ev = _accumulate_include_initial_fn(
171+
ht_e, final_ev = _accumulate_include_initial_fn(
170172
src=tmp,
171173
dst=out,
172174
sycl_queue=q,
173175
depends=[cpy_e],
174176
)
177+
_manager.add_event_pair(ht_e, final_ev)
175178
else:
176179
buf_dt = _default_accumulation_type_fn(inp_dt, q)
177180
tmp = dpt.empty(
178181
arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
179182
)
180183
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
181-
src=arr, dst=tmp, sycl_queue=q
184+
src=arr, dst=tmp, sycl_queue=q, depends=depends
182185
)
186+
_manager.add_event_pair(ht_e_cpy, cpy_e)
183187
tmp_res = dpt.empty(
184188
res_sh, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
185189
)
186190
if a1 != nd:
187191
tmp_res = dpt.permute_dims(tmp_res, perm)
188-
host_tasks_list.append(ht_e_cpy)
189192
if not include_initial:
190193
ht_e, a_e = _accumulate_fn(
191194
src=tmp,
@@ -201,18 +204,17 @@ def _accumulate_common(
201204
sycl_queue=q,
202205
depends=[cpy_e],
203206
)
204-
host_tasks_list.append(ht_e)
205-
ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray(
207+
_manager.add_event_pair(ht_e, a_e)
208+
ht_e_cpy2, final_ev = ti._copy_usm_ndarray_into_usm_ndarray(
206209
src=tmp_res, dst=out, sycl_queue=q, depends=[a_e]
207210
)
208-
host_tasks_list.append(ht_e_cpy2)
211+
_manager.add_event_pair(ht_e_cpy2, final_ev)
209212

210213
if appended_axis:
211214
out = dpt.squeeze(out)
212215
if a1 != nd:
213216
inv_perm = sorted(range(nd), key=lambda d: perm[d])
214217
out = dpt.permute_dims(out, inv_perm)
215-
dpctl.SyclEvent.wait_for(host_tasks_list)
216218

217219
return out
218220

dpctl/tensor/_clip.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
3333
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
34-
from dpctl.utils import ExecutionPlacementError
34+
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3535

3636
from ._type_utils import (
3737
WeakComplexType,
@@ -299,18 +299,21 @@ def _clip_none(x, val, out, order, _binary_fn):
299299
x = dpt.broadcast_to(x, res_shape)
300300
if val_ary.shape != res_shape:
301301
val_ary = dpt.broadcast_to(val_ary, res_shape)
302+
_manager = SequentialOrderManager
303+
dep_evs = _manager.submitted_events
302304
ht_binary_ev, binary_ev = _binary_fn(
303-
src1=x, src2=val_ary, dst=out, sycl_queue=exec_q
305+
src1=x, src2=val_ary, dst=out, sycl_queue=exec_q, depends=dep_evs
304306
)
307+
_manager.add_event_pair(ht_binary_ev, binary_ev)
305308
if not (orig_out is None or orig_out is out):
306309
# Copy the out data from temporary buffer to original memory
307-
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
310+
ht_copy_out_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
308311
src=out,
309312
dst=orig_out,
310313
sycl_queue=exec_q,
311314
depends=[binary_ev],
312315
)
313-
ht_copy_out_ev.wait()
316+
_manager.add_event_pair(ht_copy_out_ev, copy_ev)
314317
out = orig_out
315318
ht_binary_ev.wait()
316319
return out
@@ -319,9 +322,12 @@ def _clip_none(x, val, out, order, _binary_fn):
319322
buf = _empty_like_orderK(val_ary, res_dt)
320323
else:
321324
buf = dpt.empty_like(val_ary, dtype=res_dt, order=order)
325+
_manager = SequentialOrderManager
326+
dep_evs = _manager.submitted_events
322327
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
323-
src=val_ary, dst=buf, sycl_queue=exec_q
328+
src=val_ary, dst=buf, sycl_queue=exec_q, depends=dep_evs
324329
)
330+
_manager.add_event_pair(ht_copy_ev, copy_ev)
325331
if out is None:
326332
if order == "K":
327333
out = _empty_like_pair_orderK(
@@ -346,18 +352,17 @@ def _clip_none(x, val, out, order, _binary_fn):
346352
sycl_queue=exec_q,
347353
depends=[copy_ev],
348354
)
355+
_manager.add_event_pair(ht_binary_ev, binary_ev)
349356
if not (orig_out is None or orig_out is out):
350357
# Copy the out data from temporary buffer to original memory
351-
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
358+
ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
352359
src=out,
353360
dst=orig_out,
354361
sycl_queue=exec_q,
355362
depends=[binary_ev],
356363
)
357-
ht_copy_out_ev.wait()
364+
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
358365
out = orig_out
359-
ht_copy_ev.wait()
360-
ht_binary_ev.wait()
361366
return out
362367

363368

@@ -444,20 +449,22 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
444449
else:
445450
out = dpt.empty_like(x, order=order)
446451

452+
_manager = SequentialOrderManager
453+
dep_evs = _manager.submitted_events
447454
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
448-
src=x, dst=out, sycl_queue=exec_q
455+
src=x, dst=out, sycl_queue=exec_q, depends=dep_evs
449456
)
457+
_manager.add_event_pair(ht_copy_ev, copy_ev)
450458
if not (orig_out is None or orig_out is out):
451459
# Copy the out data from temporary buffer to original memory
452-
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
460+
ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
453461
src=out,
454462
dst=orig_out,
455463
sycl_queue=exec_q,
456464
depends=[copy_ev],
457465
)
458-
ht_copy_out_ev.wait()
466+
_manager.add_event_pair(ht_copy_ev, cpy_ev)
459467
out = orig_out
460-
ht_copy_ev.wait()
461468
return out
462469
elif max is None:
463470
return _clip_none(x, min, out, order, tei._maximum)
@@ -665,30 +672,40 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
665672
a_min = dpt.broadcast_to(a_min, res_shape)
666673
if a_max.shape != res_shape:
667674
a_max = dpt.broadcast_to(a_max, res_shape)
675+
_manager = SequentialOrderManager
676+
dep_ev = _manager.submitted_events
668677
ht_binary_ev, binary_ev = ti._clip(
669-
src=x, min=a_min, max=a_max, dst=out, sycl_queue=exec_q
678+
src=x,
679+
min=a_min,
680+
max=a_max,
681+
dst=out,
682+
sycl_queue=exec_q,
683+
depends=dep_ev,
670684
)
685+
_manager.add_event_pair(ht_binary_ev, binary_ev)
671686
if not (orig_out is None or orig_out is out):
672687
# Copy the out data from temporary buffer to original memory
673-
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
688+
ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
674689
src=out,
675690
dst=orig_out,
676691
sycl_queue=exec_q,
677692
depends=[binary_ev],
678693
)
679-
ht_copy_out_ev.wait()
694+
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
680695
out = orig_out
681-
ht_binary_ev.wait()
682696
return out
683697

684698
elif buf1_dt is None:
685699
if order == "K":
686700
buf2 = _empty_like_orderK(a_max, buf2_dt)
687701
else:
688702
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
703+
_manager = SequentialOrderManager
704+
dep_ev = _manager.submitted_events
689705
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
690-
src=a_max, dst=buf2, sycl_queue=exec_q
706+
src=a_max, dst=buf2, sycl_queue=exec_q, depends=dep_ev
691707
)
708+
_manager.add_event_pair(ht_copy_ev, copy_ev)
692709
if out is None:
693710
if order == "K":
694711
out = _empty_like_triple_orderK(
@@ -721,28 +738,30 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
721738
sycl_queue=exec_q,
722739
depends=[copy_ev],
723740
)
741+
_manager.add_event_pair(ht_binary_ev, binary_ev)
724742
if not (orig_out is None or orig_out is out):
725743
# Copy the out data from temporary buffer to original memory
726-
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
744+
ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
727745
src=out,
728746
dst=orig_out,
729747
sycl_queue=exec_q,
730748
depends=[binary_ev],
731749
)
732-
ht_copy_out_ev.wait()
750+
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
733751
out = orig_out
734-
ht_copy_ev.wait()
735-
ht_binary_ev.wait()
736752
return out
737753

738754
elif buf2_dt is None:
739755
if order == "K":
740756
buf1 = _empty_like_orderK(a_min, buf1_dt)
741757
else:
742758
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
759+
_manager = SequentialOrderManager
760+
dep_ev = _manager.submitted_events
743761
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
744-
src=a_min, dst=buf1, sycl_queue=exec_q
762+
src=a_min, dst=buf1, sycl_queue=exec_q, depends=dep_ev
745763
)
764+
_manager.add_event_pair(ht_copy_ev, copy_ev)
746765
if out is None:
747766
if order == "K":
748767
out = _empty_like_triple_orderK(
@@ -775,18 +794,17 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
775794
sycl_queue=exec_q,
776795
depends=[copy_ev],
777796
)
797+
_manager.add_event_pair(ht_binary_ev, binary_ev)
778798
if not (orig_out is None or orig_out is out):
779799
# Copy the out data from temporary buffer to original memory
780-
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
800+
ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
781801
src=out,
782802
dst=orig_out,
783803
sycl_queue=exec_q,
784804
depends=[binary_ev],
785805
)
786-
ht_copy_out_ev.wait()
806+
_manager.add_event_pair(ht_copy_out_ev, cpy_ev)
787807
out = orig_out
788-
ht_copy_ev.wait()
789-
ht_binary_ev.wait()
790808
return out
791809

792810
if order == "K":
@@ -806,16 +824,21 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
806824
buf1 = _empty_like_orderK(a_min, buf1_dt)
807825
else:
808826
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
827+
828+
_manager = SequentialOrderManager
829+
dep_evs = _manager.submitted_events
809830
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
810-
src=a_min, dst=buf1, sycl_queue=exec_q
831+
src=a_min, dst=buf1, sycl_queue=exec_q, depends=dep_evs
811832
)
833+
_manager.add_event_pair(ht_copy1_ev, copy1_ev)
812834
if order == "K":
813835
buf2 = _empty_like_orderK(a_max, buf2_dt)
814836
else:
815837
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
816838
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
817-
src=a_max, dst=buf2, sycl_queue=exec_q
839+
src=a_max, dst=buf2, sycl_queue=exec_q, depends=dep_evs
818840
)
841+
_manager.add_event_pair(ht_copy2_ev, copy2_ev)
819842
if out is None:
820843
if order == "K":
821844
out = _empty_like_triple_orderK(
@@ -833,13 +856,13 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
833856
x = dpt.broadcast_to(x, res_shape)
834857
buf1 = dpt.broadcast_to(buf1, res_shape)
835858
buf2 = dpt.broadcast_to(buf2, res_shape)
836-
ht_, _ = ti._clip(
859+
ht_, clip_ev = ti._clip(
837860
src=x,
838861
min=buf1,
839862
max=buf2,
840863
dst=out,
841864
sycl_queue=exec_q,
842865
depends=[copy1_ev, copy2_ev],
843866
)
844-
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
867+
_manager.add_event_pair(ht_, clip_ev)
845868
return out

0 commit comments

Comments
 (0)