@@ -213,12 +213,16 @@ def test_bitwise_and(args):
213
213
# TODO: Generalize this properly for inputs that are arrays.
214
214
if not (x1 .shape == x2 .shape == ()):
215
215
raise RuntimeError ("Error: test_bitwise_and needs to be updated for nonscalar array inputs" )
216
- x = int (x1 )
217
- y = int (x2 )
218
- res = int (a )
216
+
219
217
if a .dtype == bool_dtype :
218
+ x = bool (x1 )
219
+ y = bool (x2 )
220
+ res = bool (a )
220
221
assert (x and y ) == res
221
222
else :
223
+ x = int (x1 )
224
+ y = int (x2 )
225
+ res = int (a )
222
226
ans = int_to_dtype (x & y , dtype_nbits (a .dtype ), dtype_signed (a .dtype ))
223
227
assert ans == res
224
228
@@ -254,11 +258,13 @@ def test_bitwise_invert(x):
254
258
# TODO: Generalize this properly for inputs that are arrays.
255
259
if not (x .shape == ()):
256
260
raise RuntimeError ("Error: test_bitwise_invert needs to be updated for nonscalar array inputs" )
257
- x = int (x )
258
- res = int (a )
259
261
if a .dtype == bool_dtype :
262
+ x = bool (x )
263
+ res = bool (a )
260
264
assert (not x ) == res
261
265
else :
266
+ x = int (x )
267
+ res = int (a )
262
268
ans = int_to_dtype (~ x , dtype_nbits (a .dtype ), dtype_signed (a .dtype ))
263
269
assert ans == res
264
270
@@ -272,12 +278,15 @@ def test_bitwise_or(args):
272
278
# TODO: Generalize this properly for inputs that are arrays.
273
279
if not (x1 .shape == x2 .shape == ()):
274
280
raise RuntimeError ("Error: test_bitwise_or needs to be updated for nonscalar array inputs" )
275
- x = int (x1 )
276
- y = int (x2 )
277
- res = int (a )
278
281
if a .dtype == bool_dtype :
282
+ x = bool (x1 )
283
+ y = bool (x2 )
284
+ res = bool (a )
279
285
assert (x or y ) == res
280
286
else :
287
+ x = int (x1 )
288
+ y = int (x2 )
289
+ res = int (a )
281
290
ans = int_to_dtype (x | y , dtype_nbits (a .dtype ), dtype_signed (a .dtype ))
282
291
assert ans == res
283
292
@@ -310,12 +319,15 @@ def test_bitwise_xor(args):
310
319
# TODO: Generalize this properly for inputs that are arrays.
311
320
if not (x1 .shape == x2 .shape == ()):
312
321
raise RuntimeError ("Error: test_bitwise_xor needs to be updated for nonscalar array inputs" )
313
- x = int (x1 )
314
- y = int (x2 )
315
- res = int (a )
316
322
if a .dtype == bool_dtype :
323
+ x = bool (x1 )
324
+ y = bool (x2 )
325
+ res = bool (a )
317
326
assert (x ^ y ) == res
318
327
else :
328
+ x = int (x1 )
329
+ y = int (x2 )
330
+ res = int (a )
319
331
ans = int_to_dtype (x ^ y , dtype_nbits (a .dtype ), dtype_signed (a .dtype ))
320
332
assert ans == res
321
333
0 commit comments