Skip to content

Commit 2ac7f39

Browse files
committed
Added support of dpnp.allclose() for a device without fp64 aspect
1 parent 3a8ba50 commit 2ac7f39

File tree

6 files changed

+144
-74
lines changed

6 files changed

+144
-74
lines changed

dpnp/backend/kernels/dpnp_krnl_logic.cpp

Lines changed: 82 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
7474
sycl::nd_range<1> gws(gws_range, lws_range);
7575

7676
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
77-
auto gr = nd_it.get_group();
77+
auto gr = nd_it.get_sub_group();
7878
const auto max_gr_size = gr.get_max_local_range()[0];
7979
const size_t start =
8080
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) +
@@ -127,8 +127,72 @@ DPCTLSyclEventRef (*dpnp_all_ext_c)(DPCTLSyclQueueRef,
127127
const DPCTLEventVectorRef) =
128128
dpnp_all_c<_DataType, _ResultType>;
129129

130-
template <typename _DataType1, typename _DataType2, typename _ResultType>
131-
class dpnp_allclose_c_kernel;
130+
template <typename _DataType1, typename _DataType2, typename _TolType>
131+
class dpnp_allclose_kernel;
132+
133+
template <typename _DataType1, typename _DataType2, typename _TolType>
134+
static sycl::event dpnp_allclose(sycl::queue &q,
135+
const _DataType1 *array1,
136+
const _DataType2 *array2,
137+
bool *result,
138+
const size_t size,
139+
const _TolType rtol_val,
140+
const _TolType atol_val)
141+
{
142+
sycl::event fill_event = q.fill(result, true, 1);
143+
if (!size) {
144+
return fill_event;
145+
}
146+
147+
constexpr size_t lws = 64;
148+
constexpr size_t vec_sz = 8;
149+
150+
auto gws_range =
151+
sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
152+
auto lws_range = sycl::range<1>(lws);
153+
sycl::nd_range<1> gws(gws_range, lws_range);
154+
155+
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
156+
auto gr = nd_it.get_sub_group();
157+
const auto max_gr_size = gr.get_max_local_range()[0];
158+
const auto gr_size = gr.get_local_linear_range();
159+
const size_t start =
160+
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) +
161+
gr.get_group_linear_id() * max_gr_size);
162+
const size_t end = sycl::min(start + vec_sz * gr_size, size);
163+
164+
// each work-item iterates over "vec_sz" elements in the input arrays
165+
bool partial = true;
166+
167+
for (size_t i = start + gr.get_local_linear_id(); i < end; i += gr_size)
168+
{
169+
if constexpr (std::is_floating_point_v<_DataType1> &&
170+
std::is_floating_point_v<_DataType2>)
171+
{
172+
if (std::isinf(array1[i]) || std::isinf(array2[i])) {
173+
partial &= (array1[i] == array2[i]);
174+
continue;
175+
}
176+
}
177+
partial &= (std::abs(array1[i] - array2[i]) <=
178+
(atol_val + rtol_val * std::abs(array2[i])));
179+
}
180+
partial = sycl::all_of_group(gr, partial);
181+
182+
if (gr.leader() && (partial == false)) {
183+
result[0] = false;
184+
}
185+
};
186+
187+
auto kernel_func = [&](sycl::handler &cgh) {
188+
cgh.depends_on(fill_event);
189+
cgh.parallel_for<
190+
class dpnp_allclose_kernel<_DataType1, _DataType2, _TolType>>(
191+
gws, kernel_parallel_for_func);
192+
};
193+
194+
return q.submit(kernel_func);
195+
}
132196

133197
template <typename _DataType1, typename _DataType2, typename _ResultType>
134198
DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
@@ -140,6 +204,9 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
140204
double atol_val,
141205
const DPCTLEventVectorRef dep_event_vec_ref)
142206
{
207+
static_assert(std::is_same_v<_ResultType, bool>,
208+
"Boolean result type is required");
209+
143210
// avoid warning unused variable
144211
(void)dep_event_vec_ref;
145212

@@ -152,40 +219,21 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
152219
sycl::queue q = *(reinterpret_cast<sycl::queue *>(q_ref));
153220
sycl::event event;
154221

155-
DPNPC_ptr_adapter<_DataType1> input1_ptr(q_ref, array1_in, size);
156-
DPNPC_ptr_adapter<_DataType2> input2_ptr(q_ref, array2_in, size);
157-
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
158-
const _DataType1 *array1 = input1_ptr.get_ptr();
159-
const _DataType2 *array2 = input2_ptr.get_ptr();
160-
_ResultType *result = result1_ptr.get_ptr();
161-
162-
result[0] = true;
222+
const _DataType1 *array1 = static_cast<const _DataType1 *>(array1_in);
223+
const _DataType2 *array2 = static_cast<const _DataType2 *>(array2_in);
224+
bool *result = static_cast<bool *>(result1);
163225

164-
if (!size) {
165-
return event_ref;
226+
if (q.get_device().has(sycl::aspect::fp64)) {
227+
event =
228+
dpnp_allclose(q, array1, array2, result, size, rtol_val, atol_val);
229+
}
230+
else {
231+
float rtol = static_cast<float>(rtol_val);
232+
float atol = static_cast<float>(atol_val);
233+
event = dpnp_allclose(q, array1, array2, result, size, rtol, atol);
166234
}
167-
168-
sycl::range<1> gws(size);
169-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
170-
size_t i = global_id[0];
171-
172-
if (std::abs(array1[i] - array2[i]) >
173-
(atol_val + rtol_val * std::abs(array2[i])))
174-
{
175-
result[0] = false;
176-
}
177-
};
178-
179-
auto kernel_func = [&](sycl::handler &cgh) {
180-
cgh.parallel_for<
181-
class dpnp_allclose_c_kernel<_DataType1, _DataType2, _ResultType>>(
182-
gws, kernel_parallel_for_func);
183-
};
184-
185-
event = q.submit(kernel_func);
186235

187236
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
188-
189237
return DPCTLEvent_Copy(event_ref);
190238
}
191239

@@ -269,7 +317,7 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
269317
sycl::nd_range<1> gws(gws_range, lws_range);
270318

271319
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
272-
auto gr = nd_it.get_group();
320+
auto gr = nd_it.get_sub_group();
273321
const auto max_gr_size = gr.get_max_local_range()[0];
274322
const size_t start =
275323
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) +

dpnp/dpnp_iface_logic.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -152,42 +152,74 @@ def all(x, /, axis=None, out=None, keepdims=False, *, where=True):
152152
)
153153

154154

155-
def allclose(x1, x2, rtol=1.0e-5, atol=1.0e-8, **kwargs):
155+
def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, **kwargs):
156156
"""
157157
Returns True if two arrays are element-wise equal within a tolerance.
158158
159159
For full documentation refer to :obj:`numpy.allclose`.
160160
161+
Returns
162+
-------
163+
out : dpnp.ndarray
164+
A boolean 0-dim array. If its value is ``True``,
165+
two arrays are element-wise equal within a tolerance.
166+
161167
Limitations
162168
-----------
163-
Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar.
169+
Parameters `a` and `b` are supported either as :class:`dpnp.ndarray`,
170+
:class:`dpctl.tensor.usm_ndarray` or scalars, but both `a` and `b`
171+
can not be scalars at the same time.
164172
Keyword argument `kwargs` is currently unsupported.
165173
Otherwise the functions will be executed sequentially on CPU.
166-
Input array data types are limited by supported DPNP :ref:`Data types`.
174+
Parameters `rtol` and `atol` are supported as scalars. Otherwise
175+
``TypeError`` exeption will be raised.
176+
Input array data types are limited by supported integer and
177+
floating DPNP :ref:`Data types`.
178+
179+
See Also
180+
--------
181+
:obj:`dpnp.isclose` : Test whether two arrays are element-wise equal.
182+
:obj:`dpnp.all` : Test whether all elements evaluate to True.
183+
:obj:`dpnp.any` : Test whether any element evaluates to True.
184+
:obj:`dpnp.equal` : Return (x1 == x2) element-wise.
167185
168186
Examples
169187
--------
170188
>>> import dpnp as np
171-
>>> np.allclose([1e10,1e-7], [1.00001e10,1e-8])
172-
>>> False
189+
>>> np.allclose(np.array([1e10, 1e-7]), np.array([1.00001e10, 1e-8]))
190+
array([False])
191+
>>> np.allclose(np.array([1.0, np.nan]), np.array([1.0, np.nan]))
192+
array([False])
193+
>>> np.allclose(np.array([1.0, np.inf]), np.array([1.0, np.inf]))
194+
array([ True])
173195
174196
"""
175197

176-
rtol_is_scalar = dpnp.isscalar(rtol)
177-
atol_is_scalar = dpnp.isscalar(atol)
178-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
179-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
180-
181-
if x1_desc and x2_desc and not kwargs:
182-
if not rtol_is_scalar or not atol_is_scalar:
183-
pass
184-
else:
185-
result_obj = dpnp_allclose(x1_desc, x2_desc, rtol, atol).get_pyobj()
186-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
198+
if dpnp.isscalar(a) and dpnp.isscalar(b):
199+
# at least one of inputs has to be an array
200+
pass
201+
elif kwargs:
202+
pass
203+
else:
204+
if not dpnp.isscalar(rtol):
205+
raise TypeError(
206+
"An argument `rtol` must be a scalar, but got {}".format(
207+
type(rtol)
208+
)
209+
)
210+
elif not dpnp.isscalar(atol):
211+
raise TypeError(
212+
"An argument `atol` must be a scalar, but got {}".format(
213+
type(atol)
214+
)
215+
)
187216

188-
return result
217+
a_desc = dpnp.get_dpnp_descriptor(a, copy_when_nondefault_queue=False)
218+
b_desc = dpnp.get_dpnp_descriptor(b, copy_when_nondefault_queue=False)
219+
if a_desc and b_desc:
220+
return dpnp_allclose(a_desc, b_desc, rtol, atol).get_pyobj()
189221

190-
return call_origin(numpy.allclose, x1, x2, rtol=rtol, atol=atol, **kwargs)
222+
return call_origin(numpy.allclose, a, b, rtol=rtol, atol=atol, **kwargs)
191223

192224

193225
def any(x, /, axis=None, out=None, keepdims=False, *, where=True):

tests/skipped_tests.tbl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,11 +452,7 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transpose
452452
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
453453
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
454454
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_reversed_vdot
455-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_array_scalar
456-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_finite
457-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite
458-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite_equal_nan
459-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_min_int
455+
460456
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
461457
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
462458
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal

tests/skipped_tests_gpu.tbl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,11 +598,7 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transpose
598598
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
599599
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
600600
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
601-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_array_scalar
602-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_finite
603-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite
604-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_infinite_equal_nan
605-
tests/third_party/cupy/logic_tests/test_comparison.py::TestAllclose::test_allclose_min_int
601+
606602
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
607603
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
608604
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_not_equal

tests/test_logic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ def test_all(type, shape):
4444
assert_allclose(dpnp_res, np_res)
4545

4646

47-
@pytest.mark.skipif(
48-
not has_support_aspect64(), reason="Aborted on Iris Xe: SAT-5988"
49-
)
5047
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
5148
def test_allclose(type):
5249
a = numpy.random.rand(10)

tests/third_party/cupy/logic_tests/test_comparison.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,15 @@ class TestAllclose(unittest.TestCase):
121121
@testing.for_all_dtypes()
122122
@testing.numpy_cupy_equal()
123123
def test_allclose_finite(self, xp, dtype):
124-
a = xp.array([0.9e-5, 1.1e-5, 1000 + 1e-4, 1000 - 1e-4], dtype=dtype)
125-
b = xp.array([0, 0, 1000, 1000], dtype=dtype)
124+
a = xp.array([0.9e-5, 1.1e-5, 1000 + 1e-4, 1000 - 1e-4]).astype(dtype)
125+
b = xp.array([0, 0, 1000, 1000]).astype(dtype)
126126
return xp.allclose(a, b)
127127

128128
@testing.for_all_dtypes()
129129
@testing.numpy_cupy_equal()
130130
def test_allclose_min_int(self, xp, dtype):
131-
a = xp.array([0], dtype=dtype)
132-
b = xp.array([numpy.iinfo("i").min], dtype=dtype)
131+
a = xp.array([0]).astype(dtype)
132+
b = xp.array([numpy.iinfo("i").min]).astype(dtype)
133133
return xp.allclose(a, b)
134134

135135
@testing.for_float_dtypes()
@@ -138,24 +138,25 @@ def test_allclose_infinite(self, xp, dtype):
138138
nan = float("nan")
139139
inf = float("inf")
140140
ninf = float("-inf")
141-
a = xp.array([0, nan, nan, 0, inf, ninf], dtype=dtype)
142-
b = xp.array([0, nan, 0, nan, inf, ninf], dtype=dtype)
141+
a = xp.array([0, nan, nan, 0, inf, ninf]).astype(dtype)
142+
b = xp.array([0, nan, 0, nan, inf, ninf]).astype(dtype)
143143
return xp.allclose(a, b)
144144

145145
@testing.for_float_dtypes()
146146
@testing.numpy_cupy_equal()
147+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
147148
def test_allclose_infinite_equal_nan(self, xp, dtype):
148149
nan = float("nan")
149150
inf = float("inf")
150151
ninf = float("-inf")
151-
a = xp.array([0, nan, inf, ninf], dtype=dtype)
152-
b = xp.array([0, nan, inf, ninf], dtype=dtype)
152+
a = xp.array([0, nan, inf, ninf]).astype(dtype)
153+
b = xp.array([0, nan, inf, ninf]).astype(dtype)
153154
return xp.allclose(a, b, equal_nan=True)
154155

155156
@testing.for_all_dtypes()
156157
@testing.numpy_cupy_equal()
157158
def test_allclose_array_scalar(self, xp, dtype):
158-
a = xp.array([0.9e-5, 1.1e-5], dtype=dtype)
159+
a = xp.array([0.9e-5, 1.1e-5]).astype(dtype)
159160
b = xp.dtype(xp.dtype).type(0)
160161
return xp.allclose(a, b)
161162

0 commit comments

Comments
 (0)