Skip to content

Commit c92fd6d

Browse files
ndgrigorianoleksandr-pavlyk
authored andcommitted
BinaryElementwiseFunc._inplace_op now checks if a kernel is available
Raises `ValueError` otherwise
1 parent 1f21ce7 commit c92fd6d

File tree

1 file changed

+120
-112
lines changed

1 file changed

+120
-112
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 120 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -930,126 +930,134 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
930930
return out
931931

932932
def _inplace_op(self, o1, o2):
933-
if not isinstance(o1, dpt.usm_ndarray):
934-
raise TypeError(
935-
"Expected first argument to be "
936-
f"dpctl.tensor.usm_ndarray, got {type(o1)}"
937-
)
938-
if not o1.flags.writable:
939-
raise ValueError("provided left-hand side array is read-only")
940-
q1, o1_usm_type = o1.sycl_queue, o1.usm_type
941-
q2, o2_usm_type = _get_queue_usm_type(o2)
942-
if q2 is None:
943-
exec_q = q1
944-
res_usm_type = o1_usm_type
945-
else:
946-
exec_q = dpctl.utils.get_execution_queue((q1, q2))
947-
if exec_q is None:
948-
raise ExecutionPlacementError(
949-
"Execution placement can not be unambiguously inferred "
950-
"from input arguments."
933+
if self.binary_inplace_fn_ is not None:
934+
if not isinstance(o1, dpt.usm_ndarray):
935+
raise TypeError(
936+
"Expected first argument to be "
937+
f"dpctl.tensor.usm_ndarray, got {type(o1)}"
951938
)
952-
res_usm_type = dpctl.utils.get_coerced_usm_type(
953-
(
954-
o1_usm_type,
955-
o2_usm_type,
939+
if not o1.flags.writable:
940+
raise ValueError("provided left-hand side array is read-only")
941+
q1, o1_usm_type = o1.sycl_queue, o1.usm_type
942+
q2, o2_usm_type = _get_queue_usm_type(o2)
943+
if q2 is None:
944+
exec_q = q1
945+
res_usm_type = o1_usm_type
946+
else:
947+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
948+
if exec_q is None:
949+
raise ExecutionPlacementError(
950+
"Execution placement can not be unambiguously inferred "
951+
"from input arguments."
952+
)
953+
res_usm_type = dpctl.utils.get_coerced_usm_type(
954+
(
955+
o1_usm_type,
956+
o2_usm_type,
957+
)
956958
)
959+
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
960+
o1_shape = o1.shape
961+
o2_shape = _get_shape(o2)
962+
if not isinstance(o2_shape, (tuple, list)):
963+
raise TypeError(
964+
"Shape of second argument can not be inferred. "
965+
"Expected list or tuple."
966+
)
967+
try:
968+
res_shape = _broadcast_shape_impl(
969+
[
970+
o1_shape,
971+
o2_shape,
972+
]
973+
)
974+
except ValueError:
975+
raise ValueError(
976+
"operands could not be broadcast together with shapes "
977+
f"{o1_shape} and {o2_shape}"
978+
)
979+
if res_shape != o1_shape:
980+
raise ValueError("")
981+
sycl_dev = exec_q.sycl_device
982+
o1_dtype = o1.dtype
983+
o2_dtype = _get_dtype(o2, sycl_dev)
984+
if not _validate_dtype(o2_dtype):
985+
raise ValueError("Operand has an unsupported data type")
986+
987+
o1_dtype, o2_dtype = self.weak_type_resolver_(
988+
o1_dtype, o2_dtype, sycl_dev
957989
)
958-
dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
959-
o1_shape = o1.shape
960-
o2_shape = _get_shape(o2)
961-
if not isinstance(o2_shape, (tuple, list)):
962-
raise TypeError(
963-
"Shape of second argument can not be inferred. "
964-
"Expected list or tuple."
965-
)
966-
try:
967-
res_shape = _broadcast_shape_impl(
968-
[
969-
o1_shape,
970-
o2_shape,
971-
]
972-
)
973-
except ValueError:
974-
raise ValueError(
975-
"operands could not be broadcast together with shapes "
976-
f"{o1_shape} and {o2_shape}"
990+
991+
buf_dt, res_dt = _find_buf_dtype_in_place_op(
992+
o1_dtype,
993+
o2_dtype,
994+
self.result_type_resolver_fn_,
995+
sycl_dev,
977996
)
978-
if res_shape != o1_shape:
979-
raise ValueError("")
980-
sycl_dev = exec_q.sycl_device
981-
o1_dtype = o1.dtype
982-
o2_dtype = _get_dtype(o2, sycl_dev)
983-
if not _validate_dtype(o2_dtype):
984-
raise ValueError("Operand has an unsupported data type")
985997

986-
o1_dtype, o2_dtype = self.weak_type_resolver_(
987-
o1_dtype, o2_dtype, sycl_dev
988-
)
998+
if res_dt is None:
999+
raise ValueError(
1000+
f"function '{self.name_}' does not support input types "
1001+
f"({o1_dtype}, {o2_dtype}), "
1002+
"and the inputs could not be safely coerced to any "
1003+
"supported types according to the casting rule "
1004+
"''same_kind''."
1005+
)
9891006

990-
buf_dt, res_dt = _find_buf_dtype_in_place_op(
991-
o1_dtype,
992-
o2_dtype,
993-
self.result_type_resolver_fn_,
994-
sycl_dev,
995-
)
1007+
if res_dt != o1_dtype:
1008+
raise ValueError(
1009+
f"Output array of type {res_dt} is needed, "
1010+
f"got {o1_dtype}"
1011+
)
9961012

997-
if res_dt is None:
998-
raise ValueError(
999-
f"function '{self.name_}' does not support input types "
1000-
f"({o1_dtype}, {o2_dtype}), "
1001-
"and the inputs could not be safely coerced to any "
1002-
"supported types according to the casting rule ''same_kind''."
1003-
)
1013+
_manager = SequentialOrderManager[exec_q]
1014+
if isinstance(o2, dpt.usm_ndarray):
1015+
src2 = o2
1016+
if (
1017+
ti._array_overlap(o2, o1)
1018+
and not ti._same_logical_tensors(o2, o1)
1019+
and buf_dt is None
1020+
):
1021+
buf_dt = o2_dtype
1022+
else:
1023+
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1024+
if buf_dt is None:
1025+
if src2.shape != res_shape:
1026+
src2 = dpt.broadcast_to(src2, res_shape)
1027+
dep_evs = _manager.submitted_events
1028+
ht_, comp_ev = self.binary_inplace_fn_(
1029+
lhs=o1,
1030+
rhs=src2,
1031+
sycl_queue=exec_q,
1032+
depends=dep_evs,
1033+
)
1034+
_manager.add_event_pair(ht_, comp_ev)
1035+
else:
1036+
buf = dpt.empty_like(src2, dtype=buf_dt)
1037+
dep_evs = _manager.submitted_events
1038+
(
1039+
ht_copy_ev,
1040+
copy_ev,
1041+
) = ti._copy_usm_ndarray_into_usm_ndarray(
1042+
src=src2,
1043+
dst=buf,
1044+
sycl_queue=exec_q,
1045+
depends=dep_evs,
1046+
)
1047+
_manager.add_event_pair(ht_copy_ev, copy_ev)
10041048

1005-
if res_dt != o1_dtype:
1006-
raise ValueError(
1007-
f"Output array of type {res_dt} is needed, " f"got {o1_dtype}"
1008-
)
1049+
buf = dpt.broadcast_to(buf, res_shape)
1050+
ht_, bf_ev = self.binary_inplace_fn_(
1051+
lhs=o1,
1052+
rhs=buf,
1053+
sycl_queue=exec_q,
1054+
depends=[copy_ev],
1055+
)
1056+
_manager.add_event_pair(ht_, bf_ev)
10091057

1010-
_manager = SequentialOrderManager[exec_q]
1011-
if isinstance(o2, dpt.usm_ndarray):
1012-
src2 = o2
1013-
if (
1014-
ti._array_overlap(o2, o1)
1015-
and not ti._same_logical_tensors(o2, o1)
1016-
and buf_dt is None
1017-
):
1018-
buf_dt = o2_dtype
1019-
else:
1020-
src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)
1021-
if buf_dt is None:
1022-
if src2.shape != res_shape:
1023-
src2 = dpt.broadcast_to(src2, res_shape)
1024-
dep_evs = _manager.submitted_events
1025-
ht_, comp_ev = self.binary_inplace_fn_(
1026-
lhs=o1,
1027-
rhs=src2,
1028-
sycl_queue=exec_q,
1029-
depends=dep_evs,
1030-
)
1031-
_manager.add_event_pair(ht_, comp_ev)
1058+
return o1
10321059
else:
1033-
buf = dpt.empty_like(src2, dtype=buf_dt)
1034-
dep_evs = _manager.submitted_events
1035-
(
1036-
ht_copy_ev,
1037-
copy_ev,
1038-
) = ti._copy_usm_ndarray_into_usm_ndarray(
1039-
src=src2,
1040-
dst=buf,
1041-
sycl_queue=exec_q,
1042-
depends=dep_evs,
1043-
)
1044-
_manager.add_event_pair(ht_copy_ev, copy_ev)
1045-
1046-
buf = dpt.broadcast_to(buf, res_shape)
1047-
ht_, bf_ev = self.binary_inplace_fn_(
1048-
lhs=o1,
1049-
rhs=buf,
1050-
sycl_queue=exec_q,
1051-
depends=[copy_ev],
1060+
raise ValueError(
1061+
"binary function does not have a dedicated in-place "
1062+
"implementation"
10521063
)
1053-
_manager.add_event_pair(ht_, bf_ev)
1054-
1055-
return o1

0 commit comments

Comments
 (0)