Skip to content

Commit 9bffaea

Browse files
committed
float_dtypes -> real_float_dtypes
Better reflect the specs naming convention and avoid confusion with complex
1 parent f4ef5ae commit 9bffaea

8 files changed

+24
-24
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
"uint_names",
1515
"int_names",
1616
"all_int_names",
17-
"float_names",
17+
"real_float_names",
1818
"real_names",
1919
"complex_names",
2020
"numeric_names",
2121
"dtype_names",
2222
"int_dtypes",
2323
"uint_dtypes",
2424
"all_int_dtypes",
25-
"float_dtypes",
25+
"real_float_dtypes",
2626
"real_dtypes",
2727
"numeric_dtypes",
2828
"all_dtypes",
@@ -98,8 +98,8 @@ def __repr__(self):
9898
uint_names = ("uint8", "uint16", "uint32", "uint64")
9999
int_names = ("int8", "int16", "int32", "int64")
100100
all_int_names = uint_names + int_names
101-
float_names = ("float32", "float64")
102-
real_names = uint_names + int_names + float_names
101+
real_float_names = ("float32", "float64")
102+
real_names = uint_names + int_names + real_float_names
103103
complex_names = ("complex64", "complex128")
104104
numeric_names = real_names + complex_names
105105
dtype_names = ("bool",) + numeric_names
@@ -128,15 +128,15 @@ def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
128128

129129
uint_dtypes = _make_dtype_tuple_from_names(uint_names)
130130
int_dtypes = _make_dtype_tuple_from_names(int_names)
131-
float_dtypes = _make_dtype_tuple_from_names(float_names)
131+
real_float_dtypes = _make_dtype_tuple_from_names(real_float_names)
132132
all_int_dtypes = uint_dtypes + int_dtypes
133-
real_dtypes = all_int_dtypes + float_dtypes
133+
real_dtypes = all_int_dtypes + real_float_dtypes
134134
complex_dtypes = _make_dtype_tuple_from_names(complex_names)
135135
numeric_dtypes = real_dtypes
136136
if api_version > "2021.12":
137137
numeric_dtypes += complex_dtypes
138138
all_dtypes = (xp.bool,) + numeric_dtypes
139-
all_float_dtypes = float_dtypes
139+
all_float_dtypes = real_float_dtypes
140140
if api_version > "2021.12":
141141
all_float_dtypes += complex_dtypes
142142
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
@@ -147,7 +147,7 @@ def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
147147
"signed integer": int_dtypes,
148148
"unsigned integer": uint_dtypes,
149149
"integral": all_int_dtypes,
150-
"real floating": float_dtypes,
150+
"real floating": real_float_dtypes,
151151
"complex floating": complex_dtypes,
152152
"numeric": numeric_dtypes,
153153
}
@@ -164,7 +164,7 @@ def is_float_dtype(dtype):
164164
# See https://github.com/numpy/numpy/issues/18434
165165
if dtype is None:
166166
return False
167-
valid_dtypes = float_dtypes
167+
valid_dtypes = real_float_dtypes
168168
if api_version > "2021.12":
169169
valid_dtypes += complex_dtypes
170170
return dtype in valid_dtypes
@@ -173,7 +173,7 @@ def is_float_dtype(dtype):
173173
def get_scalar_type(dtype: DataType) -> ScalarType:
174174
if dtype in all_int_dtypes:
175175
return int
176-
elif dtype in float_dtypes:
176+
elif dtype in real_float_dtypes:
177177
return float
178178
elif dtype in complex_dtypes:
179179
return complex
@@ -245,7 +245,7 @@ class MinMax(NamedTuple):
245245
if default_int not in int_dtypes:
246246
warn(f"inferred default int is {default_int!r}, which is not an int")
247247
default_float = xp.asarray(float()).dtype
248-
if default_float not in float_dtypes:
248+
if default_float not in real_float_dtypes:
249249
warn(f"inferred default float is {default_float!r}, which is not a float")
250250
if api_version > "2021.12":
251251
default_complex = xp.asarray(complex()).dtype
@@ -346,7 +346,7 @@ def result_type(*dtypes: DataType):
346346
category_to_dtypes = {
347347
"boolean": (xp.bool,),
348348
"integer": all_int_dtypes,
349-
"floating-point": float_dtypes,
349+
"floating-point": real_float_dtypes,
350350
"numeric": numeric_dtypes,
351351
"integer or boolean": bool_and_all_int_dtypes,
352352
}
@@ -360,7 +360,7 @@ def result_type(*dtypes: DataType):
360360
dtypes = category_to_dtypes[dtype_category]
361361
func_in_dtypes[name] = dtypes
362362
# See https://github.com/data-apis/array-api/pull/413
363-
func_in_dtypes["expm1"] = float_dtypes
363+
func_in_dtypes["expm1"] = real_float_dtypes
364364

365365

366366
func_returns_bool = {
@@ -500,7 +500,7 @@ def result_type(*dtypes: DataType):
500500
func_in_dtypes["__bool__"] = (xp.bool,)
501501
func_in_dtypes["__int__"] = all_int_dtypes
502502
func_in_dtypes["__index__"] = all_int_dtypes
503-
func_in_dtypes["__float__"] = float_dtypes
503+
func_in_dtypes["__float__"] = real_float_dtypes
504504
func_in_dtypes["from_dlpack"] = numeric_dtypes
505505
func_in_dtypes["__dlpack__"] = numeric_dtypes
506506

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
shared_dtypes = shared(dtypes, key="dtype")
4040
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
4141

42-
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes]
42+
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
4343
_sorted_dtypes = [d for category in _dtype_categories for d in category]
4444

4545
def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
1616
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1717

18-
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
18+
@given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes))
1919
def test_mutually_promotable_dtypes(pair):
2020
assert pair in (
2121
(xp.float32, xp.float32),

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def assert_array_elements(
446446
dh.result_type(out.dtype, expected.dtype) # sanity check
447447
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
448448
f_func = f"[{func_name}({fmt_kw(kw)})]"
449-
if out.dtype in dh.float_dtypes:
449+
if out.dtype in dh.real_float_dtypes:
450450
for idx in sh.ndindex(out.shape):
451451
at_out = out[idx]
452452
at_expected = expected[idx]

array_api_tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def make_scalar_casting_param(
255255
[make_scalar_casting_param("__bool__", "bool", bool)]
256256
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names]
257257
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names]
258-
+ [make_scalar_casting_param("__float__", n, float) for n in dh.float_names],
258+
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_names],
259259
)
260260
@given(data=st.data())
261261
def test_scalar_casting(method_name, dtype_name, stype, data):

array_api_tests/test_data_type_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_can_cast(_from, to, data):
124124
expected = to == xp.bool
125125
else:
126126
same_family = None
127-
for dtypes in [dh.all_int_dtypes, dh.float_dtypes, dh.complex_dtypes]:
127+
for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]:
128128
if _from in dtypes:
129129
same_family = to in dtypes
130130
break
@@ -142,7 +142,7 @@ def test_can_cast(_from, to, data):
142142
assert out == expected, f"{out=}, but should be {expected} {f_func}"
143143

144144

145-
@pytest.mark.parametrize("dtype_name", dh.float_names)
145+
@pytest.mark.parametrize("dtype_name", dh.real_float_names)
146146
def test_finfo(dtype_name):
147147
try:
148148
dtype = getattr(_xp, dtype_name)

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def test_atan(x):
776776
unary_assert_against_refimpl("atan", x, out, math.atan)
777777

778778

779-
@given(*hh.two_mutual_arrays(dh.float_dtypes))
779+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
780780
def test_atan2(x1, x2):
781781
out = xp.atan2(x1, x2)
782782
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
@@ -1204,7 +1204,7 @@ def logaddexp(l: float, r: float) -> float:
12041204
return math.log(math.exp(l) + math.exp(r))
12051205

12061206

1207-
@given(*hh.two_mutual_arrays(dh.float_dtypes))
1207+
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
12081208
def test_logaddexp(x1, x2):
12091209
out = xp.logaddexp(x1, x2)
12101210
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)

array_api_tests/test_special_cases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,7 +1231,7 @@ def test_unary(func_name, func, case, x, data):
12311231

12321232

12331233
x1_strat, x2_strat = hh.two_mutual_arrays(
1234-
dtypes=dh.float_dtypes,
1234+
dtypes=dh.real_float_dtypes,
12351235
two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1),
12361236
)
12371237

@@ -1277,7 +1277,7 @@ def test_binary(func_name, func, case, x1, x2, data):
12771277

12781278
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
12791279
@given(
1280-
oneway_dtypes=hh.oneway_promotable_dtypes(dh.float_dtypes),
1280+
oneway_dtypes=hh.oneway_promotable_dtypes(dh.real_float_dtypes),
12811281
oneway_shapes=hh.oneway_broadcastable_shapes(),
12821282
data=st.data(),
12831283
)

0 commit comments

Comments
 (0)