Skip to content

Commit 0334ced

Browse files
committed
Make _test_stacks work on two-argument functions
1 parent b71bf13 commit 0334ced

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

array_api_tests/test_linalg.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,29 @@
3232
# Standin strategy for not yet implemented tests
3333
todo = none()
3434

35-
def _test_stacks(f, x, kw, res=None, dims=2, true_val=None):
35+
def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw):
3636
"""
37-
Test that f(x, **kw) maps across stacks of matrices
37+
Test that f(*args, **kw) maps across stacks of matrices
3838
3939
dims is the number of dimensions f should have for a single n x m matrix
4040
stack.
4141
42-
true_val may be a function such that true_val(x_stack) gives the true
42+
true_val may be a function such that true_val(*x_stacks) gives the true
4343
value for f on a stack
4444
"""
4545
if res is None:
46-
res = f(x)
47-
for _idx in ndindex(x.shape[:-2]):
46+
res = f(*args, **kw)
47+
48+
shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape
49+
for x in args])
50+
for _idx in ndindex(shape[:-2]):
4851
idx = _idx + (slice(None),)*dims
4952
res_stack = res[idx]
50-
x_stack = x[idx]
51-
decomp_res_stack = f(x_stack, **kw)
53+
x_stacks = [x[idx] for x in args]
54+
decomp_res_stack = f(*x_stacks, **kw)
5255
assert_exactly_equal(res_stack, decomp_res_stack)
5356
if true_val:
54-
assert_exactly_equal(decomp_res_stack, true_val(x_stack))
57+
assert_exactly_equal(decomp_res_stack, true_val(*x_stacks))
5558

5659
def _test_namedtuple(res, fields, func_name):
5760
"""
@@ -77,7 +80,7 @@ def test_cholesky(x, kw):
7780
assert res.shape == x.shape, "cholesky() did not return the correct shape"
7881
assert res.dtype == x.dtype, "cholesky() did not return the correct dtype"
7982

80-
_test_stacks(_array_module.linalg.cholesky, x, kw, res)
83+
_test_stacks(_array_module.linalg.cholesky, x, **kw, res=res)
8184

8285
# Test that the result is upper or lower triangular
8386
if kw.get('upper', False):
@@ -161,7 +164,7 @@ def test_det(x):
161164
assert res.dtype == x.dtype, "det() did not return the correct dtype"
162165
assert res.shape == x.shape[:-2], "det() did not return the correct shape"
163166

164-
_test_stacks(_array_module.linalg.det, x, {}, res, dims=0)
167+
_test_stacks(_array_module.linalg.det, x, res=res, dims=0)
165168

166169
# TODO: Test that res actually corresponds to the determinant of x
167170

@@ -197,7 +200,7 @@ def true_diag(x_stack):
197200
x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
198201
return asarray(x_stack_diag, dtype=x.dtype)
199202

200-
_test_stacks(_array_module.linalg.diagonal, x, kw, res, dims=1, true_val=true_diag)
203+
_test_stacks(_array_module.linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
201204

202205
@given(x=symmetric_matrices(finite=True))
203206
def test_eigh(x):
@@ -214,8 +217,10 @@ def test_eigh(x):
214217
assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype"
215218
assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape"
216219

217-
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvalues, x, {}, eigenvalues, dims=1)
218-
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvectors, x, {}, eigenvectors, dims=2)
220+
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvalues, x,
221+
res=eigenvalues, dims=1)
222+
_test_stacks(lambda x: _array_module.linalg.eigh(x).eigenvectors, x,
223+
res=eigenvectors, dims=2)
219224

220225
# TODO: Test that res actually corresponds to the eigenvalues and
221226
# eigenvectors of x
@@ -227,7 +232,7 @@ def test_eigvalsh(x):
227232
assert res.dtype == x.dtype, "eigvalsh() did not return the correct dtype"
228233
assert res.shape == x.shape[:-1], "eigvalsh() did not return the correct shape"
229234

230-
_test_stacks(_array_module.linalg.eigvalsh, x, {}, res, dims=1)
235+
_test_stacks(_array_module.linalg.eigvalsh, x, res=res, dims=1)
231236

232237
# TODO: Should we test that the result is the same as eigh(x).eigenvalues?
233238

0 commit comments

Comments
 (0)