@@ -136,24 +136,27 @@ def test_prod(x, data):
136
136
default_dtype = dh .default_uint
137
137
else :
138
138
default_dtype = dh .default_int
139
- m , M = dh .dtype_ranges [x .dtype ]
140
- d_m , d_M = dh .dtype_ranges [default_dtype ]
141
- if m < d_m or M > d_M :
142
- _dtype = x .dtype
139
+ if default_dtype is None :
140
+ _dtype = None
143
141
else :
144
- _dtype = default_dtype
142
+ m , M = dh .dtype_ranges [x .dtype ]
143
+ d_m , d_M = dh .dtype_ranges [default_dtype ]
144
+ if m < d_m or M > d_M :
145
+ _dtype = x .dtype
146
+ else :
147
+ _dtype = default_dtype
145
148
else :
146
149
if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
147
150
_dtype = x .dtype
148
151
else :
149
152
_dtype = dh .default_float
150
153
else :
151
154
_dtype = dtype
152
- if isinstance ( _dtype , _UndefinedStub ) :
155
+ if _dtype is None :
153
156
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
154
157
# uint32 or uint64), we skip testing the output dtype.
155
158
# See https://github.com/data-apis/array-api-tests/issues/106
156
- if _dtype in dh .uint_dtypes :
159
+ if x . dtype in dh .uint_dtypes :
157
160
assert dh .is_int_dtype (out .dtype ) # sanity check
158
161
else :
159
162
ph .assert_dtype ("prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
@@ -241,24 +244,27 @@ def test_sum(x, data):
241
244
default_dtype = dh .default_uint
242
245
else :
243
246
default_dtype = dh .default_int
244
- m , M = dh .dtype_ranges [x .dtype ]
245
- d_m , d_M = dh .dtype_ranges [default_dtype ]
246
- if m < d_m or M > d_M :
247
- _dtype = x .dtype
247
+ if default_dtype is None :
248
+ _dtype = None
248
249
else :
249
- _dtype = default_dtype
250
+ m , M = dh .dtype_ranges [x .dtype ]
251
+ d_m , d_M = dh .dtype_ranges [default_dtype ]
252
+ if m < d_m or M > d_M :
253
+ _dtype = x .dtype
254
+ else :
255
+ _dtype = default_dtype
250
256
else :
251
257
if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
252
258
_dtype = x .dtype
253
259
else :
254
260
_dtype = dh .default_float
255
261
else :
256
262
_dtype = dtype
257
- if isinstance ( _dtype , _UndefinedStub ) :
263
+ if _dtype is None :
258
264
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
259
265
# uint32 or uint64), we skip testing the output dtype.
260
266
# See https://github.com/data-apis/array-api-tests/issues/160
261
- if _dtype in dh .uint_dtypes :
267
+ if x . dtype in dh .uint_dtypes :
262
268
assert dh .is_int_dtype (out .dtype ) # sanity check
263
269
else :
264
270
ph .assert_dtype ("sum" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
0 commit comments