@@ -510,7 +510,8 @@ def __call__(self, o1, o2, out=None, order="K"):
510
510
o2 , dtype = o2_dtype , sycl_queue = exec_q
511
511
)
512
512
if buf2_dt is None :
513
- src2 = dpt .broadcast_to (src2 , res_shape )
513
+ if src2 .shape != res_shape :
514
+ src2 = dpt .broadcast_to (src2 , res_shape )
514
515
ht_ , _ = self .binary_inplace_fn_ (
515
516
lhs = o1 , rhs = src2 , sycl_queue = exec_q
516
517
)
@@ -581,9 +582,10 @@ def __call__(self, o1, o2, out=None, order="K"):
581
582
sycl_queue = exec_q ,
582
583
order = order ,
583
584
)
584
-
585
- src1 = dpt .broadcast_to (src1 , res_shape )
586
- src2 = dpt .broadcast_to (src2 , res_shape )
585
+ if src1 .shape != res_shape :
586
+ src1 = dpt .broadcast_to (src1 , res_shape )
587
+ if src2 .shape != res_shape :
588
+ src2 = dpt .broadcast_to (src2 , res_shape )
587
589
ht_binary_ev , binary_ev = self .binary_fn_ (
588
590
src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
589
591
)
@@ -628,8 +630,8 @@ def __call__(self, o1, o2, out=None, order="K"):
628
630
f"Output array of type { res_dt } is needed,"
629
631
f"got { out .dtype } "
630
632
)
631
-
632
- src1 = dpt .broadcast_to (src1 , res_shape )
633
+ if src1 . shape != res_shape :
634
+ src1 = dpt .broadcast_to (src1 , res_shape )
633
635
buf2 = dpt .broadcast_to (buf2 , res_shape )
634
636
ht_binary_ev , binary_ev = self .binary_fn_ (
635
637
src1 = src1 ,
@@ -676,7 +678,8 @@ def __call__(self, o1, o2, out=None, order="K"):
676
678
)
677
679
678
680
buf1 = dpt .broadcast_to (buf1 , res_shape )
679
- src2 = dpt .broadcast_to (src2 , res_shape )
681
+ if src2 .shape != res_shape :
682
+ src2 = dpt .broadcast_to (src2 , res_shape )
680
683
ht_binary_ev , binary_ev = self .binary_fn_ (
681
684
src1 = buf1 ,
682
685
src2 = src2 ,
0 commit comments