Skip to content

Commit 7ce7b28

Browse files
jakevdphonno
authored andcommitted
helpers: avoid mutation in invertible_matrices
1 parent 6bcede9 commit 7ce7b28

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import re
2-
import itertools
32
from contextlib import contextmanager
43
from functools import reduce
54
from math import sqrt
@@ -274,18 +273,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
274273
# For now, just generate stacks of diagonal matrices.
275274
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
276275
stack_shape = draw(stack_shapes)
277-
shape = stack_shape + (n, n)
278-
d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape),
276+
d = draw(xps.arrays(dtypes, shape=(*stack_shape, 1, n),
279277
elements=dict(allow_nan=False, allow_infinity=False)))
280278
# Functions that require invertible matrices may do anything when it is
281279
# singular, including raising an exception, so we make sure the diagonals
282280
# are sufficiently nonzero to avoid any numerical issues.
283281
assume(xp.all(xp.abs(d) > 0.5))
284-
285-
a = xp.zeros(shape)
286-
for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))):
287-
a[idx + (i, i)] = d[j]
288-
return a
282+
diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
283+
return xp.where(diag_mask, d, xp.zeros_like(d))
289284

290285
# TODO: Better name
291286
@composite

0 commit comments

Comments
 (0)