Skip to content

Fully enable usm_ndarray in-place arithmetic operators #1352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 18, 2023
280 changes: 114 additions & 166 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
_acceptance_fn_default,
_find_buf_dtype,
_find_buf_dtype2,
_find_inplace_dtype,
_to_device_supported_dtype,
)

Expand Down Expand Up @@ -79,8 +78,8 @@ def __call__(self, x, out=None, order="K"):
)

if out.shape != x.shape:
raise TypeError(
"The shape of input and output arrays are inconsistent."
raise ValueError(
"The shape of input and output arrays are inconsistent. "
f"Expected output shape is {x.shape}, got {out.shape}"
)

Expand All @@ -104,7 +103,7 @@ def __call__(self, x, out=None, order="K"):
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
is None
):
raise TypeError(
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

Expand Down Expand Up @@ -302,8 +301,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
if o1_kind_num > o2_kind_num:
if isinstance(o1_dtype, WeakBooleanType):
return dpt.bool, o2_dtype
if isinstance(o1_dtype, WeakIntegralType):
return dpt.int64, o2_dtype
if isinstance(o1_dtype, WeakComplexType):
Expand All @@ -323,8 +320,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num:
if isinstance(o2_dtype, WeakBooleanType):
return o1_dtype, dpt.bool
if isinstance(o2_dtype, WeakIntegralType):
return o1_dtype, dpt.int64
if isinstance(o2_dtype, WeakComplexType):
Expand Down Expand Up @@ -383,14 +378,6 @@ def __repr__(self):
return f"<{self.__name__} '{self.name_}'>"

def __call__(self, o1, o2, out=None, order="K"):
# FIXME: replace with check against base array
# when views can be identified
if self.binary_inplace_fn_:
if o1 is out:
return self._inplace(o1, o2)
elif o2 is out:
return self._inplace(o2, o1)

if order not in ["K", "C", "F", "A"]:
order = "K"
q1, o1_usm_type = _get_queue_usm_type(o1)
Expand Down Expand Up @@ -472,31 +459,90 @@ def __call__(self, o1, o2, out=None, order="K"):
"supported types according to the casting rule ''safe''."
)

orig_out = out
if out is not None:
if not isinstance(out, dpt.usm_ndarray):
raise TypeError(
f"output array must be of usm_ndarray type, got {type(out)}"
)

if out.shape != res_shape:
raise TypeError(
"The shape of input and output arrays are inconsistent."
raise ValueError(
"The shape of input and output arrays are inconsistent. "
f"Expected output shape is {o1_shape}, got {out.shape}"
)

if ti._array_overlap(o1, out) or ti._array_overlap(o2, out):
raise TypeError("Input and output arrays have memory overlap")
if res_dt != out.dtype:
raise TypeError(
f"Output array of type {res_dt} is needed,"
f"got {out.dtype}"
)

if (
dpctl.utils.get_execution_queue(
(o1.sycl_queue, o2.sycl_queue, out.sycl_queue)
)
dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
is None
):
raise TypeError(
raise ExecutionPlacementError(
"Input and output allocation queues are not compatible"
)

if isinstance(o1, dpt.usm_ndarray):
if ti._array_overlap(o1, out) and buf1_dt is None:
if not ti._same_logical_tensors(o1, out):
out = dpt.empty_like(out)
elif self.binary_inplace_fn_ is not None:
# if there is a dedicated in-place kernel
# it can be called here, otherwise continues
if isinstance(o2, dpt.usm_ndarray):
src2 = o2
if (
ti._array_overlap(o2, out)
and not ti._same_logical_tensors(o2, out)
and buf2_dt is None
):
buf2_dt = o2_dtype
else:
src2 = dpt.asarray(
o2, dtype=o2_dtype, sycl_queue=exec_q
)
if buf2_dt is None:
if src2.shape != res_shape:
src2 = dpt.broadcast_to(src2, res_shape)
ht_, _ = self.binary_inplace_fn_(
lhs=o1, rhs=src2, sycl_queue=exec_q
)
ht_.wait()
else:
buf2 = dpt.empty_like(src2, dtype=buf2_dt)
(
ht_copy_ev,
copy_ev,
) = ti._copy_usm_ndarray_into_usm_ndarray(
src=src2, dst=buf2, sycl_queue=exec_q
)

buf2 = dpt.broadcast_to(buf2, res_shape)
ht_, _ = self.binary_inplace_fn_(
lhs=o1,
rhs=buf2,
sycl_queue=exec_q,
depends=[copy_ev],
)
ht_copy_ev.wait()
ht_.wait()

return out

if isinstance(o2, dpt.usm_ndarray):
if (
ti._array_overlap(o2, out)
and not ti._same_logical_tensors(o2, out)
and buf2_dt is None
):
# should not reach if out is reallocated
# after being checked against o1
out = dpt.empty_like(out)

if isinstance(o1, dpt.usm_ndarray):
src1 = o1
else:
Expand Down Expand Up @@ -532,19 +578,24 @@ def __call__(self, o1, o2, out=None, order="K"):
sycl_queue=exec_q,
order=order,
)
else:
if res_dt != out.dtype:
raise TypeError(
f"Output array of type {res_dt} is needed,"
f"got {out.dtype}"
)

src1 = dpt.broadcast_to(src1, res_shape)
src2 = dpt.broadcast_to(src2, res_shape)
ht_, _ = self.binary_fn_(
if src1.shape != res_shape:
src1 = dpt.broadcast_to(src1, res_shape)
if src2.shape != res_shape:
src2 = dpt.broadcast_to(src2, res_shape)
ht_binary_ev, binary_ev = self.binary_fn_(
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
)
ht_.wait()
if not (orig_out is None or orig_out is out):
# Copy the out data from temporary buffer to original memory
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=out,
dst=orig_out,
sycl_queue=exec_q,
depends=[binary_ev],
)
ht_copy_out_ev.wait()
out = orig_out
ht_binary_ev.wait()
return out
elif buf1_dt is None:
if order == "K":
Expand Down Expand Up @@ -575,18 +626,28 @@ def __call__(self, o1, o2, out=None, order="K"):
f"Output array of type {res_dt} is needed,"
f"got {out.dtype}"
)

src1 = dpt.broadcast_to(src1, res_shape)
if src1.shape != res_shape:
src1 = dpt.broadcast_to(src1, res_shape)
buf2 = dpt.broadcast_to(buf2, res_shape)
ht_, _ = self.binary_fn_(
ht_binary_ev, binary_ev = self.binary_fn_(
src1=src1,
src2=buf2,
dst=out,
sycl_queue=exec_q,
depends=[copy_ev],
)
if not (orig_out is None or orig_out is out):
# Copy the out data from temporary buffer to original memory
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=out,
dst=orig_out,
sycl_queue=exec_q,
depends=[binary_ev],
)
ht_copy_out_ev.wait()
out = orig_out
ht_copy_ev.wait()
ht_.wait()
ht_binary_ev.wait()
return out
elif buf2_dt is None:
if order == "K":
Expand All @@ -611,24 +672,29 @@ def __call__(self, o1, o2, out=None, order="K"):
sycl_queue=exec_q,
order=order,
)
else:
if res_dt != out.dtype:
raise TypeError(
f"Output array of type {res_dt} is needed,"
f"got {out.dtype}"
)

buf1 = dpt.broadcast_to(buf1, res_shape)
src2 = dpt.broadcast_to(src2, res_shape)
ht_, _ = self.binary_fn_(
if src2.shape != res_shape:
src2 = dpt.broadcast_to(src2, res_shape)
ht_binary_ev, binary_ev = self.binary_fn_(
src1=buf1,
src2=src2,
dst=out,
sycl_queue=exec_q,
depends=[copy_ev],
)
if not (orig_out is None or orig_out is out):
# Copy the out data from temporary buffer to original memory
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=out,
dst=orig_out,
sycl_queue=exec_q,
depends=[binary_ev],
)
ht_copy_out_ev.wait()
out = orig_out
ht_copy_ev.wait()
ht_.wait()
ht_binary_ev.wait()
return out

if order in ["K", "A"]:
Expand Down Expand Up @@ -665,11 +731,6 @@ def __call__(self, o1, o2, out=None, order="K"):
sycl_queue=exec_q,
order=order,
)
else:
if res_dt != out.dtype:
raise TypeError(
f"Output array of type {res_dt} is needed, got {out.dtype}"
)

buf1 = dpt.broadcast_to(buf1, res_shape)
buf2 = dpt.broadcast_to(buf2, res_shape)
Expand All @@ -682,116 +743,3 @@ def __call__(self, o1, o2, out=None, order="K"):
)
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
return out

def _inplace(self, lhs, val):
if self.binary_inplace_fn_ is None:
raise ValueError(
f"In-place operation not supported for ufunc '{self.name_}'"
)
if not isinstance(lhs, dpt.usm_ndarray):
raise TypeError(
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
)
q1, lhs_usm_type = _get_queue_usm_type(lhs)
q2, val_usm_type = _get_queue_usm_type(val)
if q2 is None:
exec_q = q1
usm_type = lhs_usm_type
else:
exec_q = dpctl.utils.get_execution_queue((q1, q2))
if exec_q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously inferred "
"from input arguments."
)
usm_type = dpctl.utils.get_coerced_usm_type(
(
lhs_usm_type,
val_usm_type,
)
)
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
lhs_shape = _get_shape(lhs)
val_shape = _get_shape(val)
if not all(
isinstance(s, (tuple, list))
for s in (
lhs_shape,
val_shape,
)
):
raise TypeError(
"Shape of arguments can not be inferred. "
"Arguments are expected to be "
"lists, tuples, or both"
)
try:
res_shape = _broadcast_shape_impl(
[
lhs_shape,
val_shape,
]
)
except ValueError:
raise ValueError(
"operands could not be broadcast together with shapes "
f"{lhs_shape} and {val_shape}"
)
if res_shape != lhs_shape:
raise ValueError(
f"output shape {lhs_shape} does not match "
f"broadcast shape {res_shape}"
)
sycl_dev = exec_q.sycl_device
lhs_dtype = lhs.dtype
val_dtype = _get_dtype(val, sycl_dev)
if not _validate_dtype(val_dtype):
raise ValueError("Input operand of unsupported type")

lhs_dtype, val_dtype = _resolve_weak_types(
lhs_dtype, val_dtype, sycl_dev
)

buf_dt = _find_inplace_dtype(
lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
)

if buf_dt is None:
raise TypeError(
f"In-place '{self.name_}' does not support input types "
f"({lhs_dtype}, {val_dtype}), "
"and the inputs could not be safely coerced to any "
"supported types according to the casting rule ''safe''."
)

if isinstance(val, dpt.usm_ndarray):
rhs = val
overlap = ti._array_overlap(lhs, rhs)
else:
rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
overlap = False

if buf_dt == val_dtype and overlap is False:
rhs = dpt.broadcast_to(rhs, res_shape)
ht_, _ = self.binary_inplace_fn_(
lhs=lhs, rhs=rhs, sycl_queue=exec_q
)
ht_.wait()

else:
buf = dpt.empty_like(rhs, dtype=buf_dt)
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=rhs, dst=buf, sycl_queue=exec_q
)

buf = dpt.broadcast_to(buf, res_shape)
ht_, _ = self.binary_inplace_fn_(
lhs=lhs,
rhs=buf,
sycl_queue=exec_q,
depends=[copy_ev],
)
ht_copy_ev.wait()
ht_.wait()

return lhs
Loading