Skip to content

Commit a943a51

Browse files
committed
Document major elwise testing helpers
1 parent 5b26f9d commit a943a51

File tree

1 file changed

+132
-25
lines changed

1 file changed

+132
-25
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 132 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Test element-wise functions/operators against reference implementations.
3+
"""
14
import math
25
import operator
36
from enum import Enum, auto
@@ -82,31 +85,6 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
8285
return n
8386

8487

85-
# This module tests elementwise functions/operators against a reference
86-
# implementation. We iterate through the input array(s) and resulting array,
87-
# casting the indexed arrays to Python scalars and calculating the expected
88-
# output with `refimpl` function.
89-
#
90-
# This is finicky to refactor, but possible and ultimately worthwhile - hence
91-
# why these *_assert_again_refimpl() utilities exist.
92-
#
93-
# Values which are special-cased are generated and passed, but are filtered by
94-
# the `filter_` callable before they can be asserted against `refimpl`. We
95-
# automatically generate tests for special cases in the special_cases/ dir. We
96-
# still pass them here so as to ensure their presence doesn't affect the outputs
97-
# respective to non-special-cased elements.
98-
#
99-
# By default, results are casted to scalars the same way that the inputs are.
100-
# You can specify a cast via `res_stype, i.e. when a function accepts numerical
101-
# inputs but returns boolean arrays.
102-
#
103-
# By default, floating-point functions/methods are loosely asserted against. Use
104-
# `strict_check=True` when they should be strictly asserted against, i.e.
105-
# when a function should return intergrals. Likewise, use `strict_check=False`
106-
# when integer function/methods should be loosely asserted against, i.e. when
107-
# floats are used internally for optimisation or legacy reasons.
108-
109-
11088
def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
11189
"""Wraps math.isclose with very generous defaults.
11290
@@ -143,6 +121,125 @@ def unary_assert_against_refimpl(
143121
strict_check: Optional[bool] = None,
144122
expr_template: Optional[str] = None,
145123
):
124+
"""
125+
Assert unary element-wise results are as expected.
126+
127+
We iterate through every element in the input and resulting arrays, casting
128+
the respective elements (0-D arrays) to Python scalars, and assert against
129+
the expected output specified by the passed reference implementation, e.g.
130+
131+
>>> x = xp.asarray([[0, 1], [2, 4]])
132+
>>> out = xp.square(x)
133+
>>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
134+
135+
is equivalent to
136+
137+
>>> for idx in np.ndindex(x.shape):
138+
... expected = int(x[idx]) ** 2
139+
... assert int(out[idx]) == expected
140+
141+
Casting
142+
-------
143+
144+
The input scalar type is inferred from the input array's dtype like so
145+
146+
Array dtypes | Python builtin type
147+
----------------- | ---------------------
148+
xp.bool | bool
149+
xp.int*, xp.uint* | int
150+
xp.float* | float
151+
xp.complex* | complex
152+
153+
If res_stype=None (the default), the result scalar type is the same as the
154+
input scalar type. We can also specify the result scalar type ourselves, e.g.
155+
156+
>>> x = xp.asarray([42., xp.inf])
157+
>>> out = xp.isinf(x) # should be [False, True]
158+
>>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
159+
160+
Filtering special-cased values
161+
------------------------------
162+
163+
Values which are special-cased can be present in the input array, but get
164+
filtered before they can be asserted against refimpl.
165+
166+
If filter_=default_filter (the default), all non-finite and floating zero
167+
values are filtered, e.g.
168+
169+
>>> unary_assert_against_refimpl('sin', x, out, math.sin)
170+
171+
is equivalent to
172+
173+
>>> for idx in np.ndindex(x.shape):
174+
... at_x = float(x[idx])
175+
... if math.isfinite(at_x) or at_x != 0:
176+
... expected = math.sin(at_x)
177+
... assert math.isclose(float(out[idx]), expected)
178+
179+
We can also specify the filter function ourselves, e.g.
180+
181+
>>> def sqrt_filter(s: float) -> bool:
182+
... return math.isfinite(s) and s >= 0
183+
>>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter)
184+
185+
is equivalent to
186+
187+
>>> for idx in np.ndindex(x.shape):
188+
... at_x = float(x[idx])
189+
... if math.isfinite(s) and s >=0:
190+
... expected = math.sin(at_x)
191+
... assert math.isclose(float(out[idx]), expected)
192+
193+
Note we leave special-cased values in the input arrays, so as to ensure
194+
their presence doesn't affect the outputs respective to non-special-cased
195+
elements. We specifically test special case bevaiour in test_special_cases.py.
196+
197+
Assertion strictness
198+
--------------------
199+
200+
If strict_check=None (the default), integer elements are strictly asserted
201+
against, and floating elements are loosely asserted against, e.g.
202+
203+
>>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
204+
205+
is equivalent to
206+
207+
>>> for idx in np.ndindex(x.shape):
208+
... expected = in_stype(x[idx]) ** 2
209+
... if in_stype == int:
210+
... assert int(out[idx]) == expected
211+
... else: # in_stype == float
212+
... assert math.isclose(float(out[idx]), expected)
213+
214+
Specifying strict_check as True or False will assert strictly/loosely
215+
respectively, regardless of dtype. This is useful for testing functions that
216+
have definitive outputs for floating inputs, i.e. rounding functions.
217+
218+
Expressions in errors
219+
---------------------
220+
221+
Assertion error messages include an expression, by default using func_name
222+
like so
223+
224+
>>> x = xp.asarray([42., xp.inf])
225+
>>> out = xp.isinf(x)
226+
>>> out
227+
[False, False]
228+
>>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
229+
AssertionError: out[1]=False, but should be isinf(x[1])=True ...
230+
231+
We can specify the expression template ourselves, e.g.
232+
233+
>>> x = xp.asarray(True)
234+
>>> out = xp.logical_not(x)
235+
>>> out
236+
True
237+
>>> unary_assert_against_refimpl(
238+
... 'logical_not', x, out, expr_template='(not {})={}'
239+
... )
240+
AssertionError: out=True, but should be (not True)=False ...
241+
242+
"""
146243
if in_.shape != res.shape:
147244
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
148245
if expr_template is None:
@@ -194,6 +291,11 @@ def binary_assert_against_refimpl(
194291
res_name: str = "out",
195292
expr_template: Optional[str] = None,
196293
):
294+
"""
295+
Assert binary element-wise results are as expected.
296+
297+
See unary_assert_against_refimpl for more information.
298+
"""
197299
if expr_template is None:
198300
expr_template = func_name + "({}, {})={}"
199301
in_stype = dh.get_scalar_type(left.dtype)
@@ -244,6 +346,11 @@ def right_scalar_assert_against_refimpl(
244346
res_name: str = "out",
245347
expr_template: str = None,
246348
):
349+
"""
350+
Assert binary element-wise results from scalar operands are as expected.
351+
352+
See unary_assert_against_refimpl for more information.
353+
"""
247354
if filter_(right):
248355
return # short-circuit here as there will be nothing to test
249356
in_stype = dh.get_scalar_type(left.dtype)

0 commit comments

Comments
 (0)