Skip to content

Commit 8b300c2

Browse files
Add a check for support of 16 bit types
1 parent 579fcc2 commit 8b300c2

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

tests/helper.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,20 @@ def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True):
1212
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
1313
for floating and complex types.
1414
For other dtypes the assertion is based on exact matching of the arrays.
15-
When 'check_type' True (default), it asserts equal dtypes for exact types
16-
and either equal dtypes or kinds for inexact types, depending on the 64-bit precision
17-
support of the device on which `dpnp_arr` is created.
18-
19-
"""
20-
15+
When 'check_type' is True (default), the function asserts:
16+
- Equal dtypes for exact types.
17+
For inexact types:
18+
- If the numpy array's dtype is `numpy.float16`, checks if the device
19+
of the `dpnp_arr` supports 64-bit precision floating point operations.
20+
If supported, asserts equal dtypes.
21+
Otherwise, asserts equal type kinds.
22+
- For other inexact types, asserts equal dtypes if the device of the `dpnp_arr`
23+
supports 64-bit precision floating point operations or if the numpy array's inexact
24+
dtype is not a double precision type.
25+
Otherwise, asserts equal type kinds.
26+
"""
27+
28+
list_64bit_types = [numpy.float64, numpy.complex128]
2129
is_inexact = lambda x: dpnp.issubdtype(x.dtype, dpnp.inexact)
2230
if is_inexact(dpnp_arr) or is_inexact(numpy_arr):
2331
tol = 8 * max(
@@ -26,10 +34,20 @@ def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True):
2634
)
2735
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
2836
if check_type:
29-
if has_support_aspect64(dpnp_arr.sycl_device):
30-
assert dpnp_arr.dtype == numpy_arr.dtype
37+
numpy_arr_dtype = numpy_arr.dtype
38+
dpnp_arr_dtype = dpnp_arr.dtype
39+
dpnp_arr_dev = dpnp_arr.sycl_device
40+
is_np_arr_f2 = numpy_arr_dtype == numpy.float16
41+
42+
if is_np_arr_f2 and has_support_aspect16(dpnp_arr_dev):
43+
assert dpnp_arr_dtype == numpy_arr_dtype
44+
elif is_np_arr_f2 and (
45+
has_support_aspect64(dpnp_arr_dev)
46+
or numpy_arr_dtype not in list_64bit_types
47+
):
48+
assert dpnp_arr_dtype == numpy_arr_dtype
3149
else:
32-
assert dpnp_arr.dtype.kind == numpy_arr.dtype.kind
50+
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
3351
else:
3452
assert_array_equal(dpnp_arr.asnumpy(), numpy_arr)
3553
if check_type:

0 commit comments

Comments
 (0)