Skip to content

Commit 9c606c8

Browse files
committed
Allow symmetric_matrices to generate finite matrices
1 parent 9e45097 commit 9c606c8

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,11 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
129129

130130
# Note: This should become hermitian_matrices when complex dtypes are added
131131
@composite
132-
def symmetric_matrices(draw, dtypes=xps.floating_dtypes()):
132+
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
133133
shape = draw(square_matrix_shapes)
134134
dtype = draw(dtypes)
135-
a = draw(xps.arrays(dtype=dtype, shape=shape))
135+
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
136+
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
136137
upper = xp.triu(a)
137138
lower = xp.triu(a, k=1).mT
138139
return upper + lower

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,20 @@ def run(kw):
7070
assert len(c_results) > 0
7171
assert all(isinstance(kw["c"], str) for kw in c_results)
7272

73-
@given(hh.symmetric_matrices(hh.shared_floating_dtypes), hh.shared_floating_dtypes)
74-
def test_symmetric_matrices(m, dtype):
73+
@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
74+
finite=st.shared(st.booleans(), key='finite')),
75+
dtype=hh.shared_floating_dtypes,
76+
finite=st.shared(st.booleans(), key='finite'))
77+
def test_symmetric_matrices(m, dtype, finite):
7578
assert m.dtype == dtype
7679
# TODO: This part of this test should be part of the .mT test
7780
ah.assert_exactly_equal(m, m.mT)
7881

82+
if finite:
83+
ah.assert_finite(m)
7984

80-
@given(hh.positive_definite_matrices(hh.shared_floating_dtypes), hh.shared_floating_dtypes)
85+
@given(m=hh.positive_definite_matrices(hh.shared_floating_dtypes),
86+
dtype=hh.shared_floating_dtypes)
8187
def test_positive_definite_matrices(m, dtype):
8288
assert m.dtype == dtype
8389
# TODO: Test that it actually is positive definite

0 commit comments

Comments
 (0)