31
31
_acceptance_fn_default ,
32
32
_find_buf_dtype ,
33
33
_find_buf_dtype2 ,
34
- _find_inplace_dtype ,
35
34
_to_device_supported_dtype ,
36
35
)
37
36
@@ -383,14 +382,6 @@ def __repr__(self):
383
382
return f"<{ self .__name__ } '{ self .name_ } '>"
384
383
385
384
def __call__ (self , o1 , o2 , out = None , order = "K" ):
386
- # FIXME: replace with check against base array
387
- # when views can be identified
388
- if self .binary_inplace_fn_ :
389
- if o1 is out :
390
- return self ._inplace (o1 , o2 )
391
- elif o2 is out :
392
- return self ._inplace (o2 , o1 )
393
-
394
385
if order not in ["K" , "C" , "F" , "A" ]:
395
386
order = "K"
396
387
q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -472,6 +463,7 @@ def __call__(self, o1, o2, out=None, order="K"):
472
463
"supported types according to the casting rule ''safe''."
473
464
)
474
465
466
+ orig_out = out
475
467
if out is not None :
476
468
if not isinstance (out , dpt .usm_ndarray ):
477
469
raise TypeError (
@@ -484,19 +476,76 @@ def __call__(self, o1, o2, out=None, order="K"):
484
476
f"Expected output shape is { o1_shape } , got { out .shape } "
485
477
)
486
478
487
- if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
488
- raise TypeError ("Input and output arrays have memory overlap" )
479
+ if res_dt != out .dtype :
480
+ raise TypeError (
481
+ f"Output array of type { res_dt } is needed,"
482
+ f"got { out .dtype } "
483
+ )
489
484
490
485
if (
491
- dpctl .utils .get_execution_queue (
492
- (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
493
- )
486
+ dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue ))
494
487
is None
495
488
):
496
489
raise TypeError (
497
490
"Input and output allocation queues are not compatible"
498
491
)
499
492
493
+ if isinstance (o1 , dpt .usm_ndarray ):
494
+ if ti ._array_overlap (o1 , out ) and buf1_dt is None :
495
+ if not ti ._same_logical_tensors (o1 , out ):
496
+ out = dpt .empty_like (out )
497
+ elif self .binary_inplace_fn_ is not None :
498
+ # if there is a dedicated in-place kernel
499
+ # it can be called here, otherwise continues
500
+ if isinstance (o2 , dpt .usm_ndarray ):
501
+ src2 = o2
502
+ if (
503
+ ti ._array_overlap (o2 , out )
504
+ and not ti ._same_logical_tensors (o2 , out )
505
+ and buf2_dt is None
506
+ ):
507
+ buf2_dt = o2_dtype
508
+ else :
509
+ src2 = dpt .asarray (
510
+ o2 , dtype = o2_dtype , sycl_queue = exec_q
511
+ )
512
+ if buf2_dt is None :
513
+ src2 = dpt .broadcast_to (src2 , res_shape )
514
+ ht_ , _ = self .binary_inplace_fn_ (
515
+ lhs = o1 , rhs = src2 , sycl_queue = exec_q
516
+ )
517
+ ht_ .wait ()
518
+ else :
519
+ buf2 = dpt .empty_like (src2 , dtype = buf2_dt )
520
+ (
521
+ ht_copy_ev ,
522
+ copy_ev ,
523
+ ) = ti ._copy_usm_ndarray_into_usm_ndarray (
524
+ src = src2 , dst = buf2 , sycl_queue = exec_q
525
+ )
526
+
527
+ buf2 = dpt .broadcast_to (buf2 , res_shape )
528
+ ht_ , _ = self .binary_inplace_fn_ (
529
+ lhs = o1 ,
530
+ rhs = buf2 ,
531
+ sycl_queue = exec_q ,
532
+ depends = [copy_ev ],
533
+ )
534
+ ht_copy_ev .wait ()
535
+ ht_ .wait ()
536
+
537
+ return out
538
+
539
+ if isinstance (o2 , dpt .usm_ndarray ):
540
+ if (
541
+ ti ._array_overlap (o2 , out )
542
+ and not ti ._same_logical_tensors (o2 , out )
543
+ and buf2_dt is None
544
+ ):
545
+ # should not reach if out is reallocated
546
+ # after being checked against o1
547
+ out = dpt .empty_like (out )
548
+
500
549
if isinstance (o1 , dpt .usm_ndarray ):
501
550
src1 = o1
502
551
else :
@@ -532,19 +581,23 @@ def __call__(self, o1, o2, out=None, order="K"):
532
581
sycl_queue = exec_q ,
533
582
order = order ,
534
583
)
535
- else :
536
- if res_dt != out .dtype :
537
- raise TypeError (
538
- f"Output array of type { res_dt } is needed,"
539
- f"got { out .dtype } "
540
- )
541
584
542
585
src1 = dpt .broadcast_to (src1 , res_shape )
543
586
src2 = dpt .broadcast_to (src2 , res_shape )
544
- ht_ , _ = self .binary_fn_ (
587
+ ht_binary_ev , binary_ev = self .binary_fn_ (
545
588
src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
546
589
)
547
- ht_ .wait ()
590
+ if not (orig_out is None or orig_out is out ):
591
+ # Copy the out data from temporary buffer to original memory
592
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
593
+ src = out ,
594
+ dst = orig_out ,
595
+ sycl_queue = exec_q ,
596
+ depends = [binary_ev ],
597
+ )
598
+ ht_copy_out_ev .wait ()
599
+ out = orig_out
600
+ ht_binary_ev .wait ()
548
601
return out
549
602
elif buf1_dt is None :
550
603
if order == "K" :
@@ -578,15 +631,25 @@ def __call__(self, o1, o2, out=None, order="K"):
578
631
579
632
src1 = dpt .broadcast_to (src1 , res_shape )
580
633
buf2 = dpt .broadcast_to (buf2 , res_shape )
581
- ht_ , _ = self .binary_fn_ (
634
+ ht_binary_ev , binary_ev = self .binary_fn_ (
582
635
src1 = src1 ,
583
636
src2 = buf2 ,
584
637
dst = out ,
585
638
sycl_queue = exec_q ,
586
639
depends = [copy_ev ],
587
640
)
641
+ if not (orig_out is None or orig_out is out ):
642
+ # Copy the out data from temporary buffer to original memory
643
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
644
+ src = out ,
645
+ dst = orig_out ,
646
+ sycl_queue = exec_q ,
647
+ depends = [binary_ev ],
648
+ )
649
+ ht_copy_out_ev .wait ()
650
+ out = orig_out
588
651
ht_copy_ev .wait ()
589
- ht_ .wait ()
652
+ ht_binary_ev .wait ()
590
653
return out
591
654
elif buf2_dt is None :
592
655
if order == "K" :
@@ -611,24 +674,28 @@ def __call__(self, o1, o2, out=None, order="K"):
611
674
sycl_queue = exec_q ,
612
675
order = order ,
613
676
)
614
- else :
615
- if res_dt != out .dtype :
616
- raise TypeError (
617
- f"Output array of type { res_dt } is needed,"
618
- f"got { out .dtype } "
619
- )
620
677
621
678
buf1 = dpt .broadcast_to (buf1 , res_shape )
622
679
src2 = dpt .broadcast_to (src2 , res_shape )
623
- ht_ , _ = self .binary_fn_ (
680
+ ht_binary_ev , binary_ev = self .binary_fn_ (
624
681
src1 = buf1 ,
625
682
src2 = src2 ,
626
683
dst = out ,
627
684
sycl_queue = exec_q ,
628
685
depends = [copy_ev ],
629
686
)
687
+ if not (orig_out is None or orig_out is out ):
688
+ # Copy the out data from temporary buffer to original memory
689
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
690
+ src = out ,
691
+ dst = orig_out ,
692
+ sycl_queue = exec_q ,
693
+ depends = [binary_ev ],
694
+ )
695
+ ht_copy_out_ev .wait ()
696
+ out = orig_out
630
697
ht_copy_ev .wait ()
631
- ht_ .wait ()
698
+ ht_binary_ev .wait ()
632
699
return out
633
700
634
701
if order in ["K" , "A" ]:
@@ -665,11 +732,6 @@ def __call__(self, o1, o2, out=None, order="K"):
665
732
sycl_queue = exec_q ,
666
733
order = order ,
667
734
)
668
- else :
669
- if res_dt != out .dtype :
670
- raise TypeError (
671
- f"Output array of type { res_dt } is needed, got { out .dtype } "
672
- )
673
735
674
736
buf1 = dpt .broadcast_to (buf1 , res_shape )
675
737
buf2 = dpt .broadcast_to (buf2 , res_shape )
@@ -682,116 +744,3 @@ def __call__(self, o1, o2, out=None, order="K"):
682
744
)
683
745
dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
684
746
return out
685
-
686
- def _inplace (self , lhs , val ):
687
- if self .binary_inplace_fn_ is None :
688
- raise ValueError (
689
- f"In-place operation not supported for ufunc '{ self .name_ } '"
690
- )
691
- if not isinstance (lhs , dpt .usm_ndarray ):
692
- raise TypeError (
693
- f"Expected dpctl.tensor.usm_ndarray, got { type (lhs )} "
694
- )
695
- q1 , lhs_usm_type = _get_queue_usm_type (lhs )
696
- q2 , val_usm_type = _get_queue_usm_type (val )
697
- if q2 is None :
698
- exec_q = q1
699
- usm_type = lhs_usm_type
700
- else :
701
- exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
702
- if exec_q is None :
703
- raise ExecutionPlacementError (
704
- "Execution placement can not be unambiguously inferred "
705
- "from input arguments."
706
- )
707
- usm_type = dpctl .utils .get_coerced_usm_type (
708
- (
709
- lhs_usm_type ,
710
- val_usm_type ,
711
- )
712
- )
713
- dpctl .utils .validate_usm_type (usm_type , allow_none = False )
714
- lhs_shape = _get_shape (lhs )
715
- val_shape = _get_shape (val )
716
- if not all (
717
- isinstance (s , (tuple , list ))
718
- for s in (
719
- lhs_shape ,
720
- val_shape ,
721
- )
722
- ):
723
- raise TypeError (
724
- "Shape of arguments can not be inferred. "
725
- "Arguments are expected to be "
726
- "lists, tuples, or both"
727
- )
728
- try :
729
- res_shape = _broadcast_shape_impl (
730
- [
731
- lhs_shape ,
732
- val_shape ,
733
- ]
734
- )
735
- except ValueError :
736
- raise ValueError (
737
- "operands could not be broadcast together with shapes "
738
- f"{ lhs_shape } and { val_shape } "
739
- )
740
- if res_shape != lhs_shape :
741
- raise ValueError (
742
- f"output shape { lhs_shape } does not match "
743
- f"broadcast shape { res_shape } "
744
- )
745
- sycl_dev = exec_q .sycl_device
746
- lhs_dtype = lhs .dtype
747
- val_dtype = _get_dtype (val , sycl_dev )
748
- if not _validate_dtype (val_dtype ):
749
- raise ValueError ("Input operand of unsupported type" )
750
-
751
- lhs_dtype , val_dtype = _resolve_weak_types (
752
- lhs_dtype , val_dtype , sycl_dev
753
- )
754
-
755
- buf_dt = _find_inplace_dtype (
756
- lhs_dtype , val_dtype , self .result_type_resolver_fn_ , sycl_dev
757
- )
758
-
759
- if buf_dt is None :
760
- raise TypeError (
761
- f"In-place '{ self .name_ } ' does not support input types "
762
- f"({ lhs_dtype } , { val_dtype } ), "
763
- "and the inputs could not be safely coerced to any "
764
- "supported types according to the casting rule ''safe''."
765
- )
766
-
767
- if isinstance (val , dpt .usm_ndarray ):
768
- rhs = val
769
- overlap = ti ._array_overlap (lhs , rhs )
770
- else :
771
- rhs = dpt .asarray (val , dtype = val_dtype , sycl_queue = exec_q )
772
- overlap = False
773
-
774
- if buf_dt == val_dtype and overlap is False :
775
- rhs = dpt .broadcast_to (rhs , res_shape )
776
- ht_ , _ = self .binary_inplace_fn_ (
777
- lhs = lhs , rhs = rhs , sycl_queue = exec_q
778
- )
779
- ht_ .wait ()
780
-
781
- else :
782
- buf = dpt .empty_like (rhs , dtype = buf_dt )
783
- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
784
- src = rhs , dst = buf , sycl_queue = exec_q
785
- )
786
-
787
- buf = dpt .broadcast_to (buf , res_shape )
788
- ht_ , _ = self .binary_inplace_fn_ (
789
- lhs = lhs ,
790
- rhs = buf ,
791
- sycl_queue = exec_q ,
792
- depends = [copy_ev ],
793
- )
794
- ht_copy_ev .wait ()
795
- ht_ .wait ()
796
-
797
- return lhs
0 commit comments