Skip to content

Commit a945ca4

Browse files
committed
Test NaN propagation special cases
1 parent 8eeae4d commit a945ca4

File tree

2 files changed

+66
-51
lines changed

2 files changed

+66
-51
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def assert_0d_equals(
198198
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
199199
):
200200
msg = (
201-
f"{out_repr}={out_val}, should be {x_repr}={x_val} "
201+
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
202202
f"[{func_name}({fmt_kw(kw)})]"
203203
)
204204
if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):

array_api_tests/test_special_cases.py

Lines changed: 65 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,12 +1127,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11271127
return cases
11281128

11291129

1130-
category_stub_pairs = [(c, s) for c, stubs in category_to_funcs.items() for s in stubs]
11311130
unary_params = []
11321131
binary_params = []
11331132
iop_params = []
11341133
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
1135-
for category, stub in category_stub_pairs:
1134+
for stub in category_to_funcs["elementwise"]:
11361135
if stub.__doc__ is None:
11371136
warn(f"{stub.__name__}() stub has no docstring")
11381137
continue
@@ -1153,56 +1152,51 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11531152
if len(sig.parameters) == 0:
11541153
warn(f"{func=} has no parameters")
11551154
continue
1156-
if category == "elementwise":
1157-
if param_names[0] == "x":
1158-
if cases := parse_unary_case_block(case_block):
1159-
name_to_func = {stub.__name__: func}
1160-
if stub.__name__ in func_to_op.keys():
1161-
op_name = func_to_op[stub.__name__]
1162-
op = getattr(operator, op_name)
1163-
name_to_func[op_name] = op
1164-
for func_name, func in name_to_func.items():
1165-
for case in cases:
1166-
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1167-
p = pytest.param(func_name, func, case, id=id_)
1168-
unary_params.append(p)
1169-
else:
1170-
warn("TODO")
1171-
continue
1172-
if len(sig.parameters) == 1:
1173-
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
1174-
continue
1175-
if param_names[0] == "x1" and param_names[1] == "x2":
1176-
if cases := parse_binary_case_block(case_block):
1177-
name_to_func = {stub.__name__: func}
1178-
if stub.__name__ in func_to_op.keys():
1179-
op_name = func_to_op[stub.__name__]
1180-
op = getattr(operator, op_name)
1181-
name_to_func[op_name] = op
1182-
# We collect inplace operator test cases seperately
1183-
iop_name = "__i" + op_name[2:]
1184-
iop = getattr(operator, iop_name)
1185-
for case in cases:
1186-
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
1187-
p = pytest.param(iop_name, iop, case, id=id_)
1188-
iop_params.append(p)
1189-
for func_name, func in name_to_func.items():
1190-
for case in cases:
1191-
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1192-
p = pytest.param(func_name, func, case, id=id_)
1193-
binary_params.append(p)
1194-
else:
1195-
warn("TODO")
1196-
continue
1155+
if param_names[0] == "x":
1156+
if cases := parse_unary_case_block(case_block):
1157+
name_to_func = {stub.__name__: func}
1158+
if stub.__name__ in func_to_op.keys():
1159+
op_name = func_to_op[stub.__name__]
1160+
op = getattr(operator, op_name)
1161+
name_to_func[op_name] = op
1162+
for func_name, func in name_to_func.items():
1163+
for case in cases:
1164+
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1165+
p = pytest.param(func_name, func, case, id=id_)
1166+
unary_params.append(p)
11971167
else:
1198-
warn(
1199-
f"{func=} starts with two parameters '{param_names[0]}' and "
1200-
f"'{param_names[1]}', which are not named 'x1' and 'x2'"
1201-
)
1202-
elif category == "statistical":
1203-
pass # TODO
1168+
warn("TODO")
1169+
continue
1170+
if len(sig.parameters) == 1:
1171+
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
1172+
continue
1173+
if param_names[0] == "x1" and param_names[1] == "x2":
1174+
if cases := parse_binary_case_block(case_block):
1175+
name_to_func = {stub.__name__: func}
1176+
if stub.__name__ in func_to_op.keys():
1177+
op_name = func_to_op[stub.__name__]
1178+
op = getattr(operator, op_name)
1179+
name_to_func[op_name] = op
1180+
# We collect inplace operator test cases seperately
1181+
iop_name = "__i" + op_name[2:]
1182+
iop = getattr(operator, iop_name)
1183+
for case in cases:
1184+
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
1185+
p = pytest.param(iop_name, iop, case, id=id_)
1186+
iop_params.append(p)
1187+
for func_name, func in name_to_func.items():
1188+
for case in cases:
1189+
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1190+
p = pytest.param(func_name, func, case, id=id_)
1191+
binary_params.append(p)
1192+
else:
1193+
warn("TODO")
1194+
continue
12041195
else:
1205-
warn("TODO")
1196+
warn(
1197+
f"{func=} starts with two parameters '{param_names[0]}' and "
1198+
f"'{param_names[1]}', which are not named 'x1' and 'x2'"
1199+
)
12061200

12071201

12081202
# test_unary and test_binary naively generate arrays, i.e. arrays that might not
@@ -1342,3 +1336,24 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
13421336
)
13431337
break
13441338
assume(good_example)
1339+
1340+
1341+
@pytest.mark.parametrize(
1342+
"func_name", [f.__name__ for f in category_to_funcs["statistical"]]
1343+
)
1344+
@given(
1345+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
1346+
data=st.data(),
1347+
)
1348+
def test_nan_propagation(func_name, x, data):
1349+
func = getattr(xp, func_name)
1350+
set_idx = data.draw(
1351+
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
1352+
)
1353+
x[set_idx] = float("nan")
1354+
note(f"{x=}")
1355+
1356+
out = func(x)
1357+
1358+
ph.assert_shape(func_name, out.shape, ()) # sanity check
1359+
assert xp.isnan(out), f"{out=!r}, but should be NaN"

0 commit comments

Comments
 (0)