@@ -222,27 +222,6 @@ def parse_inline_code(inline_code: str) -> float:
222
222
raise ValueParseError (inline_code )
223
223
224
224
225
- class Result (NamedTuple ):
226
- value : float
227
- repr_ : str
228
- strict_check : bool
229
-
230
-
231
- def parse_result (s_result : str ) -> Result :
232
- match = None
233
- if m := r_code .match (s_result ):
234
- match = m
235
- strict_check = True
236
- elif m := r_approx_value .match (s_result ):
237
- match = m
238
- strict_check = False
239
- else :
240
- raise ValueParseError (s_result )
241
- value = parse_value (match .group (1 ))
242
- repr_ = match .group (1 )
243
- return Result (value , repr_ , strict_check )
244
-
245
-
246
225
r_special_cases = re .compile (
247
226
r"\*\*Special [Cc]ases\*\*\n+\s*"
248
227
r"For floating-point operands,\n+"
@@ -252,65 +231,74 @@ def parse_result(s_result: str) -> Result:
252
231
r_remaining_case = re .compile ("In the remaining cases.+" )
253
232
254
233
255
- unary_pattern_to_condition_factory : Dict [Pattern , Callable ] = {
256
- re .compile ("If ``x_i`` is greater than (.+), the result is (.+)" ): make_gt ,
257
- re .compile ("If ``x_i`` is less than (.+), the result is (.+)" ): make_lt ,
258
- re .compile ("If ``x_i`` is either (.+) or (.+), the result is (.+)" ): (
259
- lambda v1 , v2 : make_or (make_eq (v1 ), make_eq (v2 ))
260
- ),
261
- # This pattern must come after the previous patterns to avoid unwanted matches
262
- re .compile ("If ``x_i`` is (.+), the result is (.+)" ): make_eq ,
263
- re .compile (
264
- "If two integers are equally close to ``x_i``, the result is (.+)"
265
- ): lambda : (lambda i : (abs (i ) - math .floor (abs (i ))) == 0.5 ),
266
- }
234
+ # unary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
235
+ # re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt,
236
+ # re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt,
237
+ # re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): (
238
+ # lambda v1, v2: make_or(make_eq(v1), make_eq(v2))
239
+ # ),
240
+ # # This pattern must come after the previous patterns to avoid unwanted matches
241
+ # re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq,
242
+ # re.compile(
243
+ # "If two integers are equally close to ``x_i``, the result is (.+)"
244
+ # ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5),
245
+ # }
246
+
247
+
248
+ # def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
249
+ # match = r_special_cases.search(docstring)
250
+ # if match is None:
251
+ # return {}
252
+ # cases = match.group(1).split("\n")[:-1]
253
+ # cases = {}
254
+ # for line in cases:
255
+ # if m := r_case.match(line):
256
+ # case = m.group(1)
257
+ # else:
258
+ # warn(f"line not machine-readable: '{line}'")
259
+ # continue
260
+ # for pattern, make_cond in unary_pattern_to_condition_factory.items():
261
+ # if m := pattern.search(case):
262
+ # *s_values, s_result = m.groups()
263
+ # try:
264
+ # values = [parse_inline_code(v) for v in s_values]
265
+ # except ValueParseError as e:
266
+ # warn(f"value not machine-readable: '{e.value}'")
267
+ # break
268
+ # cond = make_cond(*values)
269
+ # try:
270
+ # result = parse_result(s_result)
271
+ # except ValueParseError as e:
272
+ # warn(f"result not machine-readable: '{e.value}'")
273
+
274
+ # break
275
+ # cases[cond] = result
276
+ # break
277
+ # else:
278
+ # if not r_remaining_case.search(case):
279
+ # warn(f"case not machine-readable: '{case}'")
280
+ # return cases
281
+
282
+ x_i = "xᵢ"
283
+ x1_i = "x1ᵢ"
284
+ x2_i = "x2ᵢ"
267
285
268
286
269
- def parse_unary_docstring (docstring : str ) -> Dict [Callable , Result ]:
270
- match = r_special_cases .search (docstring )
271
- if match is None :
272
- return {}
273
- cases = match .group (1 ).split ("\n " )[:- 1 ]
274
- cases = {}
275
- for line in cases :
276
- if m := r_case .match (line ):
277
- case = m .group (1 )
278
- else :
279
- warn (f"line not machine-readable: '{ line } '" )
280
- continue
281
- for pattern , make_cond in unary_pattern_to_condition_factory .items ():
282
- if m := pattern .search (case ):
283
- * s_values , s_result = m .groups ()
284
- try :
285
- values = [parse_inline_code (v ) for v in s_values ]
286
- except ValueParseError as e :
287
- warn (f"value not machine-readable: '{ e .value } '" )
288
- break
289
- cond = make_cond (* values )
290
- try :
291
- result = parse_result (s_result )
292
- except ValueParseError as e :
293
- warn (f"result not machine-readable: '{ e .value } '" )
287
+ class Cond (Protocol ):
288
+ expr : str
294
289
295
- break
296
- cases [cond ] = result
297
- break
298
- else :
299
- if not r_remaining_case .search (case ):
300
- warn (f"case not machine-readable: '{ case } '" )
301
- return cases
290
+ def __call__ (self , * args ) -> bool :
291
+ ...
302
292
303
293
304
- class BinaryCond (NamedTuple ):
294
+ @dataclass
295
+ class BinaryCond (Cond ):
305
296
cond : BinaryCheck
306
- repr_ : str
297
+ expr : str
307
298
308
299
def __call__ (self , i1 : float , i2 : float ) -> bool :
309
300
return self .cond (i1 , i2 )
310
301
311
- def __repr__ (self ):
312
- return self .repr_
313
-
314
302
315
303
class BinaryCondFactory (Protocol ):
316
304
def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
@@ -323,9 +311,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
323
311
r_gt = re .compile (f"greater than { r_code .pattern } " )
324
312
r_lt = re .compile (f"less than { r_code .pattern } " )
325
313
326
- x1_i = "x1ᵢ"
327
- x2_i = "x2ᵢ"
328
-
329
314
330
315
@dataclass
331
316
class ValueCondFactory (BinaryCondFactory ):
@@ -345,21 +330,21 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
345
330
signer = lambda i : i
346
331
347
332
if self .input_ == "i1" :
348
- repr_ = f"{ x1_i } == { sign } { x2_i } "
333
+ expr = f"{ x1_i } == { sign } { x2_i } "
349
334
350
335
def cond (i1 : float , i2 : float ) -> bool :
351
336
_cond = make_eq (signer (i2 ))
352
337
return _cond (i1 )
353
338
354
339
else :
355
340
assert self .input_ == "i2" # sanity check
356
- repr_ = f"{ x2_i } == { sign } { x1_i } "
341
+ expr = f"{ x2_i } == { sign } { x1_i } "
357
342
358
343
def cond (i1 : float , i2 : float ) -> bool :
359
344
_cond = make_eq (signer (i1 ))
360
345
return _cond (i2 )
361
346
362
- return BinaryCond (cond , repr_ )
347
+ return BinaryCond (cond , expr )
363
348
364
349
if m := r_not .match (group ):
365
350
group = m .group (1 )
@@ -419,38 +404,38 @@ def cond(i1: float, i2: float) -> bool:
419
404
f_i1 = x1_i
420
405
f_i2 = x2_i
421
406
if self .abs_ :
422
- f_i1 = f"abs{ f_i1 } "
423
- f_i2 = f"abs{ f_i2 } "
407
+ f_i1 = f"abs( { f_i1 } ) "
408
+ f_i2 = f"abs( { f_i2 } ) "
424
409
425
410
if self .input_ == "i1" :
426
- repr_ = repr_template .replace ("{}" , f_i1 )
411
+ expr = repr_template .replace ("{}" , f_i1 )
427
412
428
413
def cond (i1 : float , i2 : float ) -> bool :
429
414
return final_cond (i1 )
430
415
431
416
elif self .input_ == "i2" :
432
- repr_ = repr_template .replace ("{}" , f_i2 )
417
+ expr = repr_template .replace ("{}" , f_i2 )
433
418
434
419
def cond (i1 : float , i2 : float ) -> bool :
435
420
return final_cond (i2 )
436
421
437
422
elif self .input_ == "either" :
438
- repr_ = f"({ repr_template .replace ('{}' , f_i1 )} ) or ({ repr_template .replace ('{}' , f_i2 )} )"
423
+ expr = f"({ repr_template .replace ('{}' , f_i1 )} ) or ({ repr_template .replace ('{}' , f_i2 )} )"
439
424
440
425
def cond (i1 : float , i2 : float ) -> bool :
441
426
return final_cond (i1 ) or final_cond (i2 )
442
427
443
428
else :
444
429
assert self .input_ == "both" # sanity check
445
- repr_ = f"({ repr_template .replace ('{}' , f_i1 )} ) and ({ repr_template .replace ('{}' , f_i2 )} )"
430
+ expr = f"({ repr_template .replace ('{}' , f_i1 )} ) and ({ repr_template .replace ('{}' , f_i2 )} )"
446
431
447
432
def cond (i1 : float , i2 : float ) -> bool :
448
433
return final_cond (i1 ) and final_cond (i2 )
449
434
450
435
if notify :
451
- repr_ = f"not ({ repr_ } )"
436
+ expr = f"not ({ expr } )"
452
437
453
- return BinaryCond (cond , repr_ )
438
+ return BinaryCond (cond , expr )
454
439
455
440
456
441
class AndCondFactory (BinaryCondFactory ):
@@ -459,37 +444,40 @@ def __init__(self, *cond_factories: BinaryCondFactory):
459
444
460
445
def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
461
446
conds = [cond_factory (groups ) for cond_factory in self .cond_factories ]
462
- repr_ = " and " .join (f"({ cond !r } )" for cond in conds )
447
+ expr = " and " .join (f"({ cond . expr } )" for cond in conds )
463
448
464
449
def cond (i1 : float , i2 : float ) -> bool :
465
450
return all (cond (i1 , i2 ) for cond in conds )
466
451
467
- return BinaryCond (cond , repr_ )
452
+ return BinaryCond (cond , expr )
468
453
469
454
470
455
@dataclass
471
456
class SignCondFactory (BinaryCondFactory ):
472
457
re_groups_i : int
473
458
474
- def __call__ (self , groups : Tuple [str , ...]) -> BinaryCheck :
459
+ def __call__ (self , groups : Tuple [str , ...]) -> BinaryCond :
475
460
group = groups [self .re_groups_i ]
476
461
if group == "the same mathematical sign" :
477
- return same_sign
462
+ cond = same_sign
463
+ expr = f"copysign(1, { x1_i } ) == copysign(1, { x2_i } )"
478
464
elif group == "different mathematical signs" :
479
- return diff_sign
465
+ cond = diff_sign
466
+ expr = f"copysign(1, { x1_i } ) != copysign(1, { x2_i } )"
480
467
else :
481
468
raise ValueParseError (group )
469
+ return BinaryCond (cond , expr )
482
470
483
471
484
472
class BinaryResultCheck (NamedTuple ):
485
473
check_result : Callable [[float , float , float ], bool ]
486
- repr_ : str
474
+ expr : str
487
475
488
476
def __call__ (self , i1 : float , i2 : float , result : float ) -> bool :
489
477
return self .check_result (i1 , i2 , result )
490
478
491
479
def __repr__ (self ):
492
- return self .repr_
480
+ return self .expr
493
481
494
482
495
483
class BinaryResultCheckFactory (Protocol ):
@@ -512,36 +500,36 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
512
500
signer = lambda i : i
513
501
514
502
if input_ == "1" :
515
- repr_ = f"{ sign } { x1_i } "
503
+ expr = f"{ sign } { x1_i } "
516
504
517
505
def check_result (i1 : float , i2 : float , result : float ) -> bool :
518
506
_check_result = make_eq (signer (i1 ))
519
507
return _check_result (result )
520
508
521
509
else :
522
- repr_ = f"{ sign } { x2_i } "
510
+ expr = f"{ sign } { x2_i } "
523
511
524
512
def check_result (i1 : float , i2 : float , result : float ) -> bool :
525
513
_check_result = make_eq (signer (i2 ))
526
514
return _check_result (result )
527
515
528
- return BinaryResultCheck (check_result , repr_ )
516
+ return BinaryResultCheck (check_result , expr )
529
517
530
518
if m := r_code .match (group ):
531
519
value = parse_value (m .group (1 ))
532
520
_check_result = make_eq (value )
533
- repr_ = str (value )
521
+ expr = str (value )
534
522
elif m := r_approx_value .match (group ):
535
523
value = parse_value (m .group (1 ))
536
524
_check_result = make_rough_eq (value )
537
- repr_ = f"~{ value } "
525
+ expr = f"~{ value } "
538
526
else :
539
527
raise ValueParseError (group )
540
528
541
529
def check_result (i1 : float , i2 : float , result : float ) -> bool :
542
530
return _check_result (result )
543
531
544
- return BinaryResultCheck (check_result , repr_ )
532
+ return BinaryResultCheck (check_result , expr )
545
533
546
534
547
535
class ResultSignCheckFactory (ResultCheckFactory ):
@@ -572,7 +560,7 @@ class BinaryCase(NamedTuple):
572
560
check_result : BinaryResultCheck
573
561
574
562
def __repr__ (self ):
575
- return f"BinaryCase(<{ self .cond } -> { self .check_result } >)"
563
+ return f"BinaryCase(<{ self .cond . expr } -> { self .check_result } >)"
576
564
577
565
578
566
class BinaryCaseFactory (NamedTuple ):
@@ -743,7 +731,7 @@ def test_unary(func_name, func, cases, x):
743
731
f_out = f"{ sh .fmt_idx ('out' , idx )} ={ out } "
744
732
if result .strict_check :
745
733
msg = (
746
- f"{ f_out } , but should be { result .repr_ } [{ func_name } ()]\n "
734
+ f"{ f_out } , but should be { result .expr } [{ func_name } ()]\n "
747
735
f"{ f_in } "
748
736
)
749
737
if math .isnan (result .value ):
@@ -753,7 +741,7 @@ def test_unary(func_name, func, cases, x):
753
741
else :
754
742
assert math .isfinite (result .value ) # sanity check
755
743
assert math .isclose (out , result .value , abs_tol = 0.1 ), (
756
- f"{ f_out } , but should be roughly { result .repr_ } ={ result .value } "
744
+ f"{ f_out } , but should be roughly { result .expr } ={ result .value } "
757
745
f"[{ func_name } ()]\n "
758
746
f"{ f_in } "
759
747
)
0 commit comments