Skip to content

Commit 2f3be1f

Browse files
Merge pull request #1343 from IntelPython/fix-gh-1279
Fix gh-1279, implement tensor.allclose
2 parents bd996b5 + 142190f commit 2f3be1f

File tree

13 files changed

+710
-207
lines changed

13 files changed

+710
-207
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,15 @@ set_source_files_properties(
5858
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
5959
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
6060
PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math")
61+
if (UNIX)
62+
set_source_files_properties(
63+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
64+
PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES")
65+
endif()
6166
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
6267
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
6368
if(UNIX)
64-
# this option is support on Linux only
69+
# this option is supported on Linux only
6570
target_link_options(${python_module_name} PRIVATE -fsycl-link-huge-device-code)
6671
endif()
6772
target_include_directories(${python_module_name}

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@
158158
trunc,
159159
)
160160
from ._reduction import sum
161+
from ._testing import allclose
161162

162163
__all__ = [
163164
"Device",
@@ -301,4 +302,5 @@
301302
"tan",
302303
"tanh",
303304
"trunc",
305+
"allclose",
304306
]

dpctl/tensor/_testing.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy as np
18+
19+
import dpctl.tensor as dpt
20+
import dpctl.utils as du
21+
22+
from ._manipulation_functions import _broadcast_shape_impl
23+
from ._type_utils import _to_device_supported_dtype
24+
25+
26+
def _allclose_complex_fp(z1, z2, atol, rtol, equal_nan):
    """Check closeness of two complex floating-point arrays.

    The real and imaginary components are tested independently. For each
    component the checks are performed in this order, returning as soon
    as one of them fails:

    1. NaN handling: if ``equal_nan`` is true, NaNs must occur at the same
       positions in both arrays; otherwise neither array may contain NaN.
    2. Infinities must occur at the same positions in both arrays.
    3. Values at those infinite positions must compare equal (same sign).
    4. Finite values must satisfy

           abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))

    Args:
        z1, z2: complex arrays to compare; the caller (``allclose``) passes
            arrays already converted to a common dtype and broadcast to a
            common shape.
        atol: absolute tolerance.
        rtol: relative tolerance.
        equal_nan: whether NaNs at matching positions count as equal.

    Returns:
        The result of the last check evaluated, i.e. whatever ``dpt.all`` /
        ``dpt.logical_not`` return (presumably a 0-d boolean array --
        TODO confirm), which is truthy iff the arrays are close.
    """
    z1r = dpt.real(z1)
    z1i = dpt.imag(z1)
    z2r = dpt.real(z2)
    z2i = dpt.imag(z2)

    # Step 1: NaN placement.
    if equal_nan:
        check1 = dpt.all(dpt.isnan(z1r) == dpt.isnan(z2r)) and dpt.all(
            dpt.isnan(z1i) == dpt.isnan(z2i)
        )
    else:
        check1 = (
            dpt.logical_not(dpt.any(dpt.isnan(z1r)))
            and dpt.logical_not(dpt.any(dpt.isnan(z1i)))
        ) and (
            dpt.logical_not(dpt.any(dpt.isnan(z2r)))
            and dpt.logical_not(dpt.any(dpt.isnan(z2i)))
        )
    if not check1:
        return check1

    # Step 2: infinities must sit at the same positions.
    mr = dpt.isinf(z1r)
    mi = dpt.isinf(z1i)
    check2 = dpt.all(mr == dpt.isinf(z2r)) and dpt.all(mi == dpt.isinf(z2i))
    if not check2:
        return check2

    # Step 3: matching infinities must also agree in sign.
    check3 = dpt.all(z1r[mr] == z2r[mr]) and dpt.all(z1i[mi] == z2i[mi])
    if not check3:
        return check3

    # Step 4: tolerance test on finite entries only. The NaN and infinity
    # layouts were shown to agree above, so the finite mask of the first
    # operand selects finite entries of the second operand as well.
    mr = dpt.isfinite(z1r)
    mi = dpt.isfinite(z1i)
    mv1 = z1r[mr]
    mv2 = z2r[mr]
    # Use `<=` (not `<`): this matches the imaginary-part test below, the
    # real-valued helper and the documented formula, and it lets identical
    # inputs pass when both tolerances are zero.
    check4 = dpt.all(
        dpt.abs(mv1 - mv2)
        <= dpt.maximum(atol, rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2)))
    )
    if not check4:
        return check4
    mv1 = z1i[mi]
    mv2 = z2i[mi]
    check5 = dpt.all(
        dpt.abs(mv1 - mv2)
        <= dpt.maximum(atol, rtol * dpt.maximum(dpt.abs(mv1), dpt.abs(mv2)))
    )
    return check5
70+
71+
72+
def _allclose_real_fp(r1, r2, atol, rtol, equal_nan):
    """Check closeness of two real floating-point arrays.

    Checks run in sequence and the first failing one is returned:
    NaN placement (governed by ``equal_nan``), placement of infinities,
    equality of the infinite values, and finally the tolerance test

        abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))

    applied to the finite entries.

    Returns:
        The result of the last check evaluated (truthy iff arrays are
        close); presumably a 0-d boolean array -- TODO confirm.
    """
    if equal_nan:
        nans_ok = dpt.all(dpt.isnan(r1) == dpt.isnan(r2))
    else:
        # Neither operand may contain a NaN; the second operand is only
        # inspected when the first one is NaN-free.
        nans_ok = dpt.logical_not(dpt.any(dpt.isnan(r1)))
        if nans_ok:
            nans_ok = dpt.logical_not(dpt.any(dpt.isnan(r2)))
    if not nans_ok:
        return nans_ok

    inf_mask = dpt.isinf(r1)
    inf_layout_ok = dpt.all(inf_mask == dpt.isinf(r2))
    if not inf_layout_ok:
        return inf_layout_ok

    # Same positions are infinite in both; make sure the signs agree too.
    inf_values_ok = dpt.all(r1[inf_mask] == r2[inf_mask])
    if not inf_values_ok:
        return inf_values_ok

    finite_mask = dpt.isfinite(r1)
    lhs = r1[finite_mask]
    rhs = r2[finite_mask]
    abs_diff = dpt.abs(lhs - rhs)
    magnitude = dpt.maximum(dpt.abs(lhs), dpt.abs(rhs))
    tolerance = dpt.maximum(atol, rtol * magnitude)
    return dpt.all(abs_diff <= tolerance)
96+
97+
98+
def _allclose_others(r1, r2):
    """Compare arrays of non floating-point dtype by exact equality.

    Tolerances do not apply here; the result is ``dpt.all`` of the
    element-wise ``==`` comparison.
    """
    elementwise_equal = r1 == r2
    return dpt.all(elementwise_equal)
100+
101+
102+
def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False):
    """allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False)

    Returns True if two arrays are element-wise equal within tolerances.

    The testing is based on the following elementwise comparison:

           abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))

    Args:
        a1, a2: instances of ``dpctl.tensor.usm_ndarray`` to compare.
            They are converted to a common data type and broadcast to a
            common shape before the comparison.
        atol: absolute tolerance, converted with ``float``; must be
            non-negative.
        rtol: relative tolerance, converted with ``float``; must be
            non-negative.
        equal_nan: converted with ``bool``; forwarded to the
            floating-point comparison helpers to control NaN handling.

    Raises:
        TypeError: if either input is not a ``usm_ndarray``.
        ValueError: if ``atol`` or ``rtol`` is negative.
        ExecutionPlacementError: if no common execution queue can be
            determined for the two inputs.
    """
    if not isinstance(a1, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray type, got {type(a1)}."
        )
    if not isinstance(a2, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray type, got {type(a2)}."
        )

    atol = float(atol)
    rtol = float(rtol)
    if atol < 0.0 or rtol < 0.0:
        raise ValueError(
            "Absolute and relative tolerances must be non-negative"
        )
    equal_nan = bool(equal_nan)

    exec_q = du.get_execution_queue((a1.sycl_queue, a2.sycl_queue))
    if exec_q is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments."
        )

    common_shape = _broadcast_shape_impl([a1.shape, a2.shape])

    lhs, rhs = a1, a2
    if lhs.dtype != rhs.dtype:
        # Promote to a common dtype and map it to one the device supports.
        promoted_dt = np.promote_types(lhs.dtype, rhs.dtype)
        common_dt = _to_device_supported_dtype(promoted_dt, exec_q.sycl_device)
        lhs = dpt.astype(lhs, common_dt)
        rhs = dpt.astype(rhs, common_dt)

    lhs = dpt.broadcast_to(lhs, common_shape)
    rhs = dpt.broadcast_to(rhs, common_shape)

    # Dispatch on the kind of the common dtype: complex, real floating
    # point, or anything else (compared exactly).
    dtype_kind = lhs.dtype.kind
    if dtype_kind == "c":
        return _allclose_complex_fp(lhs, rhs, atol, rtol, equal_nan)
    if dtype_kind == "f":
        return _allclose_real_fp(lhs, rhs, atol, rtol, equal_nan)
    return _allclose_others(lhs, rhs)

dpctl/tensor/_usmarray.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
5858

5959
cdef void _reset(usm_ndarray self)
6060
cdef void _cleanup(usm_ndarray self)
61-
cdef usm_ndarray _clone(usm_ndarray self)
6261
cdef Py_ssize_t get_offset(usm_ndarray self) except *
6362

6463
cdef char* get_data(self)

0 commit comments

Comments
 (0)