@@ -732,6 +732,35 @@ def make_binary_cond(
732
732
* ,
733
733
input_wrapper : Optional [Callable [[float ], float ]] = None ,
734
734
) -> BinaryCond :
735
+ """
736
+ Wraps a unary condition as a binary condition, e.g.
737
+
738
+ >>> unary_cond = lambda i: i == 42
739
+
740
+ >>> binary_cond_first = make_binary_cond(BinaryCondArg.FIRST, unary_cond)
741
+ >>> binary_cond_first(42, 0)
742
+ True
743
+ >>> binary_cond_second = make_binary_cond(BinaryCondArg.SECOND, unary_cond)
744
+ >>> binary_cond_second(42, 0)
745
+ False
746
+ >>> binary_cond_second(0, 42)
747
+ True
748
+ >>> binary_cond_both = make_binary_cond(BinaryCondArg.BOTH, unary_cond)
749
+ >>> binary_cond_both(42, 0)
750
+ False
751
+ >>> binary_cond_both(42, 42)
752
+ True
753
+ >>> binary_cond_either = make_binary_cond(BinaryCondArg.EITHER, unary_cond)
754
+ >>> binary_cond_either(0, 0)
755
+ False
756
+ >>> binary_cond_either(42, 0)
757
+ True
758
+ >>> binary_cond_either(0, 42)
759
+ True
760
+ >>> binary_cond_either(42, 42)
761
+ True
762
+
763
+ """
735
764
if input_wrapper is None :
736
765
input_wrapper = noop
737
766
@@ -823,11 +852,13 @@ def parse_binary_case(case_str: str) -> BinaryCase:
823
852
if in_sign != "" or other_no == in_no :
824
853
raise ParseError (cond_str )
825
854
partial_expr = f"{ in_sign } x{ in_no } _i == { other_sign } x{ other_no } _i"
855
+
826
856
input_wrapper = lambda i : - i if other_sign == "-" else noop
857
+ # For these scenarios, we want to make sure both array elements
858
+ # generate respective to one another by using a shared strategy.
827
859
shared_from_dtype = lambda d , ** kw : st .shared (
828
860
xps .from_dtype (d , ** kw ), key = cond_str
829
861
)
830
-
831
862
if other_no == "1" :
832
863
833
864
def partial_cond (i1 : float , i2 : float ) -> bool :
0 commit comments