Skip to content

Commit 81d088d

Browse files
committed
Make assert_array_elements more efficient in the non-error case
1 parent f82c7bc commit 81d088d

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,13 @@ def assert_array_elements(
446446
dh.result_type(out.dtype, expected.dtype) # sanity check
447447
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
448448
f_func = f"[{func_name}({fmt_kw(kw)})]"
449+
450+
match = (out == expected)
451+
if xp.all(match):
452+
return
453+
454+
# In case of mismatch, generate a more helpful error. Cycling through all indices is
455+
# costly in some array api implementations, so we only do this in the case of a failure.
449456
if out.dtype in dh.real_float_dtypes:
450457
for idx in sh.ndindex(out.shape):
451458
at_out = out[idx]
@@ -467,6 +474,4 @@ def assert_array_elements(
467474
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
468475
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
469476
else:
470-
assert xp.all(
471-
out == expected
472-
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
477+
assert xp.all(match), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"

0 commit comments

Comments
 (0)