@@ -1127,12 +1127,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
1127
1127
return cases
1128
1128
1129
1129
1130
- category_stub_pairs = [(c , s ) for c , stubs in category_to_funcs .items () for s in stubs ]
1131
1130
unary_params = []
1132
1131
binary_params = []
1133
1132
iop_params = []
1134
1133
func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
1135
- for category , stub in category_stub_pairs :
1134
+ for stub in category_to_funcs [ "elementwise" ] :
1136
1135
if stub .__doc__ is None :
1137
1136
warn (f"{ stub .__name__ } () stub has no docstring" )
1138
1137
continue
@@ -1153,56 +1152,51 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
1153
1152
if len (sig .parameters ) == 0 :
1154
1153
warn (f"{ func = } has no parameters" )
1155
1154
continue
1156
- if category == "elementwise" :
1157
- if param_names [0 ] == "x" :
1158
- if cases := parse_unary_case_block (case_block ):
1159
- name_to_func = {stub .__name__ : func }
1160
- if stub .__name__ in func_to_op .keys ():
1161
- op_name = func_to_op [stub .__name__ ]
1162
- op = getattr (operator , op_name )
1163
- name_to_func [op_name ] = op
1164
- for func_name , func in name_to_func .items ():
1165
- for case in cases :
1166
- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1167
- p = pytest .param (func_name , func , case , id = id_ )
1168
- unary_params .append (p )
1169
- else :
1170
- warn ("TODO" )
1171
- continue
1172
- if len (sig .parameters ) == 1 :
1173
- warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1174
- continue
1175
- if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1176
- if cases := parse_binary_case_block (case_block ):
1177
- name_to_func = {stub .__name__ : func }
1178
- if stub .__name__ in func_to_op .keys ():
1179
- op_name = func_to_op [stub .__name__ ]
1180
- op = getattr (operator , op_name )
1181
- name_to_func [op_name ] = op
1182
- # We collect inplace operator test cases seperately
1183
- iop_name = "__i" + op_name [2 :]
1184
- iop = getattr (operator , iop_name )
1185
- for case in cases :
1186
- id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1187
- p = pytest .param (iop_name , iop , case , id = id_ )
1188
- iop_params .append (p )
1189
- for func_name , func in name_to_func .items ():
1190
- for case in cases :
1191
- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1192
- p = pytest .param (func_name , func , case , id = id_ )
1193
- binary_params .append (p )
1194
- else :
1195
- warn ("TODO" )
1196
- continue
1155
+ if param_names [0 ] == "x" :
1156
+ if cases := parse_unary_case_block (case_block ):
1157
+ name_to_func = {stub .__name__ : func }
1158
+ if stub .__name__ in func_to_op .keys ():
1159
+ op_name = func_to_op [stub .__name__ ]
1160
+ op = getattr (operator , op_name )
1161
+ name_to_func [op_name ] = op
1162
+ for func_name , func in name_to_func .items ():
1163
+ for case in cases :
1164
+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1165
+ p = pytest .param (func_name , func , case , id = id_ )
1166
+ unary_params .append (p )
1197
1167
else :
1198
- warn (
1199
- f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1200
- f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1201
- )
1202
- elif category == "statistical" :
1203
- pass # TODO
1168
+ warn ("TODO" )
1169
+ continue
1170
+ if len (sig .parameters ) == 1 :
1171
+ warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1172
+ continue
1173
+ if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1174
+ if cases := parse_binary_case_block (case_block ):
1175
+ name_to_func = {stub .__name__ : func }
1176
+ if stub .__name__ in func_to_op .keys ():
1177
+ op_name = func_to_op [stub .__name__ ]
1178
+ op = getattr (operator , op_name )
1179
+ name_to_func [op_name ] = op
1180
+ # We collect inplace operator test cases seperately
1181
+ iop_name = "__i" + op_name [2 :]
1182
+ iop = getattr (operator , iop_name )
1183
+ for case in cases :
1184
+ id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1185
+ p = pytest .param (iop_name , iop , case , id = id_ )
1186
+ iop_params .append (p )
1187
+ for func_name , func in name_to_func .items ():
1188
+ for case in cases :
1189
+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1190
+ p = pytest .param (func_name , func , case , id = id_ )
1191
+ binary_params .append (p )
1192
+ else :
1193
+ warn ("TODO" )
1194
+ continue
1204
1195
else :
1205
- warn ("TODO" )
1196
+ warn (
1197
+ f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1198
+ f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1199
+ )
1206
1200
1207
1201
1208
1202
# test_unary and test_binary naively generate arrays, i.e. arrays that might not
@@ -1342,3 +1336,24 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
1342
1336
)
1343
1337
break
1344
1338
assume (good_example )
1339
+
1340
+
1341
+ @pytest .mark .parametrize (
1342
+ "func_name" , [f .__name__ for f in category_to_funcs ["statistical" ]]
1343
+ )
1344
+ @given (
1345
+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
1346
+ data = st .data (),
1347
+ )
1348
+ def test_nan_propagation (func_name , x , data ):
1349
+ func = getattr (xp , func_name )
1350
+ set_idx = data .draw (
1351
+ xps .indices (x .shape , max_dims = 0 , allow_ellipsis = False ), label = "set idx"
1352
+ )
1353
+ x [set_idx ] = float ("nan" )
1354
+ note (f"{ x = } " )
1355
+
1356
+ out = func (x )
1357
+
1358
+ ph .assert_shape (func_name , out .shape , ()) # sanity check
1359
+ assert xp .isnan (out ), f"{ out = !r} , but should be NaN"
0 commit comments