1
1
import inspect
2
2
import math
3
3
import re
4
- from typing import Callable , Dict , NamedTuple , Pattern
4
+ from typing import Callable , Dict , List , NamedTuple , Pattern
5
5
from warnings import warn
6
6
7
7
import pytest
16
16
from ._array_module import mod as xp
17
17
from .stubs import category_to_funcs
18
18
19
+ # Condition factories
20
+ # ------------------------------------------------------------------------------
21
+
19
22
20
23
def make_eq (v : float ) -> Callable [[float ], bool ]:
21
24
if math .isnan (v ):
@@ -32,6 +35,15 @@ def eq(i: float) -> bool:
32
35
return eq
33
36
34
37
38
+ def make_neq (v : float ) -> Callable [[float ], bool ]:
39
+ eq = make_eq (v )
40
+
41
+ def neq (i : float ) -> bool :
42
+ return not eq (i )
43
+
44
+ return neq
45
+
46
+
35
47
def make_rough_eq (v : float ) -> Callable [[float ], bool ]:
36
48
assert math .isfinite (v ) # sanity check
37
49
@@ -66,6 +78,71 @@ def or_(i: float):
66
78
return or_
67
79
68
80
81
+ def make_and (cond1 : Callable , cond2 : Callable ) -> Callable :
82
+ def and_ (i : float ) -> bool :
83
+ return cond1 (i ) or cond2 (i )
84
+
85
+ return and_
86
+
87
+
88
+ def make_bin_and_factory (make_cond1 : Callable , make_cond2 : Callable ) -> Callable :
89
+ def make_bin_and (v1 : float , v2 : float ) -> Callable :
90
+ cond1 = make_cond1 (v1 )
91
+ cond2 = make_cond2 (v2 )
92
+
93
+ def bin_and (i1 : float , i2 : float ) -> bool :
94
+ return cond1 (i1 ) and cond2 (i2 )
95
+
96
+ return bin_and
97
+
98
+ return make_bin_and
99
+
100
+
101
+ def make_bin_or_factory (make_cond : Callable ) -> Callable :
102
+ def make_bin_or (v : float ) -> Callable :
103
+ cond = make_cond (v )
104
+
105
+ def bin_or (i1 : float , i2 : float ) -> bool :
106
+ return cond (i1 ) or cond (i2 )
107
+
108
+ return bin_or
109
+
110
+ return make_bin_or
111
+
112
+
113
+ def absify_cond_factory (make_cond ):
114
+ def make_abs_cond (v : float ):
115
+ cond = make_cond (v )
116
+
117
+ def abs_cond (i : float ) -> bool :
118
+ i = abs (i )
119
+ return cond (i )
120
+
121
+ return abs_cond
122
+
123
+ return make_abs_cond
124
+
125
+
126
+ def make_bin_multi_and_factory (
127
+ make_conds1 : List [Callable ], make_conds2 : List [Callable ]
128
+ ) -> Callable :
129
+ def make_bin_multi_and (* values : float ) -> Callable :
130
+ assert len (values ) == len (make_conds1 ) + len (make_conds2 )
131
+ conds1 = [make_cond (v ) for make_cond , v in zip (make_conds1 , values )]
132
+ conds2 = [make_cond (v ) for make_cond , v in zip (make_conds2 , values [::- 1 ])]
133
+
134
+ def bin_multi_and (i1 : float , i2 : float ) -> bool :
135
+ return all (cond (i1 ) for cond in conds1 ) and all (cond (i2 ) for cond in conds2 )
136
+
137
+ return bin_multi_and
138
+
139
+ return make_bin_multi_and
140
+
141
+
142
+ # Parse utils
143
+ # ------------------------------------------------------------------------------
144
+
145
+
69
146
repr_to_value = {
70
147
"NaN" : float ("nan" ),
71
148
"infinity" : float ("infinity" ),
@@ -183,6 +260,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
183
260
result = parse_result (s_result )
184
261
except ValueParseError as e :
185
262
warn (f"result not machine-readable: '{ e .value } '" )
263
+
186
264
break
187
265
condition_to_result [cond ] = result
188
266
break
@@ -193,10 +271,97 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
193
271
194
272
195
273
binary_pattern_to_condition_factory : Dict [Pattern , Callable ] = {
274
+ re .compile (
275
+ "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)"
276
+ ): make_bin_and_factory (make_eq , lambda v : lambda i : i != v ),
277
+ re .compile (
278
+ "If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), "
279
+ "and ``x2_i`` is (.+), the result is (.+)"
280
+ ): make_bin_multi_and_factory ([make_gt , make_eq ], [make_eq ]),
281
+ re .compile (
282
+ "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), "
283
+ "and ``x2_i`` is (.+), the result is (.+)"
284
+ ): make_bin_multi_and_factory ([make_lt , make_eq ], [make_eq ]),
285
+ re .compile (
286
+ "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), "
287
+ "and ``x2_i`` is not (.+), the result is (.+)"
288
+ ): make_bin_multi_and_factory ([make_lt , make_eq ], [make_eq , make_neq ]),
289
+ re .compile (
290
+ "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), "
291
+ "and ``x2_i`` is (.+), the result is (.+)"
292
+ ): make_bin_multi_and_factory ([make_eq ], [make_lt , make_eq ]),
293
+ re .compile (
294
+ "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), "
295
+ "and ``x2_i`` is not (.+), the result is (.+)"
296
+ ): make_bin_multi_and_factory ([make_eq ], [make_lt , make_neq ]),
297
+ re .compile (
298
+ "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), "
299
+ "and ``x2_i`` is (.+), the result is (.+)"
300
+ ): make_bin_multi_and_factory ([make_eq ], [make_gt , make_eq ]),
301
+ re .compile (
302
+ "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), "
303
+ "and ``x2_i`` is not (.+), the result is (.+)"
304
+ ): make_bin_multi_and_factory ([make_eq ], [make_gt , make_neq ]),
305
+ re .compile (
306
+ "If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)"
307
+ ): make_bin_and_factory (make_gt , make_eq ),
308
+ re .compile (
309
+ "If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)"
310
+ ): make_bin_and_factory (make_eq , make_gt ),
311
+ re .compile (
312
+ "If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)"
313
+ ): make_bin_and_factory (make_lt , make_eq ),
314
+ re .compile (
315
+ "If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)"
316
+ ): make_bin_and_factory (make_eq , make_lt ),
317
+ re .compile (
318
+ "If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)"
319
+ ): make_bin_and_factory (make_neq , make_eq ),
320
+ re .compile (
321
+ "If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)"
322
+ ): make_bin_and_factory (make_eq , make_neq ),
323
+ re .compile (
324
+ r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), "
325
+ "the result is (.+)"
326
+ ): make_bin_and_factory (absify_cond_factory (make_gt ), make_eq ),
327
+ re .compile (
328
+ r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)"
329
+ ): make_bin_and_factory (absify_cond_factory (make_lt ), make_eq ),
330
+ re .compile (
331
+ r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)"
332
+ ): make_bin_and_factory (absify_cond_factory (make_eq ), make_eq ),
196
333
re .compile (
197
334
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
198
- ): lambda v1 , v2 : lambda i1 , i2 : make_eq (v1 )(i1 )
199
- and make_eq (v2 )(i2 ),
335
+ ): make_bin_and_factory (make_eq , make_eq ),
336
+ re .compile (
337
+ "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
338
+ ): make_bin_or_factory (make_eq ),
339
+ re .compile (
340
+ "If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)"
341
+ ): lambda v1 , v2 , v3 : (
342
+ lambda i1 , i2 : make_or (make_eq (v1 ), make_eq (v2 ))(i1 ) and make_eq (v3 )(i2 )
343
+ ),
344
+ re .compile (
345
+ "If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)"
346
+ ): lambda v1 , v2 , v3 : (
347
+ lambda i1 , i2 : make_eq (v1 )(i1 ) and make_or (make_eq (v2 ), make_eq (v3 ))(i2 )
348
+ ),
349
+ re .compile (
350
+ "If ``x1_i`` is either (.+) or (.+) and "
351
+ "``x2_i`` is either (.+) or (.+), the result is (.+)"
352
+ ): lambda v1 , v2 , v3 , v4 : (
353
+ lambda i1 , i2 : (
354
+ make_or (make_eq (v1 ), make_eq (v2 ))(i1 )
355
+ and make_or (make_eq (v3 ), make_eq (v4 ))(i2 )
356
+ )
357
+ ),
358
+ # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+)")
359
+ # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined")
360
+ # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign and are both (.+), the result has a (.+)")
361
+ # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+)")
362
+ # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined")
363
+ # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs and are both (.+), the result has a (.+)")
364
+ # re.compile("If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+")
200
365
}
201
366
202
367
@@ -221,12 +386,6 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]:
221
386
warn (f"value not machine-readable: '{ e .value } '" )
222
387
break
223
388
cond = make_cond (* values )
224
- if (
225
- "atan2" in docstring
226
- and ph .is_pos_zero (values [0 ])
227
- and ph .is_neg_zero (values [1 ])
228
- ):
229
- breakpoint ()
230
389
try :
231
390
result = parse_result (s_result )
232
391
except ValueParseError as e :
@@ -240,6 +399,10 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]:
240
399
return condition_to_result
241
400
242
401
402
+ # Here be the tests
403
+ # ------------------------------------------------------------------------------
404
+
405
+
243
406
unary_params = []
244
407
binary_params = []
245
408
for stub in category_to_funcs ["elementwise" ]:
0 commit comments