-
Notifications
You must be signed in to change notification settings - Fork 45
ENH: test vecdot values, incl complex conj #314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
from . import _array_module as xp | ||
from ._array_module import linalg | ||
|
||
|
||
def assert_equal(x, y, msg_extra=None): | ||
extra = '' if not msg_extra else f' ({msg_extra})' | ||
if x.dtype in dh.all_float_dtypes: | ||
|
@@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None): | |
else: | ||
assert_exactly_equal(x, y, msg_extra=msg_extra) | ||
|
||
|
||
def _test_stacks(f, *args, res=None, dims=2, true_val=None, | ||
matrix_axes=(-2, -1), | ||
res_axes=None, | ||
|
@@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, | |
if true_val: | ||
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra) | ||
|
||
|
||
def _test_namedtuple(res, fields, func_name): | ||
""" | ||
Test that res is a namedtuple with the correct fields. | ||
|
@@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name): | |
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field" | ||
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}" | ||
|
||
|
||
@pytest.mark.unvectorized | ||
@pytest.mark.xp_extension('linalg') | ||
@given( | ||
|
@@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0): | |
|
||
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace) | ||
|
||
|
||
def _conj(x): | ||
"""Work around xp.conj rejecting floats.""" | ||
if xp.isdtype(x.dtype, 'complex floating'): | ||
return xp.conj(x) | ||
else: | ||
return x | ||
|
||
|
||
def _test_vecdot(namespace, x1, x2, data): | ||
vecdot = namespace.vecdot | ||
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) | ||
|
@@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data): | |
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape], | ||
out_shape=res.shape, expected=expected_shape) | ||
|
||
if x1.dtype in dh.int_dtypes: | ||
def true_val(x, y, axis=-1): | ||
return xp.sum(xp.multiply(x, y), dtype=res.dtype) | ||
else: | ||
true_val = None | ||
def true_val(x, y, axis=-1): | ||
return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately, if you look here, there is no approximate testing done at all for floating-point values https://github.com/data-apis/array-api-tests/pull/314/files?diff=unified#diff-6056c0b3af9cd3ba66387432a17f5f36bbd54220419656441a8b01bcdc4df44bR57. We should probably add a flag to that helper to allow approximate testing to be enabled. Some functions are impossible to do approximate testing for because they don't even have a single possible output (e.g., There are helpers used in the elementwise functions that could be reused here for testing floating-point (and complex) closeness. Basically, they test with very large epsilons. Even that would be enough to detect that a library isn't conjugating, which is the real concern for this test specifically. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait, what: https://github.com/data-apis/array-api-tests/blob/master/array_api_tests/test_linalg.py#L109 and I was sure this PR switches on equality testing. But There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I originally had it there but later had to comment it out when I found out that you fundamentally can't compare floating point stacks for some functions like eigh because sometimes implementations (like CuPy) use different algorithms that give different (but still mathematically correct) results. #101 (comment) I never renamed the functions. While I did turn the floating-point checks off completely, and they don't make sense for some functions like eigh, they do make sense for functions like vecdot and others where the mathematical answer is has a single well-defined value. We just need to be careful about loss of significance, especially for the tensor contraction functions (vecdot, matmul, tensordot, etc.) that involve additions. |
||
|
||
_test_stacks(vecdot, x1, x2, res=res, dims=0, | ||
matrix_axes=(axis,), true_val=true_val) | ||
|
@@ -944,6 +954,7 @@ def true_val(x, y, axis=-1): | |
def test_linalg_vecdot(x1, x2, data): | ||
_test_vecdot(linalg, x1, x2, data) | ||
|
||
|
||
@pytest.mark.unvectorized | ||
@given( | ||
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), | ||
|
@@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data): | |
def test_vecdot(x1, x2, data): | ||
_test_vecdot(_array_module, x1, x2, data) | ||
|
||
|
||
# Insanely large orders might not work. There isn't a limit specified in the | ||
# spec, so we just limit to reasonable values here. | ||
max_ord = 100 | ||
|
||
|
||
@pytest.mark.unvectorized | ||
@pytest.mark.xp_extension('linalg') | ||
@given( | ||
|
Uh oh!
There was an error while loading. Please reload this page.