Skip to content

Commit fc9a8da

Browse files
committed
Tests for new out parameter behavior for add
1 parent 369c500 commit fc9a8da

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

dpctl/tests/elementwise/test_add.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,35 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
381381
dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype=ar1.dtype)
382382
).all()
383383

384-
ar3 = dpt.ones(sz, dtype=op1_dtype)
385-
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
386-
387-
ar3[::-1] += ar4[::2]
384+
ar3 = dpt.ones(sz, dtype=op1_dtype)[::-1]
385+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)[::2]
386+
ar3 += ar4
388387
assert (
389388
dpt.asnumpy(ar3) == np.full(ar3.shape, 2, dtype=ar3.dtype)
390389
).all()
391-
392390
else:
393391
with pytest.raises(TypeError):
394392
ar1 += ar2
393+
dpt.add(ar1, ar2, out=ar1)
394+
395+
# out is second arg
396+
ar1 = dpt.ones(sz, dtype=op1_dtype)
397+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
398+
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
399+
dpt.add(ar1, ar2, out=ar2)
400+
assert (
401+
dpt.asnumpy(ar2) == np.full(ar2.shape, 2, dtype=ar2.dtype)
402+
).all()
403+
404+
ar3 = dpt.ones(sz, dtype=op1_dtype)[::-1]
405+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)[::2]
406+
dpt.add(ar3, ar4, out=ar4)
407+
assert (
408+
dpt.asnumpy(ar4) == np.full(ar4.shape, 2, dtype=ar4.dtype)
409+
).all()
410+
else:
411+
with pytest.raises(TypeError):
412+
dpt.add(ar1, ar2, out=ar2)
395413

396414

397415
def test_add_inplace_broadcasting():
@@ -403,6 +421,12 @@ def test_add_inplace_broadcasting():
403421
m += v
404422
assert (dpt.asnumpy(m) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
405423

424+
# check case where second arg is out
425+
dpt.add(v, m, out=m)
426+
assert (
427+
dpt.asnumpy(m) == np.arange(10, dtype="i4")[np.newaxis, 1:10:2]
428+
).all()
429+
406430

407431
def test_add_inplace_errors():
408432
get_queue_or_skip()
@@ -441,7 +465,7 @@ def test_add_inplace_errors():
441465
ar1 += ar2
442466

443467

444-
def test_add_inplace_overlap():
468+
def test_add_inplace_same_tensors():
445469
get_queue_or_skip()
446470

447471
ar1 = dpt.ones(10, dtype="i4")
@@ -451,7 +475,13 @@ def test_add_inplace_overlap():
451475
ar1 = dpt.ones(10, dtype="i4")
452476
ar2 = dpt.ones(10, dtype="i4")
453477
dpt.add(ar1, ar2, out=ar1)
478+
# all ar1 vals should be 2
454479
assert (dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype="i4")).all()
455480

456481
dpt.add(ar2, ar1, out=ar2)
482+
# all ar2 vals should be 3
457483
assert (dpt.asnumpy(ar2) == np.full(ar2.shape, 3, dtype="i4")).all()
484+
485+
dpt.add(ar1, ar2, out=ar2)
486+
# all ar2 vals should be 5
487+
assert (dpt.asnumpy(ar2) == np.full(ar2.shape, 5, dtype="i4")).all()

0 commit comments

Comments
 (0)