Skip to content

Added check for numpy version as np.bool is deprecated #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import lru_cache
from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union
from warnings import warn
from packaging import version

from . import api_version
from . import xp
Expand Down Expand Up @@ -126,6 +127,17 @@ def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
return tuple(dtypes)


def get_array_module_bool():
# Numpy deprecated np.bool starting from version 1.20.0
if xp.__name__ == "numpy":
xp_version = version.parse(xp.__version__)
if xp_version >= version.parse("1.20.0") and xp_version < version.parse("2.0.0"):
return xp.bool_
return xp.bool


array_mod_bool = get_array_module_bool()

uint_dtypes = _make_dtype_tuple_from_names(uint_names)
int_dtypes = _make_dtype_tuple_from_names(int_names)
real_float_dtypes = _make_dtype_tuple_from_names(real_float_names)
Expand All @@ -135,15 +147,15 @@ def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
numeric_dtypes = real_dtypes
if api_version > "2021.12":
numeric_dtypes += complex_dtypes
all_dtypes = (xp.bool,) + numeric_dtypes
all_dtypes = (array_mod_bool,) + numeric_dtypes
all_float_dtypes = real_float_dtypes
if api_version > "2021.12":
all_float_dtypes += complex_dtypes
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
bool_and_all_int_dtypes = (array_mod_bool,) + all_int_dtypes


kind_to_dtypes = {
"bool": [xp.bool],
"bool": [array_mod_bool],
"signed integer": int_dtypes,
"unsigned integer": uint_dtypes,
"integral": all_int_dtypes,
Expand Down Expand Up @@ -400,7 +412,7 @@ def result_type(*dtypes: DataType):
"the result is implementation-dependent"
)
category_to_dtypes = {
"boolean": (xp.bool,),
"boolean": (array_mod_bool,),
"integer": all_int_dtypes,
"floating-point": real_float_dtypes,
"real-valued": real_float_dtypes,
Expand Down Expand Up @@ -554,7 +566,7 @@ def result_type(*dtypes: DataType):
inplace_op_to_symbol[iop] = f"{symbol}="
func_in_dtypes[iop] = func_in_dtypes[op]
func_returns_bool[iop] = func_returns_bool[op]
func_in_dtypes["__bool__"] = (xp.bool,)
func_in_dtypes["__bool__"] = (array_mod_bool,)
func_in_dtypes["__int__"] = all_int_dtypes
func_in_dtypes["__index__"] = all_int_dtypes
func_in_dtypes["__float__"] = real_float_dtypes
Expand Down
7 changes: 5 additions & 2 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ def _float32ify(n: Union[int, float]) -> float:
return struct.unpack("!f", struct.pack("!f", n))[0]


array_mod_bool = dh.get_array_module_bool()


@wraps(xps.from_dtype)
def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
"""xps.from_dtype() without the crazy large numbers."""
if dtype == xp.bool:
if dtype == array_mod_bool:
return xps.from_dtype(dtype, **kwargs)

if dtype in dh.complex_dtypes:
Expand Down Expand Up @@ -76,7 +79,7 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
return xps.arrays(dtype, *args, elements=elements, **kwargs)


_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
_dtype_categories = [(array_mod_bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
_sorted_dtypes = [d for category in _dtype_categories for d in category]

def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
Expand Down
6 changes: 4 additions & 2 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape


array_mod_bool = dh.get_array_module_bool()

def scalar_objects(
dtype: DataType, shape: Shape
) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
Expand Down Expand Up @@ -165,7 +167,7 @@ def test_getitem_masking(shape, data):
),
hh.shapes(),
)
key = data.draw(hh.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
key = data.draw(hh.arrays(dtype=array_mod_bool, shape=mask_shapes), label="key")

if key.ndim > x.ndim or not all(
ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
Expand Down Expand Up @@ -203,7 +205,7 @@ def test_getitem_masking(shape, data):
@given(hh.shapes(), st.data())
def test_setitem_masking(shape, data):
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
key = data.draw(hh.arrays(dtype=array_mod_bool, shape=shape), label="key")
value = data.draw(
hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
)
Expand Down
8 changes: 5 additions & 3 deletions array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]
),
)

array_mod_bool = dh.get_array_module_bool()


# TODO: support testing complex dtypes
@given(dtype=st.none() | xps.real_dtypes(), data=st.data())
Expand Down Expand Up @@ -204,7 +206,7 @@ def test_asarray_scalars(shape, data):
if dtype is None:
dtype_family = data.draw(
st.sampled_from(
[(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)]
[(array_mod_bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)]
),
label="expected out dtypes",
)
Expand Down Expand Up @@ -393,7 +395,7 @@ def test_full(shape, fill_value, kw):
if kw.get("dtype", None):
dtype = kw["dtype"]
elif isinstance(fill_value, bool):
dtype = xp.bool
dtype = array_mod_bool
elif isinstance(fill_value, int):
dtype = dh.default_int
elif isinstance(fill_value, float):
Expand All @@ -410,7 +412,7 @@ def test_full(shape, fill_value, kw):
assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag]))
if kw.get("dtype", None) is None:
if isinstance(fill_value, bool):
assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]"
assert out.dtype == array_mod_bool, f"{out.dtype=}, but should be bool [full()]"
elif isinstance(fill_value, int):
ph.assert_default_int("full", out.dtype)
elif isinstance(fill_value, float):
Expand Down
7 changes: 4 additions & 3 deletions array_api_tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def non_complex_dtypes():
def float32(n: Union[int, float]) -> float:
return struct.unpack("!f", struct.pack("!f", float(n)))[0]

array_mod_bool = dh.get_array_module_bool()

@given(
x_dtype=non_complex_dtypes(),
Expand All @@ -31,7 +32,7 @@ def float32(n: Union[int, float]) -> float:
data=st.data(),
)
def test_astype(x_dtype, dtype, kw, data):
if xp.bool in (x_dtype, dtype):
if array_mod_bool in (x_dtype, dtype):
elements_strat = hh.from_dtype(x_dtype)
else:
m1, M1 = dh.dtype_ranges[x_dtype]
Expand Down Expand Up @@ -118,8 +119,8 @@ def test_can_cast(_from, to, data):

f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}"
if _from == xp.bool:
expected = to == xp.bool
if _from == array_mod_bool:
expected = to == array_mod_bool
else:
same_family = None
for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]:
Expand Down
Loading