Skip to content

Commit 4a1c0a1

Browse files
committed
Use new default_uint=None convention in sum and prod tests
1 parent c3b4006 commit 4a1c0a1

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,24 +136,27 @@ def test_prod(x, data):
136136
default_dtype = dh.default_uint
137137
else:
138138
default_dtype = dh.default_int
139-
m, M = dh.dtype_ranges[x.dtype]
140-
d_m, d_M = dh.dtype_ranges[default_dtype]
141-
if m < d_m or M > d_M:
142-
_dtype = x.dtype
139+
if default_dtype is None:
140+
_dtype = None
143141
else:
144-
_dtype = default_dtype
142+
m, M = dh.dtype_ranges[x.dtype]
143+
d_m, d_M = dh.dtype_ranges[default_dtype]
144+
if m < d_m or M > d_M:
145+
_dtype = x.dtype
146+
else:
147+
_dtype = default_dtype
145148
else:
146149
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
147150
_dtype = x.dtype
148151
else:
149152
_dtype = dh.default_float
150153
else:
151154
_dtype = dtype
152-
if isinstance(_dtype, _UndefinedStub):
155+
if _dtype is None:
153156
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
154157
# uint32 or uint64), we skip testing the output dtype.
155158
# See https://github.com/data-apis/array-api-tests/issues/106
156-
if _dtype in dh.uint_dtypes:
159+
if x.dtype in dh.uint_dtypes:
157160
assert dh.is_int_dtype(out.dtype) # sanity check
158161
else:
159162
ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)
@@ -241,24 +244,27 @@ def test_sum(x, data):
241244
default_dtype = dh.default_uint
242245
else:
243246
default_dtype = dh.default_int
244-
m, M = dh.dtype_ranges[x.dtype]
245-
d_m, d_M = dh.dtype_ranges[default_dtype]
246-
if m < d_m or M > d_M:
247-
_dtype = x.dtype
247+
if default_dtype is None:
248+
_dtype = None
248249
else:
249-
_dtype = default_dtype
250+
m, M = dh.dtype_ranges[x.dtype]
251+
d_m, d_M = dh.dtype_ranges[default_dtype]
252+
if m < d_m or M > d_M:
253+
_dtype = x.dtype
254+
else:
255+
_dtype = default_dtype
250256
else:
251257
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
252258
_dtype = x.dtype
253259
else:
254260
_dtype = dh.default_float
255261
else:
256262
_dtype = dtype
257-
if isinstance(_dtype, _UndefinedStub):
263+
if _dtype is None:
258264
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
259265
# uint32 or uint64), we skip testing the output dtype.
260266
# See https://github.com/data-apis/array-api-tests/issues/160
261-
if _dtype in dh.uint_dtypes:
267+
if x.dtype in dh.uint_dtypes:
262268
assert dh.is_int_dtype(out.dtype) # sanity check
263269
else:
264270
ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)

0 commit comments

Comments
 (0)