|
13 | 13 |
|
14 | 14 | """
|
15 | 15 |
|
16 |
| -from hypothesis import given |
17 |
| -from hypothesis.strategies import booleans, none, integers |
| 16 | +from hypothesis import assume, given |
| 17 | +from hypothesis.strategies import booleans, composite, none, integers, shared |
18 | 18 |
|
19 |
| -from .array_helpers import assert_exactly_equal, ndindex, asarray |
| 19 | +from .array_helpers import (assert_exactly_equal, ndindex, asarray, |
| 20 | + numeric_dtype_objects) |
20 | 21 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
|
21 | 22 | square_matrix_shapes, symmetric_matrices,
|
22 | 23 | positive_definite_matrices, MAX_ARRAY_SIZE,
|
23 |
| - invertible_matrices) |
| 24 | + invertible_matrices, |
| 25 | + mutually_promotable_dtypes) |
24 | 26 |
|
25 | 27 | from . import _array_module
|
26 | 28 |
|
@@ -80,14 +82,72 @@ def test_cholesky(x, kw):
|
80 | 82 | else:
|
81 | 83 | assert_exactly_equal(res, _array_module.tril(res))
|
82 | 84 |
|
| 85 | + |
| 86 | +@composite |
| 87 | +def cross_args(draw, dtype_objects=numeric_dtype_objects): |
| 88 | + """ |
| 89 | + cross() requires two arrays with a size 3 in the 'axis' dimension |
| 90 | +
|
| 91 | + To do this, we generate a shape and an axis but change the shape to be 3 |
| 92 | + in the drawn axis. |
| 93 | +
|
| 94 | + """ |
| 95 | + shape = list(draw(shapes)) |
| 96 | + size = len(shape) |
| 97 | + assume(size > 0) |
| 98 | + |
| 99 | + kw = draw(kwargs(axis=integers(-size, size-1))) |
| 100 | + axis = kw.get('axis', -1) |
| 101 | + shape[axis] = 3 |
| 102 | + |
| 103 | + mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objects)) |
| 104 | + arrays1 = xps.arrays( |
| 105 | + dtype=mutual_dtypes.map(lambda pair: pair[0]), |
| 106 | + shape=shape, |
| 107 | + ) |
| 108 | + arrays2 = xps.arrays( |
| 109 | + dtype=mutual_dtypes.map(lambda pair: pair[1]), |
| 110 | + shape=shape, |
| 111 | + ) |
| 112 | + return draw(arrays1), draw(arrays2), kw |
| 113 | + |
83 | 114 | @given(
|
84 |
| - x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
85 |
| - x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
86 |
| - kw=kwargs(axis=todo) |
| 115 | + cross_args() |
87 | 116 | )
|
88 |
| -def test_cross(x1, x2, kw): |
89 |
| - # res = _array_module.linalg.cross(x1, x2, **kw) |
90 |
| - pass |
| 117 | +def test_cross(x1_x2_kw): |
| 118 | + x1, x2, kw = x1_x2_kw |
| 119 | + |
| 120 | + axis = kw.get('axis', -1) |
| 121 | + err = "test_cross produced invalid input. This indicates a bug in the test suite." |
| 122 | + assert x1.shape == x2.shape, err |
| 123 | + shape = x1.shape |
| 124 | + assert x1.shape[axis] == x2.shape[axis] == 3, err |
| 125 | + |
| 126 | + res = _array_module.linalg.cross(x1, x2, **kw) |
| 127 | + |
| 128 | + assert res.dtype == _array_module.result_type(x1, x2), "cross() did not return the correct dtype" |
| 129 | + assert res.shape == shape, "cross() did not return the correct shape" |
| 130 | + |
| 131 | + # cross is too different from other functions to use _test_stacks, and it |
| 132 | + # is the only function that works the way it does, so it's not really |
| 133 | + # worth generalizing _test_stacks to handle it. |
| 134 | + a = axis if axis >= 0 else axis + len(shape) |
| 135 | + for _idx in ndindex(shape[:a] + shape[a+1:]): |
| 136 | + idx = _idx[:a] + (slice(None),) + _idx[a:] |
| 137 | + assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite." |
| 138 | + res_stack = res[idx] |
| 139 | + x1_stack = x1[idx] |
| 140 | + x2_stack = x2[idx] |
| 141 | + assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." |
| 142 | + decomp_res_stack = _array_module.linalg.cross(x1_stack, x2_stack) |
| 143 | + assert_exactly_equal(res_stack, decomp_res_stack) |
| 144 | + |
| 145 | + exact_cross = asarray([ |
| 146 | + x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1], |
| 147 | + x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2], |
| 148 | + x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0], |
| 149 | + ], dtype=res.dtype) |
| 150 | + assert_exactly_equal(res_stack, exact_cross) |
91 | 151 |
|
92 | 152 | @given(
|
93 | 153 | x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
|
|
0 commit comments