Skip to content

Commit c614af9

Browse files
committed
update dpnp.result_type tests
1 parent 7c45c10 commit c614af9

File tree

4 files changed

+32
-32
lines changed

4 files changed

+32
-32
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ env:
6363
third_party/cupy/statistics_tests/test_histogram.py
6464
third_party/cupy/statistics_tests/test_meanvar.py
6565
third_party/cupy/test_ndim.py
66+
third_party/cupy/test_type_routines.py
6667
VER_JSON_NAME: 'version.json'
6768
VER_SCRIPT1: "import json; f = open('version.json', 'r'); j = json.load(f); f.close(); "
6869
VER_SCRIPT2: "d = j['dpnp'][0]; print('='.join((d[s] for s in ('version', 'build'))))"

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2388,6 +2388,7 @@ def result_type(*arrays_and_dtypes):
23882388
----------
23892389
arrays_and_dtypes : list of {dpnp.ndarray, usm_ndarray, dtype}
23902390
An arbitrary length sequence of arrays or dtypes.
2391+
There should be at least one array in the input.
23912392
23922393
Returns
23932394
-------
@@ -2402,9 +2403,6 @@ def result_type(*arrays_and_dtypes):
24022403
>>> np.result_type(a, b)
24032404
dtype('int64')
24042405
2405-
>>> np.result_type(np.int64, np.complex128)
2406-
dtype('complex128')
2407-
24082406
>>> np.result_type(np.ones(10, dtype=np.float32), np.float64)
24092407
dtype('float64')
24102408

tests/test_manipulation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,8 @@ def test_result_type():
8787

8888
def test_result_type_only_dtypes():
8989
X = [dpnp.int64, dpnp.int32, dpnp.bool, dpnp.float32]
90-
X_np = [numpy.int64, numpy.int32, numpy.bool_, numpy.float32]
91-
92-
assert dpnp.result_type(*X) == numpy.result_type(*X_np)
90+
with pytest.raises(ValueError):
91+
dpnp.result_type(*X)
9392

9493

9594
def test_result_type_only_arrays():

tests/third_party/cupy/test_type_routines.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ class TestCanCast(unittest.TestCase):
3434
@testing.for_all_dtypes_combination(names=("from_dtype", "to_dtype"))
3535
@testing.numpy_cupy_equal()
3636
def test_can_cast(self, xp, from_dtype, to_dtype):
37-
if self.obj_type == "scalar":
37+
if (
38+
self.obj_type == "scalar"
39+
and numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0"
40+
):
3841
pytest.skip("to be aligned with NEP-50")
3942

4043
from_obj = _generate_type_routines_input(xp, from_dtype, self.obj_type)
41-
4244
ret = xp.can_cast(from_obj, to_dtype)
4345
assert isinstance(ret, bool)
4446
return ret
@@ -83,7 +85,9 @@ def test_common_type_bool(self, dtype):
8385
@testing.parameterize(
8486
*testing.product(
8587
{
86-
"obj_type1": ["dtype", "specifier", "scalar", "array", "primitive"],
88+
# obj_type1 is modified since at least one input should be an
89+
# array for dpnp.result_dtype
90+
"obj_type1": ["array"],
8791
"obj_type2": ["dtype", "specifier", "scalar", "array", "primitive"],
8892
}
8993
)
@@ -92,37 +96,35 @@ class TestResultType(unittest.TestCase):
9296
@testing.for_all_dtypes_combination(names=("dtype1", "dtype2"))
9397
@testing.numpy_cupy_equal()
9498
def test_result_type(self, xp, dtype1, dtype2):
95-
if "scalar" in {self.obj_type1, self.obj_type2}:
99+
if (
100+
self.obj_type2 == "scalar"
101+
and numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0"
102+
):
96103
pytest.skip("to be aligned with NEP-50")
97104

98105
input1 = _generate_type_routines_input(xp, dtype1, self.obj_type1)
99-
100106
input2 = _generate_type_routines_input(xp, dtype2, self.obj_type2)
101107

102-
flag1 = isinstance(input1, (numpy.ndarray, cupy.ndarray))
103-
flag2 = isinstance(input2, (numpy.ndarray, cupy.ndarray))
104-
dt1 = cupy.dtype(input1) if not flag1 else None
105-
dt2 = cupy.dtype(input2) if not flag2 else None
106-
# dpnp takes into account device capabilities only if one of the
107-
# inputs is an array, for such a case, if the other dtype is not
108-
# supported by device, dpnp raise ValueError. So, we skip the test.
109-
if flag1 or flag2:
110-
if (
111-
dt1 in [cupy.float64, cupy.complex128]
112-
or dt2 in [cupy.float64, cupy.complex128]
113-
) and not has_support_aspect64():
114-
pytest.skip("No fp64 support by device.")
108+
# dpnp.result_type only takes into account device capabilities. If
109+
# dtype2 is `float32` and the object is primitive, the `input2` variable
110+
# is `float` which needs a device with double precision support.
111+
# so we skip the test for such a case on a device that does not support fp64
112+
flag = self.obj_type2 == "primitive" and input2 == float
113+
if flag and not has_support_aspect64():
114+
pytest.skip("No fp64 support by device.")
115115

116116
ret = xp.result_type(input1, input2)
117117

118-
# dpnp takes into account device capabilities if one of the inputs
119-
# is an array, for such a case, we have to modify the results for
120-
# NumPy to align it with device capabilities.
121-
if (flag1 or flag2) and xp == numpy and not has_support_aspect64():
122-
ret = numpy.dtype(numpy.float32) if ret == numpy.float64 else ret
123-
ret = (
124-
numpy.dtype(numpy.complex64) if ret == numpy.complex128 else ret
125-
)
118+
# dpnp.result_type takes into account device capabilities.
119+
# So, we have to modify the results for NumPy to align it with
120+
# device capabilities.
121+
flag1 = isinstance(input1, numpy.ndarray)
122+
flag2 = isinstance(input2, numpy.ndarray)
123+
if (flag1 or flag2) and not has_support_aspect64():
124+
if ret == numpy.float64:
125+
ret = numpy.dtype(numpy.float32)
126+
elif ret == numpy.complex128:
127+
ret = numpy.dtype(numpy.complex64)
126128

127129
assert isinstance(ret, numpy.dtype)
128130
return ret

0 commit comments

Comments
 (0)