Skip to content

Commit d42c9d2

Browse files
committed
Paired nd arrays for test_equal()
1 parent f766c11 commit d42c9d2

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def test_full(shape, fill_value, dtype):
130130
else:
131131
assert all(equal(a, asarray(fill_value, **kwargs))), "full() array did not equal the fill value"
132132

133-
# TODO: implement full_like (requires hypothesis arrays support)
134133
@given(
135134
a=xps.arrays(
136135
dtype=shared(xps.scalar_dtypes(), key='dtypes'),

array_api_tests/test_elementwise_functions.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from hypothesis import given, assume
18-
from hypothesis.strategies import composite, just
18+
from hypothesis.strategies import composite, just, shared
1919

2020
import math
2121

@@ -26,7 +26,7 @@
2626
boolean_dtype_objects, floating_dtypes,
2727
numeric_dtypes, integer_or_boolean_dtypes,
2828
boolean_dtypes, mutually_promotable_dtypes,
29-
array_scalars)
29+
array_scalars, xps)
3030
from .array_helpers import (assert_exactly_equal, negative,
3131
positive_mathematical_sign,
3232
negative_mathematical_sign, logical_not,
@@ -361,9 +361,17 @@ def test_divide(args):
361361
# could test that this does implement IEEE 754 division, but we don't yet
362362
# have those sorts in general for this module.
363363

364-
@given(two_any_dtypes.flatmap(lambda i: two_array_scalars(*i)))
365-
def test_equal(args):
366-
x1, x2 = args
364+
365+
366+
@given(
367+
x1=shared(
368+
xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()), key='arrays'
369+
),
370+
x2=shared(
371+
xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()), key='arrays'
372+
),
373+
)
374+
def test_equal(x1, x2):
367375
sanity_check(x1, x2)
368376
a = _array_module.equal(x1, x2)
369377
# NOTE: assert_exactly_equal() itself uses equal(), so we must be careful

0 commit comments

Comments
 (0)