Skip to content

Commit 12e024d

Browse files
authored
Merge pull request #1719 from IntelPython/support-scalars-in-where
Add Python scalar support to `dpt.where`
2 parents b7fd613 + 9f449d1 commit 12e024d

File tree

2 files changed

+239
-46
lines changed

2 files changed

+239
-46
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 183 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,81 @@
1717
import dpctl
1818
import dpctl.tensor as dpt
1919
import dpctl.tensor._tensor_impl as ti
20-
from dpctl.tensor._manipulation_functions import _broadcast_shapes
20+
from dpctl.tensor._elementwise_common import (
21+
_get_dtype,
22+
_get_queue_usm_type,
23+
_get_shape,
24+
_validate_dtype,
25+
)
26+
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
2127
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
2228

2329
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
24-
from ._type_utils import _all_data_types, _can_cast
30+
from ._type_utils import (
31+
WeakBooleanType,
32+
WeakComplexType,
33+
WeakFloatingType,
34+
WeakIntegralType,
35+
_all_data_types,
36+
_can_cast,
37+
_is_weak_dtype,
38+
_strong_dtype_num_kind,
39+
_to_device_supported_dtype,
40+
_weak_type_num_kind,
41+
)
42+
43+
44+
def _default_dtype_from_weak_type(dt, dev):
45+
if isinstance(dt, WeakBooleanType):
46+
return dpt.bool
47+
if isinstance(dt, WeakIntegralType):
48+
return dpt.dtype(ti.default_device_int_type(dev))
49+
if isinstance(dt, WeakFloatingType):
50+
return dpt.dtype(ti.default_device_fp_type(dev))
51+
if isinstance(dt, WeakComplexType):
52+
return dpt.dtype(ti.default_device_complex_type(dev))
53+
54+
55+
def _resolve_two_weak_types(o1_dtype, o2_dtype, dev):
56+
"Resolves two weak data types per NEP-0050"
57+
if _is_weak_dtype(o1_dtype):
58+
if _is_weak_dtype(o2_dtype):
59+
return _default_dtype_from_weak_type(
60+
o1_dtype, dev
61+
), _default_dtype_from_weak_type(o2_dtype, dev)
62+
o1_kind_num = _weak_type_num_kind(o1_dtype)
63+
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
64+
if o1_kind_num > o2_kind_num:
65+
if isinstance(o1_dtype, WeakIntegralType):
66+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
67+
if isinstance(o1_dtype, WeakComplexType):
68+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
69+
return dpt.complex64, o2_dtype
70+
return (
71+
_to_device_supported_dtype(dpt.complex128, dev),
72+
o2_dtype,
73+
)
74+
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
75+
else:
76+
return o2_dtype, o2_dtype
77+
elif _is_weak_dtype(o2_dtype):
78+
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
79+
o2_kind_num = _weak_type_num_kind(o2_dtype)
80+
if o2_kind_num > o1_kind_num:
81+
if isinstance(o2_dtype, WeakIntegralType):
82+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
83+
if isinstance(o2_dtype, WeakComplexType):
84+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
85+
return o1_dtype, dpt.complex64
86+
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
87+
return (
88+
o1_dtype,
89+
_to_device_supported_dtype(dpt.float64, dev),
90+
)
91+
else:
92+
return o1_dtype, o1_dtype
93+
else:
94+
return o1_dtype, o2_dtype
2595

2696

2797
def _where_result_type(dt1, dt2, dev):
@@ -51,16 +121,17 @@ def where(condition, x1, x2, /, *, order="K", out=None):
51121
and otherwise yields from ``x2``.
52122
Must be compatible with ``x1`` and ``x2`` according
53123
to broadcasting rules.
54-
x1 (usm_ndarray): Array from which values are chosen when
55-
``condition`` is ``True``.
124+
x1 (Union[usm_ndarray, bool, int, float, complex]):
125+
Array from which values are chosen when ``condition`` is ``True``.
56126
Must be compatible with ``condition`` and ``x2`` according
57127
to broadcasting rules.
58-
x2 (usm_ndarray): Array from which values are chosen when
59-
``condition`` is not ``True``.
128+
x2 (Union[usm_ndarray, bool, int, float, complex]):
129+
Array from which values are chosen when ``condition`` is not
130+
``True``.
60131
Must be compatible with ``condition`` and ``x2`` according
61132
to broadcasting rules.
62133
order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional):
63-
Memory layout of the new output arra,
134+
Memory layout of the new output array,
64135
if parameter ``out`` is ``None``.
65136
Default: ``"K"``.
66137
out (Optional[usm_ndarray]):
@@ -81,36 +152,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
81152
raise TypeError(
82153
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(condition)}"
83154
)
84-
if not isinstance(x1, dpt.usm_ndarray):
85-
raise TypeError(
86-
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x1)}"
155+
if order not in ["K", "C", "F", "A"]:
156+
order = "K"
157+
q1, condition_usm_type = condition.sycl_queue, condition.usm_type
158+
q2, x1_usm_type = _get_queue_usm_type(x1)
159+
q3, x2_usm_type = _get_queue_usm_type(x2)
160+
if q2 is None and q3 is None:
161+
exec_q = q1
162+
out_usm_type = condition_usm_type
163+
elif q3 is None:
164+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
165+
if exec_q is None:
166+
raise ExecutionPlacementError(
167+
"Execution placement can not be unambiguously inferred "
168+
"from input arguments."
169+
)
170+
out_usm_type = dpctl.utils.get_coerced_usm_type(
171+
(
172+
condition_usm_type,
173+
x1_usm_type,
174+
)
87175
)
88-
if not isinstance(x2, dpt.usm_ndarray):
176+
elif q2 is None:
177+
exec_q = dpctl.utils.get_execution_queue((q1, q3))
178+
if exec_q is None:
179+
raise ExecutionPlacementError(
180+
"Execution placement can not be unambiguously inferred "
181+
"from input arguments."
182+
)
183+
out_usm_type = dpctl.utils.get_coerced_usm_type(
184+
(
185+
condition_usm_type,
186+
x2_usm_type,
187+
)
188+
)
189+
else:
190+
exec_q = dpctl.utils.get_execution_queue((q1, q2, q3))
191+
if exec_q is None:
192+
raise ExecutionPlacementError(
193+
"Execution placement can not be unambiguously inferred "
194+
"from input arguments."
195+
)
196+
out_usm_type = dpctl.utils.get_coerced_usm_type(
197+
(
198+
condition_usm_type,
199+
x1_usm_type,
200+
x2_usm_type,
201+
)
202+
)
203+
dpctl.utils.validate_usm_type(out_usm_type, allow_none=False)
204+
condition_shape = condition.shape
205+
x1_shape = _get_shape(x1)
206+
x2_shape = _get_shape(x2)
207+
if not all(
208+
isinstance(s, (tuple, list))
209+
for s in (
210+
x1_shape,
211+
x2_shape,
212+
)
213+
):
89214
raise TypeError(
90-
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x2)}"
215+
"Shape of arguments can not be inferred. "
216+
"Arguments are expected to be "
217+
"lists, tuples, or both"
91218
)
92-
if order not in ["K", "C", "F", "A"]:
93-
order = "K"
94-
exec_q = dpctl.utils.get_execution_queue(
95-
(
96-
condition.sycl_queue,
97-
x1.sycl_queue,
98-
x2.sycl_queue,
219+
try:
220+
res_shape = _broadcast_shape_impl(
221+
[
222+
condition_shape,
223+
x1_shape,
224+
x2_shape,
225+
]
99226
)
100-
)
101-
if exec_q is None:
102-
raise dpctl.utils.ExecutionPlacementError
103-
out_usm_type = dpctl.utils.get_coerced_usm_type(
104-
(
105-
condition.usm_type,
106-
x1.usm_type,
107-
x2.usm_type,
227+
except ValueError:
228+
raise ValueError(
229+
"operands could not be broadcast together with shapes "
230+
f"{condition_shape}, {x1_shape}, and {x2_shape}"
108231
)
109-
)
110-
111-
x1_dtype = x1.dtype
112-
x2_dtype = x2.dtype
113-
out_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
232+
sycl_dev = exec_q.sycl_device
233+
x1_dtype = _get_dtype(x1, sycl_dev)
234+
x2_dtype = _get_dtype(x2, sycl_dev)
235+
if not all(_validate_dtype(o) for o in (x1_dtype, x2_dtype)):
236+
raise ValueError("Operands have unsupported data types")
237+
x1_dtype, x2_dtype = _resolve_two_weak_types(x1_dtype, x2_dtype, sycl_dev)
238+
out_dtype = _where_result_type(x1_dtype, x2_dtype, sycl_dev)
114239
if out_dtype is None:
115240
raise TypeError(
116241
"function 'where' does not support input "
@@ -119,8 +244,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
119244
"to any supported types according to the casting rule ''safe''."
120245
)
121246

122-
res_shape = _broadcast_shapes(condition, x1, x2)
123-
124247
orig_out = out
125248
if out is not None:
126249
if not isinstance(out, dpt.usm_ndarray):
@@ -149,16 +272,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
149272
"Input and output allocation queues are not compatible"
150273
)
151274

152-
if ti._array_overlap(condition, out):
153-
if not ti._same_logical_tensors(condition, out):
154-
out = dpt.empty_like(out)
275+
if ti._array_overlap(condition, out) and not ti._same_logical_tensors(
276+
condition, out
277+
):
278+
out = dpt.empty_like(out)
155279

156-
if ti._array_overlap(x1, out):
157-
if not ti._same_logical_tensors(x1, out):
280+
if isinstance(x1, dpt.usm_ndarray):
281+
if (
282+
ti._array_overlap(x1, out)
283+
and not ti._same_logical_tensors(x1, out)
284+
and x1_dtype == out_dtype
285+
):
158286
out = dpt.empty_like(out)
159287

160-
if ti._array_overlap(x2, out):
161-
if not ti._same_logical_tensors(x2, out):
288+
if isinstance(x2, dpt.usm_ndarray):
289+
if (
290+
ti._array_overlap(x2, out)
291+
and not ti._same_logical_tensors(x2, out)
292+
and x2_dtype == out_dtype
293+
):
162294
out = dpt.empty_like(out)
163295

164296
if order == "A":
@@ -174,6 +306,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
174306
)
175307
else "C"
176308
)
309+
if not isinstance(x1, dpt.usm_ndarray):
310+
x1 = dpt.asarray(x1, dtype=x1_dtype, sycl_queue=exec_q)
311+
if not isinstance(x2, dpt.usm_ndarray):
312+
x2 = dpt.asarray(x2, dtype=x2_dtype, sycl_queue=exec_q)
177313

178314
if condition.size == 0:
179315
if out is not None:
@@ -236,9 +372,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
236372
sycl_queue=exec_q,
237373
)
238374

239-
condition = dpt.broadcast_to(condition, res_shape)
240-
x1 = dpt.broadcast_to(x1, res_shape)
241-
x2 = dpt.broadcast_to(x2, res_shape)
375+
if condition_shape != res_shape:
376+
condition = dpt.broadcast_to(condition, res_shape)
377+
if x1_shape != res_shape:
378+
x1 = dpt.broadcast_to(x1, res_shape)
379+
if x2_shape != res_shape:
380+
x2 = dpt.broadcast_to(x2, res_shape)
242381

243382
dep_evs = _manager.submitted_events
244383
hev, where_ev = ti._where(

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import ctypes
18+
import itertools
19+
1720
import numpy as np
1821
import pytest
1922
from helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -350,9 +353,9 @@ def test_where_arg_validation():
350353

351354
with pytest.raises(TypeError):
352355
dpt.where(check, x1, x2)
353-
with pytest.raises(TypeError):
356+
with pytest.raises(ValueError):
354357
dpt.where(x1, check, x2)
355-
with pytest.raises(TypeError):
358+
with pytest.raises(ValueError):
356359
dpt.where(x1, x2, check)
357360

358361

@@ -522,3 +525,54 @@ def test_where_out_arg_validation():
522525
dpt.where(condition, x1, x2, out=out_wrong_shape)
523526
with pytest.raises(ValueError):
524527
dpt.where(condition, x1, x2, out=out_not_writable)
528+
529+
530+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
531+
def test_where_python_scalar(arr_dt):
532+
q = get_queue_or_skip()
533+
skip_if_dtype_not_supported(arr_dt, q)
534+
535+
n1, n2 = 10, 10
536+
condition = dpt.tile(
537+
dpt.reshape(
538+
dpt.asarray([True, False], dtype="?", sycl_queue=q), (1, 2)
539+
),
540+
(n1, n2 // 2),
541+
)
542+
x = dpt.zeros((n1, n2), dtype=arr_dt, sycl_queue=q)
543+
py_scalars = (
544+
bool(0),
545+
int(0),
546+
float(0),
547+
complex(0),
548+
np.float32(0),
549+
ctypes.c_int(0),
550+
)
551+
for sc in py_scalars:
552+
r = dpt.where(condition, x, sc)
553+
assert isinstance(r, dpt.usm_ndarray)
554+
r = dpt.where(condition, sc, x)
555+
assert isinstance(r, dpt.usm_ndarray)
556+
557+
558+
def test_where_two_python_scalars():
559+
get_queue_or_skip()
560+
561+
n1, n2 = 10, 10
562+
condition = dpt.tile(
563+
dpt.reshape(dpt.asarray([True, False], dtype="?"), (1, 2)),
564+
(n1, n2 // 2),
565+
)
566+
567+
py_scalars = [
568+
bool(0),
569+
int(0),
570+
float(0),
571+
complex(0),
572+
np.float32(0),
573+
ctypes.c_int(0),
574+
]
575+
576+
for sc1, sc2 in itertools.product(py_scalars, repeat=2):
577+
r = dpt.where(condition, sc1, sc2)
578+
assert isinstance(r, dpt.usm_ndarray)

0 commit comments

Comments
 (0)