Skip to content

Commit 4c68483

Browse files
committed
Test promotable dtype and broadcastable shape in test_equal
1 parent a185813 commit 4c68483

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
boolean_dtype_objects, floating_dtypes,
2727
numeric_dtypes, integer_or_boolean_dtypes,
2828
boolean_dtypes, mutually_promotable_dtype_pairs,
29-
array_scalars, xps)
29+
array_scalars, two_broadcastable_shapes, xps, shared_dtypes, promotable_dtypes)
3030
from .array_helpers import (assert_exactly_equal, negative,
3131
positive_mathematical_sign,
3232
negative_mathematical_sign, logical_not,
@@ -377,11 +377,13 @@ def test_divide(args):
377377

378378

379379
@given(
380-
x1=shared(
381-
xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()), key='arrays'
380+
x1=xps.arrays(
381+
dtype=shared_dtypes,
382+
shape=shared(two_broadcastable_shapes(), key="shape_pair").map(lambda pair: pair[0])
382383
),
383-
x2=shared(
384-
xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()), key='arrays'
384+
x2=xps.arrays(
385+
dtype=promotable_dtypes(shared_dtypes),
386+
shape=shared(two_broadcastable_shapes(), key="shape_pair").map(lambda pair: pair[1])
385387
),
386388
)
387389
def test_equal(x1, x2):

0 commit comments

Comments
 (0)