Skip to content

Commit 36d63b7

Browse files
committed
fix TestResultType
1 parent 29b2f87 commit 36d63b7

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tests/third_party/cupy/test_type_routines.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,26 @@ def test_result_type(self, xp, dtype1, dtype2):
103103
flag2 = isinstance(input2, (numpy.ndarray, cupy.ndarray))
104104
dt1 = cupy.dtype(input1) if not flag1 else None
105105
dt2 = cupy.dtype(input2) if not flag2 else None
106-
# dpnp takes into account devices capabilities only if one of the
106+
# dpnp takes into account device capabilities only if one of the
107107
# inputs is an array, for such a case, if the other dtype is not
108108
# supported by device, dpnp raise ValueError. So, we skip the test.
109109
if flag1 or flag2:
110110
if (
111111
dt1 in [cupy.float64, cupy.complex128]
112112
or dt2 in [cupy.float64, cupy.complex128]
113-
and not has_support_aspect64()
114-
):
113+
) and not has_support_aspect64():
115114
pytest.skip("No fp64 support by device.")
116115

117116
ret = xp.result_type(input1, input2)
118117

118+
# dpnp takes into account device capabilities if one of the inputs
119+
# is an array, for such a case, we have to modify the results for
120+
# NumPy to align it with device capabilities.
121+
if (flag1 or flag2) and xp == numpy and not has_support_aspect64():
122+
ret = numpy.dtype(numpy.float32) if ret == numpy.float64 else ret
123+
ret = (
124+
numpy.dtype(numpy.complex64) if ret == numpy.complex128 else ret
125+
)
126+
119127
assert isinstance(ret, numpy.dtype)
120128
return ret

0 commit comments

Comments
 (0)