Skip to content

Commit 7ce56f7

Browse files
committed
add dpnp.require
1 parent 6ee4cb5 commit 7ce56f7

File tree

4 files changed

+184
-22
lines changed

4 files changed

+184
-22
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 103 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
"permute_dims",
7777
"ravel",
7878
"repeat",
79+
"require",
7980
"reshape",
8081
"result_type",
8182
"roll",
@@ -646,12 +647,8 @@ def atleast_1d(*arys):
646647
"""
647648

648649
res = []
650+
dpnp.check_supported_arrays_type(*arys)
649651
for ary in arys:
650-
if not dpnp.is_supported_array_type(ary):
651-
raise TypeError(
652-
"Each input array must be any of supported type, "
653-
f"but got {type(ary)}"
654-
)
655652
if ary.ndim == 0:
656653
result = ary.reshape(1)
657654
else:
@@ -704,12 +701,8 @@ def atleast_2d(*arys):
704701
"""
705702

706703
res = []
704+
dpnp.check_supported_arrays_type(*arys)
707705
for ary in arys:
708-
if not dpnp.is_supported_array_type(ary):
709-
raise TypeError(
710-
"Each input array must be any of supported type, "
711-
f"but got {type(ary)}"
712-
)
713706
if ary.ndim == 0:
714707
result = ary.reshape(1, 1)
715708
elif ary.ndim == 1:
@@ -768,12 +761,8 @@ def atleast_3d(*arys):
768761
"""
769762

770763
res = []
764+
dpnp.check_supported_arrays_type(*arys)
771765
for ary in arys:
772-
if not dpnp.is_supported_array_type(ary):
773-
raise TypeError(
774-
"Each input array must be any of supported type, "
775-
f"but got {type(ary)}"
776-
)
777766
if ary.ndim == 0:
778767
result = ary.reshape(1, 1, 1)
779768
elif ary.ndim == 1:
@@ -1954,6 +1943,105 @@ def repeat(a, repeats, axis=None):
19541943
return dpnp_array._create_from_usm_ndarray(usm_res)
19551944

19561945

1946+
def require(a, dtype=None, requirements=None, *, like=None):
1947+
"""
1948+
Return an dpnp.ndarray of the provided type that satisfies requirements.
1949+
1950+
This function is useful to be sure that an array with the correct flags
1951+
is returned for passing to compiled code (perhaps through ctypes).
1952+
1953+
For full documentation refer to :obj:`numpy.require`.
1954+
1955+
Parameters
1956+
----------
1957+
a : array_like
1958+
The object to be converted to a type-and-requirement-satisfying array.
1959+
dtype : data-type, optional
1960+
The required data-type. If None preserve the current dtype. If your
1961+
application requires the data to be in native byteorder, include
1962+
a byteorder specification as a part of the dtype specification.
1963+
requirements : {str, sequence of str}, , optional
1964+
The requirements list can be any of the following:
1965+
1966+
* 'F_CONTIGUOUS' ('F') - ensure a Fortran-contiguous array
1967+
* 'C_CONTIGUOUS' ('C') - ensure a C-contiguous array
1968+
* 'WRITABLE' ('W') - ensure a writable array
1969+
1970+
Returns
1971+
-------
1972+
out : dpnp.ndarray
1973+
Array with specified requirements and type if given.
1974+
1975+
Limitations
1976+
-----------
1977+
Parameter `like` is supported only with default value ``None``.
1978+
Otherwise, the function raises `NotImplementedError` exception.
1979+
1980+
See Also
1981+
--------
1982+
:obj:`dpnp.asarray` : Convert input to an ndarray.
1983+
:obj:`dpnp.asanyarray ` : Convert to an ndarray, but pass through
1984+
ndarray subclasses.
1985+
:obj:`dpnp.ascontiguousarray` : Convert input to a contiguous array.
1986+
:obj:`dpnp.asfortranarray` : Convert input to an ndarray with
1987+
column-major memory order.
1988+
:obj:`dpnp.ndarray.flags` : Information about the memory layout
1989+
of the array.
1990+
1991+
Notes
1992+
-----
1993+
The returned array will be guaranteed to have the listed requirements
1994+
by making a copy if needed.
1995+
1996+
Examples
1997+
--------
1998+
>>> import dpnp as np
1999+
>>> x = np.arange(6).reshape(2,3)
2000+
>>> x.flags
2001+
C_CONTIGUOUS : True
2002+
F_CONTIGUOUS : False
2003+
WRITEABLE : True
2004+
2005+
>>> y = np.require(x, dtype=np.float32, requirements=['W', 'F'])
2006+
>>> y.flags
2007+
C_CONTIGUOUS : False
2008+
F_CONTIGUOUS : True
2009+
WRITEABLE : True
2010+
2011+
"""
2012+
2013+
dpnp.check_limitations(like=like)
2014+
2015+
possible_flags = {
2016+
"C": "C",
2017+
"C_CONTIGUOUS": "C",
2018+
"F": "F",
2019+
"F_CONTIGUOUS": "F",
2020+
"W": "W",
2021+
"WRITEABLE": "W",
2022+
}
2023+
2024+
if not requirements:
2025+
return dpnp.asanyarray(a, dtype=dtype)
2026+
2027+
requirements = {possible_flags[x.upper()] for x in requirements}
2028+
order = "A"
2029+
if requirements.issuperset({"C", "F"}):
2030+
raise ValueError("Cannot specify both 'C' and 'F' order")
2031+
if "F" in requirements:
2032+
order = "F"
2033+
requirements.remove("F")
2034+
elif "C" in requirements:
2035+
order = "C"
2036+
requirements.remove("C")
2037+
2038+
arr = dpnp.array(a, dtype=dtype, order=order, copy=None)
2039+
if not arr.flags["W"]:
2040+
return arr.copy(order)
2041+
2042+
return arr
2043+
2044+
19572045
def reshape(a, /, newshape, order="C", copy=None):
19582046
"""
19592047
Gives a new shape to an array without changing its data.

tests/test_arraymanipulation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dpctl.tensor as dpt
12
import numpy
23
import pytest
34
from dpctl.tensor._numpy_helper import AxisError
@@ -45,6 +46,13 @@ def test_3D_array(self):
4546
desired = [a, b]
4647
assert_array_equal(res, desired)
4748

49+
def test_dpnp_dpt_array(self):
50+
a = dpnp.array([1, 2])
51+
b = dpt.asarray([2, 3])
52+
res = dpnp.atleast_1d(a, b)
53+
desired = [dpnp.array([1, 2]), dpnp.array([2, 3])]
54+
assert_array_equal(res, desired)
55+
4856

4957
class TestAtleast2d:
5058
def test_0D_array(self):
@@ -77,6 +85,13 @@ def test_3D_array(self):
7785
desired = [a, b]
7886
assert_array_equal(res, desired)
7987

88+
def test_dpnp_dpt_array(self):
89+
a = dpnp.array([1, 2])
90+
b = dpt.asarray([2, 3])
91+
res = dpnp.atleast_2d(a, b)
92+
desired = [dpnp.array([[1, 2]]), dpnp.array([[2, 3]])]
93+
assert_array_equal(res, desired)
94+
8095

8196
class TestAtleast3d:
8297
def test_0D_array(self):
@@ -109,6 +124,13 @@ def test_3D_array(self):
109124
desired = [a, b]
110125
assert_array_equal(res, desired)
111126

127+
def test_dpnp_dpt_array(self):
128+
a = dpnp.array([1, 2])
129+
b = dpt.asarray([2, 3])
130+
res = dpnp.atleast_3d(a, b)
131+
desired = [dpnp.array([[[1], [2]]]), dpnp.array([[[2], [3]]])]
132+
assert_array_equal(res, desired)
133+
112134

113135
class TestColumnStack:
114136
def test_non_iterable(self):

tests/test_manipulation.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import dpctl.tensor as dpt
24
import numpy
35
import pytest
@@ -665,6 +667,60 @@ def test_minimum_signed_integers(self, data, dtype):
665667
assert_array_equal(result, expected)
666668

667669

670+
class TestRequire:
671+
flag_names = ["C", "C_CONTIGUOUS", "F", "F_CONTIGUOUS", "W"]
672+
673+
def generate_all_false(self, dtype):
674+
a_np = numpy.zeros((10, 10))
675+
a_dp = dpnp.zeros((10, 10))
676+
a_np = a_np[::2, ::2]
677+
a_dp = a_dp[::2, ::2]
678+
a_np.flags["W"] = False
679+
a_dp.flags["W"] = False
680+
assert not a_dp.flags["C"]
681+
assert not a_dp.flags["F"]
682+
assert not a_dp.flags["W"]
683+
return a_np, a_dp
684+
685+
def set_and_check_flag(self, flag, dtype, arr):
686+
if dtype is None:
687+
dtype = arr[1].dtype
688+
a_np = numpy.require(arr[0], dtype, [flag])
689+
a_dp = dpnp.require(arr[1], dtype, [flag])
690+
assert a_np.flags[flag] == a_dp.flags[flag]
691+
assert a_np.dtype == a_dp.dtype
692+
693+
# a further call to dpnp.require ought to return the same array
694+
c = dpnp.require(a_dp, None, [flag])
695+
assert c is a_dp
696+
697+
def test_require_each(self):
698+
id = ["f4", "i4"]
699+
fd = [None, "f4", "c8"]
700+
for idtype, fdtype, flag in itertools.product(id, fd, self.flag_names):
701+
a = self.generate_all_false(idtype)
702+
self.set_and_check_flag(flag, fdtype, a)
703+
704+
@pytest.mark.parametrize("xp", [numpy, dpnp])
705+
def test_unknown_requirement(self, xp):
706+
a = self.generate_all_false("f4")
707+
assert_raises(KeyError, xp.require, a, None, "Q")
708+
709+
def test_non_array_input(self):
710+
a_np = numpy.require([1, 2, 3, 4], "i4", ["C", "W"])
711+
a_dp = dpnp.require([1, 2, 3, 4], "i4", ["C", "W"])
712+
assert a_np.flags["C"] == a_dp.flags["C"]
713+
assert a_np.flags["F"] == a_dp.flags["F"]
714+
assert a_np.flags["W"] == a_dp.flags["W"]
715+
assert a_np.dtype == a_dp.dtype
716+
assert_array_equal(a_np, a_dp)
717+
718+
@pytest.mark.parametrize("xp", [numpy, dpnp])
719+
def test_C_and_F_simul(self, xp):
720+
a = self.generate_all_false("f4")
721+
assert_raises(ValueError, xp.require, a, None, ["C", "F"])
722+
723+
668724
class TestTranspose:
669725
@pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]])
670726
def test_2d_with_axes(self, axes):

tests/third_party/cupy/manipulation_tests/test_kind.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def func(xp):
9494

9595
assert func(numpy) == func(cupy)
9696

97-
@pytest.mark.skip("dpnp.require() is not implemented yet")
9897
@testing.for_all_dtypes()
9998
def test_require_flag_check(self, dtype):
10099
possible_flags = [["C_CONTIGUOUS"], ["F_CONTIGUOUS"]]
@@ -105,36 +104,33 @@ def test_require_flag_check(self, dtype):
105104
assert arr.flags[parameter]
106105
assert arr.dtype == dtype
107106

108-
@pytest.mark.skip("dpnp.require() is not implemented yet")
107+
@pytest.mark.skip("dpnp.require() does not support requirement ['O']")
109108
@testing.for_all_dtypes()
110109
def test_require_owndata(self, dtype):
111110
x = cupy.zeros((2, 3, 4), dtype=dtype)
112111
arr = x.view()
113112
arr = cupy.require(arr, dtype, ["O"])
114113
assert arr.flags["OWNDATA"]
115114

116-
@pytest.mark.skip("dpnp.require() is not implemented yet")
117115
@testing.for_all_dtypes()
118116
def test_require_C_and_F_flags(self, dtype):
119117
x = cupy.zeros((2, 3, 4), dtype=dtype)
120118
with pytest.raises(ValueError):
121119
cupy.require(x, dtype, ["C", "F"])
122120

123-
@pytest.mark.skip("dpnp.require() is not implemented yet")
121+
@pytest.mark.skip("dpnp.require() does support requirement ['W']")
124122
@testing.for_all_dtypes()
125123
def test_require_incorrect_requirments(self, dtype):
126124
x = cupy.zeros((2, 3, 4), dtype=dtype)
127125
with pytest.raises(ValueError):
128126
cupy.require(x, dtype, ["W"])
129127

130-
@pytest.mark.skip("dpnp.require() is not implemented yet")
131128
@testing.for_all_dtypes()
132129
def test_require_incorrect_dtype(self, dtype):
133130
x = cupy.zeros((2, 3, 4), dtype=dtype)
134-
with pytest.raises(ValueError):
131+
with pytest.raises((ValueError, TypeError)):
135132
cupy.require(x, "random", "C")
136133

137-
@pytest.mark.skip("dpnp.require() is not implemented yet")
138134
@testing.for_all_dtypes()
139135
def test_require_empty_requirements(self, dtype):
140136
x = cupy.zeros((2, 3, 4), dtype=dtype)

0 commit comments

Comments
 (0)