Skip to content

Commit 8d4ccd6

Browse files
authored
Merge pull request #201 from jakevdp/fix-invertible-matrices
helpers: avoid mutation in invertible_matrices
2 parents 974367f + f81c3a2 commit 8d4ccd6

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import re
2-
import itertools
32
from contextlib import contextmanager
43
from functools import reduce, wraps
54
import math
@@ -309,18 +308,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
309308
# For now, just generate stacks of diagonal matrices.
310309
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
311310
stack_shape = draw(stack_shapes)
312-
shape = stack_shape + (n, n)
313-
d = draw(arrays(dtypes, shape=n*prod(stack_shape),
311+
d = draw(arrays(dtypes, shape=(*stack_shape, 1, n),
314312
elements=dict(allow_nan=False, allow_infinity=False)))
315313
# Functions that require invertible matrices may do anything when it is
316314
# singular, including raising an exception, so we make sure the diagonals
317315
# are sufficiently nonzero to avoid any numerical issues.
318316
assume(xp.all(xp.abs(d) > 0.5))
319-
320-
a = xp.zeros(shape)
321-
for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))):
322-
a[idx + (i, i)] = d[j]
323-
return a
317+
diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
318+
return xp.where(diag_mask, d, xp.zeros_like(d))
324319

325320
# TODO: Better name
326321
@composite

0 commit comments

Comments
 (0)