@@ -12,12 +12,20 @@ def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True):
12
12
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
13
13
for floating and complex types.
14
14
For other dtypes the assertion is based on exact matching of the arrays.
15
- When 'check_type' True (default), it asserts equal dtypes for exact types
16
- and either equal dtypes or kinds for inexact types, depending on the 64-bit precision
17
- support of the device on which `dpnp_arr` is created.
18
-
19
- """
20
-
15
+ When 'check_type' is True (default), the function asserts:
16
+ - Equal dtypes for exact types.
17
+ For inexact types:
18
+ - If the numpy array's dtype is `numpy.float16`, checks if the device
19
+ of the `dpnp_arr` supports 64-bit precision floating point operations.
20
+ If supported, asserts equal dtypes.
21
+ Otherwise, asserts equal type kinds.
22
+ - For other inexact types, asserts equal dtypes if the device of the `dpnp_arr`
23
+ supports 64-bit precision floating point operations or if the numpy array's inexact
24
+ dtype is not a double precision type.
25
+ Otherwise, asserts equal type kinds.
26
+ """
27
+
28
+ list_64bit_types = [numpy .float64 , numpy .complex128 ]
21
29
is_inexact = lambda x : dpnp .issubdtype (x .dtype , dpnp .inexact )
22
30
if is_inexact (dpnp_arr ) or is_inexact (numpy_arr ):
23
31
tol = 8 * max (
@@ -26,10 +34,20 @@ def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True):
26
34
)
27
35
assert_allclose (dpnp_arr .asnumpy (), numpy_arr , atol = tol , rtol = tol )
28
36
if check_type :
29
- if has_support_aspect64 (dpnp_arr .sycl_device ):
30
- assert dpnp_arr .dtype == numpy_arr .dtype
37
+ numpy_arr_dtype = numpy_arr .dtype
38
+ dpnp_arr_dtype = dpnp_arr .dtype
39
+ dpnp_arr_dev = dpnp_arr .sycl_device
40
+ is_np_arr_f2 = numpy_arr_dtype == numpy .float16
41
+
42
+ if is_np_arr_f2 and has_support_aspect16 (dpnp_arr_dev ):
43
+ assert dpnp_arr_dtype == numpy_arr_dtype
44
+ elif is_np_arr_f2 and (
45
+ has_support_aspect64 (dpnp_arr_dev )
46
+ or numpy_arr_dtype not in list_64bit_types
47
+ ):
48
+ assert dpnp_arr_dtype == numpy_arr_dtype
31
49
else :
32
- assert dpnp_arr . dtype . kind == numpy_arr . dtype .kind
50
+ assert dpnp_arr_dtype . kind == numpy_arr_dtype .kind
33
51
else :
34
52
assert_array_equal (dpnp_arr .asnumpy (), numpy_arr )
35
53
if check_type :
0 commit comments