Skip to content

Commit d15b934

Browse files
Update test_sum_float in test_sum.py
1 parent 42b199d commit d15b934

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tests/test_sum.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
assert_dtype_allclose,
1010
get_float_dtypes,
1111
has_support_aspect64,
12+
is_cpu_device,
1213
)
1314

1415

@@ -27,10 +28,20 @@ def test_sum_float(dtype):
2728
)
2829
ia = dpnp.array(a)
2930

31+
# Flag for type check in special cases
32+
# Skip dtype checks when dpnp handles float32 arrays on CPU
33+
# as `dpnp.sum()` and `numpy.sum()` return different dtypes:
34+
# numpy - 'float32', dpnp - 'float64'
35+
check_dtype = (ia.dtype != dpnp.float32) or not is_cpu_device(
36+
ia.sycl_device
37+
)
3038
for axis in range(len(a)):
3139
result = dpnp.sum(ia, axis=axis)
3240
expected = numpy.sum(a, axis=axis)
33-
assert_dtype_allclose(result, expected)
41+
assert_dtype_allclose(result, expected, check_type=check_dtype)
42+
if not check_dtype:
43+
# Ensure dtype kind matches when check_dtype is False
44+
assert result.dtype.kind == expected.dtype.kind
3445

3546

3647
def test_sum_int():

0 commit comments

Comments
 (0)