Skip to content

Commit f827bd7

Browse files
authored
Merge branch 'master' into add-mean-keyword-to-std-var
2 parents 46c14cb + d522480 commit f827bd7

File tree

4 files changed

+121
-6
lines changed

4 files changed

+121
-6
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@
33
# no 'uint8' dtype
44
array_api_tests/test_array_object.py::test_getitem_masking
55

6-
# no 'isdtype' function
7-
array_api_tests/test_data_type_functions.py::test_isdtype
8-
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
9-
array_api_tests/test_signatures.py::test_func_signature[isdtype]
10-
116
# missing unique-like functions
127
array_api_tests/test_has_names.py::test_has_names[set-unique_all]
138
array_api_tests/test_has_names.py::test_has_names[set-unique_counts]

doc/reference/dtype.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ Data type routines
1414
dpnp.min_scalar_type
1515
dpnp.result_type
1616
dpnp.common_type
17-
dpnp.obj2sctype
1817

1918
Creating data types
2019
-------------------

dpnp/dpnp_iface_types.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"integer",
6565
"intc",
6666
"intp",
67+
"isdtype",
6768
"issubdtype",
6869
"is_type_supported",
6970
"nan",
@@ -194,11 +195,66 @@ def iinfo(dtype):
194195
smallest representable number.
195196
196197
"""
198+
197199
if isinstance(dtype, dpnp_array):
198200
dtype = dtype.dtype
199201
return dpt.iinfo(dtype)
200202

201203

204+
def isdtype(dtype, kind):
205+
"""
206+
Returns a boolean indicating whether a provided `dtype` is
207+
of a specified data type `kind`.
208+
209+
Parameters
210+
----------
211+
dtype : dtype
212+
The input dtype.
213+
kind : {dtype, str, tuple of dtypes or strs}
214+
The input dtype or dtype kind. Allowed dtype kinds are:
215+
216+
* ``'bool'`` : boolean kind
217+
* ``'signed integer'`` : signed integer data types
218+
* ``'unsigned integer'`` : unsigned integer data types
219+
* ``'integral'`` : integer data types
220+
* ``'real floating'`` : real-valued floating-point data types
221+
* ``'complex floating'`` : complex floating-point data types
222+
* ``'numeric'`` : numeric data types
223+
224+
Returns
225+
-------
226+
out : bool
227+
A boolean indicating whether a provided `dtype` is of a specified data
228+
type `kind`.
229+
230+
See Also
231+
--------
232+
:obj:`dpnp.issubdtype` : Test if the first argument is a type code
233+
lower/equal in type hierarchy.
234+
235+
Examples
236+
--------
237+
>>> import dpnp as np
238+
>>> np.isdtype(np.float32, np.float64)
239+
False
240+
>>> np.isdtype(np.float32, "real floating")
241+
True
242+
>>> np.isdtype(np.complex128, ("real floating", "complex floating"))
243+
True
244+
245+
"""
246+
247+
if isinstance(dtype, type):
248+
dtype = dpt.dtype(dtype)
249+
250+
if isinstance(kind, type):
251+
kind = dpt.dtype(kind)
252+
elif isinstance(kind, tuple):
253+
kind = tuple(dpt.dtype(k) if isinstance(k, type) else k for k in kind)
254+
255+
return dpt.isdtype(dtype, kind)
256+
257+
202258
def issubdtype(arg1, arg2):
203259
"""
204260
Returns ``True`` if the first argument is a type code lower/equal

dpnp/tests/test_dtype_routines.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy
2+
import pytest
3+
from numpy.testing import assert_raises_regex
4+
5+
import dpnp
6+
7+
from .helper import numpy_version
8+
9+
if numpy_version() >= "2.0.0":
10+
from numpy._core.numerictypes import sctypes
11+
else:
12+
from numpy.core.numerictypes import sctypes
13+
14+
15+
class TestIsDType:
16+
dtype_group = {
17+
"signed integer": sctypes["int"],
18+
"unsigned integer": sctypes["uint"],
19+
"integral": sctypes["int"] + sctypes["uint"],
20+
"real floating": sctypes["float"],
21+
"complex floating": sctypes["complex"],
22+
"numeric": (
23+
sctypes["int"]
24+
+ sctypes["uint"]
25+
+ sctypes["float"]
26+
+ sctypes["complex"]
27+
),
28+
}
29+
30+
@pytest.mark.parametrize(
31+
"dt, close_dt",
32+
[
33+
# TODO: replace with (dpnp.uint64, dpnp.uint32) once available
34+
(dpnp.int64, dpnp.int32),
35+
(numpy.uint64, numpy.uint32),
36+
(dpnp.float64, dpnp.float32),
37+
(dpnp.complex128, dpnp.complex64),
38+
],
39+
)
40+
@pytest.mark.parametrize("dt_group", [None] + list(dtype_group.keys()))
41+
def test_basic(self, dt, close_dt, dt_group):
42+
# First check if same dtypes return "True" and different ones
43+
# give "False" (even if they're close in the dtype hierarchy).
44+
if dt_group is None:
45+
assert dpnp.isdtype(dt, dt)
46+
assert not dpnp.isdtype(dt, close_dt)
47+
assert dpnp.isdtype(dt, (dt, close_dt))
48+
49+
# Check that dtype and a dtype group that it belongs to return "True",
50+
# and "False" otherwise.
51+
elif dt in self.dtype_group[dt_group]:
52+
assert dpnp.isdtype(dt, dt_group)
53+
assert dpnp.isdtype(dt, (close_dt, dt_group))
54+
else:
55+
assert not dpnp.isdtype(dt, dt_group)
56+
57+
def test_invalid_args(self):
58+
with assert_raises_regex(TypeError, r"Expected instance of.*"):
59+
dpnp.isdtype("int64", dpnp.int64)
60+
61+
with assert_raises_regex(TypeError, r"Unsupported data type kind:.*"):
62+
dpnp.isdtype(dpnp.int64, 1)
63+
64+
with assert_raises_regex(ValueError, r"Unrecognized data type kind:.*"):
65+
dpnp.isdtype(dpnp.int64, "int64")

0 commit comments

Comments
 (0)