Skip to content

Commit bab4345

Browse files
committed
address comments
1 parent df15308 commit bab4345

File tree

4 files changed

+42
-61
lines changed

4 files changed

+42
-61
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040

4141
import math
42+
import operator
4243

4344
import dpctl.tensor as dpt
4445
import numpy
@@ -1524,6 +1525,7 @@ def fliplr(m):
15241525
--------
15251526
:obj:`dpnp.flipud` : Flip an array vertically (axis=0).
15261527
:obj:`dpnp.flip` : Flip array in one or more dimensions.
1528+
:obj:`dpnp.rot90` : Rotate array counterclockwise.
15271529
15281530
Examples
15291531
--------
@@ -1574,6 +1576,7 @@ def flipud(m):
15741576
--------
15751577
:obj:`dpnp.fliplr` : Flip array in the left/right direction.
15761578
:obj:`dpnp.flip` : Flip array in one or more dimensions.
1579+
:obj:`dpnp.rot90` : Rotate array counterclockwise.
15771580
15781581
Examples
15791582
--------
@@ -2056,12 +2059,12 @@ def resize(a, new_shape):
20562059
-------
20572060
out : dpnp.ndarray
20582061
The new array is formed from the data in the old array, repeated
2059-
if necessary to fill out the required number of elements. The
2062+
if necessary to fill out the required number of elements. The
20602063
data are repeated iterating over the array in C-order.
20612064
20622065
See Also
20632066
--------
2064-
:obj:`dpnp.ndarray.reshape` : Resize an array in-place.
2067+
:obj:`dpnp.ndarray.resize` : Resize an array in-place.
20652068
:obj:`dpnp.reshape` : Reshape an array without changing the total size.
20662069
:obj:`dpnp.pad` : Enlarge and pad an array.
20672070
:obj:`dpnp.repeat` : Repeat elements of an array.
@@ -2083,8 +2086,8 @@ def resize(a, new_shape):
20832086
Examples
20842087
--------
20852088
>>> import dpnp as np
2086-
>>> a=np.array([[0, 1], [2, 3]])
2087-
>>> np.resize(a, (2 ,3))
2089+
>>> a = np.array([[0, 1], [2, 3]])
2090+
>>> np.resize(a, (2, 3))
20882091
array([[0, 1, 2],
20892092
[3, 0, 1]])
20902093
>>> np.resize(a, (1, 4))
@@ -2097,24 +2100,24 @@ def resize(a, new_shape):
20972100

20982101
dpnp.check_supported_arrays_type(a)
20992102
if a.ndim == 0:
2100-
return dpnp.full(new_shape, a)
2103+
return dpnp.full_like(a, a, shape=new_shape)
21012104

21022105
if isinstance(new_shape, (int, numpy.integer)):
21032106
new_shape = (new_shape,)
21042107

2105-
a = dpnp.ravel(a)
21062108
new_size = 1
21072109
for dim_length in new_shape:
21082110
if dim_length < 0:
21092111
raise ValueError("all elements of `new_shape` must be non-negative")
21102112
new_size *= dim_length
21112113

2112-
if a.size == 0 or new_size == 0:
2114+
a_size = a.size
2115+
if a_size == 0 or new_size == 0:
21132116
# First case must zero fill. The second would have repeats == 0.
21142117
return dpnp.zeros_like(a, shape=new_shape)
21152118

2116-
repeats = -(-new_size // a.size) # ceil division
2117-
a = dpnp.concatenate((a,) * repeats)[:new_size]
2119+
repeats = -(-new_size // a_size) # ceil division
2120+
a = dpnp.concatenate((dpnp.ravel(a),) * repeats)[:new_size]
21182121

21192122
return a.reshape(new_shape)
21202123

@@ -2324,7 +2327,7 @@ def rot90(m, k=1, axes=(0, 1)):
23242327
23252328
Notes
23262329
-----
2327-
``rot90(m, k=1, axes=(1,0))`` is the reverse of
2330+
``rot90(m, k=1, axes=(1,0))`` is the reverse of
23282331
``rot90(m, k=1, axes=(0,1))``.
23292332
23302333
``rot90(m, k=1, axes=(1,0))`` is equivalent to
@@ -2353,8 +2356,7 @@ def rot90(m, k=1, axes=(0, 1)):
23532356
"""
23542357

23552358
dpnp.check_supported_arrays_type(m)
2356-
if not isinstance(k, (int, dpnp.integer)):
2357-
raise TypeError("k must be an integer.")
2359+
k = operator.index(k)
23582360

23592361
m_ndim = m.ndim
23602362
if m_ndim < 2:
@@ -2384,7 +2386,7 @@ def rot90(m, k=1, axes=(0, 1)):
23842386
)
23852387

23862388
if k == 1:
2387-
return dpnp.transpose(flip(m, axes[1]), axes_list)
2389+
return dpnp.transpose(dpnp.flip(m, axes[1]), axes_list)
23882390

23892391
# k == 3
23902392
return dpnp.flip(dpnp.transpose(m, axes_list), axes[1])

tests/test_manipulation.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -666,37 +666,37 @@ def test_minimum_signed_integers(self, data, dtype):
666666

667667

668668
class TestResize:
669-
def test_copies(self):
670-
a = numpy.array([[1, 2], [3, 4]])
671-
ia = dpnp.array(a)
672-
assert_equal(dpnp.resize(ia, (2, 4)), numpy.resize(a, (2, 4)))
673-
674-
a = numpy.array([[1, 2], [3, 4], [1, 2], [3, 4]])
675-
ia = dpnp.array(a)
676-
assert_equal(dpnp.resize(ia, (4, 2)), numpy.resize(a, (4, 2)))
677-
678-
a = numpy.array([[1, 2, 3], [4, 1, 2], [3, 4, 1], [2, 3, 4]])
669+
@pytest.mark.parametrize(
670+
"data, shape",
671+
[
672+
pytest.param([[1, 2], [3, 4]], (2, 4)),
673+
pytest.param([[1, 2], [3, 4], [1, 2], [3, 4]], (4, 2)),
674+
pytest.param([[1, 2, 3], [4, 1, 2], [3, 4, 1], [2, 3, 4]], (4, 3)),
675+
],
676+
)
677+
def test_copies(self, data, shape):
678+
a = numpy.array(data)
679679
ia = dpnp.array(a)
680-
assert_equal(dpnp.resize(ia, (4, 3)), numpy.resize(a, (4, 3)))
680+
assert_equal(dpnp.resize(ia, shape), numpy.resize(a, shape))
681681

682682
@pytest.mark.parametrize("newshape", [(2, 4), [2, 4], (10,), 10])
683683
def test_newshape_type(self, newshape):
684684
a = numpy.array([[1, 2], [3, 4]])
685685
ia = dpnp.array(a)
686686
assert_equal(dpnp.resize(ia, newshape), numpy.resize(a, newshape))
687687

688-
def test_repeats(self):
689-
a = numpy.array([1, 2, 3])
690-
ia = dpnp.array(a)
691-
assert_equal(dpnp.resize(ia, (2, 4)), numpy.resize(a, (2, 4)))
692-
693-
a = numpy.array([[1, 2], [3, 1], [2, 3], [1, 2]])
694-
ia = dpnp.array(a)
695-
assert_equal(dpnp.resize(ia, (4, 2)), numpy.resize(a, (4, 2)))
696-
697-
a = numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
688+
@pytest.mark.parametrize(
689+
"data, shape",
690+
[
691+
pytest.param([1, 2, 3], (2, 4)),
692+
pytest.param([[1, 2], [3, 1], [2, 3], [1, 2]], (4, 2)),
693+
pytest.param([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]], (4, 3)),
694+
],
695+
)
696+
def test_repeats(self, data, shape):
697+
a = numpy.array(data)
698698
ia = dpnp.array(a)
699-
assert_equal(dpnp.resize(ia, (4, 3)), numpy.resize(a, (4, 3)))
699+
assert_equal(dpnp.resize(ia, shape), numpy.resize(a, shape))
700700

701701
def test_zeroresize(self):
702702
a = numpy.array([[1, 2], [3, 4]])
@@ -729,9 +729,9 @@ def test_error(self, xp):
729729
assert_raises(ValueError, xp.rot90, xp.ones((2, 2)), axes=(0, 2))
730730
assert_raises(ValueError, xp.rot90, xp.ones((2, 2)), axes=(1, 1))
731731
assert_raises(ValueError, xp.rot90, xp.ones((2, 2, 2)), axes=(-2, 1))
732-
if xp == dpnp:
733-
# NumPy return result of k=3 incorrectly when k is float
734-
assert_raises(TypeError, xp.rot90, xp.ones((2, 2)), k=2.5)
732+
733+
def test_error_float_k(self):
734+
assert_raises(TypeError, dpnp.rot90, dpnp.ones((2, 2)), k=2.5)
735735

736736
def test_basic(self):
737737
a = numpy.array([[0, 1, 2], [3, 4, 5]])

tests/test_sycl_queue.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ def test_meshgrid(device):
495495
),
496496
pytest.param("real_if_close", [2.1 + 4e-15j, 5.2 + 3e-16j]),
497497
pytest.param("reciprocal", [1.0, 2.0, 4.0, 7.0]),
498+
pytest.param("rot90", [[1, 2], [3, 4]]),
498499
pytest.param("sign", [-5.0, 0.0, 4.5]),
499500
pytest.param("signbit", [-5.0, 0.0, 4.5]),
500501
pytest.param(
@@ -1298,20 +1299,6 @@ def test_resize(device):
12981299
assert_sycl_queue_equal(result_queue, expected_queue)
12991300

13001301

1301-
@pytest.mark.parametrize(
1302-
"device",
1303-
valid_devices,
1304-
ids=[device.filter_string for device in valid_devices],
1305-
)
1306-
def test_rot90(device):
1307-
dpnp_data = dpnp.array([[1, 2], [3, 4]], device=device)
1308-
result = dpnp.rot90(dpnp_data)
1309-
1310-
expected_queue = dpnp_data.sycl_queue
1311-
result_queue = result.sycl_queue
1312-
assert_sycl_queue_equal(result_queue, expected_queue)
1313-
1314-
13151302
class TestFft:
13161303
@pytest.mark.parametrize(
13171304
"func", ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"]

tests/test_usm_type.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ def test_norm(usm_type, ord, axis):
622622
pytest.param("real_if_close", [2.1 + 4e-15j, 5.2 + 3e-16j]),
623623
pytest.param("reciprocal", [1.0, 2.0, 4.0, 7.0]),
624624
pytest.param("reduce_hypot", [1.0, 2.0, 4.0, 7.0]),
625+
pytest.param("rot90", [[1, 2], [3, 4]]),
625626
pytest.param("rsqrt", [1, 8, 27]),
626627
pytest.param("sign", [-5.0, 0.0, 4.5]),
627628
pytest.param("signbit", [-5.0, 0.0, 4.5]),
@@ -1021,15 +1022,6 @@ def test_resize(usm_type):
10211022
assert result.usm_type == usm_type
10221023

10231024

1024-
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
1025-
def test_rot90(usm_type):
1026-
dpnp_data = dp.array([[1, 2], [3, 4]], usm_type=usm_type)
1027-
result = dp.rot90(dpnp_data)
1028-
1029-
assert dpnp_data.usm_type == usm_type
1030-
assert result.usm_type == usm_type
1031-
1032-
10331025
class TestFft:
10341026
@pytest.mark.parametrize(
10351027
"func", ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"]

0 commit comments

Comments
 (0)