Skip to content

Commit 6c51306

Browse files
committed
Implement test_cross()
1 parent 9fb015b commit 6c51306

File tree

1 file changed

+70
-10
lines changed

1 file changed

+70
-10
lines changed

array_api_tests/test_linalg.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
1414
"""
1515

16-
from hypothesis import given
17-
from hypothesis.strategies import booleans, none, integers
16+
from hypothesis import assume, given
17+
from hypothesis.strategies import booleans, composite, none, integers, shared
1818

19-
from .array_helpers import assert_exactly_equal, ndindex, asarray
19+
from .array_helpers import (assert_exactly_equal, ndindex, asarray,
20+
numeric_dtype_objects)
2021
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
2122
square_matrix_shapes, symmetric_matrices,
2223
positive_definite_matrices, MAX_ARRAY_SIZE,
23-
invertible_matrices)
24+
invertible_matrices,
25+
mutually_promotable_dtypes)
2426

2527
from . import _array_module
2628

@@ -80,14 +82,72 @@ def test_cholesky(x, kw):
8082
else:
8183
assert_exactly_equal(res, _array_module.tril(res))
8284

85+
86+
@composite
87+
def cross_args(draw, dtype_objects=numeric_dtype_objects):
88+
"""
89+
cross() requires two arrays with a size 3 in the 'axis' dimension
90+
91+
To do this, we generate a shape and an axis but change the shape to be 3
92+
in the drawn axis.
93+
94+
"""
95+
shape = list(draw(shapes))
96+
size = len(shape)
97+
assume(size > 0)
98+
99+
kw = draw(kwargs(axis=integers(-size, size-1)))
100+
axis = kw.get('axis', -1)
101+
shape[axis] = 3
102+
103+
mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objects))
104+
arrays1 = xps.arrays(
105+
dtype=mutual_dtypes.map(lambda pair: pair[0]),
106+
shape=shape,
107+
)
108+
arrays2 = xps.arrays(
109+
dtype=mutual_dtypes.map(lambda pair: pair[1]),
110+
shape=shape,
111+
)
112+
return draw(arrays1), draw(arrays2), kw
113+
83114
@given(
84-
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
85-
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
86-
kw=kwargs(axis=todo)
115+
cross_args()
87116
)
88-
def test_cross(x1, x2, kw):
89-
# res = _array_module.linalg.cross(x1, x2, **kw)
90-
pass
117+
def test_cross(x1_x2_kw):
118+
x1, x2, kw = x1_x2_kw
119+
120+
axis = kw.get('axis', -1)
121+
err = "test_cross produced invalid input. This indicates a bug in the test suite."
122+
assert x1.shape == x2.shape, err
123+
shape = x1.shape
124+
assert x1.shape[axis] == x2.shape[axis] == 3, err
125+
126+
res = _array_module.linalg.cross(x1, x2, **kw)
127+
128+
assert res.dtype == _array_module.result_type(x1, x2), "cross() did not return the correct dtype"
129+
assert res.shape == shape, "cross() did not return the correct shape"
130+
131+
# cross is too different from other functions to use _test_stacks, and it
132+
# is the only function that works the way it does, so it's not really
133+
# worth generalizing _test_stacks to handle it.
134+
a = axis if axis >= 0 else axis + len(shape)
135+
for _idx in ndindex(shape[:a] + shape[a+1:]):
136+
idx = _idx[:a] + (slice(None),) + _idx[a:]
137+
assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite."
138+
res_stack = res[idx]
139+
x1_stack = x1[idx]
140+
x2_stack = x2[idx]
141+
assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
142+
decomp_res_stack = _array_module.linalg.cross(x1_stack, x2_stack)
143+
assert_exactly_equal(res_stack, decomp_res_stack)
144+
145+
exact_cross = asarray([
146+
x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1],
147+
x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2],
148+
x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0],
149+
], dtype=res.dtype)
150+
assert_exactly_equal(res_stack, exact_cross)
91151

92152
@given(
93153
x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),

0 commit comments

Comments
 (0)