@@ -21,9 +21,9 @@ def array_method(name):
21
21
def function_category (name ):
22
22
return stub_module (name ).split ('_' )[0 ]
23
23
24
- def example_argument (arg ):
24
+ def example_argument (arg , func_name ):
25
25
"""
26
- Get an example argument for the argument arg
26
+ Get an example argument for the argument arg for the function func_name
27
27
28
28
The full tests for function behavior is in other files. We just need to
29
29
have an example input for each argument name that should work so that we
@@ -37,14 +37,14 @@ def example_argument(arg):
37
37
known_args = dict (
38
38
M = 1 ,
39
39
N = 1 ,
40
- arrays = (ones ((1 , 1 , 1 )), ones ((1 , 1 , 1 ))),
40
+ arrays = (ones ((1 , 3 , 3 )), ones ((1 , 3 , 3 ))),
41
41
# These cannot be the same as each other, which is why all our test
42
42
# arrays have to have at least 3 dimensions.
43
43
axis1 = 2 ,
44
44
axis2 = 2 ,
45
45
axis = 1 ,
46
46
axes = (2 , 1 , 0 ),
47
- condition = ones ((1 , 1 , 1 ), dtype = bool ),
47
+ condition = ones ((1 , 3 , 3 ), dtype = bool ),
48
48
correction = 1.0 ,
49
49
descending = True ,
50
50
dtype = float64 ,
@@ -59,20 +59,24 @@ def example_argument(arg):
59
59
return_counts = True ,
60
60
return_index = True ,
61
61
return_inverse = True ,
62
- shape = (1 , 1 , 1 ),
62
+ shape = (1 , 3 , 3 ),
63
63
shift = 1 ,
64
64
sorted = False ,
65
65
stable = False ,
66
66
start = 0 ,
67
67
step = 2 ,
68
68
stop = 1 ,
69
69
value = 0 ,
70
- x1 = ones ((1 , 1 , 1 )),
71
- x2 = ones ((1 , 1 , 1 )),
72
- x = ones ((1 , 1 , 1 )),
70
+ x1 = ones ((1 , 3 , 3 )),
71
+ x2 = ones ((1 , 3 , 3 )),
72
+ x = ones ((1 , 3 , 3 )),
73
73
)
74
74
75
75
if arg in known_args :
76
+ # This is the only special case. squeeze() requires an axis of size 1,
77
+ # but other functions such as cross() require axes of size >1
78
+ if func_name == 'squeeze' and arg == 'axis' :
79
+ return 0
76
80
return known_args [arg ]
77
81
else :
78
82
raise RuntimeError (f"Don't know how to test argument { arg } . Please update test_signatures.py" )
@@ -107,9 +111,9 @@ def test_function_positional_args(name):
107
111
if argspec .defaults :
108
112
raise RuntimeError (f"Unexpected non-keyword-only keyword argument for { name } . Please update test_signatures.py" )
109
113
110
- args = [example_argument (name ) for name in args ]
114
+ args = [example_argument (arg , name ) for arg in args ]
111
115
if not args :
112
- args = [example_argument ('x' )]
116
+ args = [example_argument ('x' , name )]
113
117
else :
114
118
# Duplicate the last positional argument for the n+1 test.
115
119
args = args + [args [- 1 ]]
@@ -142,10 +146,10 @@ def test_function_keyword_only_args(name):
142
146
kwonlyargs = argspec .kwonlyargs
143
147
kwonlydefaults = argspec .kwonlydefaults
144
148
145
- args = [example_argument (name ) for name in args ]
149
+ args = [example_argument (arg , name ) for arg in args ]
146
150
147
151
for arg in kwonlyargs :
148
- value = example_argument (arg )
152
+ value = example_argument (arg , name )
149
153
# The "only" part of keyword-only is tested by the positional test above.
150
154
doesnt_raise (lambda : mod_func (* args , ** {arg : value }),
151
155
f"{ name } () should accept the keyword-only argument { arg !r} " )
0 commit comments