Skip to content

Commit c7dd08a

Browse files
ndgrigorianoleksandr-pavlyk
authored andcommitted
Add more tests for element-wise in-place operators
Also clean up and make some tests for in-place operators more efficient
1 parent 8e2fc81 commit c7dd08a

File tree

5 files changed

+71
-55
lines changed

5 files changed

+71
-55
lines changed

dpctl/tests/elementwise/test_add.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,25 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
373373
else:
374374
with pytest.raises(ValueError):
375375
ar1 += ar2
376+
377+
ar1 = dpt.ones(sz, dtype=op1_dtype)
378+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
379+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
380+
dpt.add(ar1, ar2, out=ar1)
381+
assert (
382+
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
383+
).all()
384+
385+
ar3 = dpt.ones(sz, dtype=op1_dtype)[::-1]
386+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)[::2]
387+
dpt.add(ar3, ar4, out=ar3)
388+
assert (
389+
dpt.asnumpy(ar3) == np.full(ar3.shape, 2, dtype=ar3.dtype)
390+
).all()
391+
else:
392+
with pytest.raises(ValueError):
376393
dpt.add(ar1, ar2, out=ar1)
377394

378-
# out is second arg
379395
ar1 = dpt.ones(sz, dtype=op1_dtype)
380396
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
381397
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
@@ -401,7 +417,7 @@ def test_add_inplace_broadcasting():
401417
m = dpt.ones((100, 5), dtype="i4")
402418
v = dpt.arange(5, dtype="i4")
403419

404-
m += v
420+
dpt.add(m, v, out=m)
405421
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
406422

407423
# check case where second arg is out
@@ -411,6 +427,26 @@ def test_add_inplace_broadcasting():
411427
).all()
412428

413429

430+
def test_add_inplace_operator_broadcasting():
431+
get_queue_or_skip()
432+
433+
m = dpt.ones((100, 5), dtype="i4")
434+
v = dpt.arange(5, dtype="i4")
435+
436+
m += v
437+
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
438+
439+
440+
def test_add_inplace_operator_mutual_broadcast():
441+
get_queue_or_skip()
442+
443+
x1 = dpt.ones((1, 10), dtype="i4")
444+
x2 = dpt.ones((10, 1), dtype="i4")
445+
446+
with pytest.raises(ValueError):
447+
dpt.add._inplace_op(x1, x2)
448+
449+
414450
def test_add_inplace_errors():
415451
get_queue_or_skip()
416452
try:
@@ -425,27 +461,45 @@ def test_add_inplace_errors():
425461
ar1 = dpt.ones(2, dtype="float32", sycl_queue=gpu_queue)
426462
ar2 = dpt.ones_like(ar1, sycl_queue=cpu_queue)
427463
with pytest.raises(ExecutionPlacementError):
428-
ar1 += ar2
464+
dpt.add(ar1, ar2, out=ar1)
429465

430466
ar1 = dpt.ones(2, dtype="float32")
431467
ar2 = dpt.ones(3, dtype="float32")
432468
with pytest.raises(ValueError):
433-
ar1 += ar2
469+
dpt.add(ar1, ar2, out=ar1)
434470

435471
ar1 = np.ones(2, dtype="float32")
436472
ar2 = dpt.ones(2, dtype="float32")
437473
with pytest.raises(TypeError):
438-
ar1 += ar2
474+
dpt.add(ar1, ar2, out=ar1)
439475

440476
ar1 = dpt.ones(2, dtype="float32")
441477
ar2 = dict()
442478
with pytest.raises(ValueError):
443-
ar1 += ar2
479+
dpt.add(ar1, ar2, out=ar1)
444480

445481
ar1 = dpt.ones((2, 1), dtype="float32")
446482
ar2 = dpt.ones((1, 2), dtype="float32")
447483
with pytest.raises(ValueError):
448-
ar1 += ar2
484+
dpt.add(ar1, ar2, out=ar1)
485+
486+
487+
def test_add_inplace_operator_errors():
488+
q1 = get_queue_or_skip()
489+
q2 = get_queue_or_skip()
490+
491+
x = dpt.ones(10, dtype="i4", sycl_queue=q1)
492+
with pytest.raises(TypeError):
493+
dpt.add._inplace_op(dict(), x)
494+
495+
x.flags["W"] = False
496+
with pytest.raises(ValueError):
497+
dpt.add._inplace_op(x, 2)
498+
499+
x_q1 = dpt.ones(10, dtype="i4", sycl_queue=q1)
500+
x_q2 = dpt.ones(10, dtype="i4", sycl_queue=q2)
501+
with pytest.raises(ExecutionPlacementError):
502+
dpt.add._inplace_op(x_q1, x_q2)
449503

450504

451505
def test_add_inplace_same_tensors():

dpctl/tests/elementwise/test_bitwise_and.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,3 @@ def test_bitwise_and_inplace_dtype_matrix(op1_dtype, op2_dtype):
125125
else:
126126
with pytest.raises(ValueError):
127127
ar1 &= ar2
128-
dpt.bitwise_and(ar1, ar2, out=ar1)
129-
130-
# out is second arg
131-
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
132-
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
133-
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
134-
dpt.bitwise_and(ar1, ar2, out=ar2)
135-
assert dpt.all(ar2 == 1)
136-
137-
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
138-
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
139-
dpt.bitwise_and(ar3, ar4, out=ar4)
140-
dpt.all(ar4 == 1)
141-
else:
142-
with pytest.raises(ValueError):
143-
dpt.bitwise_and(ar1, ar2, out=ar2)

dpctl/tests/elementwise/test_bitwise_left_shift.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,3 @@ def test_bitwise_left_shift_inplace_dtype_matrix(op1_dtype, op2_dtype):
133133
else:
134134
with pytest.raises(ValueError):
135135
ar1 <<= ar2
136-
dpt.bitwise_left_shift(ar1, ar2, out=ar1)
137-
138-
# out is second arg
139-
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
140-
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
141-
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
142-
dpt.bitwise_left_shift(ar1, ar2, out=ar2)
143-
assert dpt.all(ar2 == 2)
144-
145-
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
146-
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
147-
dpt.bitwise_left_shift(ar3, ar4, out=ar4)
148-
dpt.all(ar4 == 2)
149-
else:
150-
with pytest.raises(ValueError):
151-
dpt.bitwise_left_shift(ar1, ar2, out=ar2)

dpctl/tests/elementwise/test_elementwise_classes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,20 @@ def test_binary_class_nout():
118118
assert nout == 1
119119

120120

121-
def test_biary_read_only_out():
121+
def test_binary_read_only_out():
122122
get_queue_or_skip()
123123
x1 = dpt.ones(32, dtype=dpt.float32)
124124
x2 = dpt.ones_like(x1)
125125
r = dpt.empty_like(x1)
126126
r.flags["W"] = False
127127
with pytest.raises(ValueError):
128128
binary_fn(x1, x2, out=r)
129+
130+
131+
def test_binary_no_inplace_op():
132+
get_queue_or_skip()
133+
x1 = dpt.ones(10, dtype="i4")
134+
x2 = dpt.ones_like(x1)
135+
136+
with pytest.raises(ValueError):
137+
dpt.logaddexp._inplace_op(x1, x2)

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -302,18 +302,3 @@ def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
302302
with pytest.raises(ValueError):
303303
ar1 //= ar2
304304
dpt.floor_divide(ar1, ar2, out=ar1)
305-
306-
# out is second arg
307-
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
308-
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
309-
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
310-
dpt.floor_divide(ar1, ar2, out=ar2)
311-
assert dpt.all(ar2 == 1)
312-
313-
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
314-
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
315-
dpt.floor_divide(ar3, ar4, out=ar4)
316-
dpt.all(ar4 == 1)
317-
else:
318-
with pytest.raises(ValueError):
319-
dpt.floor_divide(ar1, ar2, out=ar2)

0 commit comments

Comments
 (0)