@@ -381,17 +381,35 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
381
381
dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = ar1 .dtype )
382
382
).all ()
383
383
384
- ar3 = dpt .ones (sz , dtype = op1_dtype )
385
- ar4 = dpt .ones (2 * sz , dtype = op2_dtype )
386
-
387
- ar3 [::- 1 ] += ar4 [::2 ]
384
+ ar3 = dpt .ones (sz , dtype = op1_dtype )[::- 1 ]
385
+ ar4 = dpt .ones (2 * sz , dtype = op2_dtype )[::2 ]
386
+ ar3 += ar4
388
387
assert (
389
388
dpt .asnumpy (ar3 ) == np .full (ar3 .shape , 2 , dtype = ar3 .dtype )
390
389
).all ()
391
-
392
390
else :
393
391
with pytest .raises (TypeError ):
394
392
ar1 += ar2
393
+ dpt .add (ar1 , ar2 , out = ar1 )
394
+
395
+ # out is second arg
396
+ ar1 = dpt .ones (sz , dtype = op1_dtype )
397
+ ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
398
+ if _can_cast (ar1 .dtype , ar2 .dtype , _fp16 , _fp64 ):
399
+ dpt .add (ar1 , ar2 , out = ar2 )
400
+ assert (
401
+ dpt .asnumpy (ar2 ) == np .full (ar2 .shape , 2 , dtype = ar2 .dtype )
402
+ ).all ()
403
+
404
+ ar3 = dpt .ones (sz , dtype = op1_dtype )[::- 1 ]
405
+ ar4 = dpt .ones (2 * sz , dtype = op2_dtype )[::2 ]
406
+ dpt .add (ar3 , ar4 , out = ar4 )
407
+ assert (
408
+ dpt .asnumpy (ar4 ) == np .full (ar4 .shape , 2 , dtype = ar4 .dtype )
409
+ ).all ()
410
+ else :
411
+ with pytest .raises (TypeError ):
412
+ dpt .add (ar1 , ar2 , out = ar2 )
395
413
396
414
397
415
def test_add_inplace_broadcasting ():
@@ -403,6 +421,12 @@ def test_add_inplace_broadcasting():
403
421
m += v
404
422
assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
405
423
424
+ # check case where second arg is out
425
+ dpt .add (v , m , out = m )
426
+ assert (
427
+ dpt .asnumpy (m ) == np .arange (10 , dtype = "i4" )[np .newaxis , 1 :10 :2 ]
428
+ ).all ()
429
+
406
430
407
431
def test_add_inplace_errors ():
408
432
get_queue_or_skip ()
@@ -441,7 +465,7 @@ def test_add_inplace_errors():
441
465
ar1 += ar2
442
466
443
467
444
- def test_add_inplace_overlap ():
468
+ def test_add_inplace_same_tensors ():
445
469
get_queue_or_skip ()
446
470
447
471
ar1 = dpt .ones (10 , dtype = "i4" )
@@ -451,7 +475,13 @@ def test_add_inplace_overlap():
451
475
ar1 = dpt .ones (10 , dtype = "i4" )
452
476
ar2 = dpt .ones (10 , dtype = "i4" )
453
477
dpt .add (ar1 , ar2 , out = ar1 )
478
+ # all ar1 vals should be 2
454
479
assert (dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = "i4" )).all ()
455
480
456
481
dpt .add (ar2 , ar1 , out = ar2 )
482
+ # all ar2 vals should be 3
457
483
assert (dpt .asnumpy (ar2 ) == np .full (ar2 .shape , 3 , dtype = "i4" )).all ()
484
+
485
+ dpt .add (ar1 , ar2 , out = ar2 )
486
+ # all ar2 vals should be 5
487
+ assert (dpt .asnumpy (ar2 ) == np .full (ar2 .shape , 5 , dtype = "i4" )).all ()
0 commit comments