|
| 1 | +""" |
| 2 | +Test element-wise functions/operators against reference implementations. |
| 3 | +""" |
1 | 4 | import math
|
2 | 5 | import operator
|
3 | 6 | from enum import Enum, auto
|
@@ -82,32 +85,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
|
82 | 85 | return n
|
83 | 86 |
|
84 | 87 |
|
85 |
| -# This module tests elementwise functions/operators against a reference |
86 |
| -# implementation. We iterate through the input array(s) and resulting array, |
87 |
| -# casting the indexed arrays to Python scalars and calculating the expected |
88 |
| -# output with `refimpl` function. |
89 |
| -# |
90 |
| -# This is finicky to refactor, but possible and ultimately worthwhile - hence |
91 |
| -# why these *_assert_again_refimpl() utilities exist. |
92 |
| -# |
93 |
| -# Values which are special-cased are generated and passed, but are filtered by |
94 |
| -# the `filter_` callable before they can be asserted against `refimpl`. We |
95 |
| -# automatically generate tests for special cases in the special_cases/ dir. We |
96 |
| -# still pass them here so as to ensure their presence doesn't affect the outputs |
97 |
| -# respective to non-special-cased elements. |
98 |
| -# |
99 |
| -# By default, results are casted to scalars the same way that the inputs are. |
100 |
| -# You can specify a cast via `res_stype, i.e. when a function accepts numerical |
101 |
| -# inputs but returns boolean arrays. |
102 |
| -# |
103 |
| -# By default, floating-point functions/methods are loosely asserted against. Use |
104 |
| -# `strict_check=True` when they should be strictly asserted against, i.e. |
105 |
| -# when a function should return intergrals. Likewise, use `strict_check=False` |
106 |
| -# when integer function/methods should be loosely asserted against, i.e. when |
107 |
| -# floats are used internally for optimisation or legacy reasons. |
108 |
| - |
109 |
| - |
110 |
| -def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: |
| 88 | +def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: |
111 | 89 | """Wraps math.isclose with very generous defaults.
|
112 | 90 |
|
113 | 91 | This is useful for many floating-point operations where the spec does not
|
@@ -137,11 +115,131 @@ def unary_assert_against_refimpl(
|
137 | 115 | in_: Array,
|
138 | 116 | res: Array,
|
139 | 117 | refimpl: Callable[[T], T],
|
140 |
| - expr_template: Optional[str] = None, |
| 118 | + *, |
141 | 119 | res_stype: Optional[ScalarType] = None,
|
142 | 120 | filter_: Callable[[Scalar], bool] = default_filter,
|
143 | 121 | strict_check: Optional[bool] = None,
|
| 122 | + expr_template: Optional[str] = None, |
144 | 123 | ):
|
| 124 | + """ |
| 125 | + Assert unary element-wise results are as expected. |
| 126 | +
|
| 127 | + We iterate through every element in the input and resulting arrays, casting |
| 128 | + the respective elements (0-D arrays) to Python scalars, and assert against |
| 129 | + the expected output specified by the passed reference implementation, e.g. |
| 130 | +
|
| 131 | + >>> x = xp.asarray([[0, 1], [2, 4]]) |
| 132 | + >>> out = xp.square(x) |
| 133 | + >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2) |
| 134 | +
|
| 135 | + is equivalent to |
| 136 | +
|
| 137 | + >>> for idx in np.ndindex(x.shape): |
| 138 | + ... expected = int(x[idx]) ** 2 |
| 139 | + ... assert int(out[idx]) == expected |
| 140 | +
|
| 141 | + Casting |
| 142 | + ------- |
| 143 | +
|
| 144 | + The input scalar type is inferred from the input array's dtype like so |
| 145 | +
|
| 146 | + Array dtypes | Python builtin type |
| 147 | + ----------------- | --------------------- |
| 148 | + xp.bool | bool |
| 149 | + xp.int*, xp.uint* | int |
| 150 | + xp.float* | float |
| 151 | + xp.complex* | complex |
| 152 | +
|
| 153 | + If res_stype=None (the default), the result scalar type is the same as the |
| 154 | + input scalar type. We can also specify the result scalar type ourselves, e.g. |
| 155 | +
|
| 156 | + >>> x = xp.asarray([42., xp.inf]) |
| 157 | + >>> out = xp.isinf(x) # should be [False, True] |
| 158 | + >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool) |
| 159 | +
|
| 160 | + Filtering special-cased values |
| 161 | + ------------------------------ |
| 162 | +
|
| 163 | + Values which are special-cased can be present in the input array, but get |
| 164 | + filtered before they can be asserted against refimpl. |
| 165 | +
|
| 166 | + If filter_=default_filter (the default), all non-finite and floating zero |
| 167 | + values are filtered, e.g. |
| 168 | +
|
| 169 | + >>> unary_assert_against_refimpl('sin', x, out, math.sin) |
| 170 | +
|
| 171 | + is equivalent to |
| 172 | +
|
| 173 | + >>> for idx in np.ndindex(x.shape): |
| 174 | + ... at_x = float(x[idx]) |
| 175 | + ... if math.isfinite(at_x) or at_x != 0: |
| 176 | + ... expected = math.sin(at_x) |
| 177 | + ... assert math.isclose(float(out[idx]), expected) |
| 178 | +
|
| 179 | + We can also specify the filter function ourselves, e.g. |
| 180 | +
|
| 181 | + >>> def sqrt_filter(s: float) -> bool: |
| 182 | + ... return math.isfinite(s) and s >= 0 |
| 183 | + >>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter) |
| 184 | +
|
| 185 | + is equivalent to |
| 186 | +
|
| 187 | + >>> for idx in np.ndindex(x.shape): |
| 188 | + ... at_x = float(x[idx]) |
| 189 | + ... if math.isfinite(s) and s >=0: |
| 190 | + ... expected = math.sin(at_x) |
| 191 | + ... assert math.isclose(float(out[idx]), expected) |
| 192 | +
|
| 193 | + Note we leave special-cased values in the input arrays, so as to ensure |
| 194 | + their presence doesn't affect the outputs respective to non-special-cased |
| 195 | + elements. We specifically test special case bevaiour in test_special_cases.py. |
| 196 | +
|
| 197 | + Assertion strictness |
| 198 | + -------------------- |
| 199 | +
|
| 200 | + If strict_check=None (the default), integer elements are strictly asserted |
| 201 | + against, and floating elements are loosely asserted against, e.g. |
| 202 | +
|
| 203 | + >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2) |
| 204 | +
|
| 205 | + is equivalent to |
| 206 | +
|
| 207 | + >>> for idx in np.ndindex(x.shape): |
| 208 | + ... expected = in_stype(x[idx]) ** 2 |
| 209 | + ... if in_stype == int: |
| 210 | + ... assert int(out[idx]) == expected |
| 211 | + ... else: # in_stype == float |
| 212 | + ... assert math.isclose(float(out[idx]), expected) |
| 213 | +
|
| 214 | + Specifying strict_check as True or False will assert strictly/loosely |
| 215 | + respectively, regardless of dtype. This is useful for testing functions that |
| 216 | + have definitive outputs for floating inputs, i.e. rounding functions. |
| 217 | +
|
| 218 | + Expressions in errors |
| 219 | + --------------------- |
| 220 | +
|
| 221 | + Assertion error messages include an expression, by default using func_name |
| 222 | + like so |
| 223 | +
|
| 224 | + >>> x = xp.asarray([42., xp.inf]) |
| 225 | + >>> out = xp.isinf(x) |
| 226 | + >>> out |
| 227 | + [False, False] |
| 228 | + >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool) |
| 229 | + AssertionError: out[1]=False, but should be isinf(x[1])=True ... |
| 230 | +
|
| 231 | + We can specify the expression template ourselves, e.g. |
| 232 | +
|
| 233 | + >>> x = xp.asarray(True) |
| 234 | + >>> out = xp.logical_not(x) |
| 235 | + >>> out |
| 236 | + True |
| 237 | + >>> unary_assert_against_refimpl( |
| 238 | + ... 'logical_not', x, out, expr_template='(not {})={}' |
| 239 | + ... ) |
| 240 | + AssertionError: out=True, but should be (not True)=False ... |
| 241 | +
|
| 242 | + """ |
145 | 243 | if in_.shape != res.shape:
|
146 | 244 | raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
|
147 | 245 | if expr_template is None:
|
@@ -184,14 +282,20 @@ def binary_assert_against_refimpl(
|
184 | 282 | right: Array,
|
185 | 283 | res: Array,
|
186 | 284 | refimpl: Callable[[T, T], T],
|
187 |
| - expr_template: Optional[str] = None, |
| 285 | + *, |
188 | 286 | res_stype: Optional[ScalarType] = None,
|
| 287 | + filter_: Callable[[Scalar], bool] = default_filter, |
| 288 | + strict_check: Optional[bool] = None, |
189 | 289 | left_sym: str = "x1",
|
190 | 290 | right_sym: str = "x2",
|
191 | 291 | res_name: str = "out",
|
192 |
| - filter_: Callable[[Scalar], bool] = default_filter, |
193 |
| - strict_check: Optional[bool] = None, |
| 292 | + expr_template: Optional[str] = None, |
194 | 293 | ):
|
| 294 | + """ |
| 295 | + Assert binary element-wise results are as expected. |
| 296 | +
|
| 297 | + See unary_assert_against_refimpl for more information. |
| 298 | + """ |
195 | 299 | if expr_template is None:
|
196 | 300 | expr_template = func_name + "({}, {})={}"
|
197 | 301 | in_stype = dh.get_scalar_type(left.dtype)
|
@@ -234,13 +338,19 @@ def right_scalar_assert_against_refimpl(
|
234 | 338 | right: Scalar,
|
235 | 339 | res: Array,
|
236 | 340 | refimpl: Callable[[T, T], T],
|
237 |
| - expr_template: str = None, |
| 341 | + *, |
238 | 342 | res_stype: Optional[ScalarType] = None,
|
239 |
| - left_sym: str = "x1", |
240 |
| - res_name: str = "out", |
241 | 343 | filter_: Callable[[Scalar], bool] = default_filter,
|
242 | 344 | strict_check: Optional[bool] = None,
|
| 345 | + left_sym: str = "x1", |
| 346 | + res_name: str = "out", |
| 347 | + expr_template: str = None, |
243 | 348 | ):
|
| 349 | + """ |
| 350 | + Assert binary element-wise results from scalar operands are as expected. |
| 351 | +
|
| 352 | + See unary_assert_against_refimpl for more information. |
| 353 | + """ |
244 | 354 | if filter_(right):
|
245 | 355 | return # short-circuit here as there will be nothing to test
|
246 | 356 | in_stype = dh.get_scalar_type(left.dtype)
|
@@ -486,6 +596,7 @@ def binary_param_assert_against_refimpl(
|
486 | 596 | res: Array,
|
487 | 597 | op_sym: str,
|
488 | 598 | refimpl: Callable[[T, T], T],
|
| 599 | + *, |
489 | 600 | res_stype: Optional[ScalarType] = None,
|
490 | 601 | filter_: Callable[[Scalar], bool] = default_filter,
|
491 | 602 | strict_check: Optional[bool] = None,
|
|
0 commit comments