Skip to content

Commit bd4ab55

Browse files
committed
helpers: avoid mutation in invertible_matrices
1 parent f82c7bc commit bd4ab55

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import re
2-
import itertools
32
from contextlib import contextmanager
43
from functools import reduce
54
from math import sqrt
@@ -267,18 +266,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
267266
# For now, just generate stacks of diagonal matrices.
268267
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
269268
stack_shape = draw(stack_shapes)
270-
shape = stack_shape + (n, n)
271-
d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape),
269+
d = draw(xps.arrays(dtypes, shape=(*stack_shape, 1, n),
272270
elements=dict(allow_nan=False, allow_infinity=False)))
273271
# Functions that require invertible matrices may do anything when it is
274272
# singular, including raising an exception, so we make sure the diagonals
275273
# are sufficiently nonzero to avoid any numerical issues.
276274
assume(xp.all(xp.abs(d) > 0.5))
277-
278-
a = xp.zeros(shape)
279-
for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))):
280-
a[idx + (i, i)] = d[j]
281-
return a
275+
diag_mask = xp.arange(n) == xp.arange(n)[:, None]
276+
return xp.where(diag_mask, d, xp.zeros_like(d))
282277

283278
# TODO: Better name
284279
@composite

0 commit comments

Comments
 (0)