Skip to content

Commit 0d08906

Browse files
committed
Fix positive_definite_matrices to actually use the dtype
1 parent 74d0118 commit 0d08906

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,16 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
127127
.filter(lambda S: all(prod(i for i in shape if i) < MAX_ARRAY_SIZE for shape in S))
128128

129129
@composite
130-
def positive_definite_matrices(draw, dtype=xps.floating_dtypes()):
130+
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
131131
# For now just generate stacks of identity matrices
132132
# TODO: Generate arbitrary positive definite matrices, for instance, by
133133
# using something like
134134
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
135135
n = draw(integers(0))
136136
shape = draw(shapes) + (n, n)
137137
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
138-
return broadcast_to(eye(n), shape)
138+
dtype = draw(dtypes)
139+
return broadcast_to(eye(n, dtype=dtype), shape)
139140

140141
@composite
141142
def two_broadcastable_shapes(draw):

0 commit comments

Comments
 (0)