Skip to content

Commit e24fa99

Browse files
authored
improve covergae dpnp_iface_manipulation.py (#2326)
After merging #2230, coverage for `dpnp.iface_manipulation.py` was decreased related to `dpnp.copyto` function. This PR adds new tests to improve the coverage and updates the logic used in `dpnp.copyto` to pass the new tests.
1 parent af0fc93 commit e24fa99

File tree

2 files changed

+99
-52
lines changed

2 files changed

+99
-52
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,11 +1518,10 @@ def copyto(dst, src, casting="same_kind", where=True):
15181518
f"but got {type(dst)}"
15191519
)
15201520
if not dpnp.is_supported_array_type(src):
1521-
no_dtype_attr = not hasattr(src, "dtype")
1521+
python_sc = dpnp.isscalar(src) and not isinstance(src, numpy.generic)
15221522
src = dpnp.array(src, sycl_queue=dst.sycl_queue)
1523-
if no_dtype_attr:
1524-
# This case (scalar, list, etc) needs special handling to
1525-
# behave similar to NumPy
1523+
if python_sc:
1524+
# Python scalar needs special handling to behave similar to NumPy
15261525
if dpnp.issubdtype(src, dpnp.integer) and dpnp.issubdtype(
15271526
dst, dpnp.unsignedinteger
15281527
):

dpnp/tests/test_manipulation.py

Lines changed: 96 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,6 @@
2525
)
2626
from .third_party.cupy import testing
2727

28-
testdata = []
29-
testdata += [
30-
([True, False, True], dtype)
31-
for dtype in get_all_dtypes(no_none=True, no_complex=True)
32-
]
33-
testdata += [
34-
([1, -1, 0], dtype)
35-
for dtype in get_all_dtypes(
36-
no_none=True, no_bool=True, no_complex=True, no_unsigned=True
37-
)
38-
]
39-
testdata += [([0.1, 0.0, -0.1], dtype) for dtype in get_float_dtypes()]
40-
testdata += [([1j, -1j, 1 - 2j], dtype) for dtype in get_complex_dtypes()]
41-
4228

4329
def _compare_results(result, expected):
4430
"""Compare lists of arrays."""
@@ -49,40 +35,6 @@ def _compare_results(result, expected):
4935
assert_array_equal(x, y)
5036

5137

52-
@pytest.mark.parametrize("in_obj, out_dtype", testdata)
53-
def test_copyto_dtype(in_obj, out_dtype):
54-
ndarr = numpy.array(in_obj)
55-
expected = numpy.empty(ndarr.size, dtype=out_dtype)
56-
numpy.copyto(expected, ndarr)
57-
58-
dparr = dpnp.array(in_obj)
59-
result = dpnp.empty(dparr.size, dtype=out_dtype)
60-
dpnp.copyto(result, dparr)
61-
62-
assert_array_equal(result, expected)
63-
64-
65-
@pytest.mark.parametrize("dst", [7, numpy.ones(10), (2, 7), [5], range(3)])
66-
def test_copyto_dst_raises(dst):
67-
a = dpnp.array(4)
68-
with pytest.raises(
69-
TypeError,
70-
match="Destination array must be any of supported type, but got",
71-
):
72-
dpnp.copyto(dst, a)
73-
74-
75-
@pytest.mark.parametrize("where", [numpy.ones(10), (2, 7), [5], range(3)])
76-
def test_copyto_where_raises(where):
77-
a = dpnp.empty((2, 3))
78-
b = dpnp.arange(6).reshape((2, 3))
79-
80-
with pytest.raises(
81-
TypeError, match="`where` array must be any of supported type, but got"
82-
):
83-
dpnp.copyto(a, b, where=where)
84-
85-
8638
def test_result_type():
8739
X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.int32, "float32"]
8840
X_np = [numpy.ones((2), dtype=numpy.int64), numpy.int32, "float32"]
@@ -365,6 +317,102 @@ def test_broadcast_shapes(self, shape):
365317
assert_equal(result, expected)
366318

367319

320+
class TestCopyTo:
321+
testdata = []
322+
testdata += [
323+
([True, False, True], dtype)
324+
for dtype in get_all_dtypes(no_none=True, no_complex=True)
325+
]
326+
testdata += [
327+
([1, -1, 0], dtype)
328+
for dtype in get_all_dtypes(
329+
no_none=True, no_bool=True, no_complex=True, no_unsigned=True
330+
)
331+
]
332+
testdata += [([0.1, 0.0, -0.1], dtype) for dtype in get_float_dtypes()]
333+
testdata += [([1j, -1j, 1 - 2j], dtype) for dtype in get_complex_dtypes()]
334+
335+
@pytest.mark.parametrize("data, dt_out", testdata)
336+
def test_dtype(self, data, dt_out):
337+
a = numpy.array(data)
338+
ia = dpnp.array(a)
339+
340+
expected = numpy.empty(a.size, dtype=dt_out)
341+
result = dpnp.empty(ia.size, dtype=dt_out)
342+
numpy.copyto(expected, a)
343+
dpnp.copyto(result, ia)
344+
345+
assert_array_equal(result, expected)
346+
347+
@pytest.mark.parametrize("data, dt_out", testdata)
348+
def test_dtype_input_list(self, data, dt_out):
349+
expected = numpy.empty(3, dtype=dt_out)
350+
result = dpnp.empty(3, dtype=dt_out)
351+
assert isinstance(data, list)
352+
numpy.copyto(expected, data)
353+
dpnp.copyto(result, data)
354+
355+
assert_array_equal(result, expected)
356+
357+
@pytest.mark.parametrize("xp", [dpnp, numpy])
358+
@pytest.mark.parametrize(
359+
"data", [(1, 2, -3), [1, 2, -3]], ids=["tuple", "list"]
360+
)
361+
@pytest.mark.parametrize(
362+
"dst_dt", [dpnp.uint8, dpnp.uint16, dpnp.uint32, dpnp.uint64]
363+
)
364+
def test_casting_error(self, xp, data, dst_dt):
365+
# cannot cast to unsigned integer
366+
dst = xp.empty(3, dtype=dst_dt)
367+
assert_raises(TypeError, xp.copyto, dst, data)
368+
369+
@pytest.mark.parametrize(
370+
"dt_out", [dpnp.uint8, dpnp.uint16, dpnp.uint32, dpnp.uint64]
371+
)
372+
def test_positive_python_scalar(self, dt_out):
373+
# src is python scalar and positive
374+
expected = numpy.empty(1, dtype=dt_out)
375+
result = dpnp.array(expected)
376+
numpy.copyto(expected, 5)
377+
dpnp.copyto(result, 5)
378+
379+
assert_array_equal(result, expected)
380+
381+
@testing.with_requires("numpy>=2.1")
382+
@pytest.mark.parametrize("xp", [dpnp, numpy])
383+
@pytest.mark.parametrize(
384+
"dst_dt", [dpnp.uint8, dpnp.uint16, dpnp.uint32, dpnp.uint64]
385+
)
386+
def test_numpy_scalar(self, xp, dst_dt):
387+
dst = xp.empty(1, dtype=dst_dt)
388+
# cannot cast from signed int to unsigned int, src is numpy scalar
389+
assert_raises(TypeError, xp.copyto, dst, numpy.int32(5))
390+
assert_raises(TypeError, xp.copyto, dst, numpy.int32(-5))
391+
392+
# Python integer -5 out of bounds, src is python scalar and negative
393+
assert_raises(OverflowError, xp.copyto, dst, -5)
394+
395+
@pytest.mark.parametrize("dst", [7, numpy.ones(10), (2, 7), [5], range(3)])
396+
def test_dst_raises(self, dst):
397+
a = dpnp.array(4)
398+
with pytest.raises(
399+
TypeError,
400+
match="Destination array must be any of supported type, but got",
401+
):
402+
dpnp.copyto(dst, a)
403+
404+
@pytest.mark.parametrize("where", [numpy.ones(10), (2, 7), [5], range(3)])
405+
def test_where_raises(self, where):
406+
a = dpnp.empty((2, 3))
407+
b = dpnp.arange(6).reshape((2, 3))
408+
409+
with pytest.raises(
410+
TypeError,
411+
match="`where` array must be any of supported type, but got",
412+
):
413+
dpnp.copyto(a, b, where=where)
414+
415+
368416
class TestDelete:
369417
@pytest.mark.parametrize(
370418
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]

0 commit comments

Comments
 (0)