Skip to content

Commit ac07299

Browse files
committed
Input array should be dpnp.ndarray or usm_ndarray
1 parent 2346fbf commit ac07299

File tree

4 files changed

+47
-10
lines changed

4 files changed

+47
-10
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1960,8 +1960,9 @@ def require(a, dtype=None, requirements=None, *, like=None):
19601960
19611961
Parameters
19621962
----------
1963-
a : array_like
1964-
The object to be converted to a type-and-requirement-satisfying array.
1963+
a : {dpnp.ndarray, usm_ndarray}
1964+
The input array to be converted to a type-and-requirement-satisfying
1965+
array.
19651966
dtype : {None, data-type}, optional
19661967
The required data-type. If ``None`` preserve the current dtype.
19671968
requirements : {None, str, sequence of str}, optional
@@ -2015,6 +2016,7 @@ def require(a, dtype=None, requirements=None, *, like=None):
20152016
"""
20162017

20172018
dpnp.check_limitations(like=like)
2019+
dpnp.check_supported_arrays_type(a)
20182020

20192021
possible_flags = {
20202022
"C": "C",

tests/test_manipulation.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -701,24 +701,26 @@ def test_require_each(self):
701701
a = self.generate_all_false(idtype)
702702
self.set_and_check_flag(flag, fdtype, a)
703703

704-
@pytest.mark.parametrize("xp", [numpy, dpnp])
705-
def test_unknown_requirement(self, xp):
704+
def test_unknown_requirement(self):
706705
a = self.generate_all_false("f4")
707-
assert_raises((KeyError, ValueError), xp.require, a, None, "Q")
706+
assert_raises(KeyError, numpy.require, a[0], None, "Q")
707+
assert_raises(ValueError, dpnp.require, a[1], None, "Q")
708708

709709
def test_non_array_input(self):
710-
expected = numpy.require([1, 2, 3, 4], "i4", ["C", "W"])
711-
result = dpnp.require([1, 2, 3, 4], "i4", ["C", "W"])
710+
a_np = numpy.array([1, 2, 3, 4])
711+
a_dp = dpnp.array(a_np)
712+
expected = numpy.require(a_np, "i4", ["C", "W"])
713+
result = dpnp.require(a_dp, "i4", ["C", "W"])
712714
assert expected.flags["C"] == result.flags["C"]
713715
assert expected.flags["F"] == result.flags["F"]
714716
assert expected.flags["W"] == result.flags["W"]
715717
assert expected.dtype == result.dtype
716718
assert_array_equal(expected, result)
717719

718-
@pytest.mark.parametrize("xp", [numpy, dpnp])
719-
def test_C_and_F_simul(self, xp):
720+
def test_C_and_F_simul(self):
720721
a = self.generate_all_false("f4")
721-
assert_raises(ValueError, xp.require, a, None, ["C", "F"])
722+
assert_raises(ValueError, numpy.require, a[0], None, ["C", "F"])
723+
assert_raises(ValueError, dpnp.require, a[1], None, ["C", "F"])
722724

723725
def test_copy(self):
724726
a_np = numpy.arange(6).reshape(2, 3)

tests/test_sycl_queue.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,26 @@ def test_out_multi_dot(device):
12851285
assert_sycl_queue_equal(result.sycl_queue, exec_q)
12861286

12871287

1288+
@pytest.mark.parametrize(
1289+
"device",
1290+
valid_devices,
1291+
ids=[device.filter_string for device in valid_devices],
1292+
)
1293+
def test_require(device):
1294+
dpnp_data = dpnp.arange(10, device=device).reshape(2, 5)
1295+
result = dpnp.require(dpnp_data, dtype="f4", requirements=["F"])
1296+
1297+
expected_queue = dpnp_data.sycl_queue
1298+
result_queue = result.sycl_queue
1299+
assert_sycl_queue_equal(result_queue, expected_queue)
1300+
1301+
# No requirements
1302+
result = dpnp.require(dpnp_data, dtype="f4")
1303+
expected_queue = dpnp_data.sycl_queue
1304+
result_queue = result.sycl_queue
1305+
assert_sycl_queue_equal(result_queue, expected_queue)
1306+
1307+
12881308
@pytest.mark.parametrize(
12891309
"device",
12901310
valid_devices,

tests/test_usm_type.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,19 @@ def test_eigenvalue(func, shape, usm_type):
10131013
assert a.usm_type == dp_val.usm_type
10141014

10151015

1016+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
1017+
def test_require(usm_type):
1018+
dpnp_data = dp.arange(10, usm_type=usm_type).reshape(2, 5)
1019+
result = dp.require(dpnp_data, dtype="f4", requirements=["F"])
1020+
assert dpnp_data.usm_type == usm_type
1021+
assert result.usm_type == usm_type
1022+
1023+
# No requirements
1024+
result = dp.require(dpnp_data, dtype="f4")
1025+
assert dpnp_data.usm_type == usm_type
1026+
assert result.usm_type == usm_type
1027+
1028+
10161029
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
10171030
def test_resize(usm_type):
10181031
dpnp_data = dp.arange(10, usm_type=usm_type)

0 commit comments

Comments
 (0)