Skip to content

Commit 8bc5558

Browse files
committed
Fix the bitwise_* tests to not call int() on boolean arrays
1 parent 813eda1 commit 8bc5558

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,16 @@ def test_bitwise_and(args):
213213
# TODO: Generalize this properly for inputs that are arrays.
214214
if not (x1.shape == x2.shape == ()):
215215
raise RuntimeError("Error: test_bitwise_and needs to be updated for nonscalar array inputs")
216-
x = int(x1)
217-
y = int(x2)
218-
res = int(a)
216+
219217
if a.dtype == bool_dtype:
218+
x = bool(x1)
219+
y = bool(x2)
220+
res = bool(a)
220221
assert (x and y) == res
221222
else:
223+
x = int(x1)
224+
y = int(x2)
225+
res = int(a)
222226
ans = int_to_dtype(x & y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
223227
assert ans == res
224228

@@ -254,11 +258,13 @@ def test_bitwise_invert(x):
254258
# TODO: Generalize this properly for inputs that are arrays.
255259
if not (x.shape == ()):
256260
raise RuntimeError("Error: test_bitwise_invert needs to be updated for nonscalar array inputs")
257-
x = int(x)
258-
res = int(a)
259261
if a.dtype == bool_dtype:
262+
x = bool(x)
263+
res = bool(a)
260264
assert (not x) == res
261265
else:
266+
x = int(x)
267+
res = int(a)
262268
ans = int_to_dtype(~x, dtype_nbits(a.dtype), dtype_signed(a.dtype))
263269
assert ans == res
264270

@@ -272,12 +278,15 @@ def test_bitwise_or(args):
272278
# TODO: Generalize this properly for inputs that are arrays.
273279
if not (x1.shape == x2.shape == ()):
274280
raise RuntimeError("Error: test_bitwise_or needs to be updated for nonscalar array inputs")
275-
x = int(x1)
276-
y = int(x2)
277-
res = int(a)
278281
if a.dtype == bool_dtype:
282+
x = bool(x1)
283+
y = bool(x2)
284+
res = bool(a)
279285
assert (x or y) == res
280286
else:
287+
x = int(x1)
288+
y = int(x2)
289+
res = int(a)
281290
ans = int_to_dtype(x | y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
282291
assert ans == res
283292

@@ -310,12 +319,15 @@ def test_bitwise_xor(args):
310319
# TODO: Generalize this properly for inputs that are arrays.
311320
if not (x1.shape == x2.shape == ()):
312321
raise RuntimeError("Error: test_bitwise_xor needs to be updated for nonscalar array inputs")
313-
x = int(x1)
314-
y = int(x2)
315-
res = int(a)
316322
if a.dtype == bool_dtype:
323+
x = bool(x1)
324+
y = bool(x2)
325+
res = bool(a)
317326
assert (x ^ y) == res
318327
else:
328+
x = int(x1)
329+
y = int(x2)
330+
res = int(a)
319331
ans = int_to_dtype(x ^ y, dtype_nbits(a.dtype), dtype_signed(a.dtype))
320332
assert ans == res
321333

0 commit comments

Comments
 (0)