|
1 | 1 | import re
|
2 |
| -import itertools |
3 | 2 | from contextlib import contextmanager
|
4 | 3 | from functools import reduce, wraps
|
5 | 4 | import math
|
@@ -309,18 +308,14 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
|
309 | 308 | # For now, just generate stacks of diagonal matrices.
|
310 | 309 | n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
|
311 | 310 | stack_shape = draw(stack_shapes)
|
312 |
| - shape = stack_shape + (n, n) |
313 |
| - d = draw(arrays(dtypes, shape=n*prod(stack_shape), |
| 311 | + d = draw(arrays(dtypes, shape=(*stack_shape, 1, n), |
314 | 312 | elements=dict(allow_nan=False, allow_infinity=False)))
|
315 | 313 | # Functions that require invertible matrices may do anything when it is
|
316 | 314 | # singular, including raising an exception, so we make sure the diagonals
|
317 | 315 | # are sufficiently nonzero to avoid any numerical issues.
|
318 | 316 | assume(xp.all(xp.abs(d) > 0.5))
|
319 |
| - |
320 |
| - a = xp.zeros(shape) |
321 |
| - for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): |
322 |
| - a[idx + (i, i)] = d[j] |
323 |
| - return a |
| 317 | + diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1)) |
| 318 | + return xp.where(diag_mask, d, xp.zeros_like(d)) |
324 | 319 |
|
325 | 320 | # TODO: Better name
|
326 | 321 | @composite
|
|
0 commit comments