Skip to content

Commit ff8ea02

Browse files
committed
Code review
1 parent 05250c1 commit ff8ea02

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

array_api_strict/tests/test_array_object.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,12 @@ def _matmul_array_vals():
412412

413413

414414
@pytest.mark.parametrize("op,dtypes", binary_op_dtypes.items())
415-
def test_binary_operators_vs_numpy_generics(op, dtypes):
416-
"""Test that np.bool_, np.int64, np.float32, np.float64, np.complex64, np.complex128
417-
are disallowed in binary operators.
418-
np.float64 and np.complex128 are subclasses of float and complex, so they need
415+
def test_binary_operators_numpy_scalars(op, dtypes):
416+
"""
417+
Test that NumPy scalars (np.generic) are explicitly disallowed.
418+
419+
This must notably include np.float64 and np.complex128, which are
420+
subclasses of float and complex respectively, so they need
419421
special treatment in order to be rejected.
420422
"""
421423
match = "Expected Array or Python scalar"

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,13 @@ def test_elementwise_function_device_mismatch(func_name):
181181

182182

183183
@pytest.mark.parametrize("func_name", elementwise_function_input_types)
184-
def test_elementwise_function_vs_numpy_generics(func_name):
184+
def test_elementwise_function_numpy_scalars(func_name):
185185
"""
186-
Test that NumPy generics are explicitly disallowed.
186+
Test that NumPy scalars (np.generic) are explicitly disallowed.
187187
188-
This must notably includes np.float64 and np.complex128, which are
189-
subclasses of float and complex respectively.
188+
This must notably include np.float64 and np.complex128, which are
189+
subclasses of float and complex respectively, so they need
190+
special treatment in order to be rejected.
190191
"""
191192
func = getattr(_elementwise_functions, func_name)
192193
dtypes = elementwise_function_input_types[func_name]
@@ -203,6 +204,8 @@ def test_elementwise_function_vs_numpy_generics(func_name):
203204
_ = func(a, a)
204205
with pytest.raises(TypeError, match="neither Array nor Python scalars"):
205206
func(a, b)
207+
with pytest.raises(TypeError, match="neither Array nor Python scalars"):
208+
func(b, a)
206209
else:
207210
_ = func(a)
208211
with pytest.raises(TypeError, match="allowed"):

array_api_strict/tests/test_searching_functions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,13 @@ def test_where_device_mismatch(cond_device, x1_device, x2_device):
8181

8282

8383
@pytest.mark.parametrize("dtype", _all_dtypes)
84-
def test_where_numpy_generics(dtype):
84+
def test_where_numpy_scalars(dtype):
8585
"""
86-
Test that NumPy generics are explicitly disallowed.
86+
Test that NumPy scalars (np.generic) are explicitly disallowed.
8787
88-
This must notably includes np.float64 and np.complex128, which are
89-
subclasses of float and complex respectively.
88+
This must notably include np.float64 and np.complex128, which are
89+
subclasses of float and complex respectively, so they need
90+
special treatment in order to be rejected.
9091
"""
9192
cond = xp.asarray(True)
9293
x1 = xp.asarray(1, dtype=dtype)

0 commit comments

Comments
 (0)