Skip to content

Commit 301858b

Browse files
committed
improve covergae dpnp_iface_manipulation.py
1 parent b09533e commit 301858b

File tree

2 files changed

+82
-52
lines changed

2 files changed

+82
-52
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,11 +1494,10 @@ def copyto(dst, src, casting="same_kind", where=True):
14941494
f"but got {type(dst)}"
14951495
)
14961496
if not dpnp.is_supported_array_type(src):
1497-
no_dtype_attr = not hasattr(src, "dtype")
1497+
src_is_scalar = dpnp.isscalar(src)
14981498
src = dpnp.array(src, sycl_queue=dst.sycl_queue)
1499-
if no_dtype_attr:
1500-
# This case (scalar, list, etc) needs special handling to
1501-
# behave similar to NumPy
1499+
if src_is_scalar:
1500+
# scalar needs special handling to behave similar to NumPy
15021501
if dpnp.issubdtype(src, dpnp.integer) and dpnp.issubdtype(
15031502
dst, dpnp.unsignedinteger
15041503
):

dpnp/tests/test_manipulation.py

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,6 @@
2424
)
2525
from .third_party.cupy import testing
2626

27-
testdata = []
28-
testdata += [
29-
([True, False, True], dtype)
30-
for dtype in get_all_dtypes(no_none=True, no_complex=True)
31-
]
32-
testdata += [
33-
([1, -1, 0], dtype)
34-
for dtype in get_all_dtypes(
35-
no_none=True, no_bool=True, no_complex=True, no_unsigned=True
36-
)
37-
]
38-
testdata += [([0.1, 0.0, -0.1], dtype) for dtype in get_float_dtypes()]
39-
testdata += [([1j, -1j, 1 - 2j], dtype) for dtype in get_complex_dtypes()]
40-
4127

4228
def _compare_results(result, expected):
4329
"""Compare lists of arrays."""
@@ -48,40 +34,6 @@ def _compare_results(result, expected):
4834
assert_array_equal(x, y)
4935

5036

51-
@pytest.mark.parametrize("in_obj, out_dtype", testdata)
52-
def test_copyto_dtype(in_obj, out_dtype):
53-
ndarr = numpy.array(in_obj)
54-
expected = numpy.empty(ndarr.size, dtype=out_dtype)
55-
numpy.copyto(expected, ndarr)
56-
57-
dparr = dpnp.array(in_obj)
58-
result = dpnp.empty(dparr.size, dtype=out_dtype)
59-
dpnp.copyto(result, dparr)
60-
61-
assert_array_equal(result, expected)
62-
63-
64-
@pytest.mark.parametrize("dst", [7, numpy.ones(10), (2, 7), [5], range(3)])
65-
def test_copyto_dst_raises(dst):
66-
a = dpnp.array(4)
67-
with pytest.raises(
68-
TypeError,
69-
match="Destination array must be any of supported type, but got",
70-
):
71-
dpnp.copyto(dst, a)
72-
73-
74-
@pytest.mark.parametrize("where", [numpy.ones(10), (2, 7), [5], range(3)])
75-
def test_copyto_where_raises(where):
76-
a = dpnp.empty((2, 3))
77-
b = dpnp.arange(6).reshape((2, 3))
78-
79-
with pytest.raises(
80-
TypeError, match="`where` array must be any of supported type, but got"
81-
):
82-
dpnp.copyto(a, b, where=where)
83-
84-
8537
def test_result_type():
8638
X = [dpnp.ones((2), dtype=dpnp.int64), dpnp.int32, "float32"]
8739
X_np = [numpy.ones((2), dtype=numpy.int64), numpy.int32, "float32"]
@@ -364,6 +316,85 @@ def test_broadcast_shapes(self, shape):
364316
assert_equal(result, expected)
365317

366318

319+
class TestCopyTo:
320+
testdata = []
321+
testdata += [
322+
([True, False, True], dtype)
323+
for dtype in get_all_dtypes(no_none=True, no_complex=True)
324+
]
325+
testdata += [
326+
([1, -1, 0], dtype)
327+
for dtype in get_all_dtypes(
328+
no_none=True, no_bool=True, no_complex=True, no_unsigned=True
329+
)
330+
]
331+
testdata += [([0.1, 0.0, -0.1], dtype) for dtype in get_float_dtypes()]
332+
testdata += [([1j, -1j, 1 - 2j], dtype) for dtype in get_complex_dtypes()]
333+
334+
@pytest.mark.parametrize("data, dt_out", testdata)
335+
def test_dtype(self, data, dt_out):
336+
a = numpy.array(data)
337+
ia = dpnp.array(a)
338+
339+
expected = numpy.empty(a.size, dtype=dt_out)
340+
result = dpnp.empty(ia.size, dtype=dt_out)
341+
numpy.copyto(expected, a)
342+
dpnp.copyto(result, ia)
343+
344+
assert_array_equal(result, expected)
345+
346+
@pytest.mark.parametrize("data, dt_out", testdata)
347+
def test_dtype_input_list(self, data, dt_out):
348+
expected = numpy.empty(3, dtype=dt_out)
349+
result = dpnp.empty(3, dtype=dt_out)
350+
assert isinstance(data, list)
351+
numpy.copyto(expected, data)
352+
dpnp.copyto(result, data)
353+
354+
assert_array_equal(result, expected)
355+
356+
@pytest.mark.parametrize("xp", [dpnp, numpy])
357+
@pytest.mark.parametrize(
358+
"data", [(1, 2, -3), [1, 2, -3]], ids=["tuple", "list"]
359+
)
360+
@pytest.mark.parametrize(
361+
"dst_dt", [dpnp.uint8, dpnp.uint16, dpnp.uint32, dpnp.uint64]
362+
)
363+
def test_casting_error(self, xp, data, dst_dt):
364+
# cannot cast to unsigned integer
365+
dst = xp.empty(3, dtype=dst_dt)
366+
assert_raises(TypeError, xp.copyto, dst, data)
367+
368+
@pytest.mark.parametrize("xp", [dpnp, numpy])
369+
@pytest.mark.parametrize(
370+
"dst_dt", [dpnp.uint8, dpnp.uint16, dpnp.uint32, dpnp.uint64]
371+
)
372+
def test_scalar_error(self, xp, dst_dt):
373+
# cannot cast to unsigned integer, input is scalar
374+
dst = xp.empty(1, dtype=dst_dt)
375+
assert_raises(OverflowError, xp.copyto, dst, -5)
376+
377+
@pytest.mark.parametrize("dst", [7, numpy.ones(10), (2, 7), [5], range(3)])
378+
def test_dst_raises(self, dst):
379+
a = dpnp.array(4)
380+
with pytest.raises(
381+
TypeError,
382+
match="Destination array must be any of supported type, but got",
383+
):
384+
dpnp.copyto(dst, a)
385+
386+
@pytest.mark.parametrize("where", [numpy.ones(10), (2, 7), [5], range(3)])
387+
def test_where_raises(self, where):
388+
a = dpnp.empty((2, 3))
389+
b = dpnp.arange(6).reshape((2, 3))
390+
391+
with pytest.raises(
392+
TypeError,
393+
match="`where` array must be any of supported type, but got",
394+
):
395+
dpnp.copyto(a, b, where=where)
396+
397+
367398
class TestDelete:
368399
@pytest.mark.parametrize(
369400
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]

0 commit comments

Comments
 (0)