@@ -305,7 +305,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
305
305
306
306
307
307
r_not_code = re .compile (f"not (?:equal to )?{ r_code .pattern } " )
308
- r_array_element = re .compile (r"``([+-]?)x[12]_i``" )
308
+ r_array_element = re .compile (r"``([+-]?)x( [12]) _i``" )
309
309
r_gt = re .compile (f"greater than { r_code .pattern } " )
310
310
r_lt = re .compile (f"less than { r_code .pattern } " )
311
311
r_either_code = re .compile (f"either { r_code .pattern } or { r_code .pattern } " )
@@ -333,8 +333,8 @@ def cond(i1: float, i2: float) -> bool:
333
333
return _cond (i2 )
334
334
335
335
return cond
336
- # this branch must come after checking for array elements
337
- elif m := r_code .match (group ):
336
+
337
+ if m := r_code .match (group ):
338
338
value = parse_value (m .group (1 ))
339
339
_cond = make_eq (value )
340
340
elif m := r_not_code .match (group ):
@@ -398,39 +398,75 @@ def cond(i1: float, i2: float) -> bool:
398
398
399
399
return cond
400
400
401
+ def __repr__ (self ) -> str :
402
+ f_cond_factories = ", " .join (
403
+ repr (cond_factory ) for cond_factory in self .cond_factories
404
+ )
405
+ return f"{ self .__class__ .__name__ } ({ f_cond_factories } )"
401
406
402
- class BinaryCase (NamedTuple ):
403
- cond : BinaryCheck
404
- check_result : Callable [[float ], bool ]
405
407
408
+ BinaryResultCheck = Callable [[float , float , float ], bool ]
406
409
407
- class BinaryCaseFactory (NamedTuple ):
408
- cond_factory : CondFactory
409
- result_re_group : int
410
410
411
- def __call__ (self , groups : Tuple [str , ...]) -> BinaryCase :
412
- in_cond = self .cond_factory (groups )
411
+ class ResultCheckFactory (NamedTuple ):
412
+ re_group : int
413
+
414
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryResultCheck :
415
+ group = groups [self .re_group ]
416
+
417
+ if m := r_array_element .match (group ):
418
+ cond_factory = make_eq if m .group (1 ) != "-" else make_neq
419
+
420
+ if m .group (2 ) == "1" :
421
+
422
+ def cond (i1 : float , i2 : float , result : float ) -> bool :
423
+ _cond = cond_factory (i1 )
424
+ return _cond (result )
425
+
426
+ else :
413
427
414
- s_result = groups [self .result_re_group ]
415
- if m := r_array_element .match (s_result ):
416
- raise ValueParseError (s_result ) # TODO
417
- elif m := r_code .match (s_result ):
428
+ def cond (i1 : float , i2 : float , result : float ) -> bool :
429
+ _cond = cond_factory (i2 )
430
+ return _cond (result )
431
+
432
+ return cond
433
+
434
+ if m := r_code .match (group ):
418
435
value = parse_value (m .group (1 ))
419
- out_cond = make_eq (value )
420
- elif m := r_approx_value .match (s_result ):
436
+ _cond = make_eq (value )
437
+ elif m := r_approx_value .match (group ):
421
438
value = parse_value (m .group (1 ))
422
- out_cond = make_rough_eq (value )
439
+ _cond = make_rough_eq (value )
423
440
else :
424
- raise ValueParseError (s_result )
441
+ raise ValueParseError (group )
442
+
443
+ def cond (i1 : float , i2 : float , result : float ) -> bool :
444
+ return _cond (result )
445
+
446
+ return cond
425
447
426
- return BinaryCase (in_cond , out_cond )
448
+
449
+ class BinaryCase (NamedTuple ):
450
+ cond : BinaryCheck
451
+ check_result : BinaryResultCheck
452
+
453
+
454
+ class BinaryCaseFactory (NamedTuple ):
455
+ cond_factory : CondFactory
456
+ check_result_factory : ResultCheckFactory
457
+
458
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryCase :
459
+ cond = self .cond_factory (groups )
460
+ check_result = self .check_result_factory (groups )
461
+ return BinaryCase (cond , check_result )
427
462
428
463
429
464
binary_pattern_to_case_factory : Dict [Pattern , BinaryCaseFactory ] = {
430
465
re .compile (
431
466
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
432
467
): BinaryCaseFactory (
433
- AndCondFactory (ValueCondFactory ("i1" , 0 ), ValueCondFactory ("i2" , 1 )), 2
468
+ AndCondFactory (ValueCondFactory ("i1" , 0 ), ValueCondFactory ("i2" , 1 )),
469
+ ResultCheckFactory (2 ),
434
470
),
435
471
# re.compile(
436
472
# "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+"
@@ -671,8 +707,8 @@ def test_binary(func_name, func, cases, x1, x2):
671
707
for case in cases :
672
708
if case .cond (l , r ):
673
709
good_example = True
674
- out = float (res [o_idx ])
675
- assert case .check_result (out )
710
+ o = float (res [o_idx ])
711
+ assert case .check_result (l , r , o )
676
712
# f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
677
713
# f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
678
714
# f_out = f"{sh.fmt_idx('out', o_idx)}={out}"
0 commit comments