Skip to content

Commit d46d871

Browse files
committed
address more comments
1 parent 4fb7dd6 commit d46d871

File tree

3 files changed

+20
-13
lines changed

3 files changed

+20
-13
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,8 +1945,7 @@ def repeat(a, repeats, axis=None):
19451945

19461946
def require(a, dtype=None, requirements=None, *, like=None):
19471947
"""
1948-
Return a :class:`dpnp.ndarray` of the provided type that satisfies
1949-
requirements.
1948+
Return an dpnp.ndarray of the provided type that satisfies requirements.
19501949
19511950
This function is useful to be sure that an array with the correct flags
19521951
is returned for passing to compiled code (perhaps through ctypes).
@@ -1957,9 +1956,11 @@ def require(a, dtype=None, requirements=None, *, like=None):
19571956
----------
19581957
a : array_like
19591958
The object to be converted to a type-and-requirement-satisfying array.
1960-
dtype : {None, data-type}, optional
1961-
The required data-type. If ``None`` preserve the current dtype.
1962-
requirements : {None, str, sequence of str}, optional
1959+
dtype : data-type, optional
1960+
The required data-type. If None preserve the current dtype. If your
1961+
application requires the data to be in native byteorder, include
1962+
a byteorder specification as a part of the dtype specification.
1963+
requirements : {str, sequence of str}, , optional
19631964
The requirements list can be any of the following:
19641965
19651966
* 'F_CONTIGUOUS' ('F') - ensure a Fortran-contiguous array
@@ -1979,7 +1980,7 @@ def require(a, dtype=None, requirements=None, *, like=None):
19791980
See Also
19801981
--------
19811982
:obj:`dpnp.asarray` : Convert input to an ndarray.
1982-
:obj:`dpnp.asanyarray` : Convert to an ndarray, but pass through
1983+
:obj:`dpnp.asanyarray ` : Convert to an ndarray, but pass through
19831984
ndarray subclasses.
19841985
:obj:`dpnp.ascontiguousarray` : Convert input to a contiguous array.
19851986
:obj:`dpnp.asfortranarray` : Convert input to an ndarray with
@@ -1995,7 +1996,7 @@ def require(a, dtype=None, requirements=None, *, like=None):
19951996
Examples
19961997
--------
19971998
>>> import dpnp as np
1998-
>>> x = np.arange(6).reshape(2, 3)
1999+
>>> x = np.arange(6).reshape(2,3)
19992000
>>> x.flags
20002001
C_CONTIGUOUS : True
20012002
F_CONTIGUOUS : False
@@ -2023,7 +2024,14 @@ def require(a, dtype=None, requirements=None, *, like=None):
20232024
if not requirements:
20242025
return dpnp.asanyarray(a, dtype=dtype)
20252026

2026-
requirements = {possible_flags[x.upper()] for x in requirements}
2027+
try:
2028+
requirements = {possible_flags[x.upper()] for x in requirements}
2029+
except KeyError as exc:
2030+
incorrect_flag = (set(requirements) - set(possible_flags.keys())).pop()
2031+
raise ValueError(
2032+
f"Incorrect flag {incorrect_flag} in requirements"
2033+
) from exc
2034+
20272035
order = "A"
20282036
if requirements.issuperset({"C", "F"}):
20292037
raise ValueError("Cannot specify both 'C' and 'F' order")

tests/test_manipulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -671,8 +671,8 @@ class TestRequire:
671671
flag_names = ["C", "C_CONTIGUOUS", "F", "F_CONTIGUOUS", "W"]
672672

673673
def generate_all_false(self, dtype):
674-
a_np = numpy.zeros((10, 10))
675-
a_dp = dpnp.zeros((10, 10))
674+
a_np = numpy.zeros((10, 10), dtype=dtype)
675+
a_dp = dpnp.zeros((10, 10), dtype=dtype)
676676
a_np = a_np[::2, ::2]
677677
a_dp = a_dp[::2, ::2]
678678
a_np.flags["W"] = False
@@ -704,7 +704,7 @@ def test_require_each(self):
704704
@pytest.mark.parametrize("xp", [numpy, dpnp])
705705
def test_unknown_requirement(self, xp):
706706
a = self.generate_all_false("f4")
707-
assert_raises(KeyError, xp.require, a, None, "Q")
707+
assert_raises((KeyError, ValueError), xp.require, a, None, "Q")
708708

709709
def test_non_array_input(self):
710710
expected = numpy.require([1, 2, 3, 4], "i4", ["C", "W"])

tests/third_party/cupy/manipulation_tests/test_kind.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,11 @@ def test_require_C_and_F_flags(self, dtype):
118118
with pytest.raises(ValueError):
119119
cupy.require(x, dtype, ["C", "F"])
120120

121-
@pytest.mark.skip("dpnp.require() does support requirement ['W']")
122121
@testing.for_all_dtypes()
123122
def test_require_incorrect_requirments(self, dtype):
124123
x = cupy.zeros((2, 3, 4), dtype=dtype)
125124
with pytest.raises(ValueError):
126-
cupy.require(x, dtype, ["W"])
125+
cupy.require(x, dtype, ["O"])
127126

128127
@testing.for_all_dtypes()
129128
def test_require_incorrect_dtype(self, dtype):

0 commit comments

Comments
 (0)