@@ -299,8 +299,19 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
299
299
return cases
300
300
301
301
302
- class CondFactory (Protocol ):
303
- def __call__ (self , groups : Tuple [str , ...]) -> BinaryCheck :
302
+ class BinaryCond (NamedTuple ):
303
+ cond : BinaryCheck
304
+ repr_ : str
305
+
306
+ def __call__ (self , i1 : float , i2 : float ) -> bool :
307
+ return self .cond (i1 , i2 )
308
+
309
+ def __repr__ (self ):
310
+ return self .repr_
311
+
312
+
313
+ class BinaryCondFactory (Protocol ):
314
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
304
315
...
305
316
306
317
@@ -310,31 +321,43 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
310
321
r_gt = re .compile (f"greater than { r_code .pattern } " )
311
322
r_lt = re .compile (f"less than { r_code .pattern } " )
312
323
324
+ x1_i = "x1ᵢ"
325
+ x2_i = "x2ᵢ"
326
+
313
327
314
328
@dataclass
315
- class ValueCondFactory (CondFactory ):
329
+ class ValueCondFactory (BinaryCondFactory ):
316
330
input_ : Union [Literal ["i1" ], Literal ["i2" ], Literal ["either" ], Literal ["both" ]]
317
331
re_groups_i : int
332
+ abs_ : bool = False
318
333
319
- def __call__ (self , groups : Tuple [str , ...]) -> BinaryCheck :
334
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
320
335
group = groups [self .re_groups_i ]
321
336
322
337
if m := r_array_element .match (group ):
323
- cond_factory = make_eq if m .group (1 ) != "-" else make_neq
338
+ assert not self .abs_ # sanity check
339
+ sign = m .group (1 )
340
+ if sign == "-" :
341
+ signer = lambda i : - i
342
+ else :
343
+ signer = lambda i : i
344
+
324
345
if self .input_ == "i1" :
346
+ repr_ = f"{ x1_i } == { sign } { x2_i } "
325
347
326
348
def cond (i1 : float , i2 : float ) -> bool :
327
- _cond = cond_factory ( i2 )
349
+ _cond = make_eq ( signer ( i2 ) )
328
350
return _cond (i1 )
329
351
330
352
else :
331
353
assert self .input_ == "i2" # sanity check
354
+ repr_ = f"{ x2_i } == { sign } { x1_i } "
332
355
333
356
def cond (i1 : float , i2 : float ) -> bool :
334
- _cond = cond_factory ( i1 )
357
+ _cond = make_eq ( signer ( i1 ) )
335
358
return _cond (i2 )
336
359
337
- return cond
360
+ return BinaryCond ( cond , repr_ )
338
361
339
362
if m := r_not .match (group ):
340
363
group = m .group (1 )
@@ -345,100 +368,105 @@ def cond(i1: float, i2: float) -> bool:
345
368
if m := r_code .match (group ):
346
369
value = parse_value (m .group (1 ))
347
370
_cond = make_eq (value )
371
+ repr_template = "{} == " + str (value )
348
372
elif m := r_gt .match (group ):
349
373
value = parse_value (m .group (1 ))
350
374
_cond = make_gt (value )
375
+ repr_template = "{} > " + str (value )
351
376
elif m := r_lt .match (group ):
352
377
value = parse_value (m .group (1 ))
353
378
_cond = make_lt (value )
379
+ repr_template = "{} < " + str (value )
354
380
elif m := r_either_code .match (group ):
355
381
v1 = parse_value (m .group (1 ))
356
382
v2 = parse_value (m .group (2 ))
357
383
_cond = make_or (make_eq (v1 ), make_eq (v2 ))
384
+ repr_template = "{} == " + str (v1 ) + " or {} == " + str (v2 )
358
385
elif group in ["finite" , "a finite number" ]:
359
386
_cond = math .isfinite
387
+ repr_template = "isfinite({})"
360
388
elif group in "a positive (i.e., greater than ``0``) finite number" :
361
389
_cond = lambda i : math .isfinite (i ) and i > 0
390
+ repr_template = "isfinite({}) and {} > 0"
362
391
elif group == "a negative (i.e., less than ``0``) finite number" :
363
392
_cond = lambda i : math .isfinite (i ) and i < 0
393
+ repr_template = "isfinite({}) and {} < 0"
364
394
elif group == "positive" :
365
395
_cond = lambda i : math .copysign (1 , i ) == 1
396
+ repr_template = "copysign(1, {}) == 1"
366
397
elif group == "negative" :
367
398
_cond = lambda i : math .copysign (1 , i ) == - 1
399
+ repr_template = "copysign(1, {}) == -1"
368
400
elif "nonzero finite" in group :
369
401
_cond = lambda i : math .isfinite (i ) and i != 0
402
+ repr_template = "copysign(1, {}) == -1"
370
403
elif group == "an integer value" :
371
404
_cond = lambda i : i .is_integer ()
405
+ repr_template = "{}.is_integer()"
372
406
elif group == "an odd integer value" :
373
407
_cond = lambda i : i .is_integer () and i % 2 == 1
408
+ repr_template = "{}.is_integer() and {} % 2 == 1"
374
409
else :
375
- print (f"{ group = } " )
376
410
raise ValueParseError (group )
377
411
378
412
if notify :
379
413
final_cond = lambda i : not _cond (i )
380
414
else :
381
415
final_cond = _cond
382
416
417
+ f_i1 = x1_i
418
+ f_i2 = x2_i
419
+ if self .abs_ :
420
+ f_i1 = f"abs{ f_i1 } "
421
+ f_i2 = f"abs{ f_i2 } "
422
+
383
423
if self .input_ == "i1" :
424
+ repr_ = repr_template .replace ("{}" , f_i1 )
384
425
385
426
def cond (i1 : float , i2 : float ) -> bool :
386
427
return final_cond (i1 )
387
428
388
429
elif self .input_ == "i2" :
430
+ repr_ = repr_template .replace ("{}" , f_i2 )
389
431
390
432
def cond (i1 : float , i2 : float ) -> bool :
391
433
return final_cond (i2 )
392
434
393
435
elif self .input_ == "either" :
436
+ repr_ = f"({ repr_template .replace ('{}' , f_i1 )} ) or ({ repr_template .replace ('{}' , f_i2 )} )"
394
437
395
438
def cond (i1 : float , i2 : float ) -> bool :
396
439
return final_cond (i1 ) or final_cond (i2 )
397
440
398
441
else :
442
+ assert self .input_ == "both" # sanity check
443
+ repr_ = f"({ repr_template .replace ('{}' , f_i1 )} ) and ({ repr_template .replace ('{}' , f_i2 )} )"
399
444
400
445
def cond (i1 : float , i2 : float ) -> bool :
401
446
return final_cond (i1 ) and final_cond (i2 )
402
447
403
- return cond
404
-
405
-
406
- @dataclass
407
- class AbsCondFactory (CondFactory ):
408
- cond_factory : CondFactory
409
-
410
- def __call__ (self , groups : Tuple [str , ...]) -> BinaryCheck :
411
- _cond = self .cond_factory (groups )
412
-
413
- def cond (i1 : float , i2 : float ) -> bool :
414
- i1 = abs (i1 )
415
- i2 = abs (i2 )
416
- return _cond (i1 , i2 )
448
+ if notify :
449
+ repr_ = f"not ({ repr_ } )"
417
450
418
- return cond
451
+ return BinaryCond ( cond , repr_ )
419
452
420
453
421
- class AndCondFactory (CondFactory ):
422
- def __init__ (self , * cond_factories : CondFactory ):
454
+ class AndCondFactory (BinaryCondFactory ):
455
+ def __init__ (self , * cond_factories : BinaryCondFactory ):
423
456
self .cond_factories = cond_factories
424
457
425
- def __call__ (self , groups : Tuple [str , ...]) -> BinaryCheck :
458
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
426
459
conds = [cond_factory (groups ) for cond_factory in self .cond_factories ]
460
+ repr_ = " and " .join (f"({ cond !r} )" for cond in conds )
427
461
428
462
def cond (i1 : float , i2 : float ) -> bool :
429
463
return all (cond (i1 , i2 ) for cond in conds )
430
464
431
- return cond
432
-
433
- def __repr__ (self ) -> str :
434
- f_cond_factories = ", " .join (
435
- repr (cond_factory ) for cond_factory in self .cond_factories
436
- )
437
- return f"{ self .__class__ .__name__ } ({ f_cond_factories } )"
465
+ return BinaryCond (cond , repr_ )
438
466
439
467
440
468
@dataclass
441
- class SignCondFactory (CondFactory ):
469
+ class SignCondFactory (BinaryCondFactory ):
442
470
re_groups_i : int
443
471
444
472
def __call__ (self , groups : Tuple [str , ...]) -> BinaryCheck :
@@ -451,45 +479,67 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
451
479
raise ValueParseError (group )
452
480
453
481
454
- BinaryResultCheck = Callable [[float , float , float ], bool ]
482
+ class BinaryResultCheck (NamedTuple ):
483
+ check_result : Callable [[float , float , float ], bool ]
484
+ repr_ : str
485
+
486
+ def __call__ (self , i1 : float , i2 : float , result : float ) -> bool :
487
+ return self .check_result (i1 , i2 , result )
488
+
489
+ def __repr__ (self ):
490
+ return self .repr_
491
+
492
+
493
+ class BinaryResultCheckFactory (Protocol ):
494
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
495
+ ...
455
496
456
497
457
- class ResultCheckFactory (NamedTuple ):
498
+ @dataclass
499
+ class ResultCheckFactory (BinaryResultCheckFactory ):
458
500
re_groups_i : int
459
501
460
502
def __call__ (self , groups : Tuple [str , ...]) -> BinaryResultCheck :
461
503
group = groups [self .re_groups_i ]
462
504
463
505
if m := r_array_element .match (group ):
464
- cond_factory = make_eq if m .group (1 ) != "-" else make_neq
506
+ sign , input_ = m .groups ()
507
+ if sign == "-" :
508
+ signer = lambda i : - i
509
+ else :
510
+ signer = lambda i : i
465
511
466
- if m .group (2 ) == "1" :
512
+ if input_ == "1" :
513
+ repr_ = f"{ sign } { x1_i } "
467
514
468
- def cond (i1 : float , i2 : float , result : float ) -> bool :
469
- _cond = cond_factory ( i1 )
470
- return _cond (result )
515
+ def check_result (i1 : float , i2 : float , result : float ) -> bool :
516
+ _check_result = make_eq ( signer ( i1 ) )
517
+ return _check_result (result )
471
518
472
519
else :
520
+ repr_ = f"{ sign } { x2_i } "
473
521
474
- def cond (i1 : float , i2 : float , result : float ) -> bool :
475
- _cond = cond_factory ( i2 )
476
- return _cond (result )
522
+ def check_result (i1 : float , i2 : float , result : float ) -> bool :
523
+ _check_result = make_eq ( signer ( i2 ) )
524
+ return _check_result (result )
477
525
478
- return cond
526
+ return BinaryResultCheck ( check_result , repr_ )
479
527
480
528
if m := r_code .match (group ):
481
529
value = parse_value (m .group (1 ))
482
- _cond = make_eq (value )
530
+ _check_result = make_eq (value )
531
+ repr_ = str (value )
483
532
elif m := r_approx_value .match (group ):
484
533
value = parse_value (m .group (1 ))
485
- _cond = make_rough_eq (value )
534
+ _check_result = make_rough_eq (value )
535
+ repr_ = f"~{ value } "
486
536
else :
487
537
raise ValueParseError (group )
488
538
489
- def cond (i1 : float , i2 : float , result : float ) -> bool :
490
- return _cond (result )
539
+ def check_result (i1 : float , i2 : float , result : float ) -> bool :
540
+ return _check_result (result )
491
541
492
- return cond
542
+ return BinaryResultCheck ( check_result , repr_ )
493
543
494
544
495
545
class ResultSignCheckFactory (ResultCheckFactory ):
@@ -516,12 +566,15 @@ def cond(i1: float, i2: float, result: float) -> bool:
516
566
517
567
518
568
class BinaryCase (NamedTuple ):
519
- cond : BinaryCheck
569
+ cond : BinaryCond
520
570
check_result : BinaryResultCheck
521
571
572
+ def __repr__ (self ):
573
+ return f"BinaryCase(<{ self .cond } -> { self .check_result } >)"
574
+
522
575
523
576
class BinaryCaseFactory (NamedTuple ):
524
- cond_factory : CondFactory
577
+ cond_factory : BinaryCondFactory
525
578
check_result_factory : ResultCheckFactory
526
579
527
580
def __call__ (self , groups : Tuple [str , ...]) -> BinaryCase :
@@ -564,9 +617,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
564
617
re .compile (
565
618
r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
566
619
): BinaryCaseFactory (
567
- AndCondFactory (
568
- AbsCondFactory (ValueCondFactory ("i1" , 0 )), ValueCondFactory ("i2" , 1 )
569
- ),
620
+ AndCondFactory (ValueCondFactory ("i1" , 0 , abs_ = True ), ValueCondFactory ("i2" , 1 )),
570
621
ResultCheckFactory (2 ),
571
622
),
572
623
re .compile (
@@ -726,25 +777,11 @@ def test_binary(func_name, func, cases, x1, x2):
726
777
if case .cond (l , r ):
727
778
good_example = True
728
779
o = float (res [o_idx ])
729
- assert case .check_result (l , r , o )
730
- # f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
731
- # f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
732
- # f_out = f"{sh.fmt_idx('out', o_idx)}={out}"
733
- # if result.strict_check:
734
- # msg = (
735
- # f"{f_out}, but should be {result.repr_} [{func_name}()]\n"
736
- # f"{f_left}, {f_right}"
737
- # )
738
- # if math.isnan(result.value):
739
- # assert math.isnan(out), msg
740
- # else:
741
- # assert out == result.value, msg
742
- # else:
743
- # assert math.isfinite(result.value) # sanity check
744
- # assert math.isclose(out, result.value, abs_tol=0.1), (
745
- # f"{f_out}, but should be roughly {result.repr_}={result.value} "
746
- # f"[{func_name}()]\n"
747
- # f"{f_left}, {f_right}"
748
- # )
780
+ f_left = f"{ sh .fmt_idx ('x1' , l_idx )} ={ l } "
781
+ f_right = f"{ sh .fmt_idx ('x2' , r_idx )} ={ r } "
782
+ f_out = f"{ sh .fmt_idx ('out' , o_idx )} ={ o } "
783
+ assert case .check_result (l , r , o ), (
784
+ f"{ f_out } not good [{ func_name } ()]\n " f"{ f_left } , { f_right } "
785
+ )
749
786
break
750
787
assume (good_example )
0 commit comments