Skip to content

Commit 7eaab6d

Browse files
committed
Broadcasting made conditional in binary functions where memory overlap is possible
- Broadcasting can change the values of strides without changing array shape
1 parent fc9a8da commit 7eaab6d

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,8 @@ def __call__(self, o1, o2, out=None, order="K"):
510510
o2, dtype=o2_dtype, sycl_queue=exec_q
511511
)
512512
if buf2_dt is None:
513-
src2 = dpt.broadcast_to(src2, res_shape)
513+
if src2.shape != res_shape:
514+
src2 = dpt.broadcast_to(src2, res_shape)
514515
ht_, _ = self.binary_inplace_fn_(
515516
lhs=o1, rhs=src2, sycl_queue=exec_q
516517
)
@@ -581,9 +582,10 @@ def __call__(self, o1, o2, out=None, order="K"):
581582
sycl_queue=exec_q,
582583
order=order,
583584
)
584-
585-
src1 = dpt.broadcast_to(src1, res_shape)
586-
src2 = dpt.broadcast_to(src2, res_shape)
585+
if src1.shape != res_shape:
586+
src1 = dpt.broadcast_to(src1, res_shape)
587+
if src2.shape != res_shape:
588+
src2 = dpt.broadcast_to(src2, res_shape)
587589
ht_binary_ev, binary_ev = self.binary_fn_(
588590
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
589591
)
@@ -628,8 +630,8 @@ def __call__(self, o1, o2, out=None, order="K"):
628630
f"Output array of type {res_dt} is needed,"
629631
f"got {out.dtype}"
630632
)
631-
632-
src1 = dpt.broadcast_to(src1, res_shape)
633+
if src1.shape != res_shape:
634+
src1 = dpt.broadcast_to(src1, res_shape)
633635
buf2 = dpt.broadcast_to(buf2, res_shape)
634636
ht_binary_ev, binary_ev = self.binary_fn_(
635637
src1=src1,
@@ -676,7 +678,8 @@ def __call__(self, o1, o2, out=None, order="K"):
676678
)
677679

678680
buf1 = dpt.broadcast_to(buf1, res_shape)
679-
src2 = dpt.broadcast_to(src2, res_shape)
681+
if src2.shape != res_shape:
682+
src2 = dpt.broadcast_to(src2, res_shape)
680683
ht_binary_ev, binary_ev = self.binary_fn_(
681684
src1=buf1,
682685
src2=src2,

0 commit comments

Comments
 (0)