Skip to content

Commit 18c425e

Browse files
committed
Let matrix_shapes() take in a strategy for generating the stack shapes
This also changes it so that it can generate shapes with size 0.
1 parent 8706391 commit 18c425e

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,15 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
113113
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
114114

115115
# Matrix shapes assume stacks of matrices
116-
matrix_shapes = xps.array_shapes(min_dims=2, min_side=1).filter(
117-
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
118-
)
116+
@composite
117+
def matrix_shapes(draw, stack_shapes=shapes):
118+
stack_shape = draw(stack_shapes)
119+
mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
120+
shape = stack_shape + mat_shape
121+
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
122+
return shape
119123

120-
square_matrix_shapes = matrix_shapes.filter(lambda shape: shape[-1] == shape[-2])
124+
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
121125

122126
two_mutually_broadcastable_shapes = xps.mutually_broadcastable_shapes(num_shapes=2)\
123127
.map(lambda S: S.input_shapes)\

0 commit comments

Comments
 (0)