Skip to content

Commit 8e79ebf

Browse files
committed
Merge branch 'master' into type-promotion-refactor
2 parents b37ecc9 + 10a928b commit 8e79ebf

File tree

2 files changed

+58
-13
lines changed

2 files changed

+58
-13
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
105105
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
106106
)
107107

108+
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
109+
108110
# Matrix shapes assume stacks of matrices
109111
matrix_shapes = xps.array_shapes(min_dims=2, min_side=1).filter(
110112
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
@@ -293,9 +295,11 @@ def multiaxis_indices(draw, shapes):
293295
return tuple(res)
294296

295297

296-
def two_mutual_arrays(dtype_objs=dh.all_dtypes):
298+
def two_mutual_arrays(
299+
dtype_objs=dh.all_dtypes, two_shapes=two_mutually_broadcastable_shapes
300+
):
297301
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs))
298-
mutual_shapes = shared(two_mutually_broadcastable_shapes)
302+
mutual_shapes = shared(two_shapes)
299303
arrays1 = xps.arrays(
300304
dtype=mutual_dtypes.map(lambda pair: pair[0]),
301305
shape=mutual_shapes.map(lambda pair: pair[0]),
@@ -306,7 +310,6 @@ def two_mutual_arrays(dtype_objs=dh.all_dtypes):
306310
)
307311
return arrays1, arrays2
308312

309-
310313
@composite
311314
def kwargs(draw, **kw):
312315
"""

array_api_tests/test_linalg.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
"""
1515

1616
from hypothesis import assume, given
17-
from hypothesis.strategies import booleans, composite, none, integers, shared
17+
from hypothesis.strategies import booleans, composite, none, tuples, integers, shared
1818

19-
from .array_helpers import assert_exactly_equal, ndindex, asarray
19+
from .array_helpers import assert_exactly_equal, ndindex, asarray, equal, zero, infinity
2020
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
2121
square_matrix_shapes, symmetric_matrices,
2222
positive_definite_matrices, MAX_ARRAY_SIZE,
2323
invertible_matrices, two_mutual_arrays,
24-
mutually_promotable_dtypes)
24+
mutually_promotable_dtypes, one_d_shapes)
2525
from .pytest_helpers import raises
2626
from . import dtype_helpers as dh
2727

@@ -339,12 +339,27 @@ def test_matrix_transpose(x):
339339
_test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val)
340340

341341
@given(
342-
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
343-
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
342+
*two_mutual_arrays(dtype_objects=dh.numeric_dtypes,
343+
two_shapes=tuples(one_d_shapes, one_d_shapes))
344344
)
345345
def test_outer(x1, x2):
346-
# res = linalg.outer(x1, x2)
347-
pass
346+
# outer does not work on stacks. See
347+
# https://github.com/data-apis/array-api/issues/242.
348+
res = linalg.outer(x1, x2)
349+
350+
shape = (x1.shape[0], x2.shape[0])
351+
assert res.shape == shape, "outer() did not return the correct shape"
352+
assert res.dtype == dh.promotion_table[x1, x2], "outer() did not return the correct dtype"
353+
354+
if 0 in shape:
355+
true_res = _array_module.empty(shape, dtype=res.dtype)
356+
else:
357+
true_res = _array_module.asarray([[x1[i]*x2[j]
358+
for j in range(x2.shape[0])]
359+
for i in range(x1.shape[0])],
360+
dtype=res.dtype)
361+
362+
assert_exactly_equal(res, true_res)
348363

349364
@given(
350365
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
@@ -363,11 +378,38 @@ def test_qr(x, kw):
363378
pass
364379

365380
@given(
366-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
381+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
367382
)
368383
def test_slogdet(x):
369-
# res = linalg.slogdet(x)
370-
pass
384+
res = linalg.slogdet(x)
385+
386+
_test_namedtuple(res, ['sign', 'logabsdet'], 'slotdet')
387+
388+
sign, logabsdet = res
389+
390+
assert sign.dtype == x.dtype, "slogdet().sign did not return the correct dtype"
391+
assert sign.shape == x.shape[:-2], "slogdet().sign did not return the correct shape"
392+
assert logabsdet.dtype == x.dtype, "slogdet().logabsdet did not return the correct dtype"
393+
assert logabsdet.shape == x.shape[:-2], "slogdet().logabsdet did not return the correct shape"
394+
395+
396+
_test_stacks(lambda x: linalg.slogdet(x).sign, x,
397+
res=sign, dims=0)
398+
_test_stacks(lambda x: linalg.slogdet(x).logabsdet, x,
399+
res=logabsdet, dims=0)
400+
401+
# Check that when the determinant is 0, the sign and logabsdet are (0,
402+
# -inf).
403+
d = linalg.det(x)
404+
zero_det = equal(d, zero(d.shape, d.dtype))
405+
assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype))
406+
assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype))
407+
408+
# More generally, det(x) should equal sign*exp(logabsdet), but this does
409+
# not hold exactly due to floating-point loss of precision.
410+
411+
# TODO: Test this when we have tests for floating-point values.
412+
# assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
371413

372414
@given(
373415
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)