Skip to content

Commit bc20d33

Browse files
committed
Move the array API dtype categories into the top level
They are not an official part of the spec but are useful for various parts of the implementation.
1 parent 5605d68 commit bc20d33

File tree

3 files changed

+17
-26
lines changed

3 files changed

+17
-26
lines changed

numpy/array_api/_array_object.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,14 @@ def _check_allowed_dtypes(self, other, dtype_category, op):
9898
if other is NotImplemented:
9999
return other
100100
"""
101-
from ._dtypes import _result_type
102-
103-
_dtypes = {
104-
'all': _all_dtypes,
105-
'numeric': _numeric_dtypes,
106-
'integer': _integer_dtypes,
107-
'integer or boolean': _integer_or_boolean_dtypes,
108-
'boolean': _boolean_dtypes,
109-
'floating-point': _floating_dtypes,
110-
}
111-
112-
if self.dtype not in _dtypes[dtype_category]:
101+
from ._dtypes import _result_type, _dtype_categories
102+
103+
if self.dtype not in _dtype_categories[dtype_category]:
113104
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
114105
if isinstance(other, (int, float, bool)):
115106
other = self._promote_scalar(other)
116107
elif isinstance(other, Array):
117-
if other.dtype not in _dtypes[dtype_category]:
108+
if other.dtype not in _dtype_categories[dtype_category]:
118109
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
119110
else:
120111
return NotImplemented

numpy/array_api/_dtypes.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@
2323
_integer_or_boolean_dtypes = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
2424
_numeric_dtypes = (float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64)
2525

26+
_dtype_categories = {
27+
'all': _all_dtypes,
28+
'numeric': _numeric_dtypes,
29+
'integer': _integer_dtypes,
30+
'integer or boolean': _integer_or_boolean_dtypes,
31+
'boolean': _boolean_dtypes,
32+
'floating-point': _floating_dtypes,
33+
}
34+
35+
2636
# Note: the spec defines a restricted type promotion table compared to NumPy.
2737
# In particular, cross-kind promotions like integer + float or boolean +
2838
# integer are not allowed, even for functions that accept both kinds.

numpy/array_api/tests/test_elementwise_functions.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55
from .. import asarray, _elementwise_functions
66
from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
7-
from .._dtypes import (_all_dtypes, _boolean_dtypes, _floating_dtypes,
8-
_integer_dtypes, _integer_or_boolean_dtypes,
9-
_numeric_dtypes)
7+
from .._dtypes import (_dtype_categories, _boolean_dtypes, _floating_dtypes,
8+
_integer_dtypes)
109

1110
def nargs(func):
1211
return len(getfullargspec(func).args)
@@ -75,15 +74,6 @@ def test_function_types():
7574
'trunc': 'numeric',
7675
}
7776

78-
_dtypes = {
79-
'all': _all_dtypes,
80-
'numeric': _numeric_dtypes,
81-
'integer': _integer_dtypes,
82-
'integer_or_boolean': _integer_or_boolean_dtypes,
83-
'boolean': _boolean_dtypes,
84-
'floating': _floating_dtypes,
85-
}
86-
8777
def _array_vals():
8878
for d in _integer_dtypes:
8979
yield asarray(1, dtype=d)
@@ -94,7 +84,7 @@ def _array_vals():
9484

9585
for x in _array_vals():
9686
for func_name, types in elementwise_function_input_types.items():
97-
dtypes = _dtypes[types]
87+
dtypes = _dtype_categories[types]
9888
func = getattr(_elementwise_functions, func_name)
9989
if nargs(func) == 2:
10090
for y in _array_vals():

0 commit comments

Comments
 (0)