Skip to content

Commit 369c500

Browse files
committed
Binary elementwise functions can now act on any input in-place
- A temporary will be allocated as necessary (i.e., when arrays overlap, are not going to be cast, and are not the same logical arrays) - Uses dedicated in-place kernels where they are implemented - Now called directly by Python operators - Removes _inplace method of BinaryElementwiseFunc class - Removes _find_inplace_dtype function
1 parent 47c82f5 commit 369c500

File tree

4 files changed

+110
-177
lines changed

4 files changed

+110
-177
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 99 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
_acceptance_fn_default,
3232
_find_buf_dtype,
3333
_find_buf_dtype2,
34-
_find_inplace_dtype,
3534
_to_device_supported_dtype,
3635
)
3736

@@ -383,14 +382,6 @@ def __repr__(self):
383382
return f"<{self.__name__} '{self.name_}'>"
384383

385384
def __call__(self, o1, o2, out=None, order="K"):
386-
# FIXME: replace with check against base array
387-
# when views can be identified
388-
if self.binary_inplace_fn_:
389-
if o1 is out:
390-
return self._inplace(o1, o2)
391-
elif o2 is out:
392-
return self._inplace(o2, o1)
393-
394385
if order not in ["K", "C", "F", "A"]:
395386
order = "K"
396387
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -472,6 +463,7 @@ def __call__(self, o1, o2, out=None, order="K"):
472463
"supported types according to the casting rule ''safe''."
473464
)
474465

466+
orig_out = out
475467
if out is not None:
476468
if not isinstance(out, dpt.usm_ndarray):
477469
raise TypeError(
@@ -484,19 +476,76 @@ def __call__(self, o1, o2, out=None, order="K"):
484476
f"Expected output shape is {o1_shape}, got {out.shape}"
485477
)
486478

487-
if ti._array_overlap(o1, out) or ti._array_overlap(o2, out):
488-
raise TypeError("Input and output arrays have memory overlap")
479+
if res_dt != out.dtype:
480+
raise TypeError(
481+
f"Output array of type {res_dt} is needed,"
482+
f"got {out.dtype}"
483+
)
489484

490485
if (
491-
dpctl.utils.get_execution_queue(
492-
(o1.sycl_queue, o2.sycl_queue, out.sycl_queue)
493-
)
486+
dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
494487
is None
495488
):
496489
raise TypeError(
497490
"Input and output allocation queues are not compatible"
498491
)
499492

493+
if isinstance(o1, dpt.usm_ndarray):
494+
if ti._array_overlap(o1, out) and buf1_dt is None:
495+
if not ti._same_logical_tensors(o1, out):
496+
out = dpt.empty_like(out)
497+
elif self.binary_inplace_fn_ is not None:
498+
# if there is a dedicated in-place kernel
499+
# it can be called here, otherwise continues
500+
if isinstance(o2, dpt.usm_ndarray):
501+
src2 = o2
502+
if (
503+
ti._array_overlap(o2, out)
504+
and not ti._same_logical_tensors(o2, out)
505+
and buf2_dt is None
506+
):
507+
buf2_dt = o2_dtype
508+
else:
509+
src2 = dpt.asarray(
510+
o2, dtype=o2_dtype, sycl_queue=exec_q
511+
)
512+
if buf2_dt is None:
513+
src2 = dpt.broadcast_to(src2, res_shape)
514+
ht_, _ = self.binary_inplace_fn_(
515+
lhs=o1, rhs=src2, sycl_queue=exec_q
516+
)
517+
ht_.wait()
518+
else:
519+
buf2 = dpt.empty_like(src2, dtype=buf2_dt)
520+
(
521+
ht_copy_ev,
522+
copy_ev,
523+
) = ti._copy_usm_ndarray_into_usm_ndarray(
524+
src=src2, dst=buf2, sycl_queue=exec_q
525+
)
526+
527+
buf2 = dpt.broadcast_to(buf2, res_shape)
528+
ht_, _ = self.binary_inplace_fn_(
529+
lhs=o1,
530+
rhs=buf2,
531+
sycl_queue=exec_q,
532+
depends=[copy_ev],
533+
)
534+
ht_copy_ev.wait()
535+
ht_.wait()
536+
537+
return out
538+
539+
if isinstance(o2, dpt.usm_ndarray):
540+
if (
541+
ti._array_overlap(o2, out)
542+
and not ti._same_logical_tensors(o2, out)
543+
and buf2_dt is None
544+
):
545+
# should not reach if out is reallocated
546+
# after being checked against o1
547+
out = dpt.empty_like(out)
548+
500549
if isinstance(o1, dpt.usm_ndarray):
501550
src1 = o1
502551
else:
@@ -532,19 +581,23 @@ def __call__(self, o1, o2, out=None, order="K"):
532581
sycl_queue=exec_q,
533582
order=order,
534583
)
535-
else:
536-
if res_dt != out.dtype:
537-
raise TypeError(
538-
f"Output array of type {res_dt} is needed,"
539-
f"got {out.dtype}"
540-
)
541584

542585
src1 = dpt.broadcast_to(src1, res_shape)
543586
src2 = dpt.broadcast_to(src2, res_shape)
544-
ht_, _ = self.binary_fn_(
587+
ht_binary_ev, binary_ev = self.binary_fn_(
545588
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
546589
)
547-
ht_.wait()
590+
if not (orig_out is None or orig_out is out):
591+
# Copy the out data from temporary buffer to original memory
592+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
593+
src=out,
594+
dst=orig_out,
595+
sycl_queue=exec_q,
596+
depends=[binary_ev],
597+
)
598+
ht_copy_out_ev.wait()
599+
out = orig_out
600+
ht_binary_ev.wait()
548601
return out
549602
elif buf1_dt is None:
550603
if order == "K":
@@ -578,15 +631,25 @@ def __call__(self, o1, o2, out=None, order="K"):
578631

579632
src1 = dpt.broadcast_to(src1, res_shape)
580633
buf2 = dpt.broadcast_to(buf2, res_shape)
581-
ht_, _ = self.binary_fn_(
634+
ht_binary_ev, binary_ev = self.binary_fn_(
582635
src1=src1,
583636
src2=buf2,
584637
dst=out,
585638
sycl_queue=exec_q,
586639
depends=[copy_ev],
587640
)
641+
if not (orig_out is None or orig_out is out):
642+
# Copy the out data from temporary buffer to original memory
643+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
644+
src=out,
645+
dst=orig_out,
646+
sycl_queue=exec_q,
647+
depends=[binary_ev],
648+
)
649+
ht_copy_out_ev.wait()
650+
out = orig_out
588651
ht_copy_ev.wait()
589-
ht_.wait()
652+
ht_binary_ev.wait()
590653
return out
591654
elif buf2_dt is None:
592655
if order == "K":
@@ -611,24 +674,28 @@ def __call__(self, o1, o2, out=None, order="K"):
611674
sycl_queue=exec_q,
612675
order=order,
613676
)
614-
else:
615-
if res_dt != out.dtype:
616-
raise TypeError(
617-
f"Output array of type {res_dt} is needed,"
618-
f"got {out.dtype}"
619-
)
620677

621678
buf1 = dpt.broadcast_to(buf1, res_shape)
622679
src2 = dpt.broadcast_to(src2, res_shape)
623-
ht_, _ = self.binary_fn_(
680+
ht_binary_ev, binary_ev = self.binary_fn_(
624681
src1=buf1,
625682
src2=src2,
626683
dst=out,
627684
sycl_queue=exec_q,
628685
depends=[copy_ev],
629686
)
687+
if not (orig_out is None or orig_out is out):
688+
# Copy the out data from temporary buffer to original memory
689+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
690+
src=out,
691+
dst=orig_out,
692+
sycl_queue=exec_q,
693+
depends=[binary_ev],
694+
)
695+
ht_copy_out_ev.wait()
696+
out = orig_out
630697
ht_copy_ev.wait()
631-
ht_.wait()
698+
ht_binary_ev.wait()
632699
return out
633700

634701
if order in ["K", "A"]:
@@ -665,11 +732,6 @@ def __call__(self, o1, o2, out=None, order="K"):
665732
sycl_queue=exec_q,
666733
order=order,
667734
)
668-
else:
669-
if res_dt != out.dtype:
670-
raise TypeError(
671-
f"Output array of type {res_dt} is needed, got {out.dtype}"
672-
)
673735

674736
buf1 = dpt.broadcast_to(buf1, res_shape)
675737
buf2 = dpt.broadcast_to(buf2, res_shape)
@@ -682,116 +744,3 @@ def __call__(self, o1, o2, out=None, order="K"):
682744
)
683745
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
684746
return out
685-
686-
def _inplace(self, lhs, val):
687-
if self.binary_inplace_fn_ is None:
688-
raise ValueError(
689-
f"In-place operation not supported for ufunc '{self.name_}'"
690-
)
691-
if not isinstance(lhs, dpt.usm_ndarray):
692-
raise TypeError(
693-
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
694-
)
695-
q1, lhs_usm_type = _get_queue_usm_type(lhs)
696-
q2, val_usm_type = _get_queue_usm_type(val)
697-
if q2 is None:
698-
exec_q = q1
699-
usm_type = lhs_usm_type
700-
else:
701-
exec_q = dpctl.utils.get_execution_queue((q1, q2))
702-
if exec_q is None:
703-
raise ExecutionPlacementError(
704-
"Execution placement can not be unambiguously inferred "
705-
"from input arguments."
706-
)
707-
usm_type = dpctl.utils.get_coerced_usm_type(
708-
(
709-
lhs_usm_type,
710-
val_usm_type,
711-
)
712-
)
713-
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
714-
lhs_shape = _get_shape(lhs)
715-
val_shape = _get_shape(val)
716-
if not all(
717-
isinstance(s, (tuple, list))
718-
for s in (
719-
lhs_shape,
720-
val_shape,
721-
)
722-
):
723-
raise TypeError(
724-
"Shape of arguments can not be inferred. "
725-
"Arguments are expected to be "
726-
"lists, tuples, or both"
727-
)
728-
try:
729-
res_shape = _broadcast_shape_impl(
730-
[
731-
lhs_shape,
732-
val_shape,
733-
]
734-
)
735-
except ValueError:
736-
raise ValueError(
737-
"operands could not be broadcast together with shapes "
738-
f"{lhs_shape} and {val_shape}"
739-
)
740-
if res_shape != lhs_shape:
741-
raise ValueError(
742-
f"output shape {lhs_shape} does not match "
743-
f"broadcast shape {res_shape}"
744-
)
745-
sycl_dev = exec_q.sycl_device
746-
lhs_dtype = lhs.dtype
747-
val_dtype = _get_dtype(val, sycl_dev)
748-
if not _validate_dtype(val_dtype):
749-
raise ValueError("Input operand of unsupported type")
750-
751-
lhs_dtype, val_dtype = _resolve_weak_types(
752-
lhs_dtype, val_dtype, sycl_dev
753-
)
754-
755-
buf_dt = _find_inplace_dtype(
756-
lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
757-
)
758-
759-
if buf_dt is None:
760-
raise TypeError(
761-
f"In-place '{self.name_}' does not support input types "
762-
f"({lhs_dtype}, {val_dtype}), "
763-
"and the inputs could not be safely coerced to any "
764-
"supported types according to the casting rule ''safe''."
765-
)
766-
767-
if isinstance(val, dpt.usm_ndarray):
768-
rhs = val
769-
overlap = ti._array_overlap(lhs, rhs)
770-
else:
771-
rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
772-
overlap = False
773-
774-
if buf_dt == val_dtype and overlap is False:
775-
rhs = dpt.broadcast_to(rhs, res_shape)
776-
ht_, _ = self.binary_inplace_fn_(
777-
lhs=lhs, rhs=rhs, sycl_queue=exec_q
778-
)
779-
ht_.wait()
780-
781-
else:
782-
buf = dpt.empty_like(rhs, dtype=buf_dt)
783-
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
784-
src=rhs, dst=buf, sycl_queue=exec_q
785-
)
786-
787-
buf = dpt.broadcast_to(buf, res_shape)
788-
ht_, _ = self.binary_inplace_fn_(
789-
lhs=lhs,
790-
rhs=buf,
791-
sycl_queue=exec_q,
792-
depends=[copy_ev],
793-
)
794-
ht_copy_ev.wait()
795-
ht_.wait()
796-
797-
return lhs

dpctl/tensor/_type_utils.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -226,27 +226,9 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
226226
return None, None, None
227227

228228

229-
def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
230-
res_dt = query_fn(lhs_dtype, rhs_dtype)
231-
if res_dt and res_dt == lhs_dtype:
232-
return rhs_dtype
233-
234-
_fp16 = sycl_dev.has_aspect_fp16
235-
_fp64 = sycl_dev.has_aspect_fp64
236-
all_dts = _all_data_types(_fp16, _fp64)
237-
for buf_dt in all_dts:
238-
if _can_cast(rhs_dtype, buf_dt, _fp16, _fp64):
239-
res_dt = query_fn(lhs_dtype, buf_dt)
240-
if res_dt and res_dt == lhs_dtype:
241-
return buf_dt
242-
243-
return None
244-
245-
246229
__all__ = [
247230
"_find_buf_dtype",
248231
"_find_buf_dtype2",
249-
"_find_inplace_dtype",
250232
"_to_device_supported_dtype",
251233
"_acceptance_fn_default",
252234
"_acceptance_fn_divide",

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,8 +1286,7 @@ cdef class usm_ndarray:
12861286
return _dispatch_binary_elementwise2(other, "bitwise_xor", self)
12871287

12881288
def __iadd__(self, other):
1289-
from ._elementwise_funcs import add
1290-
return add._inplace(self, other)
1289+
return dpctl.tensor.add(self, other, out=self)
12911290

12921291
def __iand__(self, other):
12931292
res = self.__and__(other)
@@ -1325,8 +1324,7 @@ cdef class usm_ndarray:
13251324
return self
13261325

13271326
def __imul__(self, other):
1328-
from ._elementwise_funcs import multiply
1329-
return multiply._inplace(self, other)
1327+
return dpctl.tensor.multiply(self, other, out=self)
13301328

13311329
def __ior__(self, other):
13321330
res = self.__or__(other)
@@ -1350,8 +1348,7 @@ cdef class usm_ndarray:
13501348
return self
13511349

13521350
def __isub__(self, other):
1353-
from ._elementwise_funcs import subtract
1354-
return subtract._inplace(self, other)
1351+
return dpctl.tensor.subtract(self, other, out=self)
13551352

13561353
def __itruediv__(self, other):
13571354
res = self.__truediv__(other)

0 commit comments

Comments
 (0)