Skip to content

Commit bf3b773

Browse files
committed
Generate keyword arguments in test_clip()
1 parent aafb6a1 commit bf3b773

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -924,10 +924,30 @@ def test_ceil(x):
924924

925925

926926
@pytest.mark.min_version("2023.12")
927-
@given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()))
928-
def test_clip(x):
927+
@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()), data=st.data())
928+
def test_clip(x, data):
929929
# TODO: test min/max kwargs, adjust values testing accordingly
930-
out = xp.clip(x)
930+
931+
# Ensure that if both min and max are arrays that all three of x, min, max
932+
# are broadcast compatible.
933+
shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2, base_shape=x.shape))
934+
935+
dtypes = hh.real_floating_dtypes if dh.is_float_dtype(x.dtype) else hh.int_dtypes
936+
937+
min = data.draw(st.one_of(
938+
st.none(),
939+
hh.scalars(dtypes=st.just(x.dtype)),
940+
hh.arrays(dtype=dtypes, shape=shape1),
941+
))
942+
max = data.draw(st.one_of(
943+
st.none(),
944+
hh.scalars(dtypes=st.just(x.dtype)),
945+
hh.arrays(dtype=dtypes, shape=shape2),
946+
))
947+
948+
kw = data.draw(hh.specified_kwargs(("min", min, None), ("max", max, None)))
949+
950+
out = xp.clip(x, **kw)
931951
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
932952
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
933953
ph.assert_array_elements("clip", out=out, expected=x)

0 commit comments

Comments
 (0)