Skip to content

Document element-wise testing utilities #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 144 additions & 33 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Test element-wise functions/operators against reference implementations.
"""
import math
import operator
from enum import Enum, auto
Expand Down Expand Up @@ -82,32 +85,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
return n


# This module tests elementwise functions/operators against a reference
# implementation. We iterate through the input array(s) and resulting array,
# casting the indexed arrays to Python scalars and calculating the expected
# output with `refimpl` function.
#
# This is finicky to refactor, but possible and ultimately worthwhile - hence
# why these *_assert_again_refimpl() utilities exist.
#
# Values which are special-cased are generated and passed, but are filtered by
# the `filter_` callable before they can be asserted against `refimpl`. We
# automatically generate tests for special cases in the special_cases/ dir. We
# still pass them here so as to ensure their presence doesn't affect the outputs
# respective to non-special-cased elements.
#
# By default, results are casted to scalars the same way that the inputs are.
# You can specify a cast via `res_stype, i.e. when a function accepts numerical
# inputs but returns boolean arrays.
#
# By default, floating-point functions/methods are loosely asserted against. Use
# `strict_check=True` when they should be strictly asserted against, i.e.
# when a function should return intergrals. Likewise, use `strict_check=False`
# when integer function/methods should be loosely asserted against, i.e. when
# floats are used internally for optimisation or legacy reasons.


def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
"""Wraps math.isclose with very generous defaults.

This is useful for many floating-point operations where the spec does not
Expand Down Expand Up @@ -137,11 +115,131 @@ def unary_assert_against_refimpl(
in_: Array,
res: Array,
refimpl: Callable[[T], T],
expr_template: Optional[str] = None,
*,
res_stype: Optional[ScalarType] = None,
filter_: Callable[[Scalar], bool] = default_filter,
strict_check: Optional[bool] = None,
expr_template: Optional[str] = None,
):
"""
Assert unary element-wise results are as expected.

We iterate through every element in the input and resulting arrays, casting
the respective elements (0-D arrays) to Python scalars, and assert against
the expected output specified by the passed reference implementation, e.g.

>>> x = xp.asarray([[0, 1], [2, 4]])
>>> out = xp.square(x)
>>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)

is equivalent to

>>> for idx in np.ndindex(x.shape):
... expected = int(x[idx]) ** 2
... assert int(out[idx]) == expected

Casting
-------

The input scalar type is inferred from the input array's dtype like so

Array dtypes | Python builtin type
----------------- | ---------------------
xp.bool | bool
xp.int*, xp.uint* | int
xp.float* | float
xp.complex* | complex

If res_stype=None (the default), the result scalar type is the same as the
input scalar type. We can also specify the result scalar type ourselves, e.g.

>>> x = xp.asarray([42., xp.inf])
>>> out = xp.isinf(x) # should be [False, True]
>>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)

Filtering special-cased values
------------------------------

Values which are special-cased can be present in the input array, but get
filtered before they can be asserted against refimpl.

If filter_=default_filter (the default), all non-finite and floating zero
values are filtered, e.g.

>>> unary_assert_against_refimpl('sin', x, out, math.sin)

is equivalent to

>>> for idx in np.ndindex(x.shape):
... at_x = float(x[idx])
... if math.isfinite(at_x) or at_x != 0:
... expected = math.sin(at_x)
... assert math.isclose(float(out[idx]), expected)

We can also specify the filter function ourselves, e.g.

>>> def sqrt_filter(s: float) -> bool:
... return math.isfinite(s) and s >= 0
>>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter)

is equivalent to

>>> for idx in np.ndindex(x.shape):
... at_x = float(x[idx])
... if math.isfinite(s) and s >=0:
... expected = math.sin(at_x)
... assert math.isclose(float(out[idx]), expected)

Note we leave special-cased values in the input arrays, so as to ensure
their presence doesn't affect the outputs respective to non-special-cased
elements. We specifically test special case bevaiour in test_special_cases.py.

Assertion strictness
--------------------

If strict_check=None (the default), integer elements are strictly asserted
against, and floating elements are loosely asserted against, e.g.

>>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)

is equivalent to

>>> for idx in np.ndindex(x.shape):
... expected = in_stype(x[idx]) ** 2
... if in_stype == int:
... assert int(out[idx]) == expected
... else: # in_stype == float
... assert math.isclose(float(out[idx]), expected)

Specifying strict_check as True or False will assert strictly/loosely
respectively, regardless of dtype. This is useful for testing functions that
have definitive outputs for floating inputs, i.e. rounding functions.

Expressions in errors
---------------------

Assertion error messages include an expression, by default using func_name
like so

>>> x = xp.asarray([42., xp.inf])
>>> out = xp.isinf(x)
>>> out
[False, False]
>>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
AssertionError: out[1]=False, but should be isinf(x[1])=True ...

We can specify the expression template ourselves, e.g.

>>> x = xp.asarray(True)
>>> out = xp.logical_not(x)
>>> out
True
>>> unary_assert_against_refimpl(
... 'logical_not', x, out, expr_template='(not {})={}'
... )
AssertionError: out=True, but should be (not True)=False ...

"""
if in_.shape != res.shape:
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
if expr_template is None:
Expand Down Expand Up @@ -184,14 +282,20 @@ def binary_assert_against_refimpl(
right: Array,
res: Array,
refimpl: Callable[[T, T], T],
expr_template: Optional[str] = None,
*,
res_stype: Optional[ScalarType] = None,
filter_: Callable[[Scalar], bool] = default_filter,
strict_check: Optional[bool] = None,
left_sym: str = "x1",
right_sym: str = "x2",
res_name: str = "out",
filter_: Callable[[Scalar], bool] = default_filter,
strict_check: Optional[bool] = None,
expr_template: Optional[str] = None,
):
"""
Assert binary element-wise results are as expected.

See unary_assert_against_refimpl for more information.
"""
if expr_template is None:
expr_template = func_name + "({}, {})={}"
in_stype = dh.get_scalar_type(left.dtype)
Expand Down Expand Up @@ -234,13 +338,19 @@ def right_scalar_assert_against_refimpl(
right: Scalar,
res: Array,
refimpl: Callable[[T, T], T],
expr_template: str = None,
*,
res_stype: Optional[ScalarType] = None,
left_sym: str = "x1",
res_name: str = "out",
filter_: Callable[[Scalar], bool] = default_filter,
strict_check: Optional[bool] = None,
left_sym: str = "x1",
res_name: str = "out",
expr_template: str = None,
):
"""
Assert binary element-wise results from scalar operands are as expected.

See unary_assert_against_refimpl for more information.
"""
if filter_(right):
return # short-circuit here as there will be nothing to test
in_stype = dh.get_scalar_type(left.dtype)
Expand Down Expand Up @@ -486,6 +596,7 @@ def binary_param_assert_against_refimpl(
res: Array,
op_sym: str,
refimpl: Callable[[T, T], T],
*,
res_stype: Optional[ScalarType] = None,
filter_: Callable[[Scalar], bool] = default_filter,
strict_check: Optional[bool] = None,
Expand Down