32
32
# Standin strategy for not yet implemented tests
33
33
todo = none ()
34
34
35
- def _test_stacks (f , x , kw , res = None , dims = 2 , true_val = None ):
35
+ def _test_stacks (f , * args , res = None , dims = 2 , true_val = None , ** kw ):
36
36
"""
37
- Test that f(x , **kw) maps across stacks of matrices
37
+ Test that f(*args , **kw) maps across stacks of matrices
38
38
39
39
dims is the number of dimensions f should have for a single n x m matrix
40
40
stack.
41
41
42
- true_val may be a function such that true_val(x_stack ) gives the true
42
+ true_val may be a function such that true_val(*x_stacks ) gives the true
43
43
value for f on a stack
44
44
"""
45
45
if res is None :
46
- res = f (x )
47
- for _idx in ndindex (x .shape [:- 2 ]):
46
+ res = f (* args , ** kw )
47
+
48
+ shape = args [0 ].shape if len (args ) == 1 else broadcast_shapes (* [x .shape
49
+ for x in args ])
50
+ for _idx in ndindex (shape [:- 2 ]):
48
51
idx = _idx + (slice (None ),)* dims
49
52
res_stack = res [idx ]
50
- x_stack = x [idx ]
51
- decomp_res_stack = f (x_stack , ** kw )
53
+ x_stacks = [ x [idx ] for x in args ]
54
+ decomp_res_stack = f (* x_stacks , ** kw )
52
55
assert_exactly_equal (res_stack , decomp_res_stack )
53
56
if true_val :
54
- assert_exactly_equal (decomp_res_stack , true_val (x_stack ))
57
+ assert_exactly_equal (decomp_res_stack , true_val (* x_stacks ))
55
58
56
59
def _test_namedtuple (res , fields , func_name ):
57
60
"""
@@ -77,7 +80,7 @@ def test_cholesky(x, kw):
77
80
assert res .shape == x .shape , "cholesky() did not return the correct shape"
78
81
assert res .dtype == x .dtype , "cholesky() did not return the correct dtype"
79
82
80
- _test_stacks (_array_module .linalg .cholesky , x , kw , res )
83
+ _test_stacks (_array_module .linalg .cholesky , x , ** kw , res = res )
81
84
82
85
# Test that the result is upper or lower triangular
83
86
if kw .get ('upper' , False ):
@@ -161,7 +164,7 @@ def test_det(x):
161
164
assert res .dtype == x .dtype , "det() did not return the correct dtype"
162
165
assert res .shape == x .shape [:- 2 ], "det() did not return the correct shape"
163
166
164
- _test_stacks (_array_module .linalg .det , x , {}, res , dims = 0 )
167
+ _test_stacks (_array_module .linalg .det , x , res = res , dims = 0 )
165
168
166
169
# TODO: Test that res actually corresponds to the determinant of x
167
170
@@ -197,7 +200,7 @@ def true_diag(x_stack):
197
200
x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
198
201
return asarray (x_stack_diag , dtype = x .dtype )
199
202
200
- _test_stacks (_array_module .linalg .diagonal , x , kw , res , dims = 1 , true_val = true_diag )
203
+ _test_stacks (_array_module .linalg .diagonal , x , ** kw , res = res , dims = 1 , true_val = true_diag )
201
204
202
205
@given (x = symmetric_matrices (finite = True ))
203
206
def test_eigh (x ):
@@ -214,8 +217,10 @@ def test_eigh(x):
214
217
assert eigenvectors .dtype == x .dtype , "eigh().eigenvectors did not return the correct dtype"
215
218
assert eigenvectors .shape == x .shape , "eigh().eigenvectors did not return the correct shape"
216
219
217
- _test_stacks (lambda x : _array_module .linalg .eigh (x ).eigenvalues , x , {}, eigenvalues , dims = 1 )
218
- _test_stacks (lambda x : _array_module .linalg .eigh (x ).eigenvectors , x , {}, eigenvectors , dims = 2 )
220
+ _test_stacks (lambda x : _array_module .linalg .eigh (x ).eigenvalues , x ,
221
+ res = eigenvalues , dims = 1 )
222
+ _test_stacks (lambda x : _array_module .linalg .eigh (x ).eigenvectors , x ,
223
+ res = eigenvectors , dims = 2 )
219
224
220
225
# TODO: Test that res actually corresponds to the eigenvalues and
221
226
# eigenvectors of x
@@ -227,7 +232,7 @@ def test_eigvalsh(x):
227
232
assert res .dtype == x .dtype , "eigvalsh() did not return the correct dtype"
228
233
assert res .shape == x .shape [:- 1 ], "eigvalsh() did not return the correct shape"
229
234
230
- _test_stacks (_array_module .linalg .eigvalsh , x , {}, res , dims = 1 )
235
+ _test_stacks (_array_module .linalg .eigvalsh , x , res = res , dims = 1 )
231
236
232
237
# TODO: Should we test that the result is the same as eigh(x).eigenvalues?
233
238
0 commit comments