@@ -924,10 +924,30 @@ def test_ceil(x):
924
924
925
925
926
926
@pytest .mark .min_version ("2023.12" )
927
- @given (hh .arrays (dtype = hh .real_floating_dtypes , shape = hh .shapes ()))
928
- def test_clip (x ):
927
+ @given (x = hh .arrays (dtype = hh .real_floating_dtypes , shape = hh .shapes ()), data = st . data ( ))
928
+ def test_clip (x , data ):
929
929
# TODO: test min/max kwargs, adjust values testing accordingly
930
- out = xp .clip (x )
930
+
931
+ # Ensure that if both min and max are arrays that all three of x, min, max
932
+ # are broadcast compatible.
933
+ shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 , base_shape = x .shape ))
934
+
935
+ dtypes = hh .real_floating_dtypes if dh .is_float_dtype (x .dtype ) else hh .int_dtypes
936
+
937
+ min = data .draw (st .one_of (
938
+ st .none (),
939
+ hh .scalars (dtypes = st .just (x .dtype )),
940
+ hh .arrays (dtype = dtypes , shape = shape1 ),
941
+ ))
942
+ max = data .draw (st .one_of (
943
+ st .none (),
944
+ hh .scalars (dtypes = st .just (x .dtype )),
945
+ hh .arrays (dtype = dtypes , shape = shape2 ),
946
+ ))
947
+
948
+ kw = data .draw (hh .specified_kwargs (("min" , min , None ), ("max" , max , None )))
949
+
950
+ out = xp .clip (x , ** kw )
931
951
ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
932
952
ph .assert_shape ("clip" , out_shape = out .shape , expected = x .shape )
933
953
ph .assert_array_elements ("clip" , out = out , expected = x )
0 commit comments