Skip to content

Commit 8cab1af

Browse files
authored
implementation of dpnp.require (#2036)
* add dpnp.require * improve coverage * address comments * fix pre-commit * Input array should be dpnp.ndarray or usm_ndarray * update CHANGELOG.md
1 parent 35419ce commit 8cab1af

File tree

7 files changed

+239
-23
lines changed

7 files changed

+239
-23
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ In addition, this release completes implementation of `dpnp.fft` module and adds
4848
* Added runtime dependency on `intel-gpu-ocl-icd-system` package [#2023](https://github.com/IntelPython/dpnp/pull/2023)
4949
* Added implementation of `dpnp.ravel_multi_index` and `dpnp.unravel_index` functions [#2022](https://github.com/IntelPython/dpnp/pull/2022)
5050
* Added implementation of `dpnp.resize` and `dpnp.rot90` functions [#2030](https://github.com/IntelPython/dpnp/pull/2030)
51+
* Added implementation of `dpnp.require` function [#2036](https://github.com/IntelPython/dpnp/pull/2036)
5152

5253
### Change
5354

dpnp/dpnp_iface_manipulation.py

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
"permute_dims",
7878
"ravel",
7979
"repeat",
80+
"require",
8081
"reshape",
8182
"resize",
8283
"result_type",
@@ -649,12 +650,8 @@ def atleast_1d(*arys):
649650
"""
650651

651652
res = []
653+
dpnp.check_supported_arrays_type(*arys)
652654
for ary in arys:
653-
if not dpnp.is_supported_array_type(ary):
654-
raise TypeError(
655-
"Each input array must be any of supported type, "
656-
f"but got {type(ary)}"
657-
)
658655
if ary.ndim == 0:
659656
result = ary.reshape(1)
660657
else:
@@ -707,12 +704,8 @@ def atleast_2d(*arys):
707704
"""
708705

709706
res = []
707+
dpnp.check_supported_arrays_type(*arys)
710708
for ary in arys:
711-
if not dpnp.is_supported_array_type(ary):
712-
raise TypeError(
713-
"Each input array must be any of supported type, "
714-
f"but got {type(ary)}"
715-
)
716709
if ary.ndim == 0:
717710
result = ary.reshape(1, 1)
718711
elif ary.ndim == 1:
@@ -771,12 +764,8 @@ def atleast_3d(*arys):
771764
"""
772765

773766
res = []
767+
dpnp.check_supported_arrays_type(*arys)
774768
for ary in arys:
775-
if not dpnp.is_supported_array_type(ary):
776-
raise TypeError(
777-
"Each input array must be any of supported type, "
778-
f"but got {type(ary)}"
779-
)
780769
if ary.ndim == 0:
781770
result = ary.reshape(1, 1, 1)
782771
elif ary.ndim == 1:
@@ -1959,6 +1948,113 @@ def repeat(a, repeats, axis=None):
19591948
return dpnp_array._create_from_usm_ndarray(usm_res)
19601949

19611950

1951+
def require(a, dtype=None, requirements=None, *, like=None):
1952+
"""
1953+
Return a :class:`dpnp.ndarray` of the provided type that satisfies
1954+
requirements.
1955+
1956+
This function is useful to be sure that an array with the correct flags
1957+
is returned for passing to compiled code (perhaps through ctypes).
1958+
1959+
For full documentation refer to :obj:`numpy.require`.
1960+
1961+
Parameters
1962+
----------
1963+
a : {dpnp.ndarray, usm_ndarray}
1964+
The input array to be converted to a type-and-requirement-satisfying
1965+
array.
1966+
dtype : {None, data-type}, optional
1967+
The required data-type. If ``None`` preserve the current dtype.
1968+
requirements : {None, str, sequence of str}, optional
1969+
The requirements list can be any of the following:
1970+
1971+
* 'F_CONTIGUOUS' ('F') - ensure a Fortran-contiguous array
1972+
* 'C_CONTIGUOUS' ('C') - ensure a C-contiguous array
1973+
* 'WRITABLE' ('W') - ensure a writable array
1974+
1975+
Returns
1976+
-------
1977+
out : dpnp.ndarray
1978+
Array with specified requirements and type if given.
1979+
1980+
Limitations
1981+
-----------
1982+
Parameter `like` is supported only with default value ``None``.
1983+
Otherwise, the function raises `NotImplementedError` exception.
1984+
1985+
See Also
1986+
--------
1987+
:obj:`dpnp.asarray` : Convert input to an ndarray.
1988+
:obj:`dpnp.asanyarray` : Convert to an ndarray, but pass through
1989+
ndarray subclasses.
1990+
:obj:`dpnp.ascontiguousarray` : Convert input to a contiguous array.
1991+
:obj:`dpnp.asfortranarray` : Convert input to an ndarray with
1992+
column-major memory order.
1993+
:obj:`dpnp.ndarray.flags` : Information about the memory layout
1994+
of the array.
1995+
1996+
Notes
1997+
-----
1998+
The returned array will be guaranteed to have the listed requirements
1999+
by making a copy if needed.
2000+
2001+
Examples
2002+
--------
2003+
>>> import dpnp as np
2004+
>>> x = np.arange(6).reshape(2, 3)
2005+
>>> x.flags
2006+
C_CONTIGUOUS : True
2007+
F_CONTIGUOUS : False
2008+
WRITEABLE : True
2009+
2010+
>>> y = np.require(x, dtype=np.float32, requirements=['W', 'F'])
2011+
>>> y.flags
2012+
C_CONTIGUOUS : False
2013+
F_CONTIGUOUS : True
2014+
WRITEABLE : True
2015+
2016+
"""
2017+
2018+
dpnp.check_limitations(like=like)
2019+
dpnp.check_supported_arrays_type(a)
2020+
2021+
possible_flags = {
2022+
"C": "C",
2023+
"C_CONTIGUOUS": "C",
2024+
"F": "F",
2025+
"F_CONTIGUOUS": "F",
2026+
"W": "W",
2027+
"WRITEABLE": "W",
2028+
}
2029+
2030+
if not requirements:
2031+
return dpnp.asanyarray(a, dtype=dtype)
2032+
2033+
try:
2034+
requirements = {possible_flags[x.upper()] for x in requirements}
2035+
except KeyError as exc:
2036+
incorrect_flag = (set(requirements) - set(possible_flags.keys())).pop()
2037+
raise ValueError(
2038+
f"Incorrect flag {incorrect_flag} in requirements"
2039+
) from exc
2040+
2041+
order = "A"
2042+
if requirements.issuperset({"C", "F"}):
2043+
raise ValueError("Cannot specify both 'C' and 'F' order")
2044+
if "F" in requirements:
2045+
order = "F"
2046+
requirements.remove("F")
2047+
elif "C" in requirements:
2048+
order = "C"
2049+
requirements.remove("C")
2050+
2051+
arr = dpnp.array(a, dtype=dtype, order=order, copy=None)
2052+
if not arr.flags["W"]:
2053+
return arr.copy(order)
2054+
2055+
return arr
2056+
2057+
19622058
def reshape(a, /, newshape, order="C", copy=None):
19632059
"""
19642060
Gives a new shape to an array without changing its data.

tests/test_arraymanipulation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dpctl.tensor as dpt
12
import numpy
23
import pytest
34
from dpctl.tensor._numpy_helper import AxisError
@@ -45,6 +46,13 @@ def test_3D_array(self):
4546
desired = [a, b]
4647
assert_array_equal(res, desired)
4748

49+
def test_dpnp_dpt_array(self):
50+
a = dpnp.array([1, 2])
51+
b = dpt.asarray([2, 3])
52+
res = dpnp.atleast_1d(a, b)
53+
desired = [dpnp.array([1, 2]), dpnp.array([2, 3])]
54+
assert_array_equal(res, desired)
55+
4856

4957
class TestAtleast2d:
5058
def test_0D_array(self):
@@ -77,6 +85,13 @@ def test_3D_array(self):
7785
desired = [a, b]
7886
assert_array_equal(res, desired)
7987

88+
def test_dpnp_dpt_array(self):
89+
a = dpnp.array([1, 2])
90+
b = dpt.asarray([2, 3])
91+
res = dpnp.atleast_2d(a, b)
92+
desired = [dpnp.array([[1, 2]]), dpnp.array([[2, 3]])]
93+
assert_array_equal(res, desired)
94+
8095

8196
class TestAtleast3d:
8297
def test_0D_array(self):
@@ -109,6 +124,13 @@ def test_3D_array(self):
109124
desired = [a, b]
110125
assert_array_equal(res, desired)
111126

127+
def test_dpnp_dpt_array(self):
128+
a = dpnp.array([1, 2])
129+
b = dpt.asarray([2, 3])
130+
res = dpnp.atleast_3d(a, b)
131+
desired = [dpnp.array([[[1], [2]]]), dpnp.array([[[2], [3]]])]
132+
assert_array_equal(res, desired)
133+
112134

113135
class TestColumnStack:
114136
def test_non_iterable(self):

tests/test_manipulation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import dpctl.tensor as dpt
24
import numpy
35
import pytest
@@ -665,6 +667,73 @@ def test_minimum_signed_integers(self, data, dtype):
665667
assert_array_equal(result, expected)
666668

667669

670+
class TestRequire:
671+
flag_names = ["C", "C_CONTIGUOUS", "F", "F_CONTIGUOUS", "W"]
672+
673+
def generate_all_false(self, dtype):
674+
a_np = numpy.zeros((10, 10), dtype=dtype)
675+
a_dp = dpnp.zeros((10, 10), dtype=dtype)
676+
a_np = a_np[::2, ::2]
677+
a_dp = a_dp[::2, ::2]
678+
a_np.flags["W"] = False
679+
a_dp.flags["W"] = False
680+
assert not a_dp.flags["C"]
681+
assert not a_dp.flags["F"]
682+
assert not a_dp.flags["W"]
683+
return a_np, a_dp
684+
685+
def set_and_check_flag(self, flag, dtype, arr):
686+
if dtype is None:
687+
dtype = arr[1].dtype
688+
result = numpy.require(arr[0], dtype, [flag])
689+
expected = dpnp.require(arr[1], dtype, [flag])
690+
assert result.flags[flag] == expected.flags[flag]
691+
assert result.dtype == expected.dtype
692+
693+
# a further call to dpnp.require ought to return the same array
694+
c = dpnp.require(expected, None, [flag])
695+
assert c is expected
696+
697+
def test_require_each(self):
698+
id = ["f4", "i4"]
699+
fd = [None, "f4", "c8"]
700+
for idtype, fdtype, flag in itertools.product(id, fd, self.flag_names):
701+
a = self.generate_all_false(idtype)
702+
self.set_and_check_flag(flag, fdtype, a)
703+
704+
def test_unknown_requirement(self):
705+
a = self.generate_all_false("f4")
706+
assert_raises(KeyError, numpy.require, a[0], None, "Q")
707+
assert_raises(ValueError, dpnp.require, a[1], None, "Q")
708+
709+
def test_non_array_input(self):
710+
a_np = numpy.array([1, 2, 3, 4])
711+
a_dp = dpnp.array(a_np)
712+
expected = numpy.require(a_np, "i4", ["C", "W"])
713+
result = dpnp.require(a_dp, "i4", ["C", "W"])
714+
assert expected.flags["C"] == result.flags["C"]
715+
assert expected.flags["F"] == result.flags["F"]
716+
assert expected.flags["W"] == result.flags["W"]
717+
assert expected.dtype == result.dtype
718+
assert_array_equal(expected, result)
719+
720+
def test_C_and_F_simul(self):
721+
a = self.generate_all_false("f4")
722+
assert_raises(ValueError, numpy.require, a[0], None, ["C", "F"])
723+
assert_raises(ValueError, dpnp.require, a[1], None, ["C", "F"])
724+
725+
def test_copy(self):
726+
a_np = numpy.arange(6).reshape(2, 3)
727+
a_dp = dpnp.arange(6).reshape(2, 3)
728+
a_np.flags["W"] = False
729+
a_dp.flags["W"] = False
730+
expected = numpy.require(a_np, requirements=["W", "C"])
731+
result = dpnp.require(a_dp, requirements=["W", "C"])
732+
# copy is done
733+
assert result is not a_dp
734+
assert_array_equal(expected, result)
735+
736+
668737
class TestResize:
669738
@pytest.mark.parametrize(
670739
"data, shape",

tests/test_sycl_queue.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,26 @@ def test_out_multi_dot(device):
12851285
assert_sycl_queue_equal(result.sycl_queue, exec_q)
12861286

12871287

1288+
@pytest.mark.parametrize(
1289+
"device",
1290+
valid_devices,
1291+
ids=[device.filter_string for device in valid_devices],
1292+
)
1293+
def test_require(device):
1294+
dpnp_data = dpnp.arange(10, device=device).reshape(2, 5)
1295+
result = dpnp.require(dpnp_data, dtype="f4", requirements=["F"])
1296+
1297+
expected_queue = dpnp_data.sycl_queue
1298+
result_queue = result.sycl_queue
1299+
assert_sycl_queue_equal(result_queue, expected_queue)
1300+
1301+
# No requirements
1302+
result = dpnp.require(dpnp_data, dtype="f4")
1303+
expected_queue = dpnp_data.sycl_queue
1304+
result_queue = result.sycl_queue
1305+
assert_sycl_queue_equal(result_queue, expected_queue)
1306+
1307+
12881308
@pytest.mark.parametrize(
12891309
"device",
12901310
valid_devices,

tests/test_usm_type.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,19 @@ def test_eigenvalue(func, shape, usm_type):
10131013
assert a.usm_type == dp_val.usm_type
10141014

10151015

1016+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
1017+
def test_require(usm_type):
1018+
dpnp_data = dp.arange(10, usm_type=usm_type).reshape(2, 5)
1019+
result = dp.require(dpnp_data, dtype="f4", requirements=["F"])
1020+
assert dpnp_data.usm_type == usm_type
1021+
assert result.usm_type == usm_type
1022+
1023+
# No requirements
1024+
result = dp.require(dpnp_data, dtype="f4")
1025+
assert dpnp_data.usm_type == usm_type
1026+
assert result.usm_type == usm_type
1027+
1028+
10161029
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
10171030
def test_resize(usm_type):
10181031
dpnp_data = dp.arange(10, usm_type=usm_type)

tests/third_party/cupy/manipulation_tests/test_kind.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def func(xp):
9494

9595
assert func(numpy) == func(cupy)
9696

97-
@pytest.mark.skip("dpnp.require() is not implemented yet")
9897
@testing.for_all_dtypes()
9998
def test_require_flag_check(self, dtype):
10099
possible_flags = [["C_CONTIGUOUS"], ["F_CONTIGUOUS"]]
@@ -105,36 +104,32 @@ def test_require_flag_check(self, dtype):
105104
assert arr.flags[parameter]
106105
assert arr.dtype == dtype
107106

108-
@pytest.mark.skip("dpnp.require() is not implemented yet")
107+
@pytest.mark.skip("dpnp.require() does not support requirement ['O']")
109108
@testing.for_all_dtypes()
110109
def test_require_owndata(self, dtype):
111110
x = cupy.zeros((2, 3, 4), dtype=dtype)
112111
arr = x.view()
113112
arr = cupy.require(arr, dtype, ["O"])
114113
assert arr.flags["OWNDATA"]
115114

116-
@pytest.mark.skip("dpnp.require() is not implemented yet")
117115
@testing.for_all_dtypes()
118116
def test_require_C_and_F_flags(self, dtype):
119117
x = cupy.zeros((2, 3, 4), dtype=dtype)
120118
with pytest.raises(ValueError):
121119
cupy.require(x, dtype, ["C", "F"])
122120

123-
@pytest.mark.skip("dpnp.require() is not implemented yet")
124121
@testing.for_all_dtypes()
125122
def test_require_incorrect_requirments(self, dtype):
126123
x = cupy.zeros((2, 3, 4), dtype=dtype)
127124
with pytest.raises(ValueError):
128-
cupy.require(x, dtype, ["W"])
125+
cupy.require(x, dtype, ["O"])
129126

130-
@pytest.mark.skip("dpnp.require() is not implemented yet")
131127
@testing.for_all_dtypes()
132128
def test_require_incorrect_dtype(self, dtype):
133129
x = cupy.zeros((2, 3, 4), dtype=dtype)
134-
with pytest.raises(ValueError):
130+
with pytest.raises((ValueError, TypeError)):
135131
cupy.require(x, "random", "C")
136132

137-
@pytest.mark.skip("dpnp.require() is not implemented yet")
138133
@testing.for_all_dtypes()
139134
def test_require_empty_requirements(self, dtype):
140135
x = cupy.zeros((2, 3, 4), dtype=dtype)

0 commit comments

Comments
 (0)