Skip to content

Commit 3f24ef2

Browse files
committed
UnaryElementwiseFunc class now takes an acceptance function
This change was made to mirror promotion behavior of divide in reciprocal Adds getter method for the acceptance function Adds tests for reciprocal
1 parent 4e2440e commit 3f24ef2

File tree

5 files changed

+177
-11
lines changed

5 files changed

+177
-11
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
3030
from ._type_utils import (
31-
_acceptance_fn_default,
31+
_acceptance_fn_default1,
32+
_acceptance_fn_default2,
3233
_all_data_types,
3334
_find_buf_dtype,
3435
_find_buf_dtype2,
@@ -62,17 +63,39 @@ class UnaryElementwiseFunc:
6263
computational tasks complete execution, while the second event
6364
corresponds to computational tasks associated with function
6465
evaluation.
66+
acceptance_fn (callable, optional):
67+
Function to influence type promotion behavior of this unary
68+
function. The function takes 4 arguments:
69+
arg_dtype - Data type of the first argument
70+
buf_dtype - Data type the argument would be cast to
71+
res_dtype - Data type of the output array with function values
72+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
73+
evaluation is carried out.
74+
The function is invoked when the argument of the unary function
75+
requires casting, e.g. the argument of `dpctl.tensor.log` is an
76+
array with integral data type.
6577
docs (str):
6678
Documentation string for the unary function.
6779
"""
6880

69-
def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
81+
def __init__(
82+
self,
83+
name,
84+
result_type_resolver_fn,
85+
unary_dp_impl_fn,
86+
docs,
87+
acceptance_fn=None,
88+
):
7089
self.__name__ = "UnaryElementwiseFunc"
7190
self.name_ = name
7291
self.result_type_resolver_fn_ = result_type_resolver_fn
7392
self.types_ = None
7493
self.unary_fn_ = unary_dp_impl_fn
7594
self.__doc__ = docs
95+
if callable(acceptance_fn):
96+
self.acceptance_fn_ = acceptance_fn
97+
else:
98+
self.acceptance_fn_ = _acceptance_fn_default1
7699

77100
def __str__(self):
78101
return f"<{self.__name__} '{self.name_}'>"
@@ -93,6 +116,24 @@ def get_type_result_resolver_function(self):
93116
"""
94117
return self.result_type_resolver_fn_
95118

119+
def get_type_promotion_path_acceptance_function(self):
120+
"""Returns the acceptance function for this
121+
elementwise binary function.
122+
123+
Acceptance function influences the type promotion
124+
behavior of this unary function.
125+
The function takes 4 arguments:
126+
arg_dtype - Data type of the first argument
127+
buf_dtype - Data type the argument would be cast to
128+
res_dtype - Data type of the output array with function values
129+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
130+
evaluation is carried out.
131+
The function is invoked when the argument of the unary function
132+
requires casting, e.g. the argument of `dpctl.tensor.log` is an
133+
array with integral data type.
134+
"""
135+
return self.acceptance_fn_
136+
96137
@property
97138
def types(self):
98139
"""Returns information about types supported by
@@ -122,7 +163,10 @@ def __call__(self, x, out=None, order="K"):
122163
if order not in ["C", "F", "K", "A"]:
123164
order = "K"
124165
buf_dt, res_dt = _find_buf_dtype(
125-
x.dtype, self.result_type_resolver_fn_, x.sycl_device
166+
x.dtype,
167+
self.result_type_resolver_fn_,
168+
x.sycl_device,
169+
acceptance_fn=self.acceptance_fn_,
126170
)
127171
if res_dt is None:
128172
raise TypeError(
@@ -482,7 +526,7 @@ def __init__(
482526
if callable(acceptance_fn):
483527
self.acceptance_fn_ = acceptance_fn
484528
else:
485-
self.acceptance_fn_ = _acceptance_fn_default
529+
self.acceptance_fn_ = _acceptance_fn_default2
486530

487531
def __str__(self):
488532
return f"<{self.__name__} '{self.name_}'>"

dpctl/tensor/_elementwise_funcs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import dpctl.tensor._tensor_elementwise_impl as ti
1818

1919
from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
20-
from ._type_utils import _acceptance_fn_divide
20+
from ._type_utils import _acceptance_fn_divide, _acceptance_fn_reciprocal
2121

2222
# U01: ==== ABS (x)
2323
_abs_docstring_ = """
@@ -1916,6 +1916,7 @@
19161916
ti._reciprocal_result_type,
19171917
ti._reciprocal,
19181918
_reciprocal_docstring,
1919+
acceptance_fn=_acceptance_fn_reciprocal,
19191920
)
19201921

19211922

dpctl/tensor/_type_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,27 @@ def _to_device_supported_dtype(dt, dev):
132132
return dt
133133

134134

135-
def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
135+
def _acceptance_fn_default1(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
136+
return True
137+
138+
139+
def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
140+
# if the kind of result is different from
141+
# the kind of input, use the default data
142+
# we use default dtype for the resulting kind.
143+
# This guarantees alignment of reciprocal and
144+
# divide output types.
145+
if buf_dt.kind != arg_dtype.kind:
146+
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
147+
if res_dt == default_dt:
148+
return True
149+
else:
150+
return False
151+
else:
152+
return True
153+
154+
155+
def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
136156
res_dt = query_fn(arg_dtype)
137157
if res_dt:
138158
return None, res_dt
@@ -144,7 +164,11 @@ def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
144164
if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
145165
res_dt = query_fn(buf_dt)
146166
if res_dt:
147-
return buf_dt, res_dt
167+
acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
168+
if acceptable:
169+
return buf_dt, res_dt
170+
else:
171+
continue
148172

149173
return None, None
150174

@@ -163,7 +187,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
163187
raise RuntimeError
164188

165189

166-
def _acceptance_fn_default(
190+
def _acceptance_fn_default2(
167191
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
168192
):
169193
return True
@@ -230,6 +254,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
230254
"_find_buf_dtype",
231255
"_find_buf_dtype2",
232256
"_to_device_supported_dtype",
233-
"_acceptance_fn_default",
257+
"_acceptance_fn_default1",
258+
"_acceptance_fn_reciprocal",
259+
"_acceptance_fn_default2",
234260
"_acceptance_fn_divide",
235261
]
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import itertools
18+
19+
import pytest
20+
21+
import dpctl.tensor as dpt
22+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
23+
24+
from .utils import _all_dtypes, _complex_fp_dtypes
25+
26+
27+
@pytest.mark.parametrize("dtype", _all_dtypes)
28+
def test_reciprocal_out_type(dtype):
29+
q = get_queue_or_skip()
30+
skip_if_dtype_not_supported(dtype, q)
31+
32+
x = dpt.asarray(1, dtype=dtype, sycl_queue=q)
33+
one = dpt.asarray(1, dtype=dtype, sycl_queue=q)
34+
expected_dtype = dpt.divide(one, x).dtype
35+
assert dpt.reciprocal(x).dtype == expected_dtype
36+
37+
38+
@pytest.mark.parametrize("dtype", _all_dtypes)
39+
def test_reciprocal_output_contig(dtype):
40+
q = get_queue_or_skip()
41+
skip_if_dtype_not_supported(dtype, q)
42+
43+
n_seq = 1027
44+
45+
x = dpt.linspace(1, 13, num=n_seq, dtype=dtype, sycl_queue=q)
46+
res = dpt.reciprocal(x)
47+
expected = 1 / x
48+
tol = 8 * dpt.finfo(res.dtype).resolution
49+
assert dpt.allclose(res, expected, atol=tol, rtol=tol)
50+
51+
52+
@pytest.mark.parametrize("dtype", _all_dtypes)
53+
def test_reciprocal_output_strided(dtype):
54+
q = get_queue_or_skip()
55+
skip_if_dtype_not_supported(dtype, q)
56+
57+
n_seq = 2054
58+
59+
x = dpt.linspace(1, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
60+
res = dpt.reciprocal(x)
61+
expected = 1 / x
62+
tol = 8 * dpt.finfo(res.dtype).resolution
63+
assert dpt.allclose(res, expected, atol=tol, rtol=tol)
64+
65+
66+
def test_reciprocal_special_cases():
67+
get_queue_or_skip()
68+
69+
x = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4")
70+
res = dpt.reciprocal(x)
71+
expected = dpt.asarray([dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0], dtype="f4")
72+
assert dpt.allclose(res, expected, equal_nan=True)
73+
74+
75+
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
76+
def test_reciprocal_complex_special_cases(dtype):
77+
q = get_queue_or_skip()
78+
skip_if_dtype_not_supported(dtype, q)
79+
80+
nans_ = [dpt.nan, -dpt.nan]
81+
infs_ = [dpt.inf, -dpt.inf]
82+
finites_ = [-1.0, -0.0, 0.0, 1.0]
83+
inps_ = nans_ + infs_ + finites_
84+
c_ = [complex(*v) for v in itertools.product(inps_, repeat=2)]
85+
86+
z = dpt.asarray(c_, dtype=dtype)
87+
r = dpt.reciprocal(z)
88+
89+
expected = 1 / z
90+
91+
tol = dpt.finfo(r.dtype).resolution
92+
93+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)

dpctl/tests/elementwise/test_type_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def _denier_fn(dt):
124124
for fp16 in [True, False]:
125125
dev = MockDevice(fp16, fp64)
126126
arg_dt = dpt.float64
127-
r = tu._find_buf_dtype(arg_dt, _denier_fn, dev)
127+
r = tu._find_buf_dtype(
128+
arg_dt, _denier_fn, dev, tu._acceptance_fn_default1
129+
)
128130
assert r == (
129131
None,
130132
None,
@@ -157,7 +159,7 @@ def _denier_fn(dt1, dt2):
157159
arg1_dt = dpt.float64
158160
arg2_dt = dpt.complex64
159161
r = tu._find_buf_dtype2(
160-
arg1_dt, arg2_dt, _denier_fn, dev, tu._acceptance_fn_default
162+
arg1_dt, arg2_dt, _denier_fn, dev, tu._acceptance_fn_default2
161163
)
162164
assert r == (
163165
None,

0 commit comments

Comments
 (0)