|
1 | 1 | import re
|
2 |
| -import itertools |
3 | 2 | from contextlib import contextmanager
|
4 | 3 | from functools import reduce
|
5 | 4 | from math import sqrt
|
@@ -267,18 +266,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
|
267 | 266 | # For now, just generate stacks of diagonal matrices.
|
268 | 267 | n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
|
269 | 268 | stack_shape = draw(stack_shapes)
|
270 |
| - shape = stack_shape + (n, n) |
271 |
| - d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape), |
| 269 | + d = draw(xps.arrays(dtypes, shape=(*stack_shape, 1, n), |
272 | 270 | elements=dict(allow_nan=False, allow_infinity=False)))
|
273 | 271 | # Functions that require invertible matrices may do anything when it is
|
274 | 272 | # singular, including raising an exception, so we make sure the diagonals
|
275 | 273 | # are sufficiently nonzero to avoid any numerical issues.
|
276 | 274 | assume(xp.all(xp.abs(d) > 0.5))
|
277 |
| - |
278 |
| - a = xp.zeros(shape) |
279 |
| - for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): |
280 |
| - a[idx + (i, i)] = d[j] |
281 |
| - return a |
| 275 | + diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1)) |
| 276 | + return xp.where(diag_mask, d, xp.zeros_like(d)) |
282 | 277 |
|
283 | 278 | # TODO: Better name
|
284 | 279 | @composite
|
|
0 commit comments