Skip to content

Commit 616c21e

Browse files
Implements logaddexp and hypot (#1272)
* Implements logaddexp and hypot * BinaryElementwiseFunc get acceptance_fn parameter, set individually for tensor.divide The _find_buf_dtype2 must stay consistent with _find_buf_dtype in that for binary_fn that can be expressed via unary_fn(other_binary_fn(unary_fn(in1), unary_fn(in2))), e.g. logaddexp(x1, x2) == log(exp(x1) + exp(x2)) the promotion of integral types to fp type must be consistence with both evaluation. The special-casing, necessary for proper working of dpt.divide, where dpt.divide(dpt.asarray(1, dtype="i1"), dpt.asarray(1, dtype="u1")) should give default fp type resulted in logaddexp(dpt.asarray(1, dtype="i1"), dpt.asarray(1, dtype="u1")) returning array of different data-type than if evaluated alternatively, since dpt.exp(dpt.asarray(1, dtype="i1")) and dpt.exp(dpt.asarray(1, dtype="u1")) both gave float16 arrays, and the result was of float16 data type. Now, both behaviors are fixed: ``` In [5]: dpt.log(dpt.add(dpt.exp(dpt.asarray(1, dtype="u1")), dpt.exp(dpt.asarray(1, dtype="i1")))) Out[5]: usm_ndarray(1.693, dtype=float16) In [6]: dpt.logaddexp(dpt.asarray(1, dtype="u1"), dpt.asarray(1, dtype="i1")) Out[6]: usm_ndarray(1.693, dtype=float16) In [7]: dpt.divide(dpt.asarray(1, dtype="u1"), dpt.asarray(1, dtype="i1")) Out[7]: usm_ndarray(1., dtype=float32) ``` * Only try calling _inplace if inplace function is defined * Since logaddexp is not defined for complex types, do not try complex Python scalar * Implement logaddexp in numerically stable way * Added missing argument * Fixed computing allclose testing tolerance in test_dtype_matrix test When NumPy's result has coarser dtype than that of DPCTL we must use the largest resolution from each dtype. * Calls of the type log(1 + X) changed to log1p(X) in logaddexp --------- Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 7c0a54f commit 616c21e

File tree

10 files changed

+1218
-29
lines changed

10 files changed

+1218
-29
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
floor_divide,
106106
greater,
107107
greater_equal,
108+
hypot,
108109
imag,
109110
isfinite,
110111
isinf,
@@ -115,6 +116,7 @@
115116
log1p,
116117
log2,
117118
log10,
119+
logaddexp,
118120
logical_and,
119121
logical_not,
120122
logical_or,
@@ -222,6 +224,7 @@
222224
"floor_divide",
223225
"greater",
224226
"greater_equal",
227+
"hypot",
225228
"imag",
226229
"isfinite",
227230
"isinf",
@@ -241,6 +244,7 @@
241244
"not_equal",
242245
"positive",
243246
"pow",
247+
"logaddexp",
244248
"proj",
245249
"real",
246250
"sin",

dpctl/tensor/_elementwise_common.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from dpctl.utils import ExecutionPlacementError
2828

2929
from ._type_utils import (
30+
_acceptance_fn_default,
3031
_empty_like_orderK,
3132
_empty_like_pair_orderK,
3233
_find_buf_dtype,
@@ -48,6 +49,12 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4849
self.unary_fn_ = unary_dp_impl_fn
4950
self.__doc__ = docs
5051

52+
def __str__(self):
53+
return f"<{self.__name__} '{self.name_}'>"
54+
55+
def __repr__(self):
56+
return f"<{self.__name__} '{self.name_}'>"
57+
5158
def __call__(self, x, out=None, order="K"):
5259
if not isinstance(x, dpt.usm_ndarray):
5360
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -357,27 +364,33 @@ def __init__(
357364
binary_dp_impl_fn,
358365
docs,
359366
binary_inplace_fn=None,
367+
acceptance_fn=None,
360368
):
361369
self.__name__ = "BinaryElementwiseFunc"
362370
self.name_ = name
363371
self.result_type_resolver_fn_ = result_type_resolver_fn
364372
self.binary_fn_ = binary_dp_impl_fn
365373
self.binary_inplace_fn_ = binary_inplace_fn
366374
self.__doc__ = docs
375+
if callable(acceptance_fn):
376+
self.acceptance_fn_ = acceptance_fn
377+
else:
378+
self.acceptance_fn_ = _acceptance_fn_default
367379

368380
def __str__(self):
369-
return f"<BinaryElementwiseFunc '{self.name_}'>"
381+
return f"<{self.__name__} '{self.name_}'>"
370382

371383
def __repr__(self):
372-
return f"<BinaryElementwiseFunc '{self.name_}'>"
384+
return f"<{self.__name__} '{self.name_}'>"
373385

374386
def __call__(self, o1, o2, out=None, order="K"):
375387
# FIXME: replace with check against base array
376388
# when views can be identified
377-
if o1 is out:
378-
return self._inplace(o1, o2)
379-
elif o2 is out:
380-
return self._inplace(o2, o1)
389+
if self.binary_inplace_fn_:
390+
if o1 is out:
391+
return self._inplace(o1, o2)
392+
elif o2 is out:
393+
return self._inplace(o2, o1)
381394

382395
if order not in ["K", "C", "F", "A"]:
383396
order = "K"
@@ -445,7 +458,11 @@ def __call__(self, o1, o2, out=None, order="K"):
445458
o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)
446459

447460
buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
448-
o1_dtype, o2_dtype, self.result_type_resolver_fn_, sycl_dev
461+
o1_dtype,
462+
o2_dtype,
463+
self.result_type_resolver_fn_,
464+
sycl_dev,
465+
acceptance_fn=self.acceptance_fn_,
449466
)
450467

451468
if res_dt is None:

dpctl/tensor/_elementwise_funcs.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dpctl.tensor._tensor_impl as ti
1818

1919
from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
20+
from ._type_utils import _acceptance_fn_divide
2021

2122
# U01: ==== ABS (x)
2223
_abs_docstring_ = """
@@ -218,7 +219,11 @@
218219
"""
219220

220221
divide = BinaryElementwiseFunc(
221-
"divide", ti._divide_result_type, ti._divide, _divide_docstring_
222+
"divide",
223+
ti._divide_result_type,
224+
ti._divide,
225+
_divide_docstring_,
226+
acceptance_fn=_acceptance_fn_divide,
222227
)
223228

224229
# B09: ==== EQUAL (x1, x2)
@@ -665,7 +670,32 @@
665670
)
666671

667672
# B15: ==== LOGADDEXP (x1, x2)
668-
# FIXME: implement B15
673+
_logaddexp_docstring_ = """
674+
logaddexp(x1, x2, out=None, order='K')
675+
676+
Calculates the ratio for each element `x1_i` of the input array `x1` with
677+
the respective element `x2_i` of the input array `x2`.
678+
679+
Args:
680+
x1 (usm_ndarray):
681+
First input array, expected to have numeric data type.
682+
x2 (usm_ndarray):
683+
Second input array, also expected to have numeric data type.
684+
out ({None, usm_ndarray}, optional):
685+
Output array to populate.
686+
Array have the correct shape and the expected data type.
687+
order ("C","F","A","K", optional):
688+
Memory layout of the newly output array, if parameter `out` is `None`.
689+
Default: "K".
690+
Returns:
691+
usm_narray:
692+
An array containing the result of element-wise division. The data type
693+
of the returned array is determined by the Type Promotion Rules.
694+
"""
695+
696+
logaddexp = BinaryElementwiseFunc(
697+
"logaddexp", ti._logaddexp_result_type, ti._logaddexp, _logaddexp_docstring_
698+
)
669699

670700
# B16: ==== LOGICAL_AND (x1, x2)
671701
_logical_and_docstring_ = """
@@ -1100,12 +1130,40 @@
11001130
order ("C","F","A","K", optional):
11011131
Memory layout of the newly output array, if parameter `out` is `None`.
11021132
Default: "K".
1133+
Returns:
1134+
usm_narray:
1135+
An array containing the result of element-wise division. The data type
1136+
of the returned array is determined by the Type Promotion Rules.
1137+
"""
1138+
trunc = UnaryElementwiseFunc(
1139+
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
1140+
)
1141+
1142+
1143+
# B24: ==== HYPOT (x1, x2)
1144+
_hypot_docstring_ = """
1145+
hypot(x1, x2, out=None, order='K')
1146+
1147+
Calculates the ratio for each element `x1_i` of the input array `x1` with
1148+
the respective element `x2_i` of the input array `x2`.
1149+
1150+
Args:
1151+
x1 (usm_ndarray):
1152+
First input array, expected to have numeric data type.
1153+
x2 (usm_ndarray):
1154+
Second input array, also expected to have numeric data type.
1155+
out ({None, usm_ndarray}, optional):
1156+
Output array to populate.
1157+
Array have the correct shape and the expected data type.
1158+
order ("C","F","A","K", optional):
1159+
Memory layout of the newly output array, if parameter `out` is `None`.
1160+
Default: "K".
11031161
Returns:
11041162
usm_narray:
11051163
An array containing the element-wise truncated value of input array.
11061164
The returned array has the same data type as `x`.
11071165
"""
11081166

1109-
trunc = UnaryElementwiseFunc(
1110-
"trunc", ti._trunc_result_type, ti._trunc, _trunc_docstring
1167+
hypot = BinaryElementwiseFunc(
1168+
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
11111169
)

dpctl/tensor/_type_utils.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,34 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
255255
raise RuntimeError
256256

257257

258-
def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
258+
def _acceptance_fn_default(
259+
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
260+
):
261+
return True
262+
263+
264+
def _acceptance_fn_divide(
265+
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
266+
):
267+
# both are being promoted, if the kind of result is
268+
# different than the kind of original input dtypes,
269+
# we use default dtype for the resulting kind.
270+
# This covers, e.g. (array_dtype_i1 / array_dtype_u1)
271+
# result of which in divide is double (in NumPy), but
272+
# regular type promotion rules peg at float16
273+
if (ret_buf1_dt.kind != arg1_dtype.kind) and (
274+
ret_buf2_dt.kind != arg2_dtype.kind
275+
):
276+
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
277+
if res_dt == default_dt:
278+
return True
279+
else:
280+
return False
281+
else:
282+
return True
283+
284+
285+
def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
259286
res_dt = query_fn(arg1_dtype, arg2_dtype)
260287
if res_dt:
261288
return None, None, res_dt
@@ -275,21 +302,18 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
275302
if ret_buf1_dt is None or ret_buf2_dt is None:
276303
return ret_buf1_dt, ret_buf2_dt, res_dt
277304
else:
278-
# both are being promoted, if the kind of result is
279-
# different than the kind of original input dtypes,
280-
# we must use default dtype for the resulting kind.
281-
if (res_dt.kind != arg1_dtype.kind) and (
282-
res_dt.kind != arg2_dtype.kind
283-
):
284-
default_dt = _get_device_default_dtype(
285-
res_dt.kind, sycl_dev
286-
)
287-
if res_dt == default_dt:
288-
return ret_buf1_dt, ret_buf2_dt, res_dt
289-
else:
290-
continue
291-
else:
305+
acceptable = acceptance_fn(
306+
arg1_dtype,
307+
arg2_dtype,
308+
ret_buf1_dt,
309+
ret_buf2_dt,
310+
res_dt,
311+
sycl_dev,
312+
)
313+
if acceptable:
292314
return ret_buf1_dt, ret_buf2_dt, res_dt
315+
else:
316+
continue
293317

294318
return None, None, None
295319

@@ -318,4 +342,6 @@ def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
318342
"_empty_like_orderK",
319343
"_empty_like_pair_orderK",
320344
"_to_device_supported_dtype",
345+
"_acceptance_fn_default",
346+
"_acceptance_fn_divide",
321347
]

0 commit comments

Comments
 (0)