Skip to content

Commit bc86cad

Browse files
committed
Merge remote-tracking branch 'upstream/master' into gold/2021
2 parents ba2452e + 43f3b7b commit bc86cad

39 files changed

+7453
-116
lines changed

dpctl/tensor/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,28 @@
9494
from ._elementwise_funcs import (
9595
abs,
9696
add,
97+
conj,
9798
cos,
9899
divide,
99100
equal,
101+
exp,
102+
expm1,
103+
floor_divide,
104+
greater,
105+
greater_equal,
106+
imag,
100107
isfinite,
101108
isinf,
102109
isnan,
110+
less,
111+
less_equal,
112+
log,
113+
log1p,
103114
multiply,
115+
not_equal,
116+
proj,
117+
real,
118+
sin,
104119
sqrt,
105120
subtract,
106121
)
@@ -183,14 +198,29 @@
183198
"inf",
184199
"abs",
185200
"add",
201+
"conj",
186202
"cos",
203+
"exp",
204+
"expm1",
205+
"greater",
206+
"greater_equal",
207+
"imag",
187208
"isinf",
188209
"isnan",
189210
"isfinite",
211+
"less",
212+
"less_equal",
213+
"log",
214+
"log1p",
215+
"proj",
216+
"real",
217+
"sin",
190218
"sqrt",
191219
"divide",
192220
"multiply",
193221
"subtract",
194222
"equal",
223+
"not_equal",
195224
"sum",
225+
"floor_divide",
196226
]

dpctl/tensor/_elementwise_common.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,18 @@ def get(self):
162162
return self.o_
163163

164164

165-
class WeakInexactType:
166-
"""Python type representing type of Python real- or
167-
complex-valued floating point objects"""
165+
class WeakFloatingType:
166+
"""Python type representing type of Python floating point objects"""
167+
168+
def __init__(self, o):
169+
self.o_ = o
170+
171+
def get(self):
172+
return self.o_
173+
174+
175+
class WeakComplexType:
176+
"""Python type representing type of Python complex floating point objects"""
168177

169178
def __init__(self, o):
170179
self.o_ = o
@@ -189,14 +198,17 @@ def _get_dtype(o, dev):
189198
return WeakBooleanType(o)
190199
if isinstance(o, int):
191200
return WeakIntegralType(o)
192-
if isinstance(o, (float, complex)):
193-
return WeakInexactType(o)
201+
if isinstance(o, float):
202+
return WeakFloatingType(o)
203+
if isinstance(o, complex):
204+
return WeakComplexType(o)
194205
return np.object_
195206

196207

197208
def _validate_dtype(dt) -> bool:
198209
return isinstance(
199-
dt, (WeakBooleanType, WeakInexactType, WeakIntegralType)
210+
dt,
211+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
200212
) or (
201213
isinstance(dt, dpt.dtype)
202214
and dt
@@ -220,22 +232,24 @@ def _validate_dtype(dt) -> bool:
220232

221233

222234
def _weak_type_num_kind(o):
223-
_map = {"?": 0, "i": 1, "f": 2}
235+
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
224236
if isinstance(o, WeakBooleanType):
225237
return _map["?"]
226238
if isinstance(o, WeakIntegralType):
227239
return _map["i"]
228-
if isinstance(o, WeakInexactType):
240+
if isinstance(o, WeakFloatingType):
229241
return _map["f"]
242+
if isinstance(o, WeakComplexType):
243+
return _map["c"]
230244
raise TypeError(
231245
f"Unexpected type {o} while expecting "
232-
"`WeakBooleanType`, `WeakIntegralType`, or "
233-
"`WeakInexactType`."
246+
"`WeakBooleanType`, `WeakIntegralType`,"
247+
"`WeakFloatingType`, or `WeakComplexType`."
234248
)
235249

236250

237251
def _strong_dtype_num_kind(o):
238-
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 2}
252+
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
239253
if not isinstance(o, dpt.dtype):
240254
raise TypeError
241255
k = o.kind
@@ -247,10 +261,17 @@ def _strong_dtype_num_kind(o):
247261
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
248262
"Resolves weak data type per NEP-0050"
249263
if isinstance(
250-
o1_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
264+
o1_dtype,
265+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
251266
):
252267
if isinstance(
253-
o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
268+
o2_dtype,
269+
(
270+
WeakBooleanType,
271+
WeakIntegralType,
272+
WeakFloatingType,
273+
WeakComplexType,
274+
),
254275
):
255276
raise ValueError
256277
o1_kind_num = _weak_type_num_kind(o1_dtype)
@@ -260,7 +281,9 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
260281
return dpt.bool, o2_dtype
261282
if isinstance(o1_dtype, WeakIntegralType):
262283
return dpt.int64, o2_dtype
263-
if isinstance(o1_dtype.get(), complex):
284+
if isinstance(o1_dtype, WeakComplexType):
285+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
286+
return dpt.complex64, o2_dtype
264287
return (
265288
_to_device_supported_dtype(dpt.complex128, dev),
266289
o2_dtype,
@@ -269,7 +292,8 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
269292
else:
270293
return o2_dtype, o2_dtype
271294
elif isinstance(
272-
o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
295+
o2_dtype,
296+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
273297
):
274298
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
275299
o2_kind_num = _weak_type_num_kind(o2_dtype)
@@ -278,7 +302,9 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
278302
return o1_dtype, dpt.bool
279303
if isinstance(o2_dtype, WeakIntegralType):
280304
return o1_dtype, dpt.int64
281-
if isinstance(o2_dtype.get(), complex):
305+
if isinstance(o2_dtype, WeakComplexType):
306+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
307+
return o1_dtype, dpt.complex64
282308
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
283309
return (
284310
o1_dtype,
@@ -433,7 +459,7 @@ def __call__(self, o1, o2, out=None, order="K"):
433459
if out is None:
434460
if order == "K":
435461
out = _empty_like_pair_orderK(
436-
src1, src2, res_dt, res_usm_type, exec_q
462+
src1, src2, res_dt, res_shape, res_usm_type, exec_q
437463
)
438464
else:
439465
if order == "A":
@@ -482,7 +508,7 @@ def __call__(self, o1, o2, out=None, order="K"):
482508
if out is None:
483509
if order == "K":
484510
out = _empty_like_pair_orderK(
485-
src1, buf2, res_dt, res_usm_type, exec_q
511+
src1, buf2, res_dt, res_shape, res_usm_type, exec_q
486512
)
487513
else:
488514
out = dpt.empty(
@@ -524,7 +550,7 @@ def __call__(self, o1, o2, out=None, order="K"):
524550
if out is None:
525551
if order == "K":
526552
out = _empty_like_pair_orderK(
527-
buf1, src2, res_dt, res_usm_type, exec_q
553+
buf1, src2, res_dt, res_shape, res_usm_type, exec_q
528554
)
529555
else:
530556
out = dpt.empty(
@@ -578,7 +604,7 @@ def __call__(self, o1, o2, out=None, order="K"):
578604
if out is None:
579605
if order == "K":
580606
out = _empty_like_pair_orderK(
581-
buf1, buf2, res_dt, res_usm_type, exec_q
607+
buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
582608
)
583609
else:
584610
out = dpt.empty(

0 commit comments

Comments
 (0)