Skip to content

Commit acc565f

Browse files
committed
Refactored strategy used in test_equal
1 parent 7e28676 commit acc565f

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,13 @@ def multiaxis_indices(draw, shapes):
242242
extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
243243
res += extra
244244
return tuple(res)
245+
246+
247+
shared_arrays1 = xps.arrays(
248+
dtype=shared_dtypes,
249+
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[0]),
250+
)
251+
shared_arrays2 = xps.arrays(
252+
dtype=promotable_dtypes(shared_dtypes),
253+
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[1]),
254+
)

array_api_tests/test_elementwise_functions.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from hypothesis import given, assume
18-
from hypothesis.strategies import composite, just, shared
18+
from hypothesis.strategies import composite, just
1919

2020
import math
2121

@@ -26,7 +26,7 @@
2626
boolean_dtype_objects, floating_dtypes,
2727
numeric_dtypes, integer_or_boolean_dtypes,
2828
boolean_dtypes, mutually_promotable_dtype_pairs,
29-
array_scalars, two_broadcastable_shapes, xps, shared_dtypes, promotable_dtypes)
29+
array_scalars, shared_arrays1, shared_arrays2)
3030
from .array_helpers import (assert_exactly_equal, negative,
3131
positive_mathematical_sign,
3232
negative_mathematical_sign, logical_not,
@@ -375,17 +375,7 @@ def test_divide(args):
375375
# have those sorts in general for this module.
376376

377377

378-
379-
@given(
380-
x1=xps.arrays(
381-
dtype=shared_dtypes,
382-
shape=shared(two_broadcastable_shapes(), key="shape_pair").map(lambda pair: pair[0])
383-
),
384-
x2=xps.arrays(
385-
dtype=promotable_dtypes(shared_dtypes),
386-
shape=shared(two_broadcastable_shapes(), key="shape_pair").map(lambda pair: pair[1])
387-
),
388-
)
378+
@given(shared_arrays1, shared_arrays2)
389379
def test_equal(x1, x2):
390380
sanity_check(x1, x2)
391381
a = _array_module.equal(x1, x2)

0 commit comments

Comments
 (0)