Skip to content

Commit 852a1dd

Browse files
authored
Merge pull request #148 from honno/elwise-docs
Document element-wise testing utilities
2 parents 8ad6b62 + a943a51 commit 852a1dd

File tree

1 file changed

+144
-33
lines changed

1 file changed

+144
-33
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 144 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Test element-wise functions/operators against reference implementations.
3+
"""
14
import math
25
import operator
36
from enum import Enum, auto
@@ -82,32 +85,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
8285
return n
8386

8487

85-
# This module tests elementwise functions/operators against a reference
86-
# implementation. We iterate through the input array(s) and resulting array,
87-
# casting the indexed arrays to Python scalars and calculating the expected
88-
# output with `refimpl` function.
89-
#
90-
# This is finicky to refactor, but possible and ultimately worthwhile - hence
91-
# why these *_assert_again_refimpl() utilities exist.
92-
#
93-
# Values which are special-cased are generated and passed, but are filtered by
94-
# the `filter_` callable before they can be asserted against `refimpl`. We
95-
# automatically generate tests for special cases in the special_cases/ dir. We
96-
# still pass them here so as to ensure their presence doesn't affect the outputs
97-
# respective to non-special-cased elements.
98-
#
99-
# By default, results are casted to scalars the same way that the inputs are.
100-
# You can specify a cast via `res_stype, i.e. when a function accepts numerical
101-
# inputs but returns boolean arrays.
102-
#
103-
# By default, floating-point functions/methods are loosely asserted against. Use
104-
# `strict_check=True` when they should be strictly asserted against, i.e.
105-
# when a function should return intergrals. Likewise, use `strict_check=False`
106-
# when integer function/methods should be loosely asserted against, i.e. when
107-
# floats are used internally for optimisation or legacy reasons.
108-
109-
110-
def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
88+
def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
11189
"""Wraps math.isclose with very generous defaults.
11290
11391
This is useful for many floating-point operations where the spec does not
@@ -137,11 +115,131 @@ def unary_assert_against_refimpl(
137115
in_: Array,
138116
res: Array,
139117
refimpl: Callable[[T], T],
140-
expr_template: Optional[str] = None,
118+
*,
141119
res_stype: Optional[ScalarType] = None,
142120
filter_: Callable[[Scalar], bool] = default_filter,
143121
strict_check: Optional[bool] = None,
122+
expr_template: Optional[str] = None,
144123
):
124+
"""
125+
Assert unary element-wise results are as expected.
126+
127+
We iterate through every element in the input and resulting arrays, casting
128+
the respective elements (0-D arrays) to Python scalars, and assert against
129+
the expected output specified by the passed reference implementation, e.g.
130+
131+
>>> x = xp.asarray([[0, 1], [2, 4]])
132+
>>> out = xp.square(x)
133+
>>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
134+
135+
is equivalent to
136+
137+
>>> for idx in np.ndindex(x.shape):
138+
... expected = int(x[idx]) ** 2
139+
... assert int(out[idx]) == expected
140+
141+
Casting
142+
-------
143+
144+
The input scalar type is inferred from the input array's dtype like so
145+
146+
Array dtypes | Python builtin type
147+
----------------- | ---------------------
148+
xp.bool | bool
149+
xp.int*, xp.uint* | int
150+
xp.float* | float
151+
xp.complex* | complex
152+
153+
If res_stype=None (the default), the result scalar type is the same as the
154+
input scalar type. We can also specify the result scalar type ourselves, e.g.
155+
156+
>>> x = xp.asarray([42., xp.inf])
157+
>>> out = xp.isinf(x) # should be [False, True]
158+
>>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
159+
160+
Filtering special-cased values
161+
------------------------------
162+
163+
Values which are special-cased can be present in the input array, but get
164+
filtered before they can be asserted against refimpl.
165+
166+
If filter_=default_filter (the default), all non-finite and floating zero
167+
values are filtered, e.g.
168+
169+
>>> unary_assert_against_refimpl('sin', x, out, math.sin)
170+
171+
is equivalent to
172+
173+
>>> for idx in np.ndindex(x.shape):
174+
... at_x = float(x[idx])
175+
... if math.isfinite(at_x) or at_x != 0:
176+
... expected = math.sin(at_x)
177+
... assert math.isclose(float(out[idx]), expected)
178+
179+
We can also specify the filter function ourselves, e.g.
180+
181+
>>> def sqrt_filter(s: float) -> bool:
182+
... return math.isfinite(s) and s >= 0
183+
>>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter)
184+
185+
is equivalent to
186+
187+
>>> for idx in np.ndindex(x.shape):
188+
... at_x = float(x[idx])
189+
... if math.isfinite(s) and s >=0:
190+
... expected = math.sin(at_x)
191+
... assert math.isclose(float(out[idx]), expected)
192+
193+
Note we leave special-cased values in the input arrays, so as to ensure
194+
their presence doesn't affect the outputs respective to non-special-cased
195+
elements. We specifically test special case bevaiour in test_special_cases.py.
196+
197+
Assertion strictness
198+
--------------------
199+
200+
If strict_check=None (the default), integer elements are strictly asserted
201+
against, and floating elements are loosely asserted against, e.g.
202+
203+
>>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
204+
205+
is equivalent to
206+
207+
>>> for idx in np.ndindex(x.shape):
208+
... expected = in_stype(x[idx]) ** 2
209+
... if in_stype == int:
210+
... assert int(out[idx]) == expected
211+
... else: # in_stype == float
212+
... assert math.isclose(float(out[idx]), expected)
213+
214+
Specifying strict_check as True or False will assert strictly/loosely
215+
respectively, regardless of dtype. This is useful for testing functions that
216+
have definitive outputs for floating inputs, i.e. rounding functions.
217+
218+
Expressions in errors
219+
---------------------
220+
221+
Assertion error messages include an expression, by default using func_name
222+
like so
223+
224+
>>> x = xp.asarray([42., xp.inf])
225+
>>> out = xp.isinf(x)
226+
>>> out
227+
[False, False]
228+
>>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
229+
AssertionError: out[1]=False, but should be isinf(x[1])=True ...
230+
231+
We can specify the expression template ourselves, e.g.
232+
233+
>>> x = xp.asarray(True)
234+
>>> out = xp.logical_not(x)
235+
>>> out
236+
True
237+
>>> unary_assert_against_refimpl(
238+
... 'logical_not', x, out, expr_template='(not {})={}'
239+
... )
240+
AssertionError: out=True, but should be (not True)=False ...
241+
242+
"""
145243
if in_.shape != res.shape:
146244
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
147245
if expr_template is None:
@@ -184,14 +282,20 @@ def binary_assert_against_refimpl(
184282
right: Array,
185283
res: Array,
186284
refimpl: Callable[[T, T], T],
187-
expr_template: Optional[str] = None,
285+
*,
188286
res_stype: Optional[ScalarType] = None,
287+
filter_: Callable[[Scalar], bool] = default_filter,
288+
strict_check: Optional[bool] = None,
189289
left_sym: str = "x1",
190290
right_sym: str = "x2",
191291
res_name: str = "out",
192-
filter_: Callable[[Scalar], bool] = default_filter,
193-
strict_check: Optional[bool] = None,
292+
expr_template: Optional[str] = None,
194293
):
294+
"""
295+
Assert binary element-wise results are as expected.
296+
297+
See unary_assert_against_refimpl for more information.
298+
"""
195299
if expr_template is None:
196300
expr_template = func_name + "({}, {})={}"
197301
in_stype = dh.get_scalar_type(left.dtype)
@@ -234,13 +338,19 @@ def right_scalar_assert_against_refimpl(
234338
right: Scalar,
235339
res: Array,
236340
refimpl: Callable[[T, T], T],
237-
expr_template: str = None,
341+
*,
238342
res_stype: Optional[ScalarType] = None,
239-
left_sym: str = "x1",
240-
res_name: str = "out",
241343
filter_: Callable[[Scalar], bool] = default_filter,
242344
strict_check: Optional[bool] = None,
345+
left_sym: str = "x1",
346+
res_name: str = "out",
347+
expr_template: str = None,
243348
):
349+
"""
350+
Assert binary element-wise results from scalar operands are as expected.
351+
352+
See unary_assert_against_refimpl for more information.
353+
"""
244354
if filter_(right):
245355
return # short-circuit here as there will be nothing to test
246356
in_stype = dh.get_scalar_type(left.dtype)
@@ -486,6 +596,7 @@ def binary_param_assert_against_refimpl(
486596
res: Array,
487597
op_sym: str,
488598
refimpl: Callable[[T, T], T],
599+
*,
489600
res_stype: Optional[ScalarType] = None,
490601
filter_: Callable[[Scalar], bool] = default_filter,
491602
strict_check: Optional[bool] = None,

0 commit comments

Comments
 (0)