@@ -526,6 +526,10 @@ def __repr__(self) -> str:
526
526
return f"{ self .__class__ .__name__ } (<{ self } >)"
527
527
528
528
529
+ r_case_block = re .compile (r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters" )
530
+ r_case = re .compile (r"\s+-\s*(.*)\." )
531
+
532
+
529
533
class UnaryCond (Protocol ):
530
534
def __call__ (self , i : float ) -> bool :
531
535
...
@@ -586,7 +590,7 @@ def check_result(i: float, result: float) -> bool:
586
590
return check_result
587
591
588
592
589
- def parse_unary_docstring ( docstring : str ) -> List [UnaryCase ]:
593
+ def parse_unary_case_block ( case_block : str ) -> List [UnaryCase ]:
590
594
"""
591
595
Parses a Sphinx-formatted docstring of a unary function to return a list of
592
596
codified unary cases, e.g.
@@ -616,7 +620,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
616
620
... an array containing the square root of each element in ``x``
617
621
... '''
618
622
...
619
- >>> unary_cases = parse_unary_docstring(sqrt.__doc__)
623
+ >>> case_block = r_case_block.match(sqrt.__doc__).group(1)
624
+ >>> unary_cases = parse_unary_case_block(case_block)
620
625
>>> for case in unary_cases:
621
626
... print(repr(case))
622
627
UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +636,10 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
631
636
True
632
637
633
638
"""
634
-
635
- match = r_special_cases .search (docstring )
636
- if match is None :
637
- return []
638
- lines = match .group (1 ).split ("\n " )[:- 1 ]
639
639
cases = []
640
- for line in lines :
641
- if m := r_case .match (line ):
642
- case = m .group (1 )
643
- else :
644
- warn (f"line not machine-readable: '{ line } '" )
645
- continue
646
- if m := r_unary_case .search (case ):
640
+ for case_m in r_case .finditer (case_block ):
641
+ case_str = case_m .group (1 )
642
+ if m := r_unary_case .search (case_str ):
647
643
try :
648
644
cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
649
645
_check_result , result_expr = parse_result (m .group (2 ))
@@ -662,11 +658,11 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
662
658
check_result = check_result ,
663
659
)
664
660
cases .append (case )
665
- elif m := r_even_round_halves_case .search (case ):
661
+ elif m := r_even_round_halves_case .search (case_str ):
666
662
cases .append (even_round_halves_case )
667
663
else :
668
- if not r_remaining_case .search (case ):
669
- warn (f"case not machine-readable: '{ case } '" )
664
+ if not r_remaining_case .search (case_str ):
665
+ warn (f"case not machine-readable: '{ case_str } '" )
670
666
return cases
671
667
672
668
@@ -690,12 +686,6 @@ class BinaryCase(Case):
690
686
check_result : BinaryResultCheck
691
687
692
688
693
- r_special_cases = re .compile (
694
- r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
695
- r"For floating-point operands,\n+"
696
- r"((?:\s*-\s*.*\n)+)"
697
- )
698
- r_case = re .compile (r"\s+-\s*(.*)\.\n?" )
699
689
r_binary_case = re .compile ("If (.+), the result (.+)" )
700
690
r_remaining_case = re .compile ("In the remaining cases.+" )
701
691
r_cond_sep = re .compile (r"(?<!``x1_i``),? and |(?<!i\.e\.), " )
@@ -880,8 +870,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:
880
870
881
871
"""
882
872
case_m = r_binary_case .match (case_str )
883
- if case_m is None :
884
- raise ParseError (case_str )
873
+ assert case_m is not None # sanity check
885
874
cond_strs = r_cond_sep .split (case_m .group (1 ))
886
875
887
876
partial_conds = []
@@ -1078,7 +1067,7 @@ def cond(i1: float, i2: float) -> bool:
1078
1067
r_redundant_case = re .compile ("result.+determined by the rule already stated above" )
1079
1068
1080
1069
1081
- def parse_binary_docstring ( docstring : str ) -> List [BinaryCase ]:
1070
+ def parse_binary_case_block ( case_block : str ) -> List [BinaryCase ]:
1082
1071
"""
1083
1072
Parses a Sphinx-formatted docstring of a binary function to return a list of
1084
1073
codified binary cases, e.g.
@@ -1108,29 +1097,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1108
1097
... an array containing the results
1109
1098
... '''
1110
1099
...
1111
- >>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
1100
+ >>> case_block = r_case_block.match(logaddexp.__doc__).group(1)
1101
+ >>> binary_cases = parse_binary_case_block(case_block)
1112
1102
>>> for case in binary_cases:
1113
1103
... print(repr(case))
1114
1104
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
1115
1105
BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
1116
1106
BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
1117
1107
1118
1108
"""
1119
-
1120
- match = r_special_cases .search (docstring )
1121
- if match is None :
1122
- return []
1123
- lines = match .group (1 ).split ("\n " )[:- 1 ]
1124
1109
cases = []
1125
- for line in lines :
1126
- if m := r_case .match (line ):
1127
- case_str = m .group (1 )
1128
- else :
1129
- warn (f"line not machine-readable: '{ line } '" )
1130
- continue
1110
+ for case_m in r_case .finditer (case_block ):
1111
+ case_str = case_m .group (1 )
1131
1112
if r_redundant_case .search (case_str ):
1132
1113
continue
1133
- if m := r_binary_case .match (case_str ):
1114
+ if r_binary_case .match (case_str ):
1134
1115
try :
1135
1116
case = parse_binary_case (case_str )
1136
1117
cases .append (case )
@@ -1142,14 +1123,19 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1142
1123
return cases
1143
1124
1144
1125
1126
+ category_stub_pairs = [(c , s ) for c , stubs in category_to_funcs .items () for s in stubs ]
1145
1127
unary_params = []
1146
1128
binary_params = []
1147
1129
iop_params = []
1148
1130
func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
1149
- for stub in category_to_funcs [ "elementwise" ] :
1131
+ for category , stub in category_stub_pairs :
1150
1132
if stub .__doc__ is None :
1151
1133
warn (f"{ stub .__name__ } () stub has no docstring" )
1152
1134
continue
1135
+ if m := r_case_block .search (stub .__doc__ ):
1136
+ case_block = m .group (1 )
1137
+ else :
1138
+ continue
1153
1139
marks = []
1154
1140
try :
1155
1141
func = getattr (xp , stub .__name__ )
@@ -1163,47 +1149,56 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1163
1149
if len (sig .parameters ) == 0 :
1164
1150
warn (f"{ func = } has no parameters" )
1165
1151
continue
1166
- if param_names [0 ] == "x" :
1167
- if cases := parse_unary_docstring (stub .__doc__ ):
1168
- name_to_func = {stub .__name__ : func }
1169
- if stub .__name__ in func_to_op .keys ():
1170
- op_name = func_to_op [stub .__name__ ]
1171
- op = getattr (operator , op_name )
1172
- name_to_func [op_name ] = op
1173
- for func_name , func in name_to_func .items ():
1174
- for case in cases :
1175
- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1176
- p = pytest .param (func_name , func , case , id = id_ )
1177
- unary_params .append (p )
1178
- continue
1179
- if len (sig .parameters ) == 1 :
1180
- warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1181
- continue
1182
- if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1183
- if cases := parse_binary_docstring (stub .__doc__ ):
1184
- name_to_func = {stub .__name__ : func }
1185
- if stub .__name__ in func_to_op .keys ():
1186
- op_name = func_to_op [stub .__name__ ]
1187
- op = getattr (operator , op_name )
1188
- name_to_func [op_name ] = op
1189
- # We collect inplaceoperator test cases seperately
1190
- iop_name = "__i" + op_name [2 :]
1191
- iop = getattr (operator , iop_name )
1192
- for case in cases :
1193
- id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1194
- p = pytest .param (iop_name , iop , case , id = id_ )
1195
- iop_params .append (p )
1196
- for func_name , func in name_to_func .items ():
1197
- for case in cases :
1198
- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1199
- p = pytest .param (func_name , func , case , id = id_ )
1200
- binary_params .append (p )
1201
- continue
1152
+ if category == "elementwise" :
1153
+ if param_names [0 ] == "x" :
1154
+ if cases := parse_unary_case_block (case_block ):
1155
+ name_to_func = {stub .__name__ : func }
1156
+ if stub .__name__ in func_to_op .keys ():
1157
+ op_name = func_to_op [stub .__name__ ]
1158
+ op = getattr (operator , op_name )
1159
+ name_to_func [op_name ] = op
1160
+ for func_name , func in name_to_func .items ():
1161
+ for case in cases :
1162
+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1163
+ p = pytest .param (func_name , func , case , id = id_ )
1164
+ unary_params .append (p )
1165
+ else :
1166
+ warn ("TODO" )
1167
+ continue
1168
+ if len (sig .parameters ) == 1 :
1169
+ warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1170
+ continue
1171
+ if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1172
+ if cases := parse_binary_case_block (case_block ):
1173
+ name_to_func = {stub .__name__ : func }
1174
+ if stub .__name__ in func_to_op .keys ():
1175
+ op_name = func_to_op [stub .__name__ ]
1176
+ op = getattr (operator , op_name )
1177
+ name_to_func [op_name ] = op
1178
+ # We collect inplace operator test cases seperately
1179
+ iop_name = "__i" + op_name [2 :]
1180
+ iop = getattr (operator , iop_name )
1181
+ for case in cases :
1182
+ id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1183
+ p = pytest .param (iop_name , iop , case , id = id_ )
1184
+ iop_params .append (p )
1185
+ for func_name , func in name_to_func .items ():
1186
+ for case in cases :
1187
+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1188
+ p = pytest .param (func_name , func , case , id = id_ )
1189
+ binary_params .append (p )
1190
+ else :
1191
+ warn ("TODO" )
1192
+ continue
1193
+ else :
1194
+ warn (
1195
+ f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1196
+ f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1197
+ )
1198
+ elif category == "statistical" :
1199
+ pass # TODO
1202
1200
else :
1203
- warn (
1204
- f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1205
- f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1206
- )
1201
+ warn ("TODO" )
1207
1202
1208
1203
1209
1204
# test_unary and test_binary naively generate arrays, i.e. arrays that might not
0 commit comments