Skip to content

Commit 62742fa

Browse files
committed
Fix the example args in test_signatures.py for some functions
1 parent 050d222 commit 62742fa

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

array_api_tests/test_signatures.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def array_method(name):
2121
def function_category(name):
2222
return stub_module(name).split('_')[0]
2323

24-
def example_argument(arg):
24+
def example_argument(arg, func_name):
2525
"""
26-
Get an example argument for the argument arg
26+
Get an example argument for the argument arg for the function func_name
2727
2828
The full tests for function behavior is in other files. We just need to
2929
have an example input for each argument name that should work so that we
@@ -37,14 +37,14 @@ def example_argument(arg):
3737
known_args = dict(
3838
M=1,
3939
N=1,
40-
arrays=(ones((1, 1, 1)), ones((1, 1, 1))),
40+
arrays=(ones((1, 3, 3)), ones((1, 3, 3))),
4141
# These cannot be the same as each other, which is why all our test
4242
# arrays have to have at least 3 dimensions.
4343
axis1=2,
4444
axis2=2,
4545
axis=1,
4646
axes=(2, 1, 0),
47-
condition=ones((1, 1, 1), dtype=bool),
47+
condition=ones((1, 3, 3), dtype=bool),
4848
correction=1.0,
4949
descending=True,
5050
dtype=float64,
@@ -59,20 +59,24 @@ def example_argument(arg):
5959
return_counts=True,
6060
return_index=True,
6161
return_inverse=True,
62-
shape=(1, 1, 1),
62+
shape=(1, 3, 3),
6363
shift=1,
6464
sorted=False,
6565
stable=False,
6666
start=0,
6767
step=2,
6868
stop=1,
6969
value=0,
70-
x1=ones((1, 1, 1)),
71-
x2=ones((1, 1, 1)),
72-
x=ones((1, 1, 1)),
70+
x1=ones((1, 3, 3)),
71+
x2=ones((1, 3, 3)),
72+
x=ones((1, 3, 3)),
7373
)
7474

7575
if arg in known_args:
76+
# This is the only special case. squeeze() requires an axis of size 1,
77+
# but other functions such as cross() require axes of size >1
78+
if func_name == 'squeeze' and arg == 'axis':
79+
return 0
7680
return known_args[arg]
7781
else:
7882
raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py")
@@ -107,9 +111,9 @@ def test_function_positional_args(name):
107111
if argspec.defaults:
108112
raise RuntimeError(f"Unexpected non-keyword-only keyword argument for {name}. Please update test_signatures.py")
109113

110-
args = [example_argument(name) for name in args]
114+
args = [example_argument(arg, name) for arg in args]
111115
if not args:
112-
args = [example_argument('x')]
116+
args = [example_argument('x', name)]
113117
else:
114118
# Duplicate the last positional argument for the n+1 test.
115119
args = args + [args[-1]]
@@ -142,10 +146,10 @@ def test_function_keyword_only_args(name):
142146
kwonlyargs = argspec.kwonlyargs
143147
kwonlydefaults = argspec.kwonlydefaults
144148

145-
args = [example_argument(name) for name in args]
149+
args = [example_argument(arg, name) for arg in args]
146150

147151
for arg in kwonlyargs:
148-
value = example_argument(arg)
152+
value = example_argument(arg, name)
149153
# The "only" part of keyword-only is tested by the positional test above.
150154
doesnt_raise(lambda: mod_func(*args, **{arg: value}),
151155
f"{name}() should accept the keyword-only argument {arg!r}")

0 commit comments

Comments
 (0)