17
17
import dpctl
18
18
import dpctl .tensor as dpt
19
19
import dpctl .tensor ._tensor_impl as ti
20
- from dpctl .tensor ._manipulation_functions import _broadcast_shapes
20
+ from dpctl .tensor ._elementwise_common import (
21
+ _get_dtype ,
22
+ _get_queue_usm_type ,
23
+ _get_shape ,
24
+ _validate_dtype ,
25
+ )
26
+ from dpctl .tensor ._manipulation_functions import _broadcast_shape_impl
21
27
from dpctl .utils import ExecutionPlacementError , SequentialOrderManager
22
28
23
29
from ._copy_utils import _empty_like_orderK , _empty_like_triple_orderK
24
- from ._type_utils import _all_data_types , _can_cast
30
+ from ._type_utils import (
31
+ WeakBooleanType ,
32
+ WeakComplexType ,
33
+ WeakFloatingType ,
34
+ WeakIntegralType ,
35
+ _all_data_types ,
36
+ _can_cast ,
37
+ _is_weak_dtype ,
38
+ _strong_dtype_num_kind ,
39
+ _to_device_supported_dtype ,
40
+ _weak_type_num_kind ,
41
+ )
42
+
43
+
44
+ def _default_dtype_from_weak_type (dt , dev ):
45
+ if isinstance (dt , WeakBooleanType ):
46
+ return dpt .bool
47
+ if isinstance (dt , WeakIntegralType ):
48
+ return dpt .dtype (ti .default_device_int_type (dev ))
49
+ if isinstance (dt , WeakFloatingType ):
50
+ return dpt .dtype (ti .default_device_fp_type (dev ))
51
+ if isinstance (dt , WeakComplexType ):
52
+ return dpt .dtype (ti .default_device_complex_type (dev ))
53
+
54
+
55
+ def _resolve_two_weak_types (o1_dtype , o2_dtype , dev ):
56
+ "Resolves two weak data types per NEP-0050"
57
+ if _is_weak_dtype (o1_dtype ):
58
+ if _is_weak_dtype (o2_dtype ):
59
+ return _default_dtype_from_weak_type (
60
+ o1_dtype , dev
61
+ ), _default_dtype_from_weak_type (o2_dtype , dev )
62
+ o1_kind_num = _weak_type_num_kind (o1_dtype )
63
+ o2_kind_num = _strong_dtype_num_kind (o2_dtype )
64
+ if o1_kind_num > o2_kind_num :
65
+ if isinstance (o1_dtype , WeakIntegralType ):
66
+ return dpt .dtype (ti .default_device_int_type (dev )), o2_dtype
67
+ if isinstance (o1_dtype , WeakComplexType ):
68
+ if o2_dtype is dpt .float16 or o2_dtype is dpt .float32 :
69
+ return dpt .complex64 , o2_dtype
70
+ return (
71
+ _to_device_supported_dtype (dpt .complex128 , dev ),
72
+ o2_dtype ,
73
+ )
74
+ return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
75
+ else :
76
+ return o2_dtype , o2_dtype
77
+ elif _is_weak_dtype (o2_dtype ):
78
+ o1_kind_num = _strong_dtype_num_kind (o1_dtype )
79
+ o2_kind_num = _weak_type_num_kind (o2_dtype )
80
+ if o2_kind_num > o1_kind_num :
81
+ if isinstance (o2_dtype , WeakIntegralType ):
82
+ return o1_dtype , dpt .dtype (ti .default_device_int_type (dev ))
83
+ if isinstance (o2_dtype , WeakComplexType ):
84
+ if o1_dtype is dpt .float16 or o1_dtype is dpt .float32 :
85
+ return o1_dtype , dpt .complex64
86
+ return o1_dtype , _to_device_supported_dtype (dpt .complex128 , dev )
87
+ return (
88
+ o1_dtype ,
89
+ _to_device_supported_dtype (dpt .float64 , dev ),
90
+ )
91
+ else :
92
+ return o1_dtype , o1_dtype
93
+ else :
94
+ return o1_dtype , o2_dtype
25
95
26
96
27
97
def _where_result_type (dt1 , dt2 , dev ):
@@ -51,16 +121,17 @@ def where(condition, x1, x2, /, *, order="K", out=None):
51
121
and otherwise yields from ``x2``.
52
122
Must be compatible with ``x1`` and ``x2`` according
53
123
to broadcasting rules.
54
- x1 (usm_ndarray): Array from which values are chosen when
55
- ``condition`` is ``True``.
124
+ x1 (Union[ usm_ndarray, bool, int, float, complex]):
125
+ Array from which values are chosen when ``condition`` is ``True``.
56
126
Must be compatible with ``condition`` and ``x2`` according
57
127
to broadcasting rules.
58
- x2 (usm_ndarray): Array from which values are chosen when
59
- ``condition`` is not ``True``.
128
+ x2 (Union[usm_ndarray, bool, int, float, complex]):
129
+ Array from which values are chosen when ``condition`` is not
130
+ ``True``.
60
131
Must be compatible with ``condition`` and ``x2`` according
61
132
to broadcasting rules.
62
133
order (``"K"``, ``"C"``, ``"F"``, ``"A"``, optional):
63
- Memory layout of the new output arra ,
134
+ Memory layout of the new output array ,
64
135
if parameter ``out`` is ``None``.
65
136
Default: ``"K"``.
66
137
out (Optional[usm_ndarray]):
@@ -81,36 +152,90 @@ def where(condition, x1, x2, /, *, order="K", out=None):
81
152
raise TypeError (
82
153
"Expecting dpctl.tensor.usm_ndarray type, " f"got { type (condition )} "
83
154
)
84
- if not isinstance (x1 , dpt .usm_ndarray ):
85
- raise TypeError (
86
- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x1 )} "
155
+ if order not in ["K" , "C" , "F" , "A" ]:
156
+ order = "K"
157
+ q1 , condition_usm_type = condition .sycl_queue , condition .usm_type
158
+ q2 , x1_usm_type = _get_queue_usm_type (x1 )
159
+ q3 , x2_usm_type = _get_queue_usm_type (x2 )
160
+ if q2 is None and q3 is None :
161
+ exec_q = q1
162
+ out_usm_type = condition_usm_type
163
+ elif q3 is None :
164
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
165
+ if exec_q is None :
166
+ raise ExecutionPlacementError (
167
+ "Execution placement can not be unambiguously inferred "
168
+ "from input arguments."
169
+ )
170
+ out_usm_type = dpctl .utils .get_coerced_usm_type (
171
+ (
172
+ condition_usm_type ,
173
+ x1_usm_type ,
174
+ )
87
175
)
88
- if not isinstance (x2 , dpt .usm_ndarray ):
176
+ elif q2 is None :
177
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q3 ))
178
+ if exec_q is None :
179
+ raise ExecutionPlacementError (
180
+ "Execution placement can not be unambiguously inferred "
181
+ "from input arguments."
182
+ )
183
+ out_usm_type = dpctl .utils .get_coerced_usm_type (
184
+ (
185
+ condition_usm_type ,
186
+ x2_usm_type ,
187
+ )
188
+ )
189
+ else :
190
+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 , q3 ))
191
+ if exec_q is None :
192
+ raise ExecutionPlacementError (
193
+ "Execution placement can not be unambiguously inferred "
194
+ "from input arguments."
195
+ )
196
+ out_usm_type = dpctl .utils .get_coerced_usm_type (
197
+ (
198
+ condition_usm_type ,
199
+ x1_usm_type ,
200
+ x2_usm_type ,
201
+ )
202
+ )
203
+ dpctl .utils .validate_usm_type (out_usm_type , allow_none = False )
204
+ condition_shape = condition .shape
205
+ x1_shape = _get_shape (x1 )
206
+ x2_shape = _get_shape (x2 )
207
+ if not all (
208
+ isinstance (s , (tuple , list ))
209
+ for s in (
210
+ x1_shape ,
211
+ x2_shape ,
212
+ )
213
+ ):
89
214
raise TypeError (
90
- "Expecting dpctl.tensor.usm_ndarray type, " f"got { type (x2 )} "
215
+ "Shape of arguments can not be inferred. "
216
+ "Arguments are expected to be "
217
+ "lists, tuples, or both"
91
218
)
92
- if order not in [ "K" , "C" , "F" , "A" ] :
93
- order = "K"
94
- exec_q = dpctl . utils . get_execution_queue (
95
- (
96
- condition . sycl_queue ,
97
- x1 . sycl_queue ,
98
- x2 . sycl_queue ,
219
+ try :
220
+ res_shape = _broadcast_shape_impl (
221
+ [
222
+ condition_shape ,
223
+ x1_shape ,
224
+ x2_shape ,
225
+ ]
99
226
)
100
- )
101
- if exec_q is None :
102
- raise dpctl .utils .ExecutionPlacementError
103
- out_usm_type = dpctl .utils .get_coerced_usm_type (
104
- (
105
- condition .usm_type ,
106
- x1 .usm_type ,
107
- x2 .usm_type ,
227
+ except ValueError :
228
+ raise ValueError (
229
+ "operands could not be broadcast together with shapes "
230
+ f"{ condition_shape } , { x1_shape } , and { x2_shape } "
108
231
)
109
- )
110
-
111
- x1_dtype = x1 .dtype
112
- x2_dtype = x2 .dtype
113
- out_dtype = _where_result_type (x1_dtype , x2_dtype , exec_q .sycl_device )
232
+ sycl_dev = exec_q .sycl_device
233
+ x1_dtype = _get_dtype (x1 , sycl_dev )
234
+ x2_dtype = _get_dtype (x2 , sycl_dev )
235
+ if not all (_validate_dtype (o ) for o in (x1_dtype , x2_dtype )):
236
+ raise ValueError ("Operands have unsupported data types" )
237
+ x1_dtype , x2_dtype = _resolve_two_weak_types (x1_dtype , x2_dtype , sycl_dev )
238
+ out_dtype = _where_result_type (x1_dtype , x2_dtype , sycl_dev )
114
239
if out_dtype is None :
115
240
raise TypeError (
116
241
"function 'where' does not support input "
@@ -119,8 +244,6 @@ def where(condition, x1, x2, /, *, order="K", out=None):
119
244
"to any supported types according to the casting rule ''safe''."
120
245
)
121
246
122
- res_shape = _broadcast_shapes (condition , x1 , x2 )
123
-
124
247
orig_out = out
125
248
if out is not None :
126
249
if not isinstance (out , dpt .usm_ndarray ):
@@ -149,16 +272,25 @@ def where(condition, x1, x2, /, *, order="K", out=None):
149
272
"Input and output allocation queues are not compatible"
150
273
)
151
274
152
- if ti ._array_overlap (condition , out ):
153
- if not ti ._same_logical_tensors (condition , out ):
154
- out = dpt .empty_like (out )
275
+ if ti ._array_overlap (condition , out ) and not ti ._same_logical_tensors (
276
+ condition , out
277
+ ):
278
+ out = dpt .empty_like (out )
155
279
156
- if ti ._array_overlap (x1 , out ):
157
- if not ti ._same_logical_tensors (x1 , out ):
280
+ if isinstance (x1 , dpt .usm_ndarray ):
281
+ if (
282
+ ti ._array_overlap (x1 , out )
283
+ and not ti ._same_logical_tensors (x1 , out )
284
+ and x1_dtype == out_dtype
285
+ ):
158
286
out = dpt .empty_like (out )
159
287
160
- if ti ._array_overlap (x2 , out ):
161
- if not ti ._same_logical_tensors (x2 , out ):
288
+ if isinstance (x2 , dpt .usm_ndarray ):
289
+ if (
290
+ ti ._array_overlap (x2 , out )
291
+ and not ti ._same_logical_tensors (x2 , out )
292
+ and x2_dtype == out_dtype
293
+ ):
162
294
out = dpt .empty_like (out )
163
295
164
296
if order == "A" :
@@ -174,6 +306,10 @@ def where(condition, x1, x2, /, *, order="K", out=None):
174
306
)
175
307
else "C"
176
308
)
309
+ if not isinstance (x1 , dpt .usm_ndarray ):
310
+ x1 = dpt .asarray (x1 , dtype = x1_dtype , sycl_queue = exec_q )
311
+ if not isinstance (x2 , dpt .usm_ndarray ):
312
+ x2 = dpt .asarray (x2 , dtype = x2_dtype , sycl_queue = exec_q )
177
313
178
314
if condition .size == 0 :
179
315
if out is not None :
@@ -236,9 +372,12 @@ def where(condition, x1, x2, /, *, order="K", out=None):
236
372
sycl_queue = exec_q ,
237
373
)
238
374
239
- condition = dpt .broadcast_to (condition , res_shape )
240
- x1 = dpt .broadcast_to (x1 , res_shape )
241
- x2 = dpt .broadcast_to (x2 , res_shape )
375
+ if condition_shape != res_shape :
376
+ condition = dpt .broadcast_to (condition , res_shape )
377
+ if x1_shape != res_shape :
378
+ x1 = dpt .broadcast_to (x1 , res_shape )
379
+ if x2_shape != res_shape :
380
+ x2 = dpt .broadcast_to (x2 , res_shape )
242
381
243
382
dep_evs = _manager .submitted_events
244
383
hev , where_ev = ti ._where (
0 commit comments